def test_product_with_single_column_table():
    """Multiply a four-column table by a one-column table in both operand orders.

    Two scenarios are exercised: the one-column table carries a column name
    that the wide table lacks ("Y1"), and one that the wide table already
    has ("X2").
    """
    # Scenario 1: the narrow table's column name is not shared.
    narrow = Table({"A": 3, "B": 4, "C": 7}, names=["Y1"])
    wide = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    wide_first = wide * narrow
    assert all(compare(wide_first.names, ["X1", "X2", "X3", "X4", "Y1"]))
    # Spot-check one cell of the product.
    assert wide_first["a", "y", 2, 33, "B"] == 24

    narrow_first = narrow * wide
    assert all(compare(narrow_first.names, ["Y1", "X1", "X2", "X3", "X4"]))
    # Same cell, with the key re-ordered to follow the new column order.
    assert narrow_first["B", "a", "y", 2, 33] == 24

    # Scenario 2: the narrow table's column name ("X2") is shared.
    narrow = Table({"x": 3, "y": 4}, names=["X2"])
    wide = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    wide_first = wide * narrow
    assert all(compare(wide_first.names, ["X1", "X2", "X3", "X4"]))
    expected_cells = (
        (("a", "x", 2, 44), 12),
        (("a", "y", 2, 44), 32),
    )
    for key, value in expected_cells:
        assert wide_first[key] == value

    narrow_first = narrow * wide
    assert all(compare(narrow_first.names, ["X2", "X1", "X3", "X4"]))
def test_table_of_table():
    """A table whose values are themselves tables returns those tables on lookup."""
    # Keys enumerate X1 x X2 x X4 x X3 (X3 varies fastest) and are numbered
    # from 1; the ("b", "y", 2, 33) entry (count 14) is deliberately left out.
    ordered_keys = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    sample_1 = {
        key: count
        for count, key in enumerate(ordered_keys, start=1)
        if key != ("b", "y", 2, 33)
    }

    # Four separately constructed inner tables built from the same data.
    sample_2 = {
        label: Table(sample_1, ["X1", "X2", "X3", "X4"])
        for label in ("t1", "t2", "t3", "t4")
    }

    outer = Table(sample_2, ["t"])
    assert outer["t2"] == sample_2["t2"]
def test_conditional_on_table():
    """Condition a four-column table on "X2" and check the child tables."""
    joint = Table(samples, names=["X1", "X2", "X3", "X4"])
    by_x2 = joint.condition_on("X2")

    # The conditional table is keyed by X2 only ...
    assert all(compare(by_x2.names, ["X2"]))
    # ... and every child holds the remaining columns in their original order.
    for x2_value in by_x2.keys():
        assert all(compare(by_x2[x2_value].names, ["X1", "X3", "X4"]))

    # Expected values for the child table at X2 == "x".
    expected_given_x = {
        ("a", 1, 33): 1 / 52,
        ("a", 1, 44): 3 / 52,
        ("b", 1, 33): 9 / 52,
        ("b", 1, 44): 11 / 52,
        ("b", 2, 44): 12 / 52,
    }
    # Expected values for the child table at X2 == "y".
    expected_given_y = {
        ("a", 1, 33): 5 / 84,
        ("a", 1, 44): 7 / 84,
        ("b", 1, 33): 13 / 84,
        ("b", 1, 44): 15 / 84,
        ("b", 2, 33): 14 / 84,
    }

    given_x = by_x2["x"]
    given_y = by_x2["y"]
    for key, expected in expected_given_x.items():
        assert given_x[key] == expected
    for key, expected in expected_given_y.items():
        assert given_y[key] == expected

    # Chained indexing straight from the conditional table.
    assert by_x2["x"]["a", 1, 33] == 1 / 52
    assert by_x2["y"]["a", 1, 33] == 5 / 84
def test_total_table():
    """Check ``total()`` on plain tables and on an un-normalised conditional table."""
    # One column.
    one_column = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    assert one_column.total() == 12

    # Two columns.
    pair_counts = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
    two_columns = Table(pair_counts, names=["X1", "X2"])
    assert two_columns.total() == 20

    # Four columns: keys enumerate X1 x X2 x X4 x X3 (X3 varies fastest) and
    # are numbered 1..16. This local name shadows any module-level `samples`.
    ordered_keys = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    samples = {key: count for count, key in enumerate(ordered_keys, start=1)}

    four_columns = Table(samples, names=["X1", "X2", "X3", "X4"])
    by_x1 = four_columns.condition_on("X1", normalise=False)
    totals = by_x1.total()

    # The totals of a conditional table are looked up with one-element tuples.
    for key, expected in ((("a",), 36), (("b",), 100)):
        assert totals[key] == expected
def test_product_with_two_common_vars_table():
    """Multiply two tables that share the columns "X2" and "X3"."""
    left = Table(sample_1, names=["X1", "X2", "X3", "X4"])
    right = Table(sample_2, names=["X3", "X5", "X6", "X2"])

    product = left * right
    assert all(compare(product.names, ["X1", "X2", "X3", "X4", "X5", "X6"]))

    # A key present on both sides yields a value.
    assert product["a", "y", 2, 33, "high", "under"] == 150

    absent_keys = (
        # The right operand has no row for these shared-column values.
        ("a", "y", 2, 33, "low", "under"),
        # The left operand has no row for these shared-column values.
        ("b", "y", 2, 33, "high", "under"),
    )
    for key in absent_keys:
        assert product[key] is None
def test_product_with_no_common_vars_table():
    """Multiply two tables whose column names are completely disjoint."""
    left = Table(sample_1, names=["X1", "X2", "X3", "X4"])
    right = Table(sample_2, names=["Y1", "Y2", "Y3", "Y4"])

    product = left * right

    # The result carries the left columns followed by the right columns.
    expected_names = ["X1", "X2", "X3", "X4", "Y1", "Y2", "Y3", "Y4"]
    assert all(compare(product.names, expected_names))

    # Spot-check two cells of the product.
    expected_cells = (
        (("a", "x", 1, 33, 2, "high", "normal", "x"), 10),
        (("b", "x", 1, 44, 1, "low", "over", "y"), 253),
    )
    for key, value in expected_cells:
        assert product[key] == value
def test_product_with_one_common_var_table():
    """Multiply two tables that share only the column "X3"."""
    left = Table(sample_1, names=["X1", "X2", "X3", "X4"])
    right = Table(sample_2, names=["X3", "X5", "X6", "X7"])

    product = left * right

    expected_names = ["X1", "X2", "X3", "X4", "X5", "X6", "X7"]
    assert all(compare(product.names, expected_names))

    # A key present on both sides yields a value.
    assert product["a", "x", 1, 33, "high", "normal", "x"] == 2

    absent_keys = (
        # The right operand has no matching row.
        ("b", "y", 2, 44, "high", "over", "y"),
        # The left operand has no matching row.
        ("b", "y", 2, 33, "high", "normal", "y"),
    )
    for key in absent_keys:
        assert product[key] is None
