def get_trees(tree_type, tree_size, tree_path, *, schema="nexus"):
    """Build (or load) a tree of the requested kind.

    Args:
        tree_type: One of "binary", "caterpillar" (the legacy misspelling
            "catepillar" is still accepted), "birthdeath", "kingman" or
            "path".
        tree_size: Size argument forwarded to the matching ``utils``
            generator; ignored when ``tree_type == "path"``.
        tree_path: File to read when ``tree_type == "path"``; ignored for
            every other kind.
        schema: dendropy schema used to parse ``tree_path`` (defaults to
            "nexus", the format the original implementation used).

    Returns:
        The generated tree, or the tree read from ``tree_path``.

    Raises:
        ValueError: If ``tree_type`` is not recognised, or if it is "path"
            and ``tree_path`` is None.
    """
    if tree_type == "binary":
        return utils.balanced_binary(tree_size)
    # "catepillar" was the original (misspelled) key; keep accepting it so
    # existing callers do not break.
    if tree_type in ("caterpillar", "catepillar"):
        return utils.lopsided_tree(tree_size)
    if tree_type == "birthdeath":
        return utils.unrooted_birth_death_tree(tree_size)
    if tree_type == "kingman":
        return utils.unrooted_pure_kingman_tree(tree_size)
    if tree_type == "path":
        if tree_path is None:
            raise ValueError('tree_type "path" requires a tree_path')
        # Use the caller-supplied path; the previous code read an undefined
        # global ``args.path``.
        return dendropy.Tree.get(path=tree_path, schema=schema)
    # Previously an unknown type fell through and raised UnboundLocalError.
    raise ValueError(
        f"Unknown tree_type {tree_type!r}; expected one of "
        "'binary', 'caterpillar', 'birthdeath', 'kingman', 'path'"
    )


# Empirical influenza trees, read from data/ under the parent of
# sys.path[0] (the script's directory when run as a script). The matching
# FASTA files (data/NY_H3N2.fasta, data/NY_H1N1.fasta) are not loaded here.
tree_path = os.path.join(os.path.dirname(sys.path[0]), "data/NY_H3N2.newick")
H3N2_tree = dendropy.Tree.get(path=tree_path, schema="newick")

tree_H1N1_path = os.path.join(
    os.path.dirname(sys.path[0]),
    "data/NY_H1N1.newick",
)
H1N1_tree = dendropy.Tree.get(path=tree_H1N1_path, schema="newick")

# Synthetic topologies: a caterpillar of size 512 and a balanced binary
# tree of size 1024.
caterpillar_tree_512 = utils.lopsided_tree(512)
binary_tree_1024 = utils.balanced_binary(1024)

# Parallel lists: tree_str_list[k] is the label for tree_list[k].
tree_list = [H1N1_tree, H3N2_tree, binary_tree_1024, caterpillar_tree_512]
tree_str_list = ['H1N1', 'H3N2', 'binary_1024', 'caterpillar_512']

# Experiment constants.
B, threshold = 20, 128

# Candidate N grids per dataset.
N_H3N2 = [500, 1000, 1500, 2000]
N_1024 = N_H3N2.copy()  # same grid as H3N2, independent list
N_H1N1 = [300, 500, 700, 1000]
# NOTE(review): named "256" although the caterpillar tree above has size 512.
N_cat_256 = N_H1N1.copy()  # same grid as H1N1, independent list
# NOTE(review): N is reassigned to a scalar further down this script.
N = [300, 500, 700, 900, 1000]

# test params
# Experiment: for each merge method, repeatedly simulate Jukes-Cantor
# sequences on a lopsided (caterpillar) reference tree, reconstruct a tree
# with deep spectral reconstruction, and score it against the reference
# with reconstruct_tree.compare_trees.
#B = 2
# NOTE(review): this rebinds N, replacing the list assigned earlier in the
# script. It is passed as the first argument of simulate_sequences
# (presumably the number of sites; confirm).
N = 1000
num_taxa = 128
jc = generation.Jukes_Cantor()
# Single-element list holding jc.p2t(0.95); presumably converts a
# probability/similarity of 0.95 into a branch length (confirm).
mutation_rate = [jc.p2t(0.95)]
num_itr = 2  # NOTE(review): trailing "#0" suggests 20 for full runs; confirm
# reference_tree = utils.unrooted_birth_death_tree(num_taxa, birth_rate=1)
# for x in reference_tree.preorder_edge_iter():
#     x.length = 1
merging_method_list = ['least_square', 'angle']
# Per-merge-method score accumulators. They are not filled in the lines
# shown here; presumably RF_i / F1_i are appended below (confirm).
RF = {'least_square': [], 'angle': []}
F1 = {'least_square': [], 'angle': []}
for merge_method in merging_method_list:
    for i in range(num_itr):
        # A new lopsided (caterpillar) reference tree is built every run.
        #reference_tree = utils.balanced_binary(num_taxa)
        reference_tree = utils.lopsided_tree(num_taxa)
        observations, taxa_meta = generation.simulate_sequences(
            N,
            tree_model=reference_tree,
            seq_model=jc,
            mutation_rate=mutation_rate)
        # NOTE(review): built with identical arguments on every iteration;
        # could be hoisted out of both loops if it keeps no per-run state.
        spectral_method = reconstruct_tree.SpectralTreeReconstruction(
            reconstruct_tree.NeighborJoining,
            reconstruct_tree.JC_similarity_matrix)
        tree_rec = spectral_method.deep_spectral_tree_reconstruction(
            observations,
            reconstruct_tree.JC_similarity_matrix,
            taxa_metadata=taxa_meta,
            # NOTE(review): keyword spelled "threshhold" (presumably matching
            # the library signature). The value 16 is hard-coded and ignores
            # the module-level `threshold` (128); confirm which is intended.
            threshhold=16,
            merge_method=merge_method)
        # Per-run scores; presumably Robinson-Foulds distance and F1 (confirm).
        RF_i, F1_i = reconstruct_tree.compare_trees(tree_rec, reference_tree)