def main():
    """Compute per-generation lambda distances for the GraphRNN chains.

    Loads the (root, generation) pairs for the hard-coded dataset, orders
    them by generation, and transposes them so that each resulting list is
    one chain. Every graph in every chain is compared against the original
    input graph with ``GraphPairCompare.lambda_dist``; successful results are
    collected into a DataFrame and written as a tab-separated file under
    ``<base_path>/stats/lambda``.

    NOTE(review): a second function named ``main`` appears later in this
    view; if both live in the same module, the later definition replaces
    this one.
    """
    base_path = '/data/infinity-mirror/'
    input_path = '/home/dgonza26/infinity-mirror/input'
    dataset = 'flights'
    model = 'GraphRNN'

    output_path = os.path.join(base_path, 'stats', 'lambda')
    mkdir_output(output_path)

    # Order the loaded entries by generation, then keep only the roots.
    cleaned_dir = os.path.join(base_path, 'cleaned')
    by_generation = sorted(load_data(cleaned_dir, dataset), key=lambda entry: entry[1])
    R = [root for root, _generation in by_generation]

    # The reference graph is either synthetic or read from the input folder.
    if dataset == 'clique-ring-500-4':
        G = nx.ring_of_cliques(500, 4)
    else:
        G = init(os.path.join(input_path, f'{dataset}.g'))

    # Transpose: zip(*R) regroups the i-th element of every entry together.
    chains = [list(column) for column in zip(*R)]

    rows = {'model': [], 'gen': [], 'abs': []}

    reference_stats = GraphStats(graph=G, run_id=1)
    for chain_no, chain in enumerate(chains, start=1):
        print(f'chain: {chain_no}')
        for gen, graph in enumerate(chain, start=1):
            print(f'\tgen: {gen} ... ', end='', flush=True)
            comparator = GraphPairCompare(reference_stats, GraphStats(graph=graph, run_id=1))
            try:
                distance = comparator.lambda_dist()
            except Exception as e:
                # Best effort: report the failure and leave this generation
                # out of the table so all columns stay the same length.
                print(f'ERROR\n{e}')
                continue
            rows['abs'].append(distance)
            rows['model'].append('GraphRNN')
            rows['gen'].append(gen)
            print('done')

    df = pd.DataFrame(rows)
    print(df.head())

    out_file = f'{output_path}/{dataset}_{model}_lambda.csv'
    df.to_csv(out_file, float_format='%.7f', sep='\t', index=False, na_rep='nan')
    print(f'wrote {out_file}')
def _compare_all_metrics(comparator):
    """Return a dict with every comparison metric computed by *comparator*.

    The metrics are evaluated in the order listed (dict displays evaluate
    their values left to right); any exception raised by a metric propagates
    to the caller.
    """
    return {
        'lambda_dist': comparator.lambda_dist(),
        'node_diff': comparator.node_diff(),
        'edge_diff': comparator.edge_diff(),
        'pgd_pearson': comparator.pgd_pearson(),
        'pgd_spearman': comparator.pgd_spearman(),
        'deltacon0': comparator.deltacon0(),
        'degree_cvm': comparator.cvm_degree(),
        'pagerank_cvm': comparator.cvm_pagerank(),
    }


def main(base_path, dataset, models):
    """Add sequential (parent-vs-child) statistics to pickled trees.

    For every model in *models* except ``'GraphRNN'``, walks
    ``<base_path>/<dataset>/<model>`` and processes each file whose name does
    not contain ``'seq'``. The pickled root is loaded, a ``stats_seq`` dict is
    filled in for each node along the first-child chain (each child compared
    with its parent), and the result is pickled next to the input as
    ``<stem>_seq.pkl.gz`` (or ``<stem>_seq_rob.pkl.gz`` when the path
    contains ``'rob'``).

    Args:
        base_path: directory that contains one sub-directory per dataset.
        dataset: name of the dataset sub-directory.
        models: iterable of model names; it is not modified.

    Raises:
        SystemExit: with status 1 if a loaded root is neither a
            ``LightTreeNode`` nor a ``TreeNode`` and has no ``stats_seq``.
        ValueError: if a file stem does not end in ``_<integer>``.
    """
    # GraphRNN results are skipped here. Work on a filtered copy so the
    # caller's list is left untouched.
    selected_models = [name for name in models if name != 'GraphRNN']

    for model in selected_models:
        path = os.path.join(base_path, dataset, model)
        for subdir, dirs, files in os.walk(path):
            for filename in files:
                if 'seq' in filename:
                    # Skip files whose name already contains 'seq'.
                    continue

                # Take the stem from the file name only: splitting the full
                # path on '.' would truncate it at the first dotted directory.
                stem = filename.split('.')[0]
                run_id = int(stem.split('_')[-1])
                string = subdir.split('/')[-2:]
                file = os.path.join(subdir, filename)
                suffix = '_seq_rob.pkl.gz' if 'rob' in file else '_seq.pkl.gz'
                newfile = os.path.join(subdir, stem) + suffix

                print(f'starting\t{string[-2]}\t{string[-1]}\t{filename} ... ', end='', flush=True)
                root = load_pickle(file)
                node = root

                if not hasattr(node, 'stats_seq'):
                    # Older pickles lack stats_seq: rebuild the root as a
                    # TreeNode that carries an (empty) stats_seq.
                    if type(node) is LightTreeNode:
                        # The root is compared against itself here.
                        comparator = GraphPairCompare(GraphStats(graph=root.graph, run_id=run_id),
                                                      GraphStats(graph=root.graph, run_id=run_id))
                        stats = _compare_all_metrics(comparator)
                        node = TreeNode(name=node.name, graph=node.graph, stats=stats, stats_seq={},
                                        parent=node.parent, children=node.children)
                    elif type(node) is TreeNode:
                        node = TreeNode(name=node.name, graph=node.graph, stats=node.stats, stats_seq={},
                                        parent=node.parent, children=node.children)
                    else:
                        print(f'node has unknown type: {type(node)}')
                        # Unknown node type is an error: exit with a non-zero status.
                        raise SystemExit(1)

                if node.stats_seq is None or node.stats_seq == {}:
                    node.stats_seq = {}
                    # Follow the first child at every level.
                    while node.children:
                        child = node.children[0]
                        try:
                            child.stats_seq = {}
                        except AttributeError:
                            # NOTE(review): the rebuilt child is not written
                            # back into node.children here — confirm the
                            # TreeNode constructor links it into the tree.
                            child = TreeNode(name=child.name, graph=child.graph, stats=child.stats, stats_seq={},
                                             parent=child.parent, children=child.children)
                        comparator = GraphPairCompare(GraphStats(graph=node.graph, run_id=run_id),
                                                      GraphStats(graph=child.graph, run_id=run_id))
                        child.stats_seq.update(_compare_all_metrics(comparator))
                        node = child

                # NOTE(review): `root` is still the object returned by
                # load_pickle; if the root was rebuilt above, the rebuilt
                # TreeNode is not the object pickled here — confirm intent.
                # NOTE(review): the target name ends in '.gz' but the file is
                # written with a plain open(), i.e. without gzip compression.
                with open(newfile, 'wb') as f:
                    pickle.dump(root, f)
                print(f'\tdone')