def test_conditional_on_conditional_table():
    """Condition an already-conditioned table again and check the nested values.

    Each expectation is ``(outer_key, inner_key, expected)``: the outer key
    indexes the conditional table, the inner key indexes the child table.
    """
    table1 = Table(samples, names=["X1", "X2", "X3", "X4"])

    def check(conditional, expectations):
        for outer_key, inner_key, expected in expectations:
            assert conditional[outer_key][inner_key] == approx(expected)

    # Condition on X3 first and X1 second. The asserted outer keys are
    # ordered (X1, X3): the most recent conditioning column comes first.
    nested = table1.condition_on("X3").condition_on("X1")
    check(
        nested,
        [
            (("a", 1), ("x", 33), 1 / 16),
            (("a", 2), ("x", 33), 2 / 20),
            (("a", 1), ("x", 44), 3 / 16),
            (("a", 2), ("x", 44), 4 / 20),
            (("b", 1), ("x", 33), 9 / 48),
            (("b", 2), ("x", 33), 10 / 52),
            (("b", 1), ("y", 33), 13 / 48),
            (("b", 2), ("y", 44), 16 / 52),
        ],
    )

    # Reverse the conditioning order: outer keys become (X3, X1).
    nested = table1.condition_on("X1").condition_on("X3")
    check(
        nested,
        [
            ((1, "a"), ("x", 33), 1 / 16),
            ((2, "a"), ("x", 44), 4 / 20),
            ((1, "b"), ("x", 33), 9 / 48),
            ((2, "b"), ("x", 33), 10 / 52),
        ],
    )
    # Keyword lookup through `get` does not depend on the key order.
    assert nested.get(X1="a", X3=1)["x", 33] == approx(1 / 16)

    # These values are expected for outer keys ordered (X1, X3, X4); two
    # different conditioning sequences below produce that same order.
    x1_x3_x4_expectations = [
        (("a", 1, 33), "x", 1 / 6),
        (("a", 2, 33), "x", 2 / 8),
        (("a", 1, 44), "x", 3 / 10),
        (("a", 2, 44), "x", 4 / 12),
        (("b", 1, 44), "x", 11 / 26),
        (("b", 2, 44), "x", 12 / 28),
        (("b", 1, 44), "y", 15 / 26),
        (("b", 2, 44), "y", 16 / 28),
    ]

    # Two columns first, then one.
    nested = table1.condition_on("X3", "X4").condition_on("X1")
    check(nested, x1_x3_x4_expectations)

    # One column first, then two.
    nested = table1.condition_on("X4").condition_on("X1", "X3")
    check(nested, x1_x3_x4_expectations)

    # A different split: outer keys are now ordered (X1, X4, X3).
    nested = table1.condition_on("X3").condition_on("X1", "X4")
    check(
        nested,
        [
            (("a", 33, 1), "x", 1 / 6),
            (("a", 33, 2), "x", 2 / 8),
            (("a", 44, 1), "x", 3 / 10),
            (("a", 44, 2), "x", 4 / 12),
            (("b", 44, 1), "x", 11 / 26),
            (("b", 44, 2), "x", 12 / 28),
            (("b", 44, 1), "y", 15 / 26),
            (("b", 44, 2), "y", 16 / 28),
        ],
    )
def test_constructor_table():
    """Build one-column tables from several input types and read the values back."""
    word_values = (("one", 1), ("two", 2), ("three", 3))

    # From a mapping, with an explicit column name.
    table = Table({"one": 1, "two": 2, "three": 3}, names=["Y1"])
    for word, value in word_values:
        assert table[word] == value

    # From a list of (key, value) pairs, given in a different order.
    table = Table([("two", 2), ("one", 1), ("three", 3)])
    for word, value in word_values:
        assert table[word] == value

    # From a Counter: each number n occurs n times in the raw samples.
    samples = [1, 2, 2, 3, 3, 3]
    table = Table(Counter(samples))
    for number in (1, 2, 3):
        assert table[number] == number

    # From a zip iterator of keys and values.
    table = Table(zip(["one", "two", "three"], [1, 2, 3]))
    for word, value in word_values:
        assert table[word] == value

    # Integer keys.
    table = Table({1: 1, 2: 2, 3: 3})
    for number in (1, 2, 3):
        assert table[number] == number

    # Float keys.
    table = Table({1.1: 1, 2.2: 2, 3.3: 3})
    for key, value in ((1.1, 1), (2.2, 2), (3.3, 3)):
        assert table[key] == value
def test_reduce_by_name_table():
    """Reduce a four-column table by column name, with default and custom names."""

    def check_reductions(table, first, second, third, fourth):
        # `first`..`fourth` are the table's column names in positional order.

        # Fix the second column.
        reduced = table.reduce(**{second: "y"})
        assert reduced.columns.size == 3
        assert all(compare(reduced.names, [first, third, fourth]))
        assert reduced["a", 1, 33] == 5
        assert reduced["b", 2, 44] == 16

        # Fix the second and third columns.
        reduced = table.reduce(**{second: "y", third: 1})
        assert reduced.columns.size == 2
        assert all(compare(reduced.names, [first, fourth]))
        assert reduced["a", 33] == 5
        assert reduced["b", 44] == 15

        # Fix everything except the second column.
        reduced = table.reduce(**{first: "b", third: 1, fourth: 44})
        assert reduced.columns.size == 1
        assert all(compare(reduced.names, [second]))
        assert reduced["x"] == 11
        assert reduced["y"] == 15

    # Default column names.
    check_reductions(Table(samples), "X1", "X2", "X3", "X4")

    # Explicit column names.
    check_reductions(
        Table(samples, names=["Y", "Z", "W", "X"]), "Y", "Z", "W", "X"
    )
def test_reduce_by_name_on_conditioned_table():
    """Reduce conditional tables by name and check the remaining child columns."""
    joint = Table(samples, names=["X1", "X2", "X3", "X4"])

    # Conditioned on one column.
    by_x1 = joint.condition_on("X1")

    fixed_x2 = by_x1.reduce(X2="y")
    assert fixed_x2.columns.size == 1
    assert all(compare(fixed_x2.columns.children_names, ["X3", "X4"]))
    assert fixed_x2["a"][1, 33] == 5 / 36
    assert fixed_x2["b"][2, 44] == 16 / 100

    fixed_x2_x3 = by_x1.reduce(X2="y", X3=1)
    assert fixed_x2_x3.columns.size == 1
    assert all(compare(fixed_x2_x3.columns.children_names, ["X4"]))
    assert fixed_x2_x3["a"][33] == 5 / 36
    assert fixed_x2_x3["b"][44] == 15 / 100

    # Conditioned on two columns.
    by_x1_x3 = joint.condition_on("X1", "X3")

    fixed_x2 = by_x1_x3.reduce(X2="y")
    assert fixed_x2.columns.size == 2
    assert all(compare(fixed_x2.columns.children_names, ["X4"]))
    assert fixed_x2["a", 1][33] == 5 / 16
    assert fixed_x2["b", 2][44] == 16 / 52
