def main(domainxml, trainingsetcsv, restrictionstxt):
    """Build a decision tree from a training set and write it to stdout.

    Args:
        domainxml: Not used by this function; kept so the call signature
            stays the same for existing callers.
        trainingsetcsv: Object exposing ``read()``; whatever it returns is
            handed to ``dataset.read`` (presumably an open CSV file --
            confirm against the caller).
        restrictionstxt: Passed unchanged to
            ``dataset.restrictions_from_text``.
    """
    restrictions = dataset.restrictions_from_text(restrictionstxt)
    cols, data = dataset.read(trainingsetcsv.read(), restrictions)

    # Every column is paired with its index before training; the threshold
    # argument is 0 here.
    tree = c45.run(data, list(enumerate(cols)), 0)

    # The result goes through the binary layer of stdout, so stringify_tree
    # is expected to return bytes -- TODO confirm.
    tree_xml = stringify_tree(tree)
    sys.stdout.buffer.write(tree_xml)
def cross_validate(data, attributes, manifold):
    """Train and evaluate one tree per hunk of ``data``.

    ``data`` is split with ``hunk(data, manifold)`` and every
    ``(elem, rest)`` pair produced by ``pull_each`` is processed in turn:
    a tree is built from ``elem`` with a threshold of 0.05 and is then used
    to classify each row found in the chained ``rest`` hunks.

    NOTE(review): the tree is trained on ``elem`` and scored on ``rest``;
    confirm this matches the order in which ``pull_each`` yields its pairs.

    Returns:
        A 4-tuple ``(expected, actual, expected_hunked, actual_hunked)``.
        The ``*_hunked`` entries hold one list per iteration; ``expected``
        and ``actual`` are the same values flattened in iteration order.
    """
    expected_hunked = []
    actual_hunked = []

    for elem, rest in pull_each(hunk(data, manifold)):
        tree = c45.run(elem, list(enumerate(attributes)), 0.05)
        # Predictions first, then the recorded class of the same rows.
        actual_hunked.append(
            [tree.classify(row[0], attributes) for row in itertools.chain(*rest)]
        )
        expected_hunked.append([row[1] for row in itertools.chain(*rest)])

    # The flat views are simply the per-iteration lists joined end to end.
    expected = list(itertools.chain.from_iterable(expected_hunked))
    actual = list(itertools.chain.from_iterable(actual_hunked))
    return (expected, actual, expected_hunked, actual_hunked)
    def test_basic(self):
        """c45.run on a 16-row sample with threshold 0 builds the expected tree."""
        attribute_names = ["Bedrooms", "Basement", "Floorplan", "Location"]
        # Each row is ([one value per attribute], class label).
        rows = [
            ([3, "false", "traditional", "South"], "Not Visited"),
            ([3, "true", "traditional", "South"], "Visited"),
            ([3, "true", "open", "North"], "Not Visited"),
            ([3, "true", "traditional", "North"], "Not Visited"),
            ([3, "false", "open", "North"], "Not Visited"),
            ([3, "true", "traditional", "South"], "Visited"),
            ([3, "true", "open", "South"], "Not Visited"),
            ([3, "false", "traditional", "South"], "Not Visited"),
            ([4, "false", "traditional", "South"], "Visited"),
            ([4, "true", "open", "North"], "Not Visited"),
            ([4, "true", "open", "South"], "Visited"),
            ([4, "false", "traditional", "North"], "Not Visited"),
            ([4, "false", "open", "South"], "Visited"),
            ([4, "true", "open", "South"], "Visited"),
            ([4, "false", "traditional", "North"], "Not Visited"),
            ([4, "true", "open", "North"], "Not Visited"),
        ]

        result = c45.run(rows, list(enumerate(attribute_names)), 0)

        # Assemble the expected tree from the deepest split upwards.
        floorplan_split = Node(
            "Floorplan",
            ("traditional", Label("Visited")),
            ("open", Label("Not Visited")),
        )
        basement_split = Node(
            "Basement",
            ("true", floorplan_split),
            ("false", Label("Not Visited")),
        )
        bedrooms_split = Node(
            "Bedrooms",
            (3, basement_split),
            (4, Label("Visited")),
        )
        expected = Node(
            "Location",
            ("North", Label("Not Visited")),
            ("South", bedrooms_split),
        )
        self.assertEqual(result, expected)
 def test_small_run(self):
     """c45.run on the fixture data with threshold 0 splits once on "shape"."""
     # self.d and self.attr are set outside this method (presumably in
     # setUp -- not visible here).
     actual_tree = c45.run(self.d, self.attr, 0)
     expected_tree = Node("shape", (1, Label(2)), (2, Label(1)))
     self.assertEqual(actual_tree, expected_tree)