Example #1
def likelihood_subsample(taxon, treatment, ntot_subsample=50, fmax_cutoff=0.8, fmin_cutoff=0.0, subsamples=10000):
    # ntot_subsample: number of mutations to draw in each subsample

    # Load convergence matrix
    convergence_matrix = parse_file.parse_convergence_matrix(
        pt.get_path() + '/data/timecourse_final/' +
        ("%s_convergence_matrix.txt" % (treatment + taxon)))

    populations = [treatment + taxon + replicate for replicate in pt.replicates]

    gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
        convergence_matrix, populations, fmax_min=fmax_cutoff)

    G_subsample_list = []
    for i in range(subsamples):

        G_subsample = mutation_spectrum_utils.calculate_subsampled_total_parallelism(gene_parallelism_statistics, ntot_subsample=ntot_subsample)

        G_subsample_list.append(G_subsample)

    G_subsample_list.sort()

    G_CIs_dict = {}

    G_subsample_mean = np.mean(G_subsample_list)
    # 2.5th and 97.5th percentiles of the sorted subsamples (bounds of the 95% CI)
    G_subsample_025 = G_subsample_list[int(0.025 * subsamples)]
    G_subsample_975 = G_subsample_list[int(0.975 * subsamples)]

    G_CIs_dict['G_mean'] = G_subsample_mean
    G_CIs_dict['G_025'] = G_subsample_025
    G_CIs_dict['G_975'] = G_subsample_975

    return G_CIs_dict
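A minimal usage sketch (assuming the project modules pt, parse_file, and mutation_spectrum_utils are importable and the data files exist; the taxon and treatment codes below are illustrative):

import numpy as np

# Hypothetical call: bootstrap a 95% CI on the net G score for taxon 'B'
# under the '1' treatment, drawing 50 mutations per subsample.
G_CIs = likelihood_subsample('B', '1', ntot_subsample=50, subsamples=10000)
print("G = %.3f (95%% CI: %.3f-%.3f)" %
      (G_CIs['G_mean'], G_CIs['G_025'], G_CIs['G_975']))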
Example #2
def calculate_likelihood_ratio_fmax(taxon,
                                    treatment,
                                    ntot_subsample=50,
                                    fmax_partition=0.8,
                                    subsamples=10000):

    convergence_matrix = parse_file.parse_convergence_matrix(
        pt.get_path() + '/data/timecourse_final/' +
        ("%s_convergence_matrix.txt" % (treatment + taxon)))

    populations = [
        treatment + taxon + replicate for replicate in pt.replicates
    ]

    gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
        convergence_matrix, populations, fmax_min=fmax_partition)

    G_subsample_list = []
Example #3
fmax_cutoffs = np.asarray([0, 0.2, 0.4, 0.6, 0.8])
G_dict_all = {}
taxa = ['B', 'C', 'D', 'F', 'J', 'P']
treatments = ['0', '1']
ntotal_dict = {}
for taxon in taxa:

    sys.stdout.write("Sub-sampling taxon: %s\n" % (taxon))

    G_dict_all[taxon] = {}
    if taxon == 'J':
        ntotal = 50
    else:
        # calculate ntot for all frequency cutoffs
        convergence_matrix = parse_file.parse_convergence_matrix(
            pt.get_path() + '/data/timecourse_final/' +
            ("%s_convergence_matrix.txt" % ('1' + taxon)))
        populations = ['1' + taxon + replicate for replicate in pt.replicates]
        gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
            convergence_matrix, populations, fmax_min=max(fmax_cutoffs))
        ntotal = 0
        for gene_i, gene_parallelism_statistics_i in gene_parallelism_statistics.items():
            ntotal += gene_parallelism_statistics_i['observed']
    ntotal_dict[taxon] = ntotal
    for treatment in treatments:
        if treatment + taxon in pt.treatment_taxa_to_ignore:
            continue

        G_dict_all[taxon][treatment] = {}
Example #4
def plot_within_taxon_paralleliism(taxon, slope_null=1):

    fig = plt.figure(figsize=(12, 8))

    gene_data = parse_file.parse_gene_list(taxon)

    gene_names, gene_start_positions, gene_end_positions, promoter_start_positions, promoter_end_positions, gene_sequences, strands, genes, features, protein_ids = gene_data
    # to get the common gene names for each ID

    ax_multiplicity = plt.subplot2grid((2, 3), (0, 0), colspan=1)
    ax_mult_freq = plt.subplot2grid((2, 3), (0, 1), colspan=1)
    ax_venn = plt.subplot2grid((2, 3), (0, 2), colspan=1)

    ax_multiplicity.set_xscale('log', base=10)
    ax_multiplicity.set_yscale('log', base=10)
    ax_multiplicity.set_xlabel('Gene multiplicity, ' + r'$m$', fontsize=14)
    ax_multiplicity.set_ylabel('Fraction mutations ' + r'$\geq m$',
                               fontsize=14)
    ax_multiplicity.text(-0.1,
                         1.07,
                         pt.sub_plot_labels[0],
                         fontsize=18,
                         fontweight='bold',
                         ha='center',
                         va='center',
                         transform=ax_multiplicity.transAxes)

    ax_multiplicity.set_ylim([0.001, 1.1])
    ax_multiplicity.set_xlim([0.07, 130])

    ax_mult_freq.set_xscale('log', base=10)
    ax_mult_freq.set_yscale('log', base=10)
    ax_mult_freq.set_xlabel('Gene multiplicity, ' + r'$m$', fontsize=14)
    ax_mult_freq.set_ylabel('Mean maximum allele frequency, ' +
                            r'$\overline{f}_{max}$',
                            fontsize=11)
    ax_mult_freq.text(-0.1,
                      1.07,
                      pt.sub_plot_labels[1],
                      fontsize=18,
                      fontweight='bold',
                      ha='center',
                      va='center',
                      transform=ax_mult_freq.transAxes)

    ax_venn.axis('off')
    ax_venn.text(-0.1,
                 1.07,
                 pt.sub_plot_labels[2],
                 fontsize=18,
                 fontweight='bold',
                 ha='center',
                 va='center',
                 transform=ax_venn.transAxes)

    alpha_treatment_dict = {'0': 0.5, '1': 0.5, '2': 0.8}

    significant_multiplicity_dict = {}

    significant_multiplicity_values_dict = {}

    multiplicity_dict = {}

    g_score_p_label_dict = {}

    all_mults = []
    all_freqs = []

    treatments_in_taxon = []

    label_y_axes = [0.3, 0.2, 0.1]

    for treatment_idx, treatment in enumerate(pt.treatments):

        significant_multiplicity_taxon_path = pt.get_path() + \
            '/data/timecourse_final/parallel_genes_%s.txt' % (treatment + taxon)
        if not os.path.exists(significant_multiplicity_taxon_path):
            continue
        treatments_in_taxon.append(treatment)
        significant_multiplicity_taxon = open(
            significant_multiplicity_taxon_path, "r")

        significant_multiplicity_list = []
        for i, line in enumerate(significant_multiplicity_taxon):
            if i == 0:  # skip header line
                continue
            items = line.strip().split(",")
            significant_multiplicity_list.append(items[0])

            if items[0] not in significant_multiplicity_values_dict:
                significant_multiplicity_values_dict[items[0]] = {}
            significant_multiplicity_values_dict[items[0]][treatment] = float(
                items[-2])

        significant_multiplicity_dict[treatment] = significant_multiplicity_list

        populations = [
            treatment + taxon + replicate for replicate in pt.replicates
        ]

        # Load convergence matrix
        convergence_matrix = parse_file.parse_convergence_matrix(
            pt.get_path() + '/data/timecourse_final/' +
            ("%s_convergence_matrix.txt" % (treatment + taxon)))
        gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
            convergence_matrix, populations, Lmin=100)
        #print(gene_parallelism_statistics)
        G, pvalue = mutation_spectrum_utils.calculate_total_parallelism(
            gene_parallelism_statistics)

        sys.stdout.write("Total parallelism for %s = %g (p=%g)\n" %
                         (treatment + taxon, G, pvalue))

        predictors = []
        responses = []

        gene_hits = []
        gene_predictors = []
        mean_gene_freqs = []

        Ls = []

        ax_mult_freqs_x = []
        ax_mult_freqs_y = []

        for gene_name in convergence_matrix.keys():

            Ls.append(convergence_matrix[gene_name]['length'])
            m = gene_parallelism_statistics[gene_name]['multiplicity']

            if gene_name not in multiplicity_dict:
                multiplicity_dict[gene_name] = {}
            multiplicity_dict[gene_name][treatment] = m

            n = 0
            nfixed = 0
            freqs = []
            nf_max = 0

            for population in populations:
                for t, L, f, f_max in convergence_matrix[gene_name][
                        'mutations'][population]:
                    fixed_weight = timecourse_utils.calculate_fixed_weight(
                        L, f)

                    predictors.append(m)
                    responses.append(fixed_weight)

                    n += 1
                    nfixed += fixed_weight

                    # get freqs for regression
                    #if L == parse_file.POLYMORPHIC:
                    #freqs.append(f_max)
                    nf_max += timecourse_utils.calculate_fixed_weight(L, f_max)

            if n > 0:
                gene_hits.append(n)
                gene_predictors.append(m)
                #mean_gene_freqs.append(np.mean(freqs))

                if nf_max > 0:
                    ax_mult_freqs_x.append(m)
                    ax_mult_freqs_y.append(nf_max / n)

        Ls = np.asarray(Ls)
        ntot = len(predictors)
        mavg = ntot * 1.0 / len(Ls)

        predictors, responses = (np.array(x) for x in zip(
            *sorted(zip(predictors, responses), key=lambda pair: (pair[0]))))

        gene_hits, gene_predictors = (np.array(x) for x in zip(*sorted(
            zip(gene_hits, gene_predictors), key=lambda pair: (pair[0]))))

        rescaled_predictors = np.exp(np.fabs(np.log(predictors / mavg)))

        null_survival_function = mutation_spectrum_utils.NullMultiplicitySurvivalFunction.from_parallelism_statistics(
            gene_parallelism_statistics)

        # default base is 10
        theory_ms = np.logspace(-2, 2, 100)
        theory_survivals = null_survival_function(theory_ms)
        theory_survivals /= theory_survivals[0]

        sys.stderr.write("Done!\n")

        ax_multiplicity.plot(theory_ms,
                             theory_survivals,
                             lw=3,
                             color=pt.get_colors(treatment),
                             alpha=0.8,
                             ls=':',
                             zorder=1)

        ax_multiplicity.plot(
            predictors, (len(predictors) - np.arange(0, len(predictors))) *
            1.0 / len(predictors),
            lw=3,
            color=pt.get_colors(treatment),
            alpha=0.8,
            ls='--',
            label=str(int(10**int(treatment))) + '-day',
            drawstyle='steps',
            zorder=2)

        #ax_multiplicity.text(0.2, 0.3, g_score_p_label_dict['0'], fontsize=25, fontweight='bold', ha='center', va='center', transform=ax_multiplicity.transAxes)
        #ax_multiplicity.text(0.2, 0.2, g_score_p_label_dict['1'], fontsize=25, fontweight='bold', ha='center', va='center', transform=ax_multiplicity.transAxes)
        #ax_multiplicity.text(0.2, 0.1, g_score_p_label_dict['2'], fontsize=25, fontweight='bold', ha='center', va='center', transform=ax_multiplicity.transAxes)

        if pvalue < 0.001:
            pretty_pvalue = r'$\ll 0.001$'
        else:
            pretty_pvalue = '=' + str(round(pvalue, 4))

        g_score_p_label = r'$\Delta \ell_{{{}}}=$'.format(
            str(10**int(treatment))) + str(round(
                G, 3)) + ', ' + r'$P$' + pretty_pvalue

        text_color = pt.lighten_color(pt.get_colors(treatment), amount=1.3)

        ax_multiplicity.text(0.26,
                             label_y_axes[treatment_idx],
                             g_score_p_label,
                             fontsize=7,
                             ha='center',
                             va='center',
                             color='k',
                             transform=ax_multiplicity.transAxes)

        ax_mult_freq.scatter(ax_mult_freqs_x,
                             ax_mult_freqs_y,
                             color=pt.get_colors(treatment),
                             edgecolors=pt.get_colors(treatment),
                             marker=pt.plot_species_marker(taxon),
                             alpha=alpha_treatment_dict[treatment])

        all_mults.extend(ax_mult_freqs_x)
        all_freqs.extend(ax_mult_freqs_y)

        #slope, intercept, r_value, p_value, std_err = stats.linregress(np.log10(ax_mult_freqs_x), np.log10(ax_mult_freqs_y))
        #print(slope, p_value)

    # make treatment pairs
    treatments_in_taxon.sort(key=float)

    for i in range(0, len(treatments_in_taxon)):

        for j in range(i + 1, len(treatments_in_taxon)):

            ax_mult_i_j = plt.subplot2grid((2, 3), (1, i + j - 1), colspan=1)
            ax_mult_i_j.set_xscale('log', base=10)
            ax_mult_i_j.set_yscale('log', base=10)
            ax_mult_i_j.set_xlabel(str(10**int(treatments_in_taxon[i])) +
                                   '-day gene multiplicity, ' + r'$m$',
                                   fontsize=14)
            ax_mult_i_j.set_ylabel(str(10**int(treatments_in_taxon[j])) +
                                   '-day gene multiplicity, ' + r'$m$',
                                   fontsize=14)
            ax_mult_i_j.plot([0.05, 200], [0.05, 200],
                             lw=3,
                             c='grey',
                             ls='--',
                             zorder=1)
            ax_mult_i_j.set_xlim([0.05, 200])
            ax_mult_i_j.set_ylim([0.05, 200])

            ax_mult_i_j.text(-0.1,
                             1.07,
                             pt.sub_plot_labels[2 + i + j],
                             fontsize=18,
                             fontweight='bold',
                             ha='center',
                             va='center',
                             transform=ax_mult_i_j.transAxes)

            multiplicity_pair = [
                (multiplicity_dict[gene_name][treatments_in_taxon[i]],
                 multiplicity_dict[gene_name][treatments_in_taxon[j]])
                for gene_name in sorted(multiplicity_dict)
                if (multiplicity_dict[gene_name][treatments_in_taxon[i]] > 0)
                and (multiplicity_dict[gene_name][treatments_in_taxon[j]] > 0)
            ]
            significant_multiplicity_pair = [
                (significant_multiplicity_values_dict[gene_name][
                    treatments_in_taxon[i]],
                 significant_multiplicity_values_dict[gene_name][
                     treatments_in_taxon[j]])
                for gene_name in sorted(significant_multiplicity_values_dict)
                if (treatments_in_taxon[i] in
                    significant_multiplicity_values_dict[gene_name]) and (
                        treatments_in_taxon[j] in
                        significant_multiplicity_values_dict[gene_name])
            ]

            # get mean colors
            ccv = ColorConverter()

            color_1 = np.array(
                ccv.to_rgb(pt.get_colors(treatments_in_taxon[i])))
            color_2 = np.array(
                ccv.to_rgb(pt.get_colors(treatments_in_taxon[j])))

            mix_color = 0.7 * (color_1 + color_2)
            mix_color = np.min([mix_color, [1.0, 1.0, 1.0]], 0)

            if (treatments_in_taxon[i] == '0') and (treatments_in_taxon[j]
                                                    == '1'):
                #mix_color = pt.lighten_color(mix_color, amount=2.8)
                mix_color = 'gold'

            mult_i = [x[0] for x in multiplicity_pair]
            mult_j = [x[1] for x in multiplicity_pair]

            ax_mult_i_j.scatter(mult_i,
                                mult_j,
                                marker=pt.plot_species_marker(taxon),
                                facecolors=mix_color,
                                edgecolors='none',
                                alpha=0.8,
                                s=90,
                                zorder=2)

            mult_significant_i = [x[0] for x in significant_multiplicity_pair]
            mult_significant_j = [x[1] for x in significant_multiplicity_pair]
            ax_mult_i_j.scatter(mult_significant_i,
                                mult_significant_j,
                                marker=pt.plot_species_marker(taxon),
                                facecolors=mix_color,
                                edgecolors='k',
                                lw=1.5,
                                alpha=0.7,
                                s=90,
                                zorder=3)

            #slope_mult, intercept_mult, r_value_mult, p_value_mult, std_err_mult = stats.linregress(np.log10(mult_significant_i), np.log10(mult_significant_j))

            mult_ij = mult_significant_i + mult_significant_j + mult_i + mult_j

            ax_mult_i_j.set_xlim([min(mult_ij) * 0.5, max(mult_ij) * 1.5])
            ax_mult_i_j.set_ylim([min(mult_ij) * 0.5, max(mult_ij) * 1.5])

            # null slope of 1
            #ratio = (slope_mult - slope_null) / std_err_mult
            #p_value_mult_new_null = stats.t.sf(np.abs(ratio), len(mult_significant_j)-2)*2

            #if p_value_mult_new_null < 0.05:
            #    x_log10_fit_range =  np.linspace(np.log10(min(mult_i) * 0.5), np.log10(max(mult_i) * 1.5), 10000)

            #    y_fit_range = 10 ** (slope_mult*x_log10_fit_range + intercept_mult)
            #    ax_mult_i_j.plot(10**x_log10_fit_range, y_fit_range, c='k', lw=3, linestyle='--', zorder=4)

            #ax_mult_i_j.text(0.05, 0.9, r'$\beta_{1}=$'+str(round(slope_mult,3)), fontsize=12, transform=ax_mult_i_j.transAxes)
            #ax_mult_i_j.text(0.05, 0.82, r'$r^{2}=$'+str(round(r_value_mult**2,3)), fontsize=12, transform=ax_mult_i_j.transAxes)
            #ax_mult_i_j.text(0.05, 0.74, pt.get_p_value_latex(p_value_mult_new_null), fontsize=12, transform=ax_mult_i_j.transAxes)

    #if taxon == 'F':
    #    subset_tuple = (len( significant_multiplicity_dict['0']), \
    #                    len( significant_multiplicity_dict['1']), \
    #                    len(set(significant_multiplicity_dict['0']) & set(significant_multiplicity_dict['1'])))

    #    venn = venn2(subsets = subset_tuple, ax=ax_venn, set_labels=('', '', ''), set_colors=(pt.get_colors('0'), pt.get_colors('1')))
    #    c = venn2_circles(subsets=subset_tuple, ax=ax_venn, linestyle='dashed')

    subset_tuple = (len(significant_multiplicity_dict['0']), \
                    len(significant_multiplicity_dict['1']), \
                    len(set(significant_multiplicity_dict['0']) & set(significant_multiplicity_dict['1'])), \
                    len(significant_multiplicity_dict['2']), \
                    len(set(significant_multiplicity_dict['0']) & set(significant_multiplicity_dict['2'])), \
                    len(set(significant_multiplicity_dict['1']) & set(significant_multiplicity_dict['2'])), \
                    len(set(significant_multiplicity_dict['0']) & set(significant_multiplicity_dict['1']) & set(significant_multiplicity_dict['2'])))

    venn = venn3(subsets=subset_tuple,
                 ax=ax_venn,
                 set_labels=('', '', ''),
                 set_colors=(pt.get_colors('0'), pt.get_colors('1'),
                             pt.get_colors('2')))
    c = venn3_circles(subsets=subset_tuple, ax=ax_venn, linestyle='dashed')

    ax_mult_freq.set_xlim([min(all_mults) * 0.5, max(all_mults) * 1.5])
    ax_mult_freq.set_ylim([min(all_freqs) * 0.5, max(all_freqs) * 1.5])

    fig.suptitle(pt.latex_dict[taxon], fontsize=30)

    fig.subplots_adjust(wspace=0.3)  #hspace=0.3, wspace=0.5
    fig_name = pt.get_path() + "/figs/multiplicity_%s.jpg" % taxon
    fig.savefig(fig_name,
                format='jpg',
                bbox_inches="tight",
                pad_inches=0.4,
                dpi=600)
    plt.close()
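The step plot in the multiplicity panel is an empirical survival function: for each observed multiplicity m, the fraction of mutations in genes with multiplicity at least m. A self-contained numpy sketch of that transform, with made-up values:

import numpy as np

# illustrative multiplicities, one per mutation (not real data)
multiplicities = np.sort(np.asarray([0.3, 0.8, 1.2, 1.2, 2.5, 4.0, 7.1]))

# fraction of mutations with multiplicity >= m, evaluated at each sorted m
survival = (len(multiplicities) - np.arange(0, len(multiplicities))) \
    * 1.0 / len(multiplicities)

for m, s in zip(multiplicities, survival):
    print("m = %.1f -> fraction >= m: %.2f" % (m, s))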
Example #5
missed_opportunity_axis.set_xlabel('Partition time, $t^*$', fontsize=6)
missed_opportunity_axis.set_xticks(figure_utils.time_xticks)
missed_opportunity_axis.set_xticklabels(figure_utils.time_xticklabels)
missed_opportunity_axis.set_xlim([0, 55000])
#missed_opportunity_axis.set_ylim([-20,35])

#######################################
#
# Now do calculations and plot figures
#
#######################################

tstars = numpy.arange(0, 111) * 500

# Load convergence matrix
convergence_matrix = parse_file.parse_convergence_matrix(
    parse_file.data_directory + ('%s_convergence_matrix.txt' % level))

# Load significant genes
parallel_genes = parse_file.parse_parallel_genes(parse_file.data_directory +
                                                 ('parallel_%ss.txt' % level))

# Calculate gene parallelism statistics
gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
    convergence_matrix, populations)

# Calculate gene name, pop, and time vectors

# All genes
all_gene_names = []
all_pops = []
all_times = []
Example #6

    ###################################################################
    #
    # Gene parallelism analysis
    #
    ###################################################################

    for include_svs in [True, False]:

        if include_svs:
            convergence_matrix_filename = parse_file.data_directory + "gene_convergence_matrix.txt"
        else:
            convergence_matrix_filename = parse_file.data_directory + "gene_convergence_matrix_nosvs.txt"
            
        convergence_matrix = parse_file.parse_convergence_matrix(convergence_matrix_filename)

        # Calculate median appearance time
        pooled_appearance_times = []
        for gene_name in convergence_matrix.keys():
            for population in metapopulation_populations[metapopulation]:
                for t,L,Lclade,f in convergence_matrix[gene_name]['mutations'][population]:
                    pooled_appearance_times.append(t)
                    
        tstar = numpy.median(pooled_appearance_times)
                    
        sys.stdout.write("Median appearance time = %g\n" % tstar)

        gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(convergence_matrix,allowed_populations=metapopulation_populations[metapopulation],Lmin=100)
        
        G, pvalue = mutation_spectrum_utils.calculate_total_parallelism(gene_parallelism_statistics)
Example #7
treatments = pt.treatments
replicates = pt.replicates

G_subsample_dict = {}

G_all_mutations_dict = {}

ntot_subsample = 200
num_bootstraps = 10000
iter = 1000

for treatment in ['1']:

    # Load convergence matrix
    convergence_matrix_B = parse_file.parse_convergence_matrix(
        pt.get_path() + '/data/timecourse_final/' +
        ("%s_convergence_matrix.txt" % (treatment + 'B')))
    convergence_matrix_S = parse_file.parse_convergence_matrix(
        pt.get_path() + '/data/timecourse_final/' +
        ("%s_convergence_matrix.txt" % (treatment + 'S')))

    populations_B = [treatment + 'B' + replicate for replicate in replicates]
    populations_S = [treatment + 'S' + replicate for replicate in replicates]

    gene_parallelism_statistics_B = mutation_spectrum_utils.calculate_parallelism_statistics(
        convergence_matrix_B, populations_B, Lmin=100)
    gene_parallelism_statistics_S = mutation_spectrum_utils.calculate_parallelism_statistics(
        convergence_matrix_S, populations_S, Lmin=100)

    #G_subsample_list = []
    #for i in range(subsamples):
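The commented-out stub suggests the same subsampling loop as Example #1. A sketch of how the loop might continue for the two backgrounds, assuming calculate_subsampled_total_parallelism behaves as in Example #1 (one G score per call):

    G_subsample_dict['B'] = []
    G_subsample_dict['S'] = []
    for i in range(num_bootstraps):
        # one subsampled G score per genetic background per iteration
        G_subsample_dict['B'].append(
            mutation_spectrum_utils.calculate_subsampled_total_parallelism(
                gene_parallelism_statistics_B, ntot_subsample=ntot_subsample))
        G_subsample_dict['S'].append(
            mutation_spectrum_utils.calculate_subsampled_total_parallelism(
                gene_parallelism_statistics_S, ntot_subsample=ntot_subsample))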
Example #8
axis.get_yaxis().tick_left()

axis.plot([0, 20], [1, 1], 'k-', linewidth=0.25)

###
#
# Do calculations
#
###

populations = parse_file.complete_nonmutator_lines

tstars = numpy.arange(10, 110) * 500

sys.stderr.write("Loading convergence matrix...\t")
gene_convergence_matrix = parse_file.parse_convergence_matrix(
    parse_file.data_directory + ("gene_convergence_matrix.txt"))
operon_convergence_matrix = parse_file.parse_convergence_matrix(
    parse_file.data_directory + ("operon_convergence_matrix.txt"))

sys.stderr.write("Done!\n")

# Calculate gene parallelism statistics
gene_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
    gene_convergence_matrix, populations)
operon_parallelism_statistics = mutation_spectrum_utils.calculate_parallelism_statistics(
    operon_convergence_matrix, populations)

num_plotted = 0
num_off_diagonal = 0

for operon_name in operon_parallelism_statistics.keys():
Example #9
def calculate_parallelism_statistics_partition(taxon, treatment, fmax_partition=0.5):

    convergence_matrix = parse_file.parse_convergence_matrix(
        pt.get_path() + '/data/timecourse_final/' +
        ("%s_convergence_matrix.txt" % (treatment + taxon)))

    populations = [treatment + taxon + replicate for replicate in pt.replicates]

    significant_genes = []

    significant_multiplicity_taxon_path = pt.get_path() + \
        '/data/timecourse_final/parallel_genes_%s.txt' % (treatment + taxon)
    #if os.path.exists(significant_multiplicity_taxon_path) == False:
    #    continue
    significant_multiplicity_taxon = open(significant_multiplicity_taxon_path, "r")
    for i, line in enumerate(significant_multiplicity_taxon):
        if i == 0:
            continue
        line = line.strip()
        items = line.split(",")
        if items[0] not in significant_genes:
            significant_genes.append(items[0])

    # Now calculate gene counts
    #Ltot = 0
    #Ngenes = 0
    ntot_less = 0
    ntot_greater = 0
    fmax_true_false = []
    positions = [0]
    n_greater_all = []
    n_less_all = []
    for gene_name in sorted(convergence_matrix.keys()):

        if gene_name not in significant_genes:
            continue

        #L = max([convergence_matrix[gene_name]['length'],Lmin])
        n_less = 0
        n_greater = 0
        #num_pops = 0

        for population in populations:

            # partition mutations by the cutoff on maximum allele frequency
            convergence_matrix_mutations_population_filtered_greater = [
                k for k in convergence_matrix[gene_name]['mutations'][population]
                if (k[-1] >= fmax_partition)]
            convergence_matrix_mutations_population_filtered_less = [
                k for k in convergence_matrix[gene_name]['mutations'][population]
                if (k[-1] < fmax_partition)]

            new_muts_greater = len(convergence_matrix_mutations_population_filtered_greater)
            new_muts_less = len(convergence_matrix_mutations_population_filtered_less)

            #fmax_greater =

            if (new_muts_greater > 0) and (new_muts_less > 0):

                n_greater += new_muts_greater
                n_less += new_muts_less

                #num_pops += 1
                #n += new_muts
                #for t,l,f,f_max in convergence_matrix_mutations_population_filtered_greater:
                #    times.append(t)
                #    # get maximum allele frequency

        if (n_greater == 0) or (n_less == 0):
            continue

        fmax_true_false.extend([True] * n_greater)
        fmax_true_false.extend([False] * n_less)

        ntot_less += n_less
        ntot_greater += n_greater

        n_greater_all.append(n_greater)
        n_less_all.append(n_less)

        positions.append(ntot_less + ntot_greater)
    n_less_all = np.asarray(n_less_all)
    n_greater_all = np.asarray(n_greater_all)
    n_all = n_less_all + n_greater_all

    positions = np.asarray(positions)
    fmax_true_false = np.asarray(fmax_true_false)

    ntot = ntot_less + ntot_greater

    # Log-likelihood of the partition: for each gene, compare the observed split
    # of its mutations (n_less, n_greater) against the pooled split
    # (ntot_less, ntot_greater); a G-test-style statistic.
    likelihood_partition = sum(
        (n_less_all * np.log((n_less_all * ntot) / (ntot_less * n_all))) +
        (n_greater_all * np.log((n_greater_all * ntot) / (ntot_greater * n_all))))

    print("likelihood", fmax_partition, likelihood_partition)

    null_likelihood_partition = []

    for i in range(1000):

        fmax_true_false_permute = np.random.permutation(fmax_true_false)

        n_less_permute_all = []
        n_greater_permute_all = []

        for gene_idx in range(len(positions)-1):

            gene_fmax_permute = fmax_true_false_permute[positions[gene_idx]:positions[gene_idx+1]]

            n_greater_permute = len(gene_fmax_permute[gene_fmax_permute==True])
            n_less_permute = len(gene_fmax_permute[gene_fmax_permute!=True])

            if (n_greater_permute > 0 ) and (n_less_permute > 0):

                n_greater_permute_all.append(n_greater_permute)
                n_less_permute_all.append(n_less_permute)

        n_less_permute_all = np.asarray(n_less_permute_all)
        n_greater_permute_all = np.asarray(n_greater_permute_all)

        n_permute_all = n_less_permute_all + n_greater_permute_all

        # total mutation counts in each class (sums, consistent with the
        # observed-statistic computation above)
        ntot_less_permute = n_less_permute_all.sum()
        ntot_greater_permute = n_greater_permute_all.sum()

        ntot_permute = ntot_less_permute + ntot_greater_permute

        likelihood_partition_permute = sum(
            (n_less_permute_all * np.log((n_less_permute_all * ntot_permute) / (ntot_less_permute * n_permute_all))) +
            (n_greater_permute_all * np.log((n_greater_permute_all * ntot_permute) / (ntot_greater_permute * n_permute_all))))

        null_likelihood_partition.append(likelihood_partition_permute)

    null_likelihood_partition = np.asarray(null_likelihood_partition)
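The function stops after assembling the null distribution; the standard next step for a permutation test would be an empirical p-value (a sketch, not part of the original source):

    # one-sided permutation p-value: fraction of null statistics at least as
    # extreme as the observed one; +1 in both terms avoids reporting p = 0
    p_value = (np.sum(null_likelihood_partition >= likelihood_partition) + 1.0) \
        / (len(null_likelihood_partition) + 1.0)

    return likelihood_partition, p_value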