def test_getitem_table():
    """Check positional indexing and ``get`` with mixed positional/keyword keys."""
    # ---- One column ----
    table = Table({"one": 1, "two": 2, "three": 3}, names=["Y1"])
    single_rows = (("one", 1), ("two", 2), ("three", 3))
    for word, value in single_rows:
        assert table[word] == value
    for word, value in single_rows:
        assert table.get(Y1=word) == value

    # ---- Two columns ----
    table = Table(
        {("one", "Red"): 1, ("two", "Green"): 2, ("three", "Blue"): 3},
        names=["Y1", "Y2"],
    )
    pair_rows = (("one", "Red", 1), ("two", "Green", 2), ("three", "Blue", 3))

    # Positional indexing with the parts in the wrong order finds nothing.
    assert table["Red", "one"] is None

    for word, colour, value in pair_rows:
        assert table[word, colour] == value

    for word, colour, value in pair_rows:
        assert table.get(word, colour) == value
        assert table.get(word, Y2=colour) == value
        # The positional part is paired with the column not named by keyword.
        assert table.get(colour, Y1=word) == value
        assert table.get(Y1=word, Y2=colour) == value

    # ---- Three columns ----
    table = Table(
        {
            ("one", "Red", 11): 1,
            ("two", "Green", 22): 2,
            ("three", "Blue", 33): 3,
        },
        names=["Y1", "Y2", "Y3"],
    )
    triple_rows = (
        ("one", "Red", 11, 1),
        ("two", "Green", 22, 2),
        ("three", "Blue", 33, 3),
    )

    # Positional indexing with the parts in the wrong order finds nothing.
    for shuffled in (("Red", "one", 11), (11, "Red", "one"), ("one", 11, "Red")):
        assert table[shuffled] is None

    for word, colour, number, value in triple_rows:
        assert table[word, colour, number] == value

    for word, colour, number, value in triple_rows:
        assert table.get(word, colour, number) == value
        assert table.get(word, Y2=colour, Y3=number) == value
        assert table.get(colour, Y1=word, Y3=number) == value
        assert table.get(Y1=word, Y2=colour, Y3=number) == value
        if word == "one":
            # Only the first row also checks a positional last-column value,
            # with the keywords given in both orders.
            assert table.get(number, Y1=word, Y2=colour) == value
            assert table.get(number, Y2=colour, Y1=word) == value
def test_normalise_table():
    """Check in-place ``normalise()`` on plain and on conditional tables."""
    # One column: the counts sum to 12.
    one_column = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    one_column.normalise()
    for key, expected in (("a", 3 / 12), ("b", 4 / 12)):
        assert one_column[key] == expected

    # Two columns: the counts sum to 20.
    two_columns = Table(
        {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6},
        names=["X1", "X2"],
    )
    two_columns.normalise()
    expected_pairs = {
        ("a", "x"): 4 / 20,
        ("a", "y"): 4 / 20,
        ("b", "x"): 6 / 20,
        ("b", "y"): 6 / 20,
    }
    for key, expected in expected_pairs.items():
        assert two_columns[key] == expected

    # Four columns: keys enumerate X1 x X2 x X4 x X3 (X3 varies fastest) and
    # are numbered 1..16. This local name shadows any module-level `samples`.
    ordered_keys = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    samples = {key: count for count, key in enumerate(ordered_keys, start=1)}
    joint = Table(samples, names=["X1", "X2", "X3", "X4"])

    def check(conditional, expectations):
        # Each expectation is (outer_key, inner_key, expected_value).
        for outer_key, inner_key, expected in expectations:
            assert conditional[outer_key][inner_key] == expected

    # Conditioned on X1.
    conditional = joint.condition_on("X1")
    conditional.normalise()
    check(
        conditional,
        [
            ("a", ("x", 1, 33), 1 / 36),
            ("a", ("y", 1, 33), 5 / 36),
            ("a", ("y", 2, 33), 6 / 36),
            ("a", ("y", 2, 44), 8 / 36),
            ("b", ("x", 1, 33), 9 / 100),
            ("b", ("y", 1, 33), 13 / 100),
            ("b", ("y", 2, 44), 16 / 100),
        ],
    )

    # Conditioned on X2.
    conditional = joint.condition_on("X2")
    conditional.normalise()
    check(
        conditional,
        [
            ("x", ("a", 1, 33), 1 / 52),
            ("x", ("a", 2, 44), 4 / 52),
            ("y", ("b", 1, 33), 13 / 84),
        ],
    )

    # Conditioned on X1 and X2.
    conditional = joint.condition_on("X1", "X2")
    conditional.normalise()
    check(
        conditional,
        [
            (("a", "x"), (1, 33), 1 / 10),
            (("a", "x"), (2, 44), 4 / 10),
            (("b", "y"), (2, 33), 14 / 58),
        ],
    )
