def run_stats(ref_sim,
              ref_pair,
              data,
              data_freqs=None,
              frequency_range=None,
              exclude=False,
              p_value=1e-5,
              muted_dir='',
              test_m='fisher'):
    '''
    Co-factor of md_reference_comp: compare one pair of populations with
    heatmap_v2 and compute the difference of their normalised counts.

    - ref_sim: simulation ID. The batch label is the text before the first 'C';
      the chromosome label is the second 'C'-separated field of the text
      before the first '.'.
    - ref_pair: list of two (sim, pop) tuples. It cannot be a dictionary
      because the same simulation / reference tag may appear twice.
    - data: {sim: {'counts': {pop: g}, 'Nvars': {pop: g}, 'sizes': {pop: g}}}
      (counts presumably numpy arrays -- they are normalised with np.sum and
      subtracted).
    - data_freqs: optional {sim: {pop: freqs}}. When non-empty, a 'freqs'
      entry keyed by the position in ref_pair (0, 1, ...) is added.
    - frequency_range, exclude, p_value, muted_dir, test_m: forwarded to
      heatmap_v2. Defaults match md_reference_comp. These used to be read as
      undefined free names.

    Returns a dict with 'grids' and 'sigs' (from heatmap_v2 with
    output='pval'), 'sizes', 'batch', 'diffs' and, optionally, 'freqs'.
    '''
    if frequency_range is None:
        frequency_range = [0, 1]

    batch = ref_sim.split('C')[0]
    sizes = [data[sim]['sizes'][pop] for sim, pop in ref_pair]

    chromosomes = [ref_sim.split('.')[0].split('C')[1]]

    pop_counts = {g: data[g[0]]['counts'][g[1]] for g in ref_pair}

    num_variants = {g: data[g[0]]['Nvars'][g[1]] for g in ref_pair}

    ratio_grid, sig_cells = heatmap_v2(chromosomes,
                                       pop_counts,
                                       num_variants, {},
                                       frequency_range,
                                       exclude,
                                       p_value,
                                       muted_dir,
                                       tag='',
                                       test=test_m,
                                       output='pval')

    # Normalise counts to proportions before taking the difference.
    pop_counts = {z: s / np.sum(s) for z, s in pop_counts.items()}

    grid_diffs = pop_counts[ref_pair[0]] - pop_counts[ref_pair[1]]

    comb_stats = {
        'grids': ratio_grid,
        'sigs': sig_cells,
        'sizes': sizes,
        'batch': batch,
        'diffs': grid_diffs
    }

    if data_freqs:
        # Key by position: list.index() would map repeated entries to the
        # same slot.
        comb_stats['freqs'] = {
            idx: data_freqs[sim][pop]
            for idx, (sim, pop) in enumerate(ref_pair)
        }

    return comb_stats
def _md_pair_stats(data, members, chromosomes, batch, frequency_range,
                   exclude, p_value, muted_dir, test_m):
    '''
    Run heatmap_v2 on two populations and bundle the comparison summary.

    - members: list of two (sim, pop) tuples. A list (not a dict keyed by
      sim) keeps both entries when they come from the same simulation.
    - chromosomes, batch: labels derived from the reference simulation ID.
    - frequency_range, exclude, p_value, muted_dir, test_m: forwarded to
      heatmap_v2.

    Returns {'grids', 'sigs', 'sizes', 'batch', 'diffs'}. 'diffs' is the
    first population's normalised counts minus the second's. The two pop
    names must differ because they key the dicts passed to heatmap_v2.
    '''
    (_, first_pop), (_, second_pop) = members
    sizes = [data[sim]['sizes'][pop] for sim, pop in members]
    pop_counts = {pop: data[sim]['counts'][pop] for sim, pop in members}
    num_variants = {pop: data[sim]['Nvars'][pop] for sim, pop in members}

    ratio_grid, sig_cells = heatmap_v2(chromosomes,
                                       pop_counts,
                                       num_variants, {},
                                       frequency_range,
                                       exclude,
                                       p_value,
                                       muted_dir,
                                       tag='',
                                       test=test_m,
                                       output='pval')

    # Normalise each population's counts to proportions before differencing.
    props = {pop: counts / np.sum(counts) for pop, counts in pop_counts.items()}

    return {
        'grids': ratio_grid,
        'sigs': sig_cells,
        'sizes': sizes,
        'batch': batch,
        'diffs': props[first_pop] - props[second_pop]
    }


