Example #1
def infer_branch_associations(path, total_strains_count, strain_fraction_branch_association):
    from sf_geneCluster_align_makeTree import load_sorted_clusters
    from sf_coreTree_json import metadata_load
    metaFile= '%s%s'%(path,'metainfo.tsv')
    data_description = '%s%s'%(path,'meta_tidy.tsv')
    association_dict = defaultdict(dict)
    metadata = Metadata(metaFile, data_description)
    metadata_dict = metadata.to_dict()

    sorted_genelist = load_sorted_clusters(path)
    ## sorted_genelist: [(clusterID, [ count_strains,[memb1,...],count_genes]),...]
    for clusterID, gene in sorted_genelist:
        if gene[-1]>=total_strains_count*strain_fraction_branch_association: # and clusterID=='GC00001136':
            print(clusterID)
            tree = Phylo.read("%s/geneCluster/%s.nwk"%(path, clusterID), 'newick')
            assoc = BranchAssociation(tree, metadata_dict)
            for col, d in metadata.data_description.iterrows():
                if d['associate']=='yes':
                    if 'log_scale' in d and d['log_scale']=='yes':
                        t = lambda x:np.log(x)
                    else:
                        t = lambda x:x
                    assoc.calc_up_down_averages(d["meta_category"], transform = t)
                    max_assoc = assoc.calc_significance()
                    association_dict[clusterID][d["meta_category"]] = max_assoc

    write_pickle("%s/branch_association.cpk"%path, association_dict)
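A minimal driver for the function above might look like the sketch below; the data directory and the 0.95 strain fraction are illustrative assumptions, and the module-level imports of the original file (numpy as np, Bio.Phylo, collections.defaultdict, plus panX's Metadata, BranchAssociation and write_pickle helpers) are taken as given.

## hypothetical invocation; assumes `path` ends with '/' and that
## metainfo.tsv and meta_tidy.tsv already exist under it
path = './data/my_species/'
total_strains_count = 50
infer_branch_associations(path, total_strains_count,
                          strain_fraction_branch_association=0.95)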
Example #2
def infer_presence_absence_associations(path, total_strains_count,
    min_strain_fraction_association, max_strain_fraction_association):
    from sf_geneCluster_align_makeTree import load_sorted_clusters
    from sf_coreTree_json import metadata_load
    metaFile= '%s%s'%(path,'metainfo.tsv')
    data_description = '%s%s'%(path,'meta_tidy.tsv')
    association_dict = defaultdict(dict)
    metadata = Metadata(metaFile, data_description)
    metadata_dict = metadata.to_dict()
    min_strains_association = total_strains_count*min_strain_fraction_association
    max_strains_association = total_strains_count*max_strain_fraction_association
    sorted_genelist = load_sorted_clusters(path)
    ## sorted_genelist: [(clusterID, [ count_strains,[memb1,...],count_genes]),...]
    # TODO fix vis
    tree = Phylo.read("%sgeneCluster/strain_tree.nwk"%(path), 'newick')
    assoc = PresenceAbsenceAssociation(tree, metadata_dict)
    for clusterID, gene in sorted_genelist:
        if gene[-1]>min_strains_association and gene[-1]<max_strains_association:
            print(clusterID)
            gl = load_gain_loss(path, clusterID)
            for col, d in metadata.data_description.iterrows():
                if d['associate']=='yes':
                    if 'log_scale' in d and d['log_scale']=='yes':
                        t = lambda x:np.log(x)
                    else:
                        t = lambda x:x
                    assoc.set_gain_loss(gl)
                    score = assoc.calc_association_simple(d["meta_category"], transform = t)
                    if np.isinf(score):
                        association_dict[clusterID][d["meta_category"]] = 0.0
                    else:
                        association_dict[clusterID][d["meta_category"]] = np.abs(score)

    write_pickle("%s/presence_absence_association.cpk"%path, association_dict)
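Per metadata column the transform is either identity or np.log, and infinite association scores are clamped to 0.0 before storing. A standalone sketch of that dispatch, with a hypothetical metadata row:

import numpy as np

def pick_transform(d):
    # np.log for log-scaled metadata columns, identity otherwise
    if 'log_scale' in d and d['log_scale']=='yes':
        return np.log
    return lambda x: x

row = {'associate': 'yes', 'log_scale': 'yes', 'meta_category': 'MIC'}  # hypothetical row
t = pick_transform(row)
print(t(100.0))  # ~4.605, the log-transformed value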
def make_gene_presence_absence_matrix(input_filepath):
    os.chdir(input_filepath)
    gene_order= ','.join([gene.rstrip() for gene, content in load_sorted_clusters('./')])
    ## output destination (undefined in the original excerpt; assumed name)
    output_filepath= './geneCluster/genePresence.csv'

    with open('./geneCluster/genePresence.aln') as inputf,\
         open(output_filepath,'wb') as outputf:
        outputf.write('accession,%s\n'%gene_order)
        for strain, genes in read_fasta(inputf).iteritems():
            outputf.write('%s,%s\n'%(strain,','.join(genes)))
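On a toy presence alignment with two strains and three clusters (hypothetical IDs), the CSV written above comes out with one strain per row and one 0/1 column per cluster:

accession,GC001,GC002,GC003
strainA,1,0,1
strainB,1,1,1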
def find_and_merge_unclustered_genes(path,
                                     nstrains,
                                     window_size=5,
                                     strain_proportion=0.3,
                                     sigma_scale=3):
    """
    detect the unclustered genes and concatenate them
    params:
        nstrains: total number of strains
        window_size
        strain_proportion
        sigma_scale
    return:
        a dict with key of the merged cluster and value of
        a list of related unclustered cluster-name for deletion

    """
    file_path = '%s%s' % (path, 'geneCluster/')
    gene_clusters = load_sorted_clusters(path)
    length_to_cluster = defaultdict(list)
    length_list = []
    ## calculate cluster length distribution, link clusterIDs with their clusterLength
    for gid, (clusterID, gene) in enumerate(gene_clusters):
        # average length of the cluster in amino acids
        clusterLength = int(
            np.mean([
                len(igene)
                for igene in read_fasta(file_path + '%s%s' %
                                        (clusterID, '.fna')).values()
            ]) / 3.0)
        length_to_cluster[clusterLength].append(clusterID)
        length_list.append(clusterLength)
    cluster_length_distribution = np.bincount(length_list)

    ## calculate smoothed cluster length distribution
    window = np.ones(window_size, dtype=float) / window_size
    smoothed_length_distribution = np.convolve(cluster_length_distribution,
                                               window,
                                               mode='same')

    ## detect peaks
    peaks = (cluster_length_distribution -
             smoothed_length_distribution) > np.maximum(
                 strain_proportion * nstrains,
                 sigma_scale * np.sqrt(smoothed_length_distribution))
    position_peaks = np.where(peaks)[0]
    #cluster_len_peaks= position_peaks*3
    ## concatenate clusters with the same aver. length, return dict of these clusters
    merged_clusters_dict = defaultdict(dict)
    for index, i_peak in enumerate(position_peaks, 1):
        merged_cluster_filename, cluster_needed_deletion = concatenate_cluster_files(
            length_to_cluster[i_peak], index, file_path)
        merged_clusters_dict[merged_cluster_filename] = cluster_needed_deletion
    return merged_clusters_dict
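The peak test flags length bins whose count exceeds the smoothed background by max(strain_proportion*nstrains, sigma_scale*sqrt(background)). A self-contained numpy sketch of the rule on synthetic counts:

import numpy as np

nstrains, strain_proportion, sigma_scale, window_size = 100, 0.3, 3, 5
# synthetic cluster-length histogram with an artificial spike at length 10
counts = np.full(21, 5)
counts[10] = 60
window = np.ones(window_size, dtype=float) / window_size
smoothed = np.convolve(counts, window, mode='same')
threshold = np.maximum(strain_proportion * nstrains,
                       sigma_scale * np.sqrt(smoothed))
print(np.where((counts - smoothed) > threshold)[0])  # -> [10]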
def make_genepresence_alignment(path, disable_gain_loss,
                                merged_gain_loss_output):
    '''
    loop over all gene clusters and append 0/1 to each strain-specific
    string, building a pseudo-alignment of gene presence/absence
    '''
    geneClusterPath = '%s%s' % (path, 'protein_fna/diamond_matches/')
    output_path = '%s%s' % (path, 'geneCluster/')

    ## load strain list and prepare for gene presence/absence
    strain_list = load_pickle('%s%s' % (path, 'strain_list.cpk'))
    set_totalStrain = set([istrain for istrain in strain_list])
    totalStrain = len(set_totalStrain)
    dt_strainGene = defaultdict(str)

    sorted_genelist = load_sorted_clusters(path)
    ## sorted_genelist: [(clusterID, [ count_strains,[memb1,...],count_genes]),...]
    for clusterID, gene in sorted_genelist:
        ## append 0/1 to each strain
        create_genePresence(dt_strainGene, totalStrain, set_totalStrain,
                            gene[1])

    with open('%s%s' % (output_path, 'genePresence.aln'),
              'wb') as presence_outfile:
        for istkey in dt_strainGene:
            write_in_fa(presence_outfile, istkey, dt_strainGene[istkey])
    write_pickle('%s%s' % (output_path, 'dt_genePresence.cpk'), dt_strainGene)

    if disable_gain_loss:
        geneEvents_dt = {i: 0 for i in range(len(sorted_genelist))}
        write_pickle('%s%s' % (output_path, 'dt_geneEvents.cpk'),
                     geneEvents_dt)
        if merged_gain_loss_output:
            gene_loss_fname = '%s%s' % (output_path, 'geneGainLossEvent.json')
            write_json(dt_strainGene, gene_loss_fname, indent=1)
        else:
            ## strainID as key, presence pattern as value (converted into np.array)
            keylist = dt_strainGene.keys()
            keylist.sort()
            strainID_keymap = {ind: k
                               for ind, k in enumerate(keylist)
                               }  # dict(zip(keylist, range(3)))
            presence_arr = np.array([
                np.array(dt_strainGene[k], 'c') for k in keylist
            ])  # 0: present, 3: absent
            presence_arr[presence_arr == '1'] = '3'
            for ind, (clusterID, gene) in enumerate(sorted_genelist):
                pattern_dt = {
                    strainID_keymap[strain_ind]: str(patt)
                    for strain_ind, patt in enumerate(presence_arr[:, ind])
                }
                pattern_fname = '%s%s_patterns.json' % (output_path, clusterID)
                write_json(pattern_dt, pattern_fname, indent=1)
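The else-branch transposes strain-indexed presence strings into one pattern dict per cluster. A minimal sketch of that transposition with toy strain names and patterns:

import numpy as np

dt_strainGene = {'strainA': '101', 'strainB': '111'}  # toy presence strings
keylist = sorted(dt_strainGene)
presence_arr = np.array([list(dt_strainGene[k]) for k in keylist])
presence_arr[presence_arr == '1'] = '3'  # recode '1' -> '3' as in the function above
for ind in range(presence_arr.shape[1]):
    pattern_dt = {k: str(presence_arr[i, ind]) for i, k in enumerate(keylist)}
    print(ind, pattern_dt)  # one {strain: state} dict per gene cluster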
def make_core_all_targz(species_name, analysis_folder):
    species_folder= analysis_folder+species_name
    cwd=os.getcwd()
    os.chdir(species_folder)

    ## packing core genes
    os.system('mkdir -p ./core_gene_alignments;')
    with open('./geneCluster/core_geneList.txt') as core_list:
        # all core gene alignments in FASTA files
        for gene in core_list:
            os.system('cp ./vis/geneCluster/'+gene.rstrip()+'.gz ./core_gene_alignments')
            os.system('cp ./vis/geneCluster/'+gene.rstrip().replace('_na','_aa')+'.gz ./core_gene_alignments')
        #os.system('gunzip ./core_gene_alignments/*')
        os.system('tar -zcf core_gene_alignments.tar.gz core_gene_alignments; rm -r ./core_gene_alignments; mv core_gene_alignments.tar.gz vis')

    ## packing all genes in pan-genome
    os.system('mkdir -p all_gene_alignments')
    for gene, content in load_sorted_clusters('./'):
        os.system('cp ./vis/geneCluster/'+gene.rstrip()+'_na_aln.fa.gz ./all_gene_alignments')
        os.system('cp ./vis/geneCluster/'+gene.rstrip()+'_aa_aln.fa.gz ./all_gene_alignments')
    #os.system('gunzip ./all_gene_alignments/*')
    os.system('tar -zcf all_gene_alignments.tar.gz all_gene_alignments; rm -r all_gene_alignments; mv all_gene_alignments.tar.gz ./vis')
    os.chdir(cwd) # return to the caller's working directory
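The same packing can be done without shelling out; below is a sketch using the standard tarfile and shutil modules (same layout as above, error handling omitted). pack_alignments is a hypothetical helper, not part of the source.

import os, shutil, tarfile

def pack_alignments(gene_files, archive_name, vis_dir='./vis'):
    # bundle the given alignment files into a gzipped tar placed under vis_dir
    tmp_dir = archive_name.replace('.tar.gz', '')
    os.makedirs(tmp_dir)
    for fname in gene_files:
        shutil.copy(os.path.join(vis_dir, 'geneCluster', fname), tmp_dir)
    with tarfile.open(archive_name, 'w:gz') as tar:
        tar.add(tmp_dir)
    shutil.rmtree(tmp_dir)
    shutil.move(archive_name, vis_dir)

Calling it with the filenames read from core_geneList.txt would mirror the core-gene branch above.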
Example #9
def export_gain_loss(tree, path, merged_gain_loss_output):
    '''
    export inferred gene gain/loss events: write the strain tree with
    internal node names, per-branch gain/loss strings, and per-cluster
    event counts and presence patterns for visualization
    '''
    # write final tree with internal node names as assigned by treetime
    sep = '/'
    output_path = sep.join([path.rstrip(sep), 'geneCluster/'])
    events_dict_path = sep.join([output_path, 'dt_geneEvents.cpk'])
    gene_pattern_dict_path = sep.join([output_path, 'dt_genePattern.cpk'])

    tree_fname = sep.join([output_path, 'strain_tree.nwk'])
    Phylo.write(tree.tree, tree_fname, 'newick')

    gene_gain_loss_dict = defaultdict(str)
    preorder_strain_list = []  #store the preorder nodes as strain list
    for node in tree.tree.find_clades(
            order='preorder'):  # order does not matter much here
        if node.up is None: continue
        #print(node.name ,len(node.geneevents),node.geneevents)
        gain_loss = [
            str(int(ancestral) * 2 + int(derived)) for ancestral, derived in
            zip(node.up.genepresence, node.genepresence)
        ]
        gene_gain_loss_dict[node.name] = "".join(gain_loss)
        preorder_strain_list.append(node.name)

    gain_loss_array = np.array(
        [[i for i in gain_loss_str]
         for gain_loss_str in gene_gain_loss_dict.values()],
        dtype=int)
    # 1 and 2 are codes for gain/loss events
    events_array = ((gain_loss_array == 1) |
                    (gain_loss_array == 2)).sum(axis=0)
    events_dict = {index: event for index, event in enumerate(events_array)}

    write_pickle(events_dict_path, events_dict)

    if merged_gain_loss_output:
        ## export gene loss dict to json for visualization
        #gene_loss_fname = sep.join([ output_path, 'geneGainLossEvent.json'])
        #write_json(gene_gain_loss_dict, gene_loss_fname, indent=1)
        write_pickle(gene_pattern_dict_path, gene_gain_loss_dict)
    else:
        ## strainID as key, presence pattern as value (converted into np.array)
        sorted_genelist = load_sorted_clusters(path)
        strainID_keymap = {
            ind: k
            for ind, k in enumerate(preorder_strain_list)
        }
        #presence_arr= np.array([ np.fromstring(gene_gain_loss_dict[k], np.int8)-48 for k in preorder_strain_list])
        presence_arr = np.array([
            np.array(gene_gain_loss_dict[k], 'c') for k in preorder_strain_list
        ])
        ## if true, write pattern dict instead of pattern string in a json file
        pattern_json_flag = False
        for ind, (clusterID, gene) in enumerate(sorted_genelist):
            pattern_fname = '%s%s_patterns.json' % (output_path, clusterID)
            if pattern_json_flag:
                pattern_dt = {
                    strainID_keymap[strain_ind]: str(patt)
                    for strain_ind, patt in enumerate(presence_arr[:, ind])
                }
                write_json(pattern_dt, pattern_fname, indent=1)
            #print(preorder_strain_list,clusterID)
            #print(''.join([ str(patt) for patt in presence_arr[:, ind]]))
            with open(pattern_fname, 'w') as write_pattern:
                write_pattern.write(
                    '{"patterns":"' +
                    ''.join([str(patt)
                             for patt in presence_arr[:, ind]]) + '"}')
def create_core_SNP_matrix(path,
                           core_cutoff=1.0,
                           core_gene_strain_fpath=''):
    """ create SNP matrix using core gene SNPs
        input: strain_list.cpk, core_geneList.cpk
        output: SNP_whole_matrix.aln
        core_cutoff: fraction of strains required to call a gene core
            default: 1.0 (strictly core gene, present in all strains)
            customized: 0.9 (soft core, considered core if present in 90% of strains)
    """
    import os, sys, operator
    import numpy as np
    import numpy.ma as ma
    from collections import defaultdict
    from sf_miscellaneous import read_fasta, write_pickle, load_pickle, write_in_fa

    alnFilePath = '%s%s' % (path, 'geneCluster/')
    output_path = alnFilePath

    ## create core gene list
    corelist = []
    strain_list = load_pickle(path + 'strain_list.cpk')
    totalStrain = len(strain_list)
    sorted_geneList = load_sorted_clusters(path)
    if core_gene_strain_fpath != '':
        with open(core_gene_strain_fpath, 'rb') as core_gene_strain_file:
            core_strain_set = set(
                [i.rstrip().replace('-', '_') for i in core_gene_strain_file])
    with open(output_path + 'core_geneList.txt', 'wb') as outfile:
        for clusterID, vg in sorted_geneList:
            if core_cutoff == 1.0:
                strain_core_cutoff = totalStrain
            else:
                strain_core_cutoff = int(totalStrain * core_cutoff)
            if vg[0] == vg[2] and vg[0] >= strain_core_cutoff:
                coreGeneName = '%s%s' % (clusterID, '_na_aln.fa')
                ## sequences might be discarded because of premature stops
                coreGeneName_path = alnFilePath + coreGeneName
                if os.path.exists(coreGeneName_path) and len(
                        read_fasta(coreGeneName_path)) >= strain_core_cutoff:
                    if core_gene_strain_fpath != '' and len(
                            core_strain_set -
                            set([i.split('|')[0] for i in vg[1]])) != 0:
                        continue
                    outfile.write(coreGeneName + '\n')
                    corelist.append(coreGeneName)
                else:
                    #print '%s%s%s'%('warning: ',coreGeneName_path,' is not a core gene')
                    pass

        write_pickle(output_path + 'core_geneList.cpk', corelist)

    refSeqList = load_pickle(path + 'strain_list.cpk')
    refSeqList.sort()

    snp_fre_lst = []
    snp_wh_matrix_flag = 0
    snp_pos_dt = defaultdict(list)
    snp_whole_matrix = np.array([])
    snps_by_gene = []
    for align_file in corelist:  ## core genes
        nuc_array = np.array([])  # array to store nucleotides for each gene
        gene_seq_dt = read_fasta(alnFilePath + align_file)
        if core_cutoff != 1.0:
            # set sequences for missing gene (space*gene_length)
            missing_gene_seq = ' ' * len(gene_seq_dt.values()[0])
            totalStrain_sorted_lst = sorted(strain_list)
        # build strain_seq_dt from gene_seq_dt
        strain_seq_dt = defaultdict()
        for gene, seq in gene_seq_dt.iteritems():
            strain_seq_dt[gene.split('-')[0]] = seq  # strain-locus_tag-...
        strain_seq_sorted_lst = sorted(strain_seq_dt.items(),
                                       key=lambda x: x[0])

        start_flag = 0
        if core_cutoff == 1.0:
            for ka, va in strain_seq_sorted_lst:
                if start_flag == 0:
                    nuc_array = np.array(np.fromstring(va, dtype='S1'))
                    start_flag = 1
                else:
                    nuc_array = np.vstack(
                        (nuc_array, np.fromstring(va, dtype='S1')))
            ## find SNP positions
            position_polymorphic = np.any(nuc_array != nuc_array[0, :], axis=0)
            position_has_gap = np.any(nuc_array == '-', axis=0)
            position_SNP = position_polymorphic & (~position_has_gap)
            snp_columns = nuc_array[:, position_SNP]
            snp_pos_dt[align_file] = np.where(position_SNP)[0]
        else:
            ## add '-' for missing genes when dealing with soft core genes
            core_gene_strain = [gene for gene in strain_seq_dt.keys()]
            for strain in totalStrain_sorted_lst:
                if start_flag == 0:
                    if strain in core_gene_strain:
                        nuc_array = np.array(
                            np.fromstring(strain_seq_dt[strain], dtype='S1'))
                    else:
                        print 'Soft core gene: gene absent in strain %s on cluster %s' % (
                            strain, align_file)
                        nuc_array = np.array(
                            np.fromstring(missing_gene_seq, dtype='S1'))
                    start_flag = 1
                else:
                    if strain in core_gene_strain:
                        nuc_array = np.vstack(
                            (nuc_array,
                             np.fromstring(strain_seq_dt[strain], dtype='S1')))
                    else:
                        print 'Soft core gene: gene absent in strain %s on cluster %s' % (
                            strain, align_file)
                        nuc_array = np.vstack((nuc_array,
                                               np.fromstring(missing_gene_seq,
                                                             dtype='S1')))
            ## find SNP positions
            ## mask missing genes -- determine rows that have ' ' in every column
            is_missing = np.all(nuc_array == ' ', axis=1)
            masked_non_missing_array = np.ma.masked_array(
                nuc_array, nuc_array == ' ')
            position_polymorphic = np.any(
                masked_non_missing_array != masked_non_missing_array[0, :],
                axis=0)
            position_has_gap = np.any(masked_non_missing_array == '-', axis=0)
            position_SNP = position_polymorphic & (~position_has_gap)
            # the below seems duplicated from 5 lines above??
            if is_missing.sum() > 0:  # with missing genes
                nuc_array[is_missing] = '-'
            snp_columns = nuc_array[:, position_SNP]
            snp_pos_dt[align_file] = np.where(position_SNP)[0]
            #print snp_columns

        if snp_wh_matrix_flag == 0:
            snp_whole_matrix = snp_columns
            snp_wh_matrix_flag = 1
        else:
            snp_whole_matrix = np.hstack((snp_whole_matrix, snp_columns))
    write_pickle(output_path + 'snp_pos.cpk', snp_pos_dt)

    with open(output_path + 'SNP_whole_matrix.aln', 'wb') as outfile:
        for ind, isw in enumerate(snp_whole_matrix):
            write_in_fa(outfile, refSeqList[ind], isw.tostring())
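The strict-core branch keeps exactly the alignment columns that are polymorphic and gap-free; a compact numpy illustration on a toy alignment:

import numpy as np

# toy alignment: 3 strains x 5 sites
nuc_array = np.array([list('ACG-T'),
                      list('ACGAT'),
                      list('ATGAT')])
position_polymorphic = np.any(nuc_array != nuc_array[0, :], axis=0)
position_has_gap = np.any(nuc_array == '-', axis=0)
position_SNP = position_polymorphic & (~position_has_gap)
print(np.where(position_SNP)[0])   # -> [1]; site 3 is polymorphic but gapped
print(nuc_array[:, position_SNP])  # the SNP columns, here C/C/T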
Example #11
def geneCluster_to_json(path, enable_RNA_clustering, store_locus_tag,
                        raw_locus_tag, optional_table_column):
    """
    create json file for gene cluster table visualization
    input:  path to genecluster output
    output: geneCluster.json
    """
    # define path and make output directory
    geneCluster_path = '%s%s' % (path, 'geneCluster/')
    output_path = '%s%s' % (path, 'vis/')

    # open files
    geneClusterJSON_outfile = open(output_path + 'geneCluster.json', 'wb')
    ##store locus_tags in a separate file for large dataset
    if store_locus_tag:
        locus_tag_outfile = open(path + 'search_locus_tag.tsv', 'wb')

    ### load precomputed annotations, diversity, associations etc
    # load geneID_to_descriptions
    geneID_to_descriptions = load_pickle(path + 'geneID_to_description.cpk')

    if enable_RNA_clustering:
        # load RNAID_to_description_file
        geneID_to_descriptions.update(
            load_pickle(path + 'RNAID_to_description.cpk'))

    gene_diversity_Dt = load_pickle(geneCluster_path + 'gene_diversity.cpk')
    ## load gain/loss event count dictionary
    dt_geneEvents = load_pickle(geneCluster_path + 'dt_geneEvents.cpk')
    ## load association
    branch_associations_path = path + 'branch_association.cpk'
    if os.path.isfile(branch_associations_path):
        branch_associations = load_pickle(branch_associations_path)
    else:
        branch_associations = {}
    presence_absence_associations_path = path + 'presence_absence_association.cpk'
    if os.path.isfile(presence_absence_associations_path):
        presence_absence_associations = load_pickle(
            presence_absence_associations_path)
    else:
        presence_absence_associations = {}

    ## load list of clustered sorted by strain count
    sorted_genelist = load_sorted_clusters(path)

    geneClusterJSON_outfile.write('[')
    ## sorted_genelist: [(clusterID, [ count_strains,[memb1,...],count_genes]),...]
    for gid, (clusterID, gene) in enumerate(sorted_genelist):
        strain_count, gene_list, gene_count = gene
        # #print strain_count, gene_count
        if gid != 0:  ## begin
            geneClusterJSON_outfile.write(',\n')

        ## annotation majority
        allAnn, majority_annotation = consolidate_annotation(
            path, gene_list, geneID_to_descriptions)

        ## geneName majority
        all_geneName, majority_geneName = consolidate_geneName(
            path, gene_list, geneID_to_descriptions)

        ## extract gain/loss event count
        gene_event = dt_geneEvents[gid]

        ## average length
        seqs = read_fasta(geneCluster_path + '%s%s' %
                          (clusterID, '.fna')).values()
        geneClusterLength = int(np.mean([len(igene) for igene in seqs]))

        ## msa
        #geneCluster_aln='%s%s'%(clusterID,'_aa.aln')
        geneCluster_aln = clusterID

        ## check for duplicates
        if gene_count > strain_count:
            duplicated_state = 'yes'
            dup_list = [ig.split('|')[0] for ig in gene_list]
            # "#" to delimit (gene/gene_count)key/value ; "@" to seperate genes
            # Counter({'g1': 2, 'g2': 1})
            dup_detail = ''.join([
                '%s#%s@' % (kd, vd)
                for kd, vd in Counter(dup_list).iteritems() if vd > 1
            ])[:-1]
        else:
            duplicated_state = 'no'
            dup_detail = ''

        ## locus_tag
        if raw_locus_tag:  # make a string of all locus tags [1] in igl.split('|')
            all_locus_tags = ' '.join([igl.split('|')[1] for igl in gene_list])
        else:  # in addition to locus tag, keep strain name (but replace '|')
            all_locus_tags = ' '.join(
                [igl.replace('|', '_') for igl in gene_list])

        ## optionally store locus tags to file, remove from geneClusterJSON
        if store_locus_tag:
            locus_tag_outfile.write('%s\t%s\n' % (clusterID, all_locus_tags))
            all_locus_tags = ''

        ## default cluster json fields
        cluster_json_line = [
            '"geneId":' + str(gid + 1), '"geneLen":' + str(geneClusterLength),
            '"count":' + str(strain_count), '"dupli":"' + duplicated_state +
            '"', '"dup_detail":"' + dup_detail + '"',
            '"ann":"' + majority_annotation + '"',
            '"msa":"' + geneCluster_aln + '"',
            '"divers":"' + gene_diversity_Dt[clusterID] + '"',
            '"event":"' + str(gene_event) + '"', '"allAnn":"' + allAnn + '"',
            '"GName":"' + majority_geneName + '"',
            '"allGName":"' + all_geneName + '"',
            '"locus":"' + all_locus_tags + '"'
        ]

        if optional_table_column:
            cluster_json_line.extend(
                optional_geneCluster_properties(gene_list,
                                                optional_table_column))
        if clusterID in branch_associations:
            cluster_json_line.extend(
                geneCluster_associations(branch_associations[clusterID],
                                         suffix='BA'))
        if clusterID in presence_absence_associations:
            cluster_json_line.extend(
                geneCluster_associations(
                    presence_absence_associations[clusterID], suffix='PA'))

        #write file
        cluster_json_line = ','.join(cluster_json_line)
        geneClusterJSON_outfile.write('{' + cluster_json_line + '}')

    # close files
    geneClusterJSON_outfile.write(']')
    geneClusterJSON_outfile.close()
    if store_locus_tag: locus_tag_outfile.close()
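Each cluster becomes one record in the geneCluster.json array; the hand-built string above is equivalent to serializing a dict with the json module (toy values, field names taken from the code above):

import json

record = {
    'geneId': 1, 'geneLen': 320, 'count': 50,
    'dupli': 'no', 'dup_detail': '',
    'ann': 'hypothetical protein', 'msa': 'GC00000001',
    'divers': '0.01', 'event': '2',
    'allAnn': 'hypothetical protein', 'GName': '', 'allGName': '',
    'locus': 'strainA_tag1 strainB_tag2',
}
print(json.dumps(record))  # one element of the geneCluster.json array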