def test_product_with_table_of_table_with_no_common_table():
    """Products with tables keyed by "one"/"two"/"three" are asserted to be empty.

    Every product below is expected to have length zero; the test still
    checks the column names (and, for conditional results, the child column
    names) of each result.
    """

    def expect_empty(product, names, children_names=None):
        assert len(product) == 0
        assert all(compare(product.names, names))
        if children_names is not None:
            assert all(compare(product.columns.children_names, children_names))

    table1 = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    # ---- One shared column: X1 ----
    table2 = Table({"one": 1, "two": 2, "three": 3}, names=["X1"])
    con_1 = table1.condition_on("X1")
    # P(X2, X3, X4 | X1) * P(X1) -> P(X2, X3, X4, X1)
    expect_empty(con_1 * table2, ["X2", "X3", "X4", "X1"])
    # P(X1) * P(X2, X3, X4 | X1) -> P(X1, X2, X3, X4)
    expect_empty(table2 * con_1, ["X1", "X2", "X3", "X4"])

    # ---- Two shared columns: X1, X2 ----
    table2 = Table(
        {("one", "x"): 1, ("two", "x"): 2, ("three", "y"): 3},
        names=["X1", "X2"],
    )
    con_1 = table1.condition_on("X1", "X2")
    # P(X3, X4 | X1, X2) * P(X1, X2) -> P(X3, X4, X1, X2)
    expect_empty(con_1 * table2, ["X3", "X4", "X1", "X2"])
    # P(X1, X2) * P(X3, X4 | X1, X2) -> P(X1, X2, X3, X4)
    expect_empty(table2 * con_1, ["X1", "X2", "X3", "X4"])

    con_2 = table2.condition_on("X1")
    # P(X3, X4 | X1, X2) * P(X2 | X1) -> P(X3, X4, X2 | X1)
    expect_empty(con_1 * con_2, ["X1"], ["X3", "X4", "X2"])
    # P(X2 | X1) * P(X3, X4 | X1, X2) -> P(X2, X3, X4 | X1)
    expect_empty(con_2 * con_1, ["X1"], ["X2", "X3", "X4"])

    # ---- Three shared columns: X1, X2, X3 ----
    table3 = Table(
        {("one", "x", 1): 1, ("two", "x", 1): 2, ("three", "y", 2): 3},
        names=["X1", "X2", "X3"],
    )
    con_1 = table1.condition_on("X1", "X2", "X3")
    # P(X4 | X1, X2, X3) * P(X1, X2, X3) -> P(X4, X1, X2, X3)
    expect_empty(con_1 * table3, ["X4", "X1", "X2", "X3"])
    # P(X1, X2, X3) * P(X4 | X1, X2, X3) -> P(X1, X2, X3, X4)
    expect_empty(table3 * con_1, ["X1", "X2", "X3", "X4"])

    con_2 = table3.condition_on("X1")
    # P(X4 | X1, X2, X3) * P(X2, X3 | X1) -> P(X4, X2, X3 | X1)
    expect_empty(con_1 * con_2, ["X1"], ["X4", "X2", "X3"])
    # P(X2, X3 | X1) * P(X4 | X1, X2, X3) -> P(X2, X3, X4 | X1)
    expect_empty(con_2 * con_1, ["X1"], ["X2", "X3", "X4"])

    con_2 = table3.condition_on("X1", "X2")
    # P(X4 | X1, X2, X3) * P(X3 | X1, X2) -> P(X4, X3 | X1, X2)
    expect_empty(con_1 * con_2, ["X1", "X2"], ["X4", "X3"])
    # P(X3 | X1, X2) * P(X4 | X1, X2, X3) -> P(X3, X4 | X1, X2)
    expect_empty(con_2 * con_1, ["X1", "X2"], ["X3", "X4"])

    con_2 = table3.condition_on("X2", "X3")
    # P(X4 | X1, X2, X3) * P(X1 | X2, X3) -> P(X4, X1 | X2, X3)
    expect_empty(con_1 * con_2, ["X2", "X3"], ["X4", "X1"])
    # P(X1 | X2, X3) * P(X4 | X1, X2, X3) -> P(X1, X4 | X2, X3)
    expect_empty(con_2 * con_1, ["X2", "X3"], ["X1", "X4"])
def test_product_with_table_of_table_column_table():
    """Multiply conditional tables by tables over the conditioning columns.

    Every product, taken in both operand orders, is compared entry by entry
    against the normalised five-column source table.
    """
    all_names = ["X1", "X2", "X3", "X4", "X5"]
    table1 = Table(sample_3, names=["X1", "X2", "X3", "X4", "X5"])

    def assert_all(product, reference):
        # Look each product entry up in `reference` by column name, so the
        # comparison does not depend on the product's column order.
        for key in product:
            named_key = product.columns.named_key(key)
            assert product[key] == approx(reference.get(**named_key))

    table1.normalise()

    # For each conditioning set, pair the conditional table with the table
    # obtained by marginalising out all the other columns.
    conditioning_sets = (
        ("X1",),
        ("X2",),
        ("X5",),
        ("X1", "X2"),
        ("X1", "X3"),
        ("X2", "X3"),
    )
    for given in conditioning_sets:
        others = [name for name in all_names if name not in given]
        con_1 = table1.condition_on(*given)
        table2 = table1.marginal(*others)
        assert_all(con_1 * table2, table1)
        assert_all(table2 * con_1, table1)

    # Product of two conditional tables:
    # P(X1 | X2, X3) * P(X4, X5 | X1, X2, X3), in both operand orders.
    con_1 = table1.marginal("X4", "X5").condition_on("X2", "X3")
    table2 = table1.condition_on("X1", "X2", "X3")
    assert_all(con_1 * table2, table1)
    assert_all(table2 * con_1, table1)
def test_marginals_names_exception_table():
    """Check that ``Table.marginal`` raises ``ValueError`` for invalid requests.

    Tables and intermediate marginals are built outside the
    ``pytest.raises`` blocks, so only the ``marginal`` call under test can
    satisfy the expectation. A ``ValueError`` raised during setup now fails
    the test instead of letting it pass for the wrong reason.
    """
    pair_counts = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}

    # One-column table without explicit names.
    table = Table({"a": 3, "b": 4, "c": 5})
    with pytest.raises(ValueError):
        table.marginal("X1")

    # Two-column table without explicit names: "X0" and "X3" are not columns.
    table = Table(pair_counts)
    for bad_name in ("X0", "X3"):
        with pytest.raises(ValueError):
            table.marginal(bad_name)

    # Marginalising "X1" a second time, on a table already marginalised on it.
    remaining = table.marginal("X1")
    with pytest.raises(ValueError):
        remaining.marginal("X1")

    # Two-column table with explicit names: "X1" is not a column.
    table = Table(pair_counts, names=["Y", "Z"])
    with pytest.raises(ValueError):
        table.marginal("X1")

    # Marginalising "Y" a second time, on a table already marginalised on it.
    remaining = table.marginal("Y")
    with pytest.raises(ValueError):
        remaining.marginal("Y")

    # Marginalising over every column at once.
    with pytest.raises(ValueError):
        table.marginal("Y", "Z")
def test_marginals_names_table():
    """Check which column names remain after ``marginal`` removes some."""

    def check(table, cases):
        # Each case is (columns passed to `marginal`, names expected to remain).
        for dropped, kept in cases:
            assert all(compare(table.marginal(*dropped).names, kept))

    # ---- Two columns ----
    samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}

    # Default names.
    check(
        Table(samples),
        [
            (("X1",), ["X2"]),
            (("X2",), ["X1"]),
        ],
    )
    # Explicit names.
    check(
        Table(samples, names=["Y", "Z"]),
        [
            (("Y",), ["Z"]),
            (("Z",), ["Y"]),
        ],
    )

    # ---- Three columns ----
    samples = {
        ("a", "x", 1): 4,
        ("a", "x", 2): 4,
        ("a", "y", 1): 6,
        ("a", "y", 2): 6,
        ("b", "x", 1): 8,
        ("b", "x", 2): 8,
        ("b", "y", 1): 10,
        ("b", "y", 2): 10,
    }

    # Default names.
    check(
        Table(samples),
        [
            (("X1",), ["X2", "X3"]),
            (("X2",), ["X1", "X3"]),
            (("X3",), ["X1", "X2"]),
            (("X1", "X3"), ["X2"]),
            (("X2", "X3"), ["X1"]),
        ],
    )
    # Explicit names.
    check(
        Table(samples, names=["Y", "Z", "W"]),
        [
            (("Y",), ["Z", "W"]),
            (("Z",), ["Y", "W"]),
            (("W",), ["Y", "Z"]),
            (("Y", "W"), ["Z"]),
            (("Z", "W"), ["Y"]),
        ],
    )
