Ejemplo n.º 1
0
def main(train_f, image_f, test_f, output_f):
    """Build a decision tree from a training TSV, draw it, and classify a test set.

    Args:
        train_f: Path of the training file, loaded via ``tsv.get_list``.
        image_f: Path of the JPEG the tree drawing is written to.
        test_f: Test-file argument forwarded to ``classify_pokemon``.
        output_f: Output-file argument forwarded to ``classify_pokemon``.
    """
    rows = []
    for raw_row in tsv.get_list(train_f):
        # Convert to numeric, then drop the Pokemon name in the first column
        # so only the remaining fields are handed to the tree builder.
        row = move_tree.convert_numeric(raw_row)
        row.pop(0)
        rows.append(row)

    tree = dtree_build.buildtree(rows)
    # NOTE(review): `labels` is not defined in this function; it must come
    # from module scope elsewhere in the file — confirm it exists.
    dtree_draw.drawtree(tree, labels, jpeg=image_f)
    classify_pokemon(tree, test_f, output_f)
Ejemplo n.º 2
0
def main(train_f, image_f, test_f, output_f):
    """Build a decision tree from a tab-separated file, draw it, and classify a test set.

    Each non-blank line of ``train_f`` is split on tabs. The first field
    (presumably the move name) is dropped. The remaining fields are converted
    with ``convert_numeric`` to form one training row.

    Args:
        train_f: Path of the tab-separated training file.
        image_f: Path of the JPEG the tree drawing is written to.
        test_f: Test-file argument forwarded to ``classify_moves``.
        output_f: Output-file argument forwarded to ``classify_moves``.
    """
    rows = []

    # Create a 2D array to pass into the function which creates the tree.
    # `with` guarantees the file is closed even if parsing raises.
    with open(train_f) as data:
        for line in data:
            if not line.strip():
                # Blank lines would otherwise turn into empty rows ([]).
                continue
            fields = line.rstrip().split('\t')
            fields.pop(0)  # drop the leading name column; it is not a feature
            # Convert fields into integers where appropriate.
            rows.append(convert_numeric(fields))

    tree = dtree_build.buildtree(rows)
    # NOTE(review): `labels` is not defined in this function; it must come
    # from module scope elsewhere in the file — confirm it exists.
    dtree_draw.drawtree(tree, labels, jpeg=image_f)
    classify_moves(tree, test_f, output_f)
Ejemplo n.º 3
0
def _parse_value(attribute):
    """Return *attribute* converted to float when possible, otherwise unchanged."""
    try:
        return float(attribute)
    except ValueError:
        # Non-numeric fields (e.g. categorical labels) are kept as strings.
        return attribute


def _load_csv(csv_file_name):
    """Read a comma-separated file into a list of rows of parsed values."""
    # newline='' is required by the csv module so quoted fields containing
    # newlines are parsed correctly.
    with open(csv_file_name, newline='') as csvfile:
        return [[_parse_value(attribute) for attribute in row]
                for row in csv.reader(csvfile, delimiter=',')]


def main(col_names=None):
    """Build and report a decision tree for the CSV file named on the command line.

    Command line: ``sys.argv[1]`` is the input CSV path. A second extra
    argument (any value) also draws the tree to ``<csv>.jpg``. A third extra
    argument (any value) also prints the tree as JSON and writes it to
    ``<csv>.json``.

    Args:
        col_names: Optional column names forwarded to ``printtree`` and
            ``dtree_to_jsontree``.

    Returns:
        None. If no input file is given, prints a usage message and returns early.
    """
    # parse command-line arguments to read the name of the input csv file
    # and optional 'draw tree' parameter
    if len(sys.argv) < 2:  # input file name should be specified
        print("Please specify input csv file name")
        return

    csv_file_name = sys.argv[1]
    data = _load_csv(csv_file_name)

    print("Total number of records = ", len(data))
    tree = dtree_build.buildtree(data, min_gain=0.01, min_samples=5)

    dtree_build.printtree(tree, '', col_names)

    max_tree_depth = dtree_build.max_depth(tree)
    print("max number of questions=" + str(max_tree_depth))

    if len(sys.argv) > 2:  # draw option specified
        import dtree_draw
        dtree_draw.drawtree(tree, jpeg=csv_file_name + '.jpg')

    if len(sys.argv) > 3:  # create json file for d3.js visualization
        import json
        import dtree_to_json
        json_tree = dtree_to_json.dtree_to_jsontree(tree, col_names)
        print(json_tree)

        # create json data for d3.js interactive visualization
        with open(csv_file_name + ".json", "w") as write_file:
            json.dump(json_tree, write_file)
Ejemplo n.º 4
0
import regr_dtree
import sys

if __name__ == "__main__":
    # Toy training set: each sample is [size, color, fruit name].
    fruits = [
        [4, 'red', 'apple'],
        [4, 'green', 'apple'],
        [1, 'red', 'cherry'],
        [1, 'green', 'grape'],
        [5, 'red', 'apple'],
    ]

    tree = regr_dtree.buildtree(fruits)
    regr_dtree.printtree(tree, '', ["size", "color"])

    # Classify a few unseen samples; each entry is (printed prefix, sample).
    queries = (
        ("fruit [2, 'red'] is: ", [2, 'red']),
        ("fruit [4.5, 'red'] is: ", [4.5, 'red']),
        ("fruit [1.4, 'green'] is: ", [1.4, 'green']),
    )
    for prefix, sample in queries:
        print(prefix, regr_dtree.classify(sample, tree))

    depth = regr_dtree.max_depth(tree)
    print("max number of questions=" + str(depth))

    # Any extra command-line argument enables drawing the tree to a JPEG.
    if len(sys.argv) > 1:  # draw option specified
        import dtree_draw
        dtree_draw.drawtree(tree, jpeg='fruits_dt.jpg')