def test_load_rb_com_11():
    """Check the first instance from ``load_rb_com_11`` against a known record.

    Only one instance is requested (``num_instances=1``) and it must equal
    the hard-coded dictionary below exactly.
    """

    def component(kind, **bounds):
        # Build one component entry: the given 'l'/'r'/'b'/'t' values plus
        # its 'type' tag.
        return dict(bounds, type=kind)

    expected = {
        '_guid': 'ea022d3d-5c9e-46d7-be23-8ea718fe7816',
        '_human_cluster_label': '0',
        'component0': component('cube0', l=0.0, r=1.0, b=1.0, t=2.0),
        'component1': component('cube0', l=2.0, r=3.0, b=3.0, t=4.0),
        'component14': component('ufoo0', l=1.0, r=4.0, b=4.0, t=5.0),
        'component2': component('plat0', l=1.0, r=4.0, b=1.0, t=2.0),
        'component3': component('plat0', l=1.0, r=4.0, b=2.0, t=3.0),
        'component4': component('rect0', l=0.0, r=5.0, b=0.0, t=1.0),
    }

    loaded = load_rb_com_11(num_instances=1)
    first_instance = loaded[0]
    assert expected == first_instance
def output_json(file="forest", size=100, prune=True, seed=50, burn=1):
    """Fit a TrestleTree on a dataset sample and write it to ``output.js``.

    The chosen dataset is loaded, shuffled, truncated to ``size`` instances,
    optionally passed through ``ObjectVariablizer``, and used to fit a
    ``TrestleTree``. The JSON form of the tree's root is pretty-printed to
    stdout and written to ``output.js`` (in the current working directory) as
    the JavaScript statement ``var trestle_output = <json>;``.

    Args:
        file: Dataset name. One of "forest", "voting", "iris", "mushroom",
            "rb_com_11", "rb_s_07", "rb_s_13" or "rb_wb_03". Any other value
            falls back to the forest-fires dataset.
        size: Maximum number of (shuffled) instances to keep.
        prune: Currently unused; kept so existing callers keep working.
        seed: Seed given to ``random.seed`` before shuffling.
        burn: Passed to ``TrestleTree.fit`` as ``iterations``.

    Returns:
        None. The function works through its side effects (stdout and the
        ``output.js`` file).
    """
    random.seed(seed)

    # Dataset name -> (name of the loader on ``ds``, whether the instances
    # are run through ObjectVariablizer before fitting). The loader is looked
    # up by name so that only the selected one is ever touched.
    datasets = {
        "forest": ("load_forest_fires", False),
        "voting": ("load_congressional_voting", False),
        "iris": ("load_iris", False),
        "mushroom": ("load_mushroom", False),
        "rb_com_11": ("load_rb_com_11", True),
        "rb_s_07": ("load_rb_s_07", True),
        "rb_s_13": ("load_rb_s_13", True),
        "rb_wb_03": ("load_rb_wb_03", True),
    }
    # Unrecognised names deliberately fall back to forest fires.
    loader_name, variables = datasets.get(file, datasets["forest"])
    instances = getattr(ds, loader_name)()

    random.shuffle(instances)
    pprint.pprint(instances[0])
    instances = instances[:size]
    print(len(instances))

    if variables:
        variablizer = ObjectVariablizer()
        instances = [variablizer.transform(t) for t in instances]

    tree = TrestleTree()
    tree.fit(instances, iterations=burn)

    # Serialise the tree once and reuse it for both the console dump and the
    # file, instead of walking the tree twice.
    tree_json = tree.root.output_json()
    pprint.pprint(tree_json)

    with open('output.js', 'w') as out:
        out.write("var trestle_output = " + json.dumps(tree_json) + ";")
# Grouped bar chart: one group per entry of `hueristics` (defined outside
# this excerpt), one bar per dataset within each group.
width = 0.3
x = np.arange(len(hueristics))

hueristic_names = ['AIC', 'BIC', 'CU', 'AICc']
# for i in range(len(clusters)):
#     hueristic_names[i] +=  '\nClusters='+str(len(set(clusters[i])))

# Each dataset's bars are shifted by one bar width so the three sit side by
# side around every x position. Bar heights come from calculate_aris
# (presumably one value per heuristic -- confirm against its definition).
b1, b2, b3 = [
    plt.bar(positions,
            calculate_aris(load_dataset()),
            width,
            color=colour,
            alpha=.8,
            align='center')
    for positions, load_dataset, colour in (
        (x - width, load_rb_wb_03, 'r'),
        (x, load_rb_com_11, 'b'),
        (x + width, load_rb_s_13, 'g'),
    )
]

# One legend entry per dataset, using the first bar of each series as handle.
legend_handles = (b1[0], b2[0], b3[0])
legend_labels = ('wb_03', 'com_11', 's_13')
plt.legend(legend_handles, legend_labels)

plt.title(
    "TRESTLE Clustering Accuracy of Best Clustering by Different Hueristics")
plt.ylabel("Adjusted Rand Index (Agreement Correcting for Chance)")
plt.ylim(0, 1)