def test_marginal_if_table_of_table():
    """Marginalise child columns of conditional tables and check the results."""
    # Keys enumerate X1 x X2 x X4 x X3 (X3 varies fastest) and are numbered
    # from 1; the ("b", "y", 2, 33) entry (count 14) is deliberately left out.
    ordered_keys = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    sample_1 = {
        key: count
        for count, key in enumerate(ordered_keys, start=1)
        if key != ("b", "y", 2, 33)
    }
    table1 = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    def check_approx(marginal, expectations):
        # Each expectation is (outer_key, inner_key, expected_value).
        for outer_key, inner_key, expected in expectations:
            assert marginal[outer_key][inner_key] == approx(expected)

    # ---- Conditioned on X1, default normalisation ----
    by_x1 = table1.condition_on("X1")

    # The conditioning column itself cannot be marginalised.
    with pytest.raises(ValueError):
        by_x1.marginal("X1")

    without_x2 = by_x1.marginal("X2")
    assert all(compare(without_x2.names, ["X1"]))
    assert all(compare(without_x2.children_names, ["X3", "X4"]))
    check_approx(
        without_x2,
        [
            ("a", (1, 33), (1 + 5) / 36),
            ("a", (1, 44), (3 + 7) / 36),
            ("a", (2, 33), (2 + 6) / 36),
            ("a", (2, 44), (4 + 8) / 36),
            ("b", (1, 33), (9 + 13) / 86),
            ("b", (1, 44), (11 + 15) / 86),
        ],
    )

    # ---- Conditioned on X1 without normalisation: raw sums, compared exactly ----
    by_x1 = table1.condition_on("X1", normalise=False)
    without_x2 = by_x1.marginal("X2", normalise=False)
    raw_expectations = [
        ("a", (1, 33), 1 + 5),
        ("a", (1, 44), 3 + 7),
        ("a", (2, 33), 2 + 6),
        ("a", (2, 44), 4 + 8),
        ("b", (1, 33), 9 + 13),
        ("b", (1, 44), 11 + 15),
    ]
    for outer_key, inner_key, expected in raw_expectations:
        assert without_x2[outer_key][inner_key] == expected

    # The calls below still use the un-normalised conditional table, but
    # `marginal` is called without `normalise=False`.
    without_x2_x3 = by_x1.marginal("X2", "X3")
    assert all(compare(without_x2_x3.names, ["X1"]))
    assert all(compare(without_x2_x3.children_names, ["X4"]))
    check_approx(
        without_x2_x3,
        [
            ("a", 33, (1 + 5 + 2 + 6) / 36),
            ("a", 44, (3 + 7 + 4 + 8) / 36),
            ("b", 33, (9 + 10 + 13) / 86),
            ("b", 44, (11 + 12 + 15 + 16) / 86),
        ],
    )

    without_x2_x4 = by_x1.marginal("X2", "X4")
    assert all(compare(without_x2_x4.names, ["X1"]))
    assert all(compare(without_x2_x4.children_names, ["X3"]))
    check_approx(
        without_x2_x4,
        [
            ("a", 1, (1 + 3 + 5 + 7) / 36),
            ("a", 2, (2 + 4 + 6 + 8) / 36),
            ("b", 1, (9 + 11 + 13 + 15) / 86),
            ("b", 2, (10 + 12 + 16) / 86),
        ],
    )

    # ---- Conditioned on X1 and X3 ----
    by_x1_x3 = table1.condition_on("X1", "X3")

    # Conditioning columns, or all of the child columns, are rejected.
    for rejected in (("X1",), ("X3",), ("X2", "X4")):
        with pytest.raises(ValueError):
            by_x1_x3.marginal(*rejected)

    without_x2 = by_x1_x3.marginal("X2")
    assert all(compare(without_x2.names, ["X1", "X3"]))
    assert all(compare(without_x2.children_names, ["X4"]))
    check_approx(
        without_x2,
        [
            (("a", 1), 33, (1 + 5) / 16),
            (("a", 1), 44, (3 + 7) / 16),
            (("a", 2), 44, (4 + 8) / 20),
            (("b", 1), 33, (9 + 13) / 48),
            (("b", 2), 33, 10 / 38),
            (("b", 2), 44, (12 + 16) / 38),
        ],
    )

    # ---- Conditioned on X1, X3 and X4: the only child column is rejected ----
    by_x1_x3_x4 = table1.condition_on("X1", "X3", "X4")
    with pytest.raises(ValueError):
        by_x1_x3_x4.marginal("X2")
def test_add_table():
    """Add tables to themselves and check that every value doubles."""
    # ---- One column ----
    samples = {"a": 3, "b": 4, "c": 5}
    doubled = Table(samples, names=["X1"]) + Table(samples, names=["X1"])
    for key, count in samples.items():
        assert doubled[key] == 2 * count

    # ---- Two columns ----
    samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
    doubled = Table(samples, names=["X1", "X2"]) + Table(
        samples, names=["X1", "X2"]
    )
    for key, count in samples.items():
        assert doubled[key] == 2 * count

    # ---- Four columns ----
    # Keys enumerate X1 x X2 x X4 x X3 (X3 varies fastest) and are numbered
    # from 1; the ("b", "y", 2, 33) entry (count 14) is deliberately left out.
    ordered_keys = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    sample_2 = {
        key: count
        for count, key in enumerate(ordered_keys, start=1)
        if key != ("b", "y", 2, 33)
    }
    table1 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
    table2 = Table(sample_2, names=["X1", "X2", "X3", "X4"])

    doubled = table1 + table2
    expected_cells = (
        (("a", "x", 1, 33), 2 * 1),
        (("a", "y", 2, 33), 2 * 6),
        (("b", "x", 2, 44), 2 * 12),
        (("b", "y", 2, 44), 2 * 16),
    )
    for key, expected in expected_cells:
        assert doubled[key] == expected

    # ---- Conditional tables, one conditioning column ----
    doubled = table1.condition_on("X1") + table2.condition_on("X1")
    expected_nested = (
        ("a", ("x", 1, 33), (2 * 1) / 36),
        ("a", ("x", 1, 44), (2 * 3) / 36),
        ("b", ("x", 1, 33), (2 * 9) / 86),
        ("b", ("x", 1, 44), (2 * 11) / 86),
    )
    for outer_key, inner_key, expected in expected_nested:
        assert doubled[outer_key][inner_key] == expected

    # ---- Conditional tables, two conditioning columns ----
    doubled = table1.condition_on("X1", "X3") + table2.condition_on("X1", "X3")
    expected_nested = (
        (("a", 1), ("x", 33), (2 * 1) / 16),
        (("a", 2), ("y", 44), (2 * 8) / 20),
        (("b", 1), ("x", 44), (2 * 11) / 48),
        (("b", 2), ("y", 44), (2 * 16) / 38),
    )
    for outer_key, inner_key, expected in expected_nested:
        assert doubled[outer_key][inner_key] == expected
    # The entry left out of the input stays absent in the sum.
    assert doubled["b", 2]["y", 33] is None
def test_statistical_independence_table():
    """Check P(x, y, z, ...) == P(x) P(y) P(z) ... for independent variables.

    A joint table is built as a product of tables whose counts factorise
    per variable. The single-variable marginals of the joint are then
    multiplied back together; after normalising both sides, the result
    must match the joint entry by entry.

    Note: the multi-variable input tables must themselves be statistically
    independent, which is why their counts are generated as products of
    per-variable weights.
    """
    # Per-variable weights. Iteration order matters: it fixes the insertion
    # order of the generated dictionaries below.
    x1_weights = {"x": 1, "y": 9}
    x2_weights = {1: 6, 2: 4}
    s_1 = {
        (x1, x2): w1 * w2
        for x1, w1 in x1_weights.items()
        for x2, w2 in x2_weights.items()
    }
    table1 = Table(s_1, names=["X1", "X2"])

    y1_weights = {1: 4, 2: 2}
    y2_weights = {"high": 3, "low": 2}
    y3_weights = {"under": 1, "normal": 2, "over": 3, "obese": 4}
    y4_weights = {"x": 1, "y": 3}
    # Y4 is the outermost loop and Y3 the innermost, although Y4 is the last
    # element of each key.
    s_2 = {
        (y1, y2, y3, y4): w1 * w2 * w3 * w4
        for y4, w4 in y4_weights.items()
        for y1, w1 in y1_weights.items()
        for y2, w2 in y2_weights.items()
        for y3, w3 in y3_weights.items()
    }
    table2 = Table(s_2, names=["Y1", "Y2", "Y3", "Y4"])

    single_table3 = Table({11: 2, 22: 4, 33: 3}, names=["Z"])

    joint_dist = table1 * table2 * single_table3

    # One marginal per variable: marginalise out every other column. A list
    # comprehension (rather than a set difference) keeps the argument order
    # deterministic from run to run.
    marginals = [
        joint_dist.marginal(
            *[other for other in joint_dist.names if other != name]
        )
        for name in joint_dist.names
    ]

    # np.prod replaces np.product, which was deprecated in NumPy 1.25 and
    # removed in NumPy 2.0.
    joint_dist2 = np.prod(marginals)

    # Normalise both tables before comparing them.
    t1 = sum(joint_dist.values())
    t2 = sum(joint_dist2.values())
    joint_dist = Table(
        {k: v / t1 for k, v in joint_dist.items()}, joint_dist.names
    )
    joint_dist2 = Table(
        {k: v / t2 for k, v in joint_dist2.items()}, joint_dist2.names
    )

    for k1 in joint_dist:
        assert joint_dist[k1] == approx(joint_dist2[k1], abs=1e-16)
def test_conditional_on_exception_table():
    """Check that ``Table.condition_on`` raises ``ValueError`` for invalid requests.

    Wherever the failing call is unambiguous, tables and first-level
    conditionals are built outside the ``pytest.raises`` blocks, so a
    ``ValueError`` raised during setup cannot make the test pass for the
    wrong reason.
    """
    # ---- Conditioning on every column of a plain table ----
    table1 = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    with pytest.raises(ValueError):
        table1.condition_on("X1")

    table1 = Table(
        {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6},
        names=["X1", "X2"],
    )
    with pytest.raises(ValueError):
        table1.condition_on("X1", "X2")

    table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
    with pytest.raises(ValueError):
        table1.condition_on("X1", "X2", "X3", "X4")

    # ---- Re-conditioning a conditional table on its conditioning columns ----
    # Each case is (columns of the first conditioning, columns of the second).
    repeated_conditioning = (
        (("X1",), ("X1",)),
        (("X1", "X2"), ("X1", "X2")),
        (("X1", "X2"), ("X2", "X1")),
        (("X1", "X2", "X3"), ("X1", "X2", "X3")),
    )
    for first, second in repeated_conditioning:
        con_1 = table1.condition_on(*first)
        with pytest.raises(ValueError):
            con_1.condition_on(*second)

    # ---- Repeated column names within a single call ----
    # NOTE(review): it is not clear from this test which of the two calls is
    # expected to raise, since the first already repeats "X2". Both stay
    # inside the block, as in the original; confirm and narrow if possible.
    with pytest.raises(ValueError):
        con_1 = table1.condition_on("X3", "X2", "X2")
        con_1.condition_on("X1", "X3", "X1")
def test_reduce_exception_table():
    """Check that ``reduce`` raises ``ValueError`` when every column is fixed.

    For conditional tables, "every column" means all of the child columns.
    Tables and conditionals are built outside the ``pytest.raises`` blocks,
    so only the ``reduce`` call under test can satisfy the expectation.
    """
    # ---- Plain tables: fixing all columns ----
    table1 = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    with pytest.raises(ValueError):
        table1.reduce(X1="a")

    table1 = Table(
        {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6},
        names=["X1", "X2"],
    )
    with pytest.raises(ValueError):
        table1.reduce(X1="a", X2="y")

    table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
    with pytest.raises(ValueError):
        table1.reduce(X1="a", X2="x", X3=1, X4=44)

    # ---- Conditional tables: fixing all child columns ----
    # Each case is (conditioning columns, keyword arguments for `reduce`).
    cases = (
        (("X1",), {"X2": "x", "X3": 1, "X4": 44}),
        (("X2",), {"X1": "a", "X3": 1, "X4": 44}),
        (("X3",), {"X1": "a", "X2": "x", "X4": 44}),
        (("X1", "X2"), {"X3": 1, "X4": 44}),
        (("X1", "X3"), {"X2": "x", "X4": 44}),
        (("X1", "X3", "X4"), {"X2": "x"}),
    )
    for given, fixed_values in cases:
        con_1 = table1.condition_on(*given)
        with pytest.raises(ValueError):
            con_1.reduce(**fixed_values)
