# Helpers to render trees as standalone TikZ pictures with python2latex (p2l).
def draw_tree_structure(tree, show_pdf=True):
    """Compiles the structure of 'tree' into a standalone TikZ PDF."""
    doc = p2l.Document(str(tree).replace(' ', '_'),
                       options=('varwidth', ),
                       doc_type='standalone',
                       border='1cm')
    doc.add_package('tikz')
    del doc.packages['geometry']
    doc.add_to_preamble('\\usetikzlibrary{shapes}')
    doc += tree_struct_to_tikz(tree)
    doc.build(show_pdf=show_pdf)


def draw_decision_tree(decision_tree, show_pdf=True):
    """Compiles a fitted decision tree into a standalone TikZ PDF, with labels taken from its label encoder."""
    doc = p2l.Document(str(decision_tree.tree).replace(' ', '_'),
                       options=('varwidth', ),
                       doc_type='standalone',
                       border='1cm')
    doc.add_package('tikz')
    del doc.packages['geometry']
    doc.add_to_preamble('\\usetikzlibrary{shapes}')
    doc += decision_tree_to_tikz(decision_tree,
                                 decision_tree.label_encoder.labels)
    doc.build(show_pdf=show_pdf)
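
A minimal usage sketch for these helpers (a sketch only; it reuses the iris setup that appears in the later examples of this section):

from sklearn.datasets import load_iris
from partitioning_machines import DecisionTreeClassifier, gini_impurity_criterion

# Fit the decision tree of partitioning_machines on the iris dataset.
dataset = load_iris()
X = dataset.data
y = [dataset.target_names[i] for i in dataset.target]
dtc = DecisionTreeClassifier(gini_impurity_criterion)
dtc.fit(X, y)

# Build the PDF without opening a viewer.
draw_decision_tree(dtc, show_pdf=False)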
Code example #3
# 'DTCsklearn' is assumed to be scikit-learn's tree classifier, e.g. imported as:
# from sklearn.tree import DecisionTreeClassifier as DTCsklearn
dtc_sklearn = DTCsklearn(random_state=42)
dtc_sklearn.fit(X, y)
dtc_sklearn_conv = tree_from_sklearn_decision_tree(dtc_sklearn)

# 'dtc_ours' is assumed to be the partitioning_machines DecisionTreeClassifier fitted on the same (X, y).
dtc_ours_pic = decision_tree_to_tikz(dtc_ours,
                                     dtc_ours.label_encoder.labels,
                                     min_node_distance=1.8,
                                     level_distance=1.8,
                                     show_impurity=True,
                                     show_n_examples_by_label=True)
dtc_ours_pic.body.insert(0, r'\node at (0,1) {Our implementation};')

dtc_sklearn_pic = decision_tree_to_tikz(dtc_sklearn_conv,
                                        dtc_sklearn.classes_,
                                        min_node_distance=1.8,
                                        level_distance=1.8,
                                        show_impurity=True,
                                        show_n_examples_by_label=True)
dtc_sklearn_pic.body.insert(0, r'\node at (0,1) {Scikit-learn implementation};')

doc = p2l.Document('comparison_with_sklearn', doc_type='standalone', border='1cm')
doc.add_package('tikz')
del doc.packages['geometry']
doc.add_to_preamble('\\usetikzlibrary{shapes}')

doc += dtc_ours_pic
doc += '\\hspace{2cm}'
doc += dtc_sklearn_pic

doc.build()
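
Beyond the side-by-side drawing, a quick numerical check that both trees classify the training data the same way can be appended to the script (a sketch; it assumes dtc_ours and dtc_sklearn were fitted on the same labels so their predictions are comparable):

import numpy as np

# Fraction of training examples on which the two implementations predict the same label.
agreement = np.mean(np.array(dtc_ours.predict(X)) == np.array(dtc_sklearn.predict(X)))
print(f'Training-set prediction agreement: {agreement:.3f}')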
import python2latex as p2l
from sklearn.datasets import load_iris

from partitioning_machines import (Tree, decision_tree_to_tikz,
                                   gini_impurity_criterion,
                                   breiman_alpha_pruning_objective)
from partitioning_machines import DecisionTreeClassifier

dataset = load_iris()
X = dataset.data
y = [dataset.target_names[i] for i in dataset.target]
dtc = DecisionTreeClassifier(gini_impurity_criterion)
dtc.fit(X, y)

sequence_of_trees = [decision_tree_to_tikz(dtc, dtc.label_encoder.labels)]

pruning_coefs = dtc.compute_pruning_coefficients(
    breiman_alpha_pruning_objective)

for pruning_coef_threshold in pruning_coefs:
    n_nodes_removed = dtc.prune_tree(pruning_coef_threshold)
    if n_nodes_removed > 0:
        sequence_of_trees.append(
            decision_tree_to_tikz(dtc, dtc.label_encoder.labels))

doc = p2l.Document('sequential_tree_pruning',
                   doc_type='standalone',
                   border='1cm')
doc.add_package('tikz')
del doc.packages['geometry']
doc.add_to_preamble('\\usetikzlibrary{shapes}')
for tree in sequence_of_trees:
    doc += tree
    doc += r'\hspace{1cm}'
doc.build()
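
For context, breiman_alpha_pruning_objective is named after Breiman et al.'s cost-complexity coefficient, which for an internal node t is (R(t) - R(T_t)) / (n_leaves(T_t) - 1), where R(t) is the training error of t turned into a leaf and R(T_t) that of the subtree rooted at t. To see which of the returned thresholds actually changed the tree, one can print them next to the number of drawings produced (a small sketch that only reads the objects already defined above):

# Each threshold only adds a drawing when prune_tree removes at least one
# node, so repeated thresholds do not contribute a new picture.
for i, coef in enumerate(pruning_coefs):
    print(f'Pruning coefficient #{i}: {coef:.4f}')
print(f'Number of trees drawn (including the unpruned tree): {len(sequence_of_trees)}')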
def process_results(exp_name='first_exp'):
    """
    Produces Tables 2 to 20 from the paper (Appendix E). Will try to call pdflatex if installed.
    
    Args:
        exp_name (str): Name of the experiment used when the experiments were run. If no experiments by that name are found, entries are set to 'nan'.
    
    Prints to the console some compiled statistics used in the paper and the tex string used to produce the tables, and compiles it with pdflatex if possible.
    """

    doc = p2l.Document(exp_name + '_results_by_dataset', '.')
    doc.add_package('natbib')

    tables = [build_table(dataset, exp_name) for dataset in dataset_list]

    # Other stats
    print('Some compiled statistics used in the paper:\n')

    times_leaves_cart = [
        table[3, 4].data[0][0] / table[3, 2].data[0][0] for table in tables
    ]
    print('CART tree has',
          sum(times_leaves_cart) / len(times_leaves_cart),
          'times fewer leaves than our model.')
    acc_gain_vs_cart = [
        table[2, 4].data[0][0] - table[2, 2].data[0][0] for table in tables
    ]
    print('Our model has a mean accuracy gain of',
          sum(acc_gain_vs_cart) / len(acc_gain_vs_cart),
          'over the CART algorithm.')
    time_ours_vs_cart = [
        table[5, 2].data[0][0] / table[5, 4].data[0][0] for table in tables
    ]
    print(
        'It took on average',
        sum(time_ours_vs_cart) / len(time_ours_vs_cart),
        'times less time to prune the tree with our model than with the CART algorithm.'
    )

    times_leaves_mcart = [
        table[3, 3].data[0][0] / table[3, 2].data[0][0] for table in tables
    ]
    print('CART tree has',
          sum(times_leaves_mcart) / len(times_leaves_mcart),
          'times fewer leaves than the modified CART algorithm.')
    acc_gain_vs_mcart = [
        table[2, 3].data[0][0] - table[2, 2].data[0][0] for table in tables
    ]
    print('The modified CART algorithm has a mean accuracy gain of',
          sum(acc_gain_vs_mcart) / len(acc_gain_vs_mcart),
          'over the CART algorithm.')

    print('\n')

    doc.body.extend(tables)
    print(doc.build(save_to_disk=False))

    try:
        doc.build()
    except Exception:
        # pdflatex is not installed or compilation failed; the tex string was already printed.
        pass
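
As with the other scripts in this section, the function is meant to be run directly; a minimal entry point (using the default experiment name) is:

if __name__ == '__main__':
    # Compile the per-dataset result tables for the default experiment name.
    process_results(exp_name='first_exp')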
Code example #6
    # (Fragment: accuracies of the tree pruned with the bound-based objective.)
    acc_tr_bound = accuracy_score(y_true=y_tr, y_pred=decision_tree.predict(X_tr))
    acc_ts_bound = accuracy_score(y_true=y_ts, y_pred=decision_tree.predict(X_ts))
    print(f'Accuracy score of pruned tree on train dataset: {acc_tr_bound:.3f}')
    print(f'Accuracy score of pruned tree on test dataset: {acc_ts_bound:.3f}')

    # Restore the copy of the tree made before pruning, so it can be pruned again with cross-validation.
    decision_tree = copy_of_tree
    n_folds = 10
    optimal_threshold = prune_with_cv(decision_tree, X_tr, y_tr, n_folds=n_folds)
    print(f'Optimal cross-validated pruning coefficient threshold: {optimal_threshold:.3f}')
    pruned_tree_with_cv = decision_tree_to_tikz(decision_tree, classes)
    acc_tr_cv = accuracy_score(y_true=y_tr, y_pred=decision_tree.predict(X_tr))
    acc_ts_cv = accuracy_score(y_true=y_ts, y_pred=decision_tree.predict(X_ts))
    print(f'Accuracy score of pruned tree on train dataset: {acc_tr_cv:.3f}')
    print(f'Accuracy score of pruned tree on test dataset: {acc_ts_cv:.3f}')


    doc = p2l.Document('tree_pruning_comparison', filepath='experiments/scripts/', doc_type='standalone', border='1cm')
    doc.add_package('tikz')
    del doc.packages['geometry']
    doc.add_to_preamble('\\usetikzlibrary{shapes}')
    
    table = doc.new(p2l.Table((5,1), as_float_env=False))
    table[0,0] = 'Iris dataset'
    table[0,0].add_rule()
    table[1,0] = f'Number of examples (total): {n_examples}'
    table[2,0] = f'Train-test split: {X_tr.shape[0]}:{X_ts.shape[0]}'
    table[3,0] = f'Number of features: {n_features}'
    table[4,0] = f'Number of folds in CV: {n_folds}'
    
    table = doc.new(p2l.Table((3,3), as_float_env=False, bottom_rule=False, top_rule=False))
    table[0,0] = 'Full tree (no pruning)'
    # Align the TikZ pictures on their top edge when placed side by side.
    original_tree.kwoptions['baseline'] = '(current bounding box.north)'


def process_results(exp_name='first_exp'):
    """
    Produces Table 1 from the paper (Appendix E). Will try to call pdflatex if installed.
    
    Args:
        exp_name (str): Name of the experiment used when the experiments were run. If no experiments by that name are found, entries are set to 'nan'.
    
    Prints to the console the tex string used to produce the table, and compiles it with pdflatex if possible.
    """
    doc = p2l.Document(exp_name + '_all_results', '.')

    model_names = [
        'original',
        'cart',
        'm-cart',
        'ours',
    ]

    dataset_list = list(load_datasets())

    caption = """Mean test accuracy and standard deviation on 25 random splits of 19 datasets taken from the UCI Machine Learning Repository \\citep{Dua:2019}. In parenthesis is the total number of examples followed by the number of classes of the dataset. The best performances up to a $0.0025$ accuracy gap are highlighted in bold."""

    label = "results"

    alignment = r'l@{\hspace{6pt}}c@{\hspace{6pt}}c@{\hspace{6pt}}c@{\hspace{6pt}}c'
    table = doc.new(
        p2l.Table((len(dataset_list) + 2, 5),
                  float_format='.3f',
                  alignment=alignment,
                  caption=caption,
                  label=label))
    table.body.insert(0, '\\small')

    table[0:2, 0].multicell('Dataset', v_shift='-3pt')
    table[0, 1:] = 'Model'
    table[1, 1:] = ['Original', 'CART', 'M-CART', 'Ours']
    table[0, 1:].add_rule()
    table[2:, 0] = [
        d.name.replace('_', ' ').title() + f' ({d.n_examples}, {d.n_classes})'
        for d in dataset_list
    ]
    table[1].add_rule()

    models_exp_name = [exp_name] * 4

    for d, dataset in enumerate(dataset_list):
        for i, (model, model_exp_name) in enumerate(
                zip(model_names, models_exp_name)):
            ts_acc = []
            path = './experiments/results/' + dataset.name + '/' + model_exp_name + '/'
            try:
                with open(path + model + '.csv', 'r', newline='') as file:
                    reader = csv.reader(file)
                    header = next(reader)
                    for row in reader:
                        ts_acc.append(row[3])
            except FileNotFoundError:
                ts_acc.append(np.nan)

            table[d + 2, i + 1] = MeanWithStd(np.array(ts_acc, dtype=float))

        table[d + 2, 1:].highlight_best(
            highlight=lambda content: '$\\mathbf{' + content[1:-1] + '}$',
            atol=0.0025,
            rtol=0)

    d = [dataset_list[i].load() for i in [0, 2, 3, 4, 16]]

    table[2, 0] = f'BCWD\\textsuperscript{{a}} ({d[0].n_examples}, {d[0].n_classes})'
    table[4, 0] = f'CMSC\\textsuperscript{{b}} ({d[1].n_examples}, {d[1].n_classes})'
    table[5, 0] = f'CBS\\textsuperscript{{c}} ({d[2].n_examples}, {d[2].n_classes})'
    table[6, 0] = f'DRD\\textsuperscript{{d}} ({d[3].n_examples}, {d[3].n_classes})'
    table[18, 0] = f'WFR24\\textsuperscript{{e}} ({d[4].n_examples}, {d[4].n_classes})'

    table += """\n\\footnotesize \\textsuperscript{a}Breast Cancer Wisconsin Diagnostic, \\textsuperscript{b}Climate Model Simulation Crashes, \\textsuperscript{c}Connectionist Bench Sonar,\n\n\\textsuperscript{d}Diabetic Retinopathy Debrecen, \\textsuperscript{e}Wall Following Robot 24"""

    doc.add_package('natbib')

    print(doc.build(save_to_disk=False))

    try:
        doc.build()
    except Exception:
        # pdflatex is not installed or compilation failed; the tex string was already printed.
        pass