def md_reference_comp(data,
                      p_value=1e-5,
                      test_m='fisher',
                      individually=False,
                      Nbins=10,
                      exclude=False,
                      frequency_range=None,
                      data_freqs=None,
                      extract='pval',
                      muted_dir='',
                      tag_ref='_ss'):
    '''
    Parse data dictionary.
        data: {sim: {counts:{pop:g}, Nvars:{pop:g}, sizes:{pop:g}}}
    Simulations whose ID contains tag_ref hold subset populations (created
    with ind_assignment_scatter_v1); the others are reference simulations.

    i: for each reference simulation, compare every pair of its populations
       with heatmap_v2 (fisher or chi2 test for p-values).
    ii: link each subset population to its reference population and to its
        relative size (subset size / reference size, rounded to 3 decimals).
    iii: for each reference pair, bin the relative sizes into Nbins - 1
         intervals (lo, hi] over [0, 1] and compare every subset pair whose
         two sizes fall in the same bin.

    frequency_range defaults to [0, 1] and data_freqs to {} (optional
    {sim: {pop: freqs}}). 'individually' and 'extract' are currently unused;
    heatmap_v2 is always called with output='pval'.

    Returns (pop_asso, ref_combos):
        pop_asso: {ref_sim: {ref_pop: {rel_size: {subset_sim: subset_pop}}}}
        ref_combos: {ref_sim: {'combs': {(pop1, pop2): {size_bin: [stats]}},
                               'sizes': data[ref_sim]['sizes'],
                               'stats': {(pop1, pop2): stats}}}
    '''
    if frequency_range is None:
        frequency_range = [0, 1]
    if data_freqs is None:
        data_freqs = {}

    heatmap_opts = dict(frequency_range=frequency_range,
                        exclude=exclude,
                        p_value=p_value,
                        muted_dir=muted_dir,
                        test_m=test_m)

    edges = np.round(np.linspace(0, 1, Nbins), 4)
    bins = list(zip(edges[:-1], edges[1:]))

    avail = list(data.keys())
    categ = {0: [], 1: []}
    for idx, sim in enumerate(avail):
        categ[int(tag_ref in sim)].append(idx)

    print([len(categ[x]) for x in [0, 1]])

    ### possible combinations per reference simulation.
    ref_combos = {}

    for idx in categ[0]:
        ref = avail[idx]
        batch = ref.split('C')[0]
        ref_combs = list(it.combinations(data[ref]['counts'].keys(), 2))
        comb_stats = {}

        for pair in ref_combs:
            pop1, pop2 = pair
            chromosomes = [ref.split('.')[0].split('C')[1]]
            stats = _md_pair_stats(data, [(ref, pop1), (ref, pop2)],
                                   chromosomes, batch, **heatmap_opts)

            # .get(): data_freqs may be empty (the default) or lack this sim.
            if data_freqs.get(ref):
                stats['freqs'] = {
                    pop1: data_freqs[ref][pop1],
                    pop2: data_freqs[ref][pop2]
                }

            comb_stats[pair] = stats

        ref_combos[ref] = {
            'combs': {pair: {} for pair in ref_combs},
            'sizes': data[ref]['sizes'],
            'stats': comb_stats
        }

    #### relative subset sizes per population per reference simulation
    # recursively_default_dict is defined elsewhere; presumably a nested
    # defaultdict (the chained assignment below relies on auto-creation).
    pop_asso = {avail[x]: recursively_default_dict() for x in categ[0]}

    for idx in categ[1]:
        sub_sim = avail[idx]
        ref_sim = sub_sim.split(tag_ref)[0]
        sub_pops = [x for x in data[sub_sim]['counts'].keys() if tag_ref in x]

        for sub_pop in sub_pops:
            # str.strip(tag_ref) would remove any of tag_ref's characters from
            # both ends (e.g. 'pops_ss' -> 'pop'); remove the tag itself.
            ref_pop = sub_pop.split('.')[0].replace(tag_ref, '')
            rel_size = (data[sub_sim]['sizes'][sub_pop] /
                        data[ref_sim]['sizes'][ref_pop])
            pop_asso[ref_sim][ref_pop][round(rel_size, 3)][sub_sim] = sub_pop

    ### combine reference population pairs with subset size ranges.
    for ref_sim, sim_asso in pop_asso.items():
        print(ref_sim)
        batch = ref_sim.split('C')[0]

        for combo, combo_bins in ref_combos[ref_sim]['combs'].items():
            pop1, pop2 = combo
            chromosomes = [ref_sim.split('.')[0].split('C')[1]]

            available_sizes = {z: sorted(sim_asso[z].keys()) for z in combo}

            bins_dict = {
                b: {
                    z: [x for x in available_sizes[z] if b[0] < x <= b[1]]
                    for z in combo
                }
                for b in bins
            }

            print([len(bins_dict[b][pop1]) for b in bins])
            print([len(bins_dict[b][pop2]) for b in bins])

            bins_combs = {
                b: list(it.product(bins_dict[b][pop1], bins_dict[b][pop2]))
                for b in bins
            }

            print([len(bins_combs[c]) for c in bins_combs.keys()])

            for bend, size_pairs in bins_combs.items():
                entries = combo_bins[bend] = []

                for i, j in size_pairs:
                    for sub1, subpop1 in sim_asso[pop1][i].items():
                        for sub2, subpop2 in sim_asso[pop2][j].items():
                            # A list of tuples keeps both members even when
                            # sub1 == sub2 (same subset simulation).
                            stats = _md_pair_stats(
                                data, [(sub1, subpop1), (sub2, subpop2)],
                                chromosomes, batch, **heatmap_opts)

                            if data_freqs:
                                stats['freqs'] = {
                                    pop1: data_freqs[sub1][subpop1],
                                    pop2: data_freqs[sub2][subpop2]
                                }

                            entries.append(stats)

    return pop_asso, ref_combos