def test_product_exceptions_table():
    """In-place multiplication of a table by a plain string raises ``ValueError``."""
    # Built outside the guard: only the ``*=`` itself is expected to raise.
    left_operand = Table(sample_1)

    with pytest.raises(ValueError):
        left_operand *= "ttt"
def test_marginals_table():
    """Exercise ``Table.marginal`` on tables built without explicit names.

    The columns are addressed as ``X1`` ... ``X4``. Each expectation is
    written as a numerator over the sum of all counts in the source table.
    """

    def check(result, total, expected):
        # ``expected`` maps lookup key -> numerator. The key listing is
        # compared as tuples, so scalar keys are wrapped in 1-tuples there
        # while the item lookup uses the key exactly as written.
        key_tuples = [
            key if isinstance(key, tuple) else (key, ) for key in expected
        ]
        assert all(compare(result.keys(), key_tuples))
        for key, numerator in expected.items():
            assert result[key] == numerator / total

    # Single RV dist.
    with pytest.raises(ValueError):
        single = Table({"A": 2, "B": 3, "C": 4})
        single.marginal("X1")

    # Two levels dist. (integer levels)
    table = Table({(1, 1): 4, (1, 2): 4, (2, 1): 6, (2, 2): 6})
    check(table.marginal("X1"), 20, {1: 10, 2: 10})
    check(table.marginal("X2"), 20, {1: 8, 2: 12})

    # Two levels dist. (string levels)
    table = Table({
        ("a", "x"): 4,
        ("a", "y"): 4,
        ("b", "x"): 6,
        ("b", "y"): 6,
    })
    # The same marginal is requested and checked twice in a row.
    for _ in range(2):
        check(table.marginal("X1"), 20, {"x": 10, "y": 10})
    check(table.marginal("X2"), 20, {"a": 8, "b": 12})

    # Three levels dist.
    table = Table({
        ("a", "x", 1): 4,
        ("a", "x", 2): 4,
        ("a", "y", 1): 6,
        ("a", "y", 2): 6,
        ("b", "x", 1): 8,
        ("b", "x", 2): 8,
        ("b", "y", 1): 10,
        ("b", "y", 2): 10,
    })
    check(
        table.marginal("X1"),
        56,
        {("x", 1): 12, ("x", 2): 12, ("y", 1): 16, ("y", 2): 16},
    )
    check(
        table.marginal("X2"),
        56,
        {("a", 1): 10, ("a", 2): 10, ("b", 1): 18, ("b", 2): 18},
    )
    check(
        table.marginal("X3"),
        56,
        {("a", "x"): 8, ("a", "y"): 12, ("b", "x"): 16, ("b", "y"): 20},
    )
    check(table.marginal("X1", "X2"), 56, {1: 28, 2: 28})
    check(table.marginal("X1", "X3"), 56, {"x": 24, "y": 32})
    check(table.marginal("X2", "X3"), 56, {"a": 20, "b": 36})

    # Four levels dist.
    table = Table({
        ("a", "x", 1, 33): 1,
        ("a", "x", 2, 33): 2,
        ("a", "x", 1, 44): 3,
        ("a", "x", 2, 44): 4,
        ("a", "y", 1, 33): 5,
        ("a", "y", 2, 33): 6,
        ("a", "y", 1, 44): 7,
        ("a", "y", 2, 44): 8,
        ("b", "x", 1, 33): 9,
        ("b", "x", 2, 33): 10,
        ("b", "x", 1, 44): 11,
        ("b", "x", 2, 44): 12,
        ("b", "y", 1, 33): 13,
        ("b", "y", 2, 33): 14,
        ("b", "y", 1, 44): 15,
        ("b", "y", 2, 44): 16,
    })
    check(
        table.marginal("X3"),
        136,
        {
            ("a", "x", 33): 3,
            ("a", "x", 44): 7,
            ("a", "y", 33): 11,
            ("a", "y", 44): 15,
            ("b", "x", 33): 19,
            ("b", "x", 44): 23,
            ("b", "y", 33): 27,
            ("b", "y", 44): 31,
        },
    )
    check(
        table.marginal("X4"),
        136,
        {
            ("a", "x", 1): 4,
            ("a", "x", 2): 6,
            ("a", "y", 1): 12,
            ("a", "y", 2): 14,
            ("b", "x", 1): 20,
            ("b", "x", 2): 22,
            ("b", "y", 1): 28,
            ("b", "y", 2): 30,
        },
    )
    check(
        table.marginal("X1", "X4"),
        136,
        {("x", 1): 24, ("x", 2): 28, ("y", 1): 40, ("y", 2): 44},
    )
    check(table.marginal("X1", "X2", "X4"), 136, {1: 64, 2: 72})

    # marginalize two times
    once = table.marginal("X1", "X4")
    twice = once.marginal("X2")
    check(twice, 136, {1: 64, 2: 72})

    # marginalize three times
    once = table.marginal("X4")
    twice = once.marginal("X3")
    thrice = twice.marginal("X2")
    check(thrice, 136, {"a": 36, "b": 100})
def test_marginal_by_name_table():
    """Exercise ``Table.marginal`` on a table with user-supplied column names.

    The four columns are named ``Age``, ``Sex``, ``Edu`` and ``Etn``; the
    counts run from 1 to 16 and add up to 136.
    """

    def check(result, expected, total=136):
        # ``expected`` maps lookup key -> numerator. The key listing is
        # compared as tuples, so scalar keys are wrapped in 1-tuples there.
        key_tuples = [
            key if isinstance(key, tuple) else (key, ) for key in expected
        ]
        assert all(compare(result.keys(), key_tuples))
        for key, numerator in expected.items():
            assert result[key] == numerator / total

    # Four levels dist.
    counts = {
        ("a", "x", 1, 33): 1,
        ("a", "x", 2, 33): 2,
        ("a", "x", 1, 44): 3,
        ("a", "x", 2, 44): 4,
        ("a", "y", 1, 33): 5,
        ("a", "y", 2, 33): 6,
        ("a", "y", 1, 44): 7,
        ("a", "y", 2, 44): 8,
        ("b", "x", 1, 33): 9,
        ("b", "x", 2, 33): 10,
        ("b", "x", 1, 44): 11,
        ("b", "x", 2, 44): 12,
        ("b", "y", 1, 33): 13,
        ("b", "y", 2, 33): 14,
        ("b", "y", 1, 44): 15,
        ("b", "y", 2, 44): 16,
    }
    table = Table(counts, names=["Age", "Sex", "Edu", "Etn"])

    without_edu = {
        ("a", "x", 33): 3,
        ("a", "x", 44): 7,
        ("a", "y", 33): 11,
        ("a", "y", 44): 15,
        ("b", "x", 33): 19,
        ("b", "x", 44): 23,
        ("b", "y", 33): 27,
        ("b", "y", 44): 31,
    }
    check(table.marginal("Edu"), without_edu)

    # With normalise=False the values are compared to the raw integer sums
    # (no key-listing check here).
    raw = table.marginal("Edu", normalise=False)
    for key, raw_sum in without_edu.items():
        assert raw[key] == raw_sum

    check(
        table.marginal("Etn"),
        {
            ("a", "x", 1): 4,
            ("a", "x", 2): 6,
            ("a", "y", 1): 12,
            ("a", "y", 2): 14,
            ("b", "x", 1): 20,
            ("b", "x", 2): 22,
            ("b", "y", 1): 28,
            ("b", "y", 2): 30,
        },
    )
    check(
        table.marginal("Age", "Etn"),
        {("x", 1): 24, ("x", 2): 28, ("y", 1): 40, ("y", 2): 44},
    )
    check(table.marginal("Age", "Sex", "Etn"), {1: 64, 2: 72})

    # marginalize two times
    once = table.marginal("Age", "Etn")
    twice = once.marginal("Sex")
    check(twice, {1: 64, 2: 72})

    # marginalize three times
    once = table.marginal("Etn")
    twice = once.marginal("Edu")
    thrice = twice.marginal("Sex")
    check(thrice, {"a": 36, "b": 100})
def test_factoring_table():
    """Multiplying a marginal by the matching conditional reproduces the joint.

    Every count in the joint table is the product of one weight per column.
    The check is done for conditioning on one, two and three columns, and
    for both operand orders of ``*``.
    """
    # Per-column weights; a cell's count is the product of its four weights.
    y1_weights = {1: 4, 2: 2}
    y2_weights = {"high": 3, "low": 2}
    y3_weights = {"under": 1, "normal": 2, "over": 3, "obese": 4}
    y4_weights = {"x": 1, "y": 3}
    # Y4 varies slowest and Y3 fastest, which fixes the insertion order.
    s_1 = {
        (y1, y2, y3, y4): w1 * w2 * w3 * w4
        for y4, w4 in y4_weights.items()
        for y1, w1 in y1_weights.items()
        for y2, w2 in y2_weights.items()
        for y3, w3 in y3_weights.items()
    }

    joint = Table(s_1, names=["Y1", "Y2", "Y3", "Y4"])
    joint.normalise()

    def assert_matches(candidate, reference):
        # Look every cell of ``candidate`` up in ``reference`` by column
        # name, so the column order of the two tables need not agree.
        for key in candidate:
            named = candidate.columns.named_key(key)
            assert candidate[key] == approx(reference.get(**named))

    # (names passed to marginal, names passed to condition_on,
    #  whether the conditional is the left operand of the first product)
    rounds = (
        (("Y2", "Y3", "Y4"), ("Y1", ), True),
        # On two columns
        (("Y3", "Y4"), ("Y1", "Y2"), False),
        # On Three columns
        (("Y4", ), ("Y1", "Y2", "Y3"), False),
    )
    for marginal_names, given_names, conditional_first in rounds:
        marginal = joint.marginal(*marginal_names)
        conditional = joint.condition_on(*given_names)
        if conditional_first:
            left, right = conditional, marginal
        else:
            left, right = marginal, conditional

        product = left * right
        assert_matches(product, joint)

        swapped = right * left
        assert_matches(swapped, joint)
def test_add_exception_table():
    """Check that ``+=`` between incompatible tables raises ``ValueError``.

    Covers plain tables whose column names differ or are reordered, and
    conditional tables built with different conditioning columns or from
    tables with a different number of columns.
    """

    def two_column_counts():
        # Fresh dict per case; both operands of a case share one dict.
        return {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}

    # Wrong column name
    with pytest.raises(ValueError):
        counts = {"a": 3, "b": 4, "c": 5}
        left = Table(counts, names=["X1"])
        right = Table(counts, names=["X2"])
        left += right

    # NOTE(review): ["Y2", "Y2"] repeats a name; possibly meant to be
    # ["Y1", "Y2"] -- confirm. Kept exactly as written.
    name_mismatches = (
        (["X1", "X2"], ["Y2", "Y2"]),
        (["X1", "X2"], ["X1", "Y2"]),
        # wrong order in names
        (["X1", "X2"], ["X2", "X1"]),
    )
    for left_names, right_names in name_mismatches:
        with pytest.raises(ValueError):
            counts = two_column_counts()
            left = Table(counts, names=left_names)
            right = Table(counts, names=right_names)
            left += right

    sample_2 = {
        ("a", "x", 1, 33): 1,
        ("a", "x", 2, 33): 2,
        ("a", "x", 1, 44): 3,
        ("a", "x", 2, 44): 4,
        ("a", "y", 1, 33): 5,
        ("a", "y", 2, 33): 6,
        ("a", "y", 1, 44): 7,
        ("a", "y", 2, 44): 8,
        ("b", "x", 1, 33): 9,
        ("b", "x", 2, 33): 10,
        ("b", "x", 1, 44): 11,
        ("b", "x", 2, 44): 12,
        ("b", "y", 1, 33): 13,
        # ("b", "y", 2, 33): 14,   (entry kept disabled)
        ("b", "y", 1, 44): 15,
        ("b", "y", 2, 44): 16,
    }
    sample_3 = {
        ("a", 1, 33): 1,
        ("a", 2, 33): 2,
        ("a", 1, 44): 7,
        ("a", 2, 44): 8,
        ("b", 1, 33): 9,
        ("b", 2, 33): 10,
        ("b", 2, 44): 16,
    }
    four_names = ("X1", "X2", "X3", "X4")
    three_names = ("X1", "X2", "X3")

    # conditional
    # The left operand always comes from ``sample_2`` with four columns.
    # Each row: (left conditioning columns, right counts, right names,
    #            right conditioning columns)
    conditional_cases = (
        (("X1", ), sample_2, four_names, ("X2", )),
        (("X1", "X2"), sample_2, four_names, ("X2", )),
        (("X1", ), sample_3, three_names, ("X1", )),
        (("X1", "X2"), sample_3, three_names, ("X1", "X2")),
    )
    for left_given, right_counts, right_names, right_given in conditional_cases:
        # Everything stays inside the guard, so the ValueError may come
        # from construction, conditioning or the addition itself.
        with pytest.raises(ValueError):
            left = Table(sample_2, names=list(four_names))
            right = Table(right_counts, names=list(right_names))
            left_conditional = left.condition_on(*left_given)
            right_conditional = right.condition_on(*right_given)
            left_conditional += right_conditional