# Shared imports assumed by the snippets below (a sketch: seqpy,
# Unique_kmer_detect_direct, Recls_withR_new and generate_kmer_with_sts_con_block
# are project-local modules; helpers such as get_pre, hierarchy, get_leaf_union,
# get_intersect, get_diff and delete are defined elsewhere in the project).
import gc
import math
import os
import pickle
import random
import re
import time
from collections import OrderedDict, defaultdict

import bidict
import numpy as np
import psutil
import scipy.sparse as sp
from Bio import SeqIO
from treelib import Node, Tree  # assumption: treelib provides the Tree/Node API used below

import seqpy
import Unique_kmer_detect_direct
import Recls_withR_new
import generate_kmer_with_sts_con_block

def count_dbs(lines, ksize, alls):
    # block id -> k-mer -> strain name -> '' (presence sets)
    d = defaultdict(lambda: defaultdict(dict))
    d2 = {}  # block id -> number of 's' rows in the block
    c = 1
    for line in lines:
        if not line: continue
        if line[0] == 'a':
            blockid = c
            d2[blockid] = 0
            c += 1
        if line[0] == 's':
            d2[blockid] += 1
    c = 1
    for line in lines:
        if not line: continue
        if line[0] == 'a':
            blockid = c
            c += 1
        if line[0] == 's':
            if d2[blockid] > alls: continue  # skip blocks with too many 's' rows
            ele = line.split()
            for i in range(len(ele[-1]) - ksize + 1):
                kmer = ele[-1][i:i + ksize]
                if len(kmer) != ksize: continue
                if 'N' in kmer: continue  # skip ambiguous bases
                rev_kmer = seqpy.revcomp(kmer)
                d[blockid][kmer][ele[1]] = ''
                d[blockid][rev_kmer][ele[1]] = ''
    return d
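# A minimal usage sketch for count_dbs() (assumptions: seqpy.revcomp is the
# project's reverse-complement helper; the toy input mimics MAF-style blocks,
# and `alls` caps how many 's' rows a block may have before it is skipped).
# Call this right after the definition above; later examples redefine
# count_dbs with a different signature.
def _demo_count_dbs():
    toy = [
        'a score=1',
        's strainA 0 8 + 100 ACGTACGT',
        's strainB 0 8 + 100 ACGTACGA',
        '',
        'a score=2',
        's strainA 10 6 + 100 TTTNGG',   # N-containing k-mers are dropped
    ]
    blocks = count_dbs(toy, 4, alls=5)
    print(list(blocks[1]['ACGT']))   # ['strainA', 'strainB']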
# Example 2
def build_kmer_dict(d, k):
    print('::Build kmer dict for all clusters')
    dlabel = defaultdict(lambda: {})  # encoded k-mer -> {cluster id: ''}
    c = 1
    for l in d:
        print('::Process - ', c, '/', len(d))
        single = 0
        if len(d[l]) == 1:
            single = 1
        for g in d[l]:
            if single == 1:
                pre = get_pre(g)
            else:
                pre = 'C' + str(l)
            seq_dict = {rec.id: rec.seq for rec in SeqIO.parse(g, "fasta")}
            for cl in seq_dict:
                seq = str(seq_dict[cl])
                for i in range(len(seq) - k + 1):
                    if not special_match(seq[i:i + k]): continue
                    kmrb = encode(seq[i:i + k])
                    rev_kmrb = encode(seqpy.revcomp(seq[i:i + k]))
                    dlabel[kmrb][c] = ''
                    dlabel[rev_kmrb][c] = ''
        c += 1
    return dlabel
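# special_match(), encode() and decode() are used above but defined elsewhere
# in the project. A minimal sketch of what they could look like (assumption:
# 2-bit packing of unambiguous bases; the real implementations may differ):
_BASE2BIT = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
_BIT2BASE = 'ACGT'

def special_match(kmer, _allowed=frozenset('ACGT')):
    # True only when the k-mer contains unambiguous bases.
    return all(b in _allowed for b in kmer)

def encode(kmer):
    # Pack a k-mer into an integer, two bits per base.
    v = 0
    for b in kmer:
        v = (v << 2) | _BASE2BIT[b]
    return v

def decode(v, k=31):
    # Inverse of encode(); k must match the original k-mer length
    # (31 here is an assumed default, not taken from the project).
    out = []
    for _ in range(k):
        out.append(_BIT2BASE[v & 3])
        v >>= 2
    return ''.join(reversed(out))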
def out_block(oi, dblock, block, strain_cls):
    oi.write('a\n')
    block_seq = []
    for s in dblock[block]:
        # each dblock entry is (strand, start, end) into the strain sequence
        with open(strain_cls[s], 'r') as fs:
            head = fs.readline()  # consume the FASTA header line
            seq_all = fs.readline().strip()
        tem = ['s', s, str(dblock[block][s][1])]
        seq = seq_all[dblock[block][s][1]:dblock[block][s][2]]
        if dblock[block][s][0] == '-':
            seq = seqpy.revcomp(seq)
        tem.append(str(len(seq)))
        tem.append(dblock[block][s][0])
        tem.append(str(len(seq_all)))
        tem.append(seq)
        block_seq.append(' '.join(tem))
    oi.write('\n'.join(block_seq) + '\n\n')
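# A runnable sketch of the inputs out_block() expects (assumptions: each
# strain file is one header line plus one sequence line, and dblock maps
# block -> strain -> (strand, start, end); all names here are illustrative).
def _demo_out_block():
    import io, tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.fa', delete=False) as fs:
        fs.write('>strainA\nACGTACGTACGT\n')
        path = fs.name
    dblock = {1: {'strainA': ('+', 2, 8)}}   # take seq[2:8] on the + strand
    oi = io.StringIO()
    out_block(oi, dblock, 1, {'strainA': path})
    print(oi.getvalue())   # a\ns strainA 2 6 + 12 GTACGT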
def build_kmer_dict(d, k):
    print('Load kmer to dict...')
    dlabel = defaultdict(lambda: {})  # k-mer string -> {genome index: ''}
    c = 1
    label_match = {}
    for g in d:
        print('Process: ', c, '/', len(d))
        seq_dict = {rec.id: rec.seq for rec in SeqIO.parse(g, "fasta")}
        for cl in seq_dict:
            seq = str(seq_dict[cl])
            for i in range(len(seq) - k + 1):
                kmer = seq[i:i + k]
                rev_kmer = seqpy.revcomp(kmer)
                dlabel[kmer][c] = ''
                dlabel[rev_kmer][c] = ''
        pre = Unique_kmer_detect_direct.get_pre(g)
        label_match[c] = pre
        c += 1
    return dlabel, label_match
# Example 5
def build_kmer_dict(ddir, k, dsi, cls_num):
    match_1 = list(range(1, cls_num + 1))
    match_2 = ['0'] * cls_num
    dlabel = defaultdict(lambda: {})  # k-mer -> {strain index: 1}
    for pre in ddir:
        g = ddir[pre]
        seq_dict = {rec.id: rec.seq for rec in SeqIO.parse(g, "fasta")}
        for cl in seq_dict:
            seq = str(seq_dict[cl])
            for i in range(len(seq) - k + 1):
                kmer = seq[i:i + k]
                rev_kmer = seqpy.revcomp(kmer)
                dlabel[kmer][dsi[pre] - 1] = 1
                dlabel[rev_kmer][dsi[pre] - 1] = 1
    return dlabel
def count_dbs(lines, ksize):
    # block id -> k-mer -> strain name -> '' (presence sets)
    d = defaultdict(lambda: defaultdict(dict))
    c = 1
    for line in lines:
        if not line: continue
        if line[0] == 'a':
            blockid = c
            c += 1
        if line[0] == 's':
            ele = line.split()
            for i in range(len(ele[-1]) - ksize + 1):
                kmer = ele[-1][i:i + ksize]
                if len(kmer) != ksize: continue
                if 'N' in kmer: continue
                rev_kmer = seqpy.revcomp(kmer)
                d[blockid][kmer][ele[1]] = ''
                d[blockid][rev_kmer][ele[1]] = ''
    return d
# Example 7
def unique_kmer_out(d, k, dlabel, out_d, base_dir, if_base):
    print('::Scan unique kmer and output')
    count = 1
    for cls in d:
        print('::Process - ', count, '/', len(d))
        count += 1
        uk_count = 0
        resd = {}  # records the final unique k-mers of this cluster
        single_cls = 0
        if len(d[cls]) == 1:
            single_cls = 1
            if if_base != 'Y':
                o = open(out_d + '/C' + str(cls) + '.fasta', 'w+')
        else:
            o = open(out_d + '/C' + str(cls) + '.fasta', 'w+')
        for g in d[cls]:
            if single_cls == 1 and if_base == 'Y':
                pre = get_pre(g)
                o = open(base_dir + '/' + pre + '.fasta', 'w+')

            seq_dict = {rec.id: rec.seq for rec in SeqIO.parse(g, "fasta")}

            for cl in seq_dict:
                seq = str(seq_dict[cl])
                for i in range(len(seq) - k + 1):
                    if not special_match(seq[i:i + k]): continue
                    kmrb = encode(seq[i:i + k])
                    if len(dlabel[kmrb]) == 1:
                        resd[kmrb] = ''
                        resd[encode(seqpy.revcomp(seq[i:i + k]))] = ''
        for kmrb in resd:
            uk_count += 1
            kmr = decode(kmrb)
            o.write('>' + str(uk_count) + '\n' + str(kmr) + '\n')
        o.close()
# Example 8
def extract_kmers(fna_i, fna_path, ksize, kmer_index_dict, kmer_index, Lv,
                  spec, tree_dir, alpha_ratio, identifier):
    kmer_sta = defaultdict(int)
    for j in fna_i:
        for seq_record in SeqIO.parse(fna_path[j], "fasta"):
            temp = str(seq_record.seq)
            for k in range(0, len(temp) - ksize + 1):  # include the final window
                forward = temp[k:k + ksize]
                reverse = seqpy.revcomp(forward)
                for kmer in [forward, reverse]:
                    try:
                        kmer_sta[kmer_index_dict[kmer]] += 1
                    except KeyError:
                        kmer_index_dict[kmer] = kmer_index
                        kmer_sta[kmer_index] += 1
                        kmer_index += 1
    alpha = len(fna_i) * alpha_ratio
    for x in kmer_sta:
        if (kmer_sta[x] >= alpha):
            Lv[identifier].add(x)
        else:
            spec[identifier].add(x)
    print(identifier, len(Lv[identifier]), len(spec[identifier]))
    return kmer_index
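# A small driver for extract_kmers() (assumptions: seqpy.revcomp is available
# and the temporary FASTA path is illustrative). With a single genome and
# alpha_ratio=1, every k-mer reaches the alpha threshold and lands in Lv,
# so spec stays empty.
def _demo_extract_kmers():
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.fa', delete=False) as fh:
        fh.write('>chr1\nACGTACGTAC\n')
        path = fh.name
    Lv, spec = defaultdict(set), defaultdict(set)
    nxt = extract_kmers([0], {0: path}, 4, {}, 1, Lv, spec,
                        tree_dir=None, alpha_ratio=1, identifier=7)
    print(nxt, len(Lv[7]), len(spec[7]))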
# Example 9
def unique_kmer_out_inside_cls(d,k,dlabel,out_dir):
	print('::Scan unique kmer inside cluster and output')
	count=1
	knum=1
	kid_match={} # Kmer -> ID
	#used_kmr={} # record all unique kmers
	sid_match={} # Strain -> ID
	ids_match={} # ID -> Strain
	head=[]
	match_1=[]
	match_2=[]
	for s in d:
		match_1.append(int(s))
		match_2.append('0')
	kmatrix=defaultdict(lambda:{}) # k-mer id -> strain id -> 0/1
	for s in d: # For each strain in the cluster -> 's' here refers to the id of strain
		head.append(s)
		count+=1
		uk_count=0
		resd={}

		for s2 in d[s]: # Only one time
			pre=Unique_kmer_detect_direct.get_pre(s2)
			sid_match[pre]=s # Name -> Strain id
			ids_match[s]=pre # ID -> Strain Name
			#o=open(out_dir+'/'+pre+'.fasta','w+')
			seq_dict = {rec.id : rec.seq for rec in SeqIO.parse(s2, "fasta")}
			for cl in seq_dict:
				seq=str(seq_dict[cl])
				#rev_seq=seqpy.revcomp(seq)
				for i in range(len(seq)-k+1):
					kmer=seq[i:i+k]
					if len(dlabel[kmer])==1:
						#duniq_num[s]+=1
						uk_count+=1
						resd[kmer]=''
						resd[seqpy.revcomp(kmer)]=''

						

		kcount=0
		for kmr in resd:
			kcount+=1
			kid_match[kmr]=knum
			kmatrix[knum][s-1]=1
			knum+=1

	
	tem=[]
	head=sorted(head)
	for h in head:
		tem.append(str(h))
	#o2.write(','.join(tem)+'\n')
	head_out=','.join(tem)+'\n'
	'''
	for kid in sorted(kmatrix.keys()):
		outa=[kmatrix[kid][key] for key in sorted(kmatrix[kid].keys())]
		#o2.write(','.join(outa)+'\n')
	#o2.close()

	with open(out_dir+'/uk_kid.pkl','wb') as o3:
		pickle.dump(kid_match, o3, pickle.HIGHEST_PROTOCOL)
	'''
	tem=[]
	for i in sorted(ids_match.keys()):
		tem.append(ids_match[i])
	with open(out_dir+'/id2strain.pkl','wb') as o4: # The list of strain name.
		pickle.dump(tem, o4, pickle.HIGHEST_PROTOCOL)
	#print('Unique part -> kid_match:',len(kid_match),', kmatrix:',len(kmatrix))
	return match_1,match_2,kid_match,kmatrix, head_out,knum,sid_match
def generate_kmer_match_from_uk(input_uk, ksize, out_dir, dlabel, match_1,
                                match_2, head_out, knum, kid_match, kmatrix,
                                sid_match):
    f = open(input_uk, 'r')
    lines = f.read().split('\n')
    f.close()
    c = 1
    dbs_count = count_dbs(lines, ksize)  # block id -> k-mer -> {strain: ''}
    dtotal_kmer = {}
    for line in lines:
        if not line: continue
        if line[0] == 'a':
            if c == 1:
                dstrain = {}
                dtotal_kmer = {}
                blockid = c
                c += 1
            else:
                if len(dtotal_kmer) > 0:
                    for k in dtotal_kmer:
                        if len(dlabel[k]) == 1: continue  # unique k-mers are handled separately
                        if k in kid_match: continue
                        kid_match[k] = knum
                        for e in dict(dbs_count[blockid][k]):
                            kmatrix[knum][sid_match[e] - 1] = 1
                        knum += 1

                dstrain = {}
                dtotal_kmer = {}
                blockid = c
                c += 1

        if line[0] == 's':
            ele = line.split()
            dstrain[ele[1]] = ''
            for i in range(len(ele[-1]) - ksize + 1):
                kmer = ele[-1][i:i + ksize]
                if len(kmer) != ksize: continue
                if 'N' in kmer: continue
                rev_kmer = seqpy.revcomp(kmer)

                # keep k only if every genome carrying it sits in this block
                if len(dlabel[kmer]) == len(dbs_count[blockid][kmer]):
                    dtotal_kmer[kmer] = ''
                if len(dlabel[rev_kmer]) == len(dbs_count[blockid][rev_kmer]):
                    dtotal_kmer[rev_kmer] = ''
    # flush the final block after the loop
    if len(dtotal_kmer) > 0:
        for k in dtotal_kmer:
            if len(dlabel[k]) == 1: continue
            if k in kid_match: continue
            kid_match[k] = knum
            for e in dict(dbs_count[blockid][k]):
                kmatrix[knum][sid_match[e] - 1] = 1
            knum += 1
    for k in kid_match:
        # drop shared k-mers already recorded; keep unique ones in dlabel
        if k in dlabel and len(dlabel[k]) > 1:
            del dlabel[k]
    gc.collect()
    return knum, kid_match, kmatrix
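# The key filter above is len(dlabel[k]) == len(dbs_count[blockid][k]):
# a k-mer is kept only when every genome carrying it (per the global dlabel
# index) also appears in the current alignment block. A toy illustration
# with plain dicts shaped like the presence sets used above:
def _demo_block_filter():
    dlabel_toy = {'ACGT': {'s1': '', 's2': ''},
                  'TTTT': {'s1': '', 's2': '', 's3': ''}}
    block_toy = {'ACGT': {'s1': '', 's2': ''},
                 'TTTT': {'s1': '', 's2': ''}}   # s3 is missing from the block
    for kmer in block_toy:
        keep = len(dlabel_toy[kmer]) == len(block_toy[kmer])
        print(kmer, 'kept' if keep else 'dropped')   # ACGT kept, TTTT dropped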
def generate_kmer_match_from_global(input_gb, ksize, out_dir, dlabel, match_1,
                                    match_2, head_out, sid_match, label_match,
                                    knum, kid_match, kmatrix, mas, cid):
    f = open(input_gb, 'r')
    lines = f.read().split('\n')
    f.close()
    o = open(out_dir + '/all_kmer.fasta', 'w+')
    if mas == 0:
        c = 1
        dtotal_kmer = {}
        for line in lines:
            if not line: continue
            if line[0] == 'a':
                if c == 1:
                    dtotal_kmer = {}
                    c += 1
                else:
                    if len(dtotal_kmer) > 0:
                        for k in dtotal_kmer:
                            if k in kid_match: continue
                            if len(dlabel[k]) == 1: continue
                            kid_match[k] = knum
                            for e in dlabel[k]:
                                kmatrix[knum][sid_match[label_match[e]] -
                                              1] = 1
                            knum += 1
                    dtotal_kmer = {}
                    c += 1
            if line[0] == 's':
                ele = line.split()
                #dstrain[ele[1]]=''
                for i in range(len(ele[-1]) - ksize + 1):
                    kmer = ele[-1][i:i + ksize]
                    if len(kmer) != ksize: continue
                    if 'N' in kmer: continue
                    rev_kmer = seqpy.revcomp(kmer)
                    if len(dlabel[kmer]) != len(sid_match):
                        dtotal_kmer[kmer] = ''
                    if len(dlabel[rev_kmer]) != len(sid_match):
                        dtotal_kmer[rev_kmer] = ''
        if len(dtotal_kmer) > 0:
            for k in dtotal_kmer:
                if k in kid_match: continue
                if len(dlabel[k]) == 1: continue
                kid_match[k] = knum
                for e in dlabel[k]:
                    kmatrix[knum][sid_match[label_match[e]] - 1] = 1
                knum += 1
    else:
        #generate_kmer_with_mafft_con_block.extract_kmer_mugsy(input_gb,ksize,kid_match,kmatrix,knum,sid_match,dlabel,label_match)
        generate_kmer_with_sts_con_block.extract_kmer_sts(
            input_gb, ksize, kid_match, kmatrix, knum, sid_match, dlabel,
            label_match)
    # Finish k-mer searching part...
    kc = 1
    dlabel = {}
    gc.collect()
    row = len(kmatrix)
    column = len(match_1)
    for nk in kid_match:
        o.write('>' + str(kc) + '\n' + nk + '\n')
        kc += 1
    o.close()
    #o1=open(out_dir+'/all_strain.csv','w+')
    #o1.write(head_out)
    print(
        str(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))) +
        ' - StrainScan::build_DB: C' + str(cid) +
        '- Fill the Sparse matrix: Row: ', row, ' Column: ', column)
    mat = sp.dok_matrix((row, column), dtype=np.int8)
    for kmr in kmatrix:
        mat[kmr - 1, list(kmatrix[kmr].keys())] = 1
    mat = mat.tocsr()
    sp.save_npz(out_dir + '/all_strains.npz', mat)

    with open(out_dir + '/all_kid.pkl', 'wb') as o2:
        pickle.dump(kid_match, o2, pickle.HIGHEST_PROTOCOL)
    # All k-mer sets are generated; re-cluster the strains to merge near-identical (<1% different) cases
    print(
        str(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))) +
        ' - StrainScan::build_DB: C' + str(cid) + '- Recluster matrix')
    Recls_withR_new.remove_1per(out_dir + '/all_strains.npz',
                                out_dir + '/id2strain.pkl', out_dir)
    if os.path.exists(out_dir + '/all_strains_re.npz'):
        os.system('rm ' + out_dir + '/all_strains.npz ' + out_dir +
                  '/id2strain.pkl')
    else:
        print('The L2 re-clustering step failed! Please check!')
        print('Related path: ' + out_dir)
    print(
        str(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))) +
        ' - StrainScan::build_DB: C' + str(cid) +
        '- The current cluster is finished!')
def unique_kmer_out_inside_cls(d, k, dlabel, out_dir, uknum):
    print('::Scan unique kmer inside cluster and output')
    count = 1
    knum = 1
    kid_match = {}  # Kmer -> ID
    sid_match = {}  # Strain -> ID
    ids_match = {}  # ID -> Strain
    head = []
    #o=open(out_dir+'/unique_kmer_all.fasta','w+')
    match_1 = []
    match_2 = []
    for s in d:
        match_1.append(s)
        match_2.append('0')
    kmatrix = defaultdict(
        lambda: {})  # {1:{1:0, 2:1, ....}} # Kmer id -> Strain id-> 0 or 1
    for s in d:  # For each strain in the cluster -> 's' here refers to the id of strain
        head.append(s)
        count += 1
        uk_count = 0
        resd = OrderedDict()  # insertion order lets sampling follow genome position
        for s2 in d[s]:  # Only one time
            pre = Unique_kmer_detect_direct.get_pre(s2)
            sid_match[pre] = s  # Name -> Strain id
            ids_match[s] = pre  # ID -> Strain Name
            #o=open(out_dir+'/'+pre+'.fasta','w+')
            seq_dict = {rec.id: rec.seq for rec in SeqIO.parse(s2, "fasta")}
            for cl in seq_dict:
                seq = str(seq_dict[cl])
                #rev_seq=seqpy.revcomp(seq)
                for i in range(len(seq) - k + 1):
                    kmer = seq[i:i + k]
                    if len(dlabel[kmer]) == 1:
                        if 'N' in kmer: continue
                        #duniq_num[s]+=1
                        resd[kmer] = ''
                        resd[seqpy.revcomp(kmer)] = ''

        kcount = 0
        intervals = 0
        if len(resd) > uknum:
            # Sample unique k-mers evenly along the genome
            if uknum == 0: continue
            intervals = math.ceil(len(resd) / uknum)
            keys = list(resd.keys())
            for i in range(0, len(keys), intervals):
                kmr = keys[i]
                rev_kmr = seqpy.revcomp(kmr)
                kid_match[kmr] = knum
                kmatrix[knum][s - 1] = 1
                knum += 1
                kid_match[rev_kmr] = knum
                kmatrix[knum][s - 1] = 1
                knum += 1
                kcount += 2

        else:
            for kmr in resd:
                #uk_count+=1
                kcount += 1
                kid_match[kmr] = knum
                kmatrix[knum][s - 1] = 1
                knum += 1
        print('Log:', ids_match[s], kcount, len(resd), intervals)
    tem = []
    head = sorted(head)
    for h in head:
        tem.append(str(h))
    #o2.write(','.join(tem)+'\n')
    head_out = ','.join(tem) + '\n'
    '''
	for kid in sorted(kmatrix.keys()):
		outa=[kmatrix[kid][key] for key in sorted(kmatrix[kid].keys())]
		#o2.write(','.join(outa)+'\n')
	#o2.close()

	with open(out_dir+'/uk_kid.pkl','wb') as o3:
		pickle.dump(kid_match, o3, pickle.HIGHEST_PROTOCOL)
	'''
    tem = []
    for i in sorted(ids_match.keys()):
        tem.append(ids_match[i])
    with open(out_dir + '/id2strain.pkl',
              'wb') as o4:  # The list of strain name.
        pickle.dump(tem, o4, pickle.HIGHEST_PROTOCOL)
    #print('Unique part -> kid_match:',len(kid_match),', kmatrix:',len(kmatrix))
    return match_1, match_2, kid_match, kmatrix, head_out, knum, sid_match
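# The uknum sampling above in isolation: taking evenly spaced keys from the
# insertion-ordered dict spreads the kept unique k-mers along the genome
# instead of clustering them at its start (values here are illustrative).
def _demo_uk_sampling():
    resd_toy = OrderedDict(('kmer%d' % i, '') for i in range(10))
    uknum_toy = 4
    step = math.ceil(len(resd_toy) / uknum_toy)   # 3
    keys = list(resd_toy.keys())
    print(keys[::step])   # ['kmer0', 'kmer3', 'kmer6', 'kmer9']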
def generate_kmer_match_from_global(input_gb, ksize, out_dir, dlabel, match_1,
                                    match_2, head_out, sid_match, label_match,
                                    knum, kid_match, kmatrix):
    f = open(input_gb, 'r')
    lines = f.read().split('\n')
    f.close()
    o = open(out_dir + '/all_kmer.fasta', 'w+')
    c = 1
    dtotal_kmer = {}
    for line in lines:
        if not line: continue
        if line[0] == 'a':
            if c == 1:
                dtotal_kmer = {}
                c += 1
            else:
                if len(dtotal_kmer) > 0:
                    for k in dtotal_kmer:
                        if k in kid_match: continue
                        if len(dlabel[k]) == 1: continue
                        kid_match[k] = knum
                        for e in dlabel[k]:
                            kmatrix[knum][sid_match[label_match[e]] - 1] = 1
                        knum += 1
                dtotal_kmer = {}
                c += 1

        if line[0] == 's':
            ele = line.split()
            for i in range(len(ele[-1]) - ksize + 1):
                kmer = ele[-1][i:i + ksize]
                if len(kmer) != ksize: continue
                if 'N' in kmer: continue
                rev_kmer = seqpy.revcomp(kmer)
                # keep k-mers that are not shared by every strain
                if len(dlabel[kmer]) != len(sid_match):
                    dtotal_kmer[kmer] = ''
                if len(dlabel[rev_kmer]) != len(sid_match):
                    dtotal_kmer[rev_kmer] = ''
    # flush the final block
    if len(dtotal_kmer) > 0:
        for k in dtotal_kmer:
            if k in kid_match: continue
            if len(dlabel[k]) == 1: continue
            kid_match[k] = knum
            for e in dlabel[k]:
                kmatrix[knum][sid_match[label_match[e]] - 1] = 1
            knum += 1
    kc = 1
    dlabel = {}
    gc.collect()
    row = len(kmatrix)
    column = len(match_1)
    for nk in kid_match:
        o.write('>' + str(kc) + '\n' + nk + '\n')
        kc += 1
    o.close()
    print('Memory usage: %.4f GB' %
          (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
    #o1.write(head_out)
    print('Fill the Sparse matrix: Row: ', row, ' Column: ', column)
    mat = sp.dok_matrix((row, column), dtype=np.int8)
    print('Memory usage: %.4f GB' %
          (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
    for kmr in kmatrix:
        mat[kmr - 1, list(kmatrix[kmr].keys())] = 1
    mat = mat.tocsr()
    sp.save_npz(out_dir + '/all_strains.npz', mat)

    with open(out_dir + '/all_kid.pkl', 'wb') as o2:
        pickle.dump(kid_match, o2, pickle.HIGHEST_PROTOCOL)
    # All k-mer sets are generated; re-cluster the strains to merge near-identical (<1% different) cases
    Recls_withR_new.remove_1per(out_dir + '/all_strains.npz',
                                out_dir + '/id2strain.pkl', out_dir)
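# The sparse fill used above in miniature (assumptions: scipy and numpy are
# installed; kmatrix keys are 1-based k-mer ids, values map 0-based strain
# ids to 1, exactly as built by the functions above):
def _demo_sparse_fill():
    kmatrix_toy = {1: {0: 1, 2: 1}, 2: {1: 1}}
    mat = sp.dok_matrix((len(kmatrix_toy), 3), dtype=np.int8)
    for kid in kmatrix_toy:
        mat[kid - 1, list(kmatrix_toy[kid].keys())] = 1
    mat = mat.tocsr()
    print(mat.toarray())   # [[1 0 1], [0 1 0]]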
# Example 14
def generate_kmer_match_from_uk(input_uk,ksize,out_dir,dlabel,match_1,match_2,head_out,knum,kid_match,kmatrix,sid_match):
	#import pickle
	f=open(input_uk,'r')
	lines=f.read().split('\n')
	#o=open(out_dir+'/partial_kmer_addUk2.fasta','w+')
	#kmatrix=defaultdict(lambda:{}) # K-mer id -> Strain id -> '0' or '1'
	#kid_match={}
	#knum=1
	#o2=open(out_dir+'/kmer_match.txt','w+')
	c=1
	#dk_match=defaultdict(lambda:{}) # Kmer ->  {strain1:'',strain2:''}
	dbs_count=count_dbs(lines,ksize)	# Block_ID -> {s1:'',s2:'',....}
	dtotal_kmer={}
	for line in lines:
		if not line:continue
		if line[0]=='a':
			if c==1:
				dstrain={}
				dtotal_kmer={}
				blockid=c
				c+=1
			else:
				#if len(dtotal_kmer)>(80-ksize+1)*2: # Minimum kmer cutoff -> Length: 100bp
				if len(dtotal_kmer)>0:
					for k in dtotal_kmer:
						#dk_match[k][blockid]=dstrain
						#if len(dict(dbs_count[blockid][k]))==1:continue # Filter Unique K-mer
						#dk_match[k]=dict(dbs_count[blockid][k])
						if len(dlabel[k])==1:continue
						if k in kid_match:continue
						kid_match[k]=knum
						#used_kmr[k]=None
						#kmatrix[knum]=dict(zip(match_1,match_2))
						for e in dict(dbs_count[blockid][k]):
							kmatrix[knum][sid_match[e]-1]=1
						knum+=1
					
							
				dstrain={}
				dtotal_kmer={}
				blockid=c
				c+=1

		if line[0]=='s':
			ele=line.split()
			dstrain[ele[1]]=''
			for i in range(len(ele[-1])-ksize+1):
				kmer=ele[-1][i:i+ksize]
				if not len(kmer)==ksize:continue
				if re.search('N',kmer):continue
				rev_kmer=seqpy.revcomp(kmer)

				if len(dlabel[kmer])==len(dbs_count[blockid][kmer]):
					dtotal_kmer[kmer]=''
				if len(dlabel[rev_kmer])==len(dbs_count[blockid][rev_kmer]):
					dtotal_kmer[rev_kmer]=''
	#if len(dtotal_kmer)>(80-ksize+1)*2:
	if len(dtotal_kmer)>0:
		for k in dtotal_kmer:
			#dk_match[k][blockid]=dstrain
			#if len(dict(dbs_count[blockid][k]))==1:continue
			#dk_match[k]=dict(dbs_count[blockid][k])
			if len(dlabel[k])==1:continue
			if k in kid_match:continue
			kid_match[k]=knum 
			#used_kmr[k]=None
			#kmatrix[knum]=dict(zip(match_1,match_2))
			for e in dict(dbs_count[blockid][k]):
				kmatrix[knum][sid_match[e]-1]=1
			knum+=1
	for k in kid_match:
		if len(dlabel[k])==1:continue
		del dlabel[k]
	gc.collect()	
	# Only for E. coli
	'''
	dlabel={}
	o=open(out_dir+'/all_kmer.fasta','w+')
	kc=1
	for nk in kid_match:
		o.write('>'+str(kc)+'\n'+nk+'\n')
		kc+=1
	
	row=len(kmatrix)
	column=len(match_1)
	print('Fill the Sparse matrix: Row: ',row,' Column: ',column)
	mat=sp.dok_matrix((row,column),dtype=np.int8)
	for kmr in kmatrix:
		mat[kmr-1,list(kmatrix[kmr].keys())]=1
	mat=mat.tocsr()
	sp.save_npz(out_dir+'/all_strains.npz',mat)
	with open(out_dir+'/all_kid.pkl','wb') as o2:
		pickle.dump(kid_match,o2,pickle.HIGHEST_PROTOCOL)
	'''

	return knum,kid_match,kmatrix
# Example 15
def generate_kmer_match_from_global(input_gb,ksize,out_dir,dlabel,match_1,match_2,head_out,sid_match,label_match,knum,kid_match,kmatrix):
	f=open(input_gb,'r')
	lines=f.read().split('\n')
	o=open(out_dir+'/all_kmer.fasta','w+')
	#kmatrix=defaultdict(lambda:{}) # K-mer id -> Strain id -> '0' or '1'
	#kid_match={}
	#knum=1
	c=1
	dtotal_kmer={}
	for line in lines:
		if not line:continue
		if line[0]=='a':
			if c==1:
				#dstrain={}
				dtotal_kmer={}
				#blockid=c
				c+=1
			else:
				if len(dtotal_kmer)>0:
					for k in dtotal_kmer:
						if k in kid_match:continue
						if len(dlabel[k])==1:continue
						#if k in kid_match:continue
						kid_match[k]=knum
						#used_kmr[k]=None
						#kmatrix[knum]=dict(zip(match_1,match_2))
						for e in dlabel[k]:
							kmatrix[knum][sid_match[label_match[e]]-1]=1
						knum+=1
				#dstrain={}
				dtotal_kmer={}
				blockid=c
				c+=1
				
		if line[0]=='s':
			ele=line.split()
			#dstrain[ele[1]]=''
			for i in range(len(ele[-1])-ksize+1):
				kmer=ele[-1][i:i+ksize]
				if kmer in kid_match:continue
				if not len(kmer)==ksize:continue
				if re.search('N',kmer):continue
				rev_kmer=seqpy.revcomp(kmer)
				if rev_kmer in kid_match:continue
				if not len(dlabel[kmer])==len(sid_match): 
					dtotal_kmer[kmer]=''
				if not len(dlabel[rev_kmer])==len(sid_match):
					dtotal_kmer[rev_kmer]=''
	if len(dtotal_kmer)>0:
		for k in dtotal_kmer:
			if k in kid_match:continue
			if len(dlabel[k])==1:continue
			kid_match[k]=knum
			#used_kmr[k]=None
			#kmatrix[knum]=dict(zip(match_1,match_2))
			for e in dlabel[k]:
				kmatrix[knum][sid_match[label_match[e]]-1]=1
			knum+=1
	dlabel={}
	gc.collect()
	# Output part
	# Fill the Sparse matrix
	row=len(kmatrix)
	column=len(match_1)
	print('Fill the Sparse matrix: Row: ',row,' Column: ',column)
	#kmatrix=dict(kmatrix)
	mat = sp.dok_matrix((row,column), dtype=np.int8)
	for kmr in kmatrix:
		mat[kmr-1,list(kmatrix[kmr].keys())]=1
		'''
		for strain in kmatrix[kmr]:
			mat[kmr-1,strain-1]=1
		'''
	mat = mat.tocsr()
	sp.save_npz(out_dir+'/all_strains.npz',mat)
	
	kc=1
	for nk in kid_match:
		o.write('>'+str(kc)+'\n'+nk+'\n')
		kc+=1
	with open(out_dir+'/all_kid.pkl','wb') as o2:
		pickle.dump(kid_match,o2,pickle.HIGHEST_PROTOCOL)
def parse_sts_res(sts_res, ksize, dlabel, sid_match, kid_match, knum, kmatrix,
                  label_match):
    # Get the number of strains
    ifile = sts_res
    ifile2 = 'reference.fasta'
    num = 0
    dmsa = {}  # sequence name -> aligned (MSA) sequence
    my_dict = SeqIO.to_dict(SeqIO.parse(ifile, "fasta"))
    my_dict2 = SeqIO.to_dict(SeqIO.parse(ifile2, "fasta"))

    for r in my_dict:
        num += 1
        seqn = re.sub(r'\..*', '', r)
        seq = str(my_dict[r].seq)
        dmsa[seqn] = seq
    for r in my_dict2:
        num += 1
        seqn = re.sub(r'\..*', '', r)
        seq = str(my_dict2[r].seq)
        dmsa[seqn] = seq
    '''
	o=open('msa_res.aln','w+')
	while True:
		line=f.readline().strip()
		if not line:break
		
		if line[0]=='s':
			ele=line.split()
			name=re.split('\.',ele[1])[-1]
			seq=ele[-1]
			dmsa[name]=seq
			o.write('>'+name+'\n'+seq+'\n')
		
		if re.search('>',line):
			name=re.sub('>','',line)
			dmsa[name]=''
			#num+=1
		else:
			dmsa[name]+=line
	'''

    fdir = os.path.split(os.path.realpath(__file__))[0]
    pwd = os.getcwd()
    #print('perl '+fdir+'/aln2entropy.pl '+pwd+'/msa_res.aln '+str(len(dmsa))+' > tem_column.txt')
    os.system('strainest map2snp reference.fasta ' + sts_res + ' snp.dgrp')
    #os.system('perl '+fdir+'/aln2entropy.pl '+pwd+'/msa_res.aln '+str(len(dmsa))+' > tem_column.txt')
    #exit()
    # Parse each column
    dash_cutoff = int(0.1 * num)  # 100 -> <=10 dash is ok
    #print(dash_cutoff,num)
    f2 = open('snp.dgrp', 'r')
    line = f2.readline()
    out_c = []
    while True:
        line = f2.readline().strip()
        if not line: break
        ele = re.split(',', line)
        cid = ele[0]
        #dash_num=int(re.sub('-:','',ele[2]))
        btype = {}
        dash_num = 0
        for e in ele[2:]:
            if 'N' == e:
                dash_num += 1
            else:
                btype[e] = ''
                #bn=re.sub('.*:','',e)
                #bn=int(bn)
                #if bn>0:btype+=1
        if dash_num > dash_cutoff: continue
        #print(btype)
        if len(btype) > 1:
            out_c.append(int(cid))
    #print(out_c)
    #exit()
    print('Log: Extract kmers from msa..., k-mer count:', len(kid_match),
          len(kmatrix))
    #for s in dmsa:
    #tc=0
    for c in out_c:
        tem_kmr = {}
        go = 1
        for s in dmsa:
            #print(s,c,dmsa[s][c],dmsa[s][c-1])
            if dmsa[s][c - 1] == 'N':
                #go=0
                continue
            kmr = extract_kmers(c - 1, dmsa[s], ksize)
            if 'N' in kmr:
                continue
            if len(kmr) != ksize: go = 0
            if len(dlabel[kmr]) == 0: go = 0
            rev_kmr = seqpy.revcomp(kmr)
            tem_kmr[rev_kmr] = ''
            tem_kmr[kmr] = ''
        if go == 0: continue
        if len(tem_kmr) == 0: continue
        for kmr in tem_kmr:
            if kmr in kid_match: continue
            if kmr not in dlabel: continue
            #if len(dlabel[kmr])==1:continue
            if len(dlabel[kmr]) >= len(sid_match): continue
            if len(dlabel[kmr]) == 0: continue
            kid_match[kmr] = knum
            for e in dlabel[kmr]:
                kmatrix[knum][sid_match[label_match[e]] - 1] = 1
            #print(knum,len(kmatrix),len(kid_match))
            knum += 1
            '''
			knum+=1
			kid_match[rev_kmr]=knum
			for e in dlabel[rev_kmr]:
				kmatrix[knum][sid_match[label_match[e]]-1]=1
			knum+=1
			'''
    print('Log: Extraction done..., k-mer count:', len(kid_match),
          len(kmatrix))
    os.system('rm -rf ' + os.getcwd() + '/ref ' + os.getcwd() + '/output')
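# A toy pass over snp.dgrp-style rows, mirroring the column filter above
# (assumption about the format, inferred from the parsing code:
# "column_id,ref,base,base,..." with 'N' marking a gap):
def _demo_parse_columns():
    rows = ['10,A,A,A,N', '25,C,C,T,T', '40,G,N,N,N']
    num, out_c = 4, []
    dash_cutoff = int(0.1 * num)   # at most 0 gaps tolerated here
    for line in rows:
        ele = line.split(',')
        btype, dash_num = {}, 0
        for e in ele[2:]:
            if e == 'N':
                dash_num += 1
            else:
                btype[e] = ''
        if dash_num > dash_cutoff: continue
        if len(btype) > 1:       # polymorphic column within the gap cutoff
            out_c.append(int(ele[0]))
    print(out_c)   # [25]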
# Example 17
def build_tree(arg):
    # read parameters
    start = time.time()
    dist_matrix_file = arg[0]
    cls_file = arg[1]
    tree_dir = arg[2]
    ksize = arg[3]
    params = arg[4]
    alpha_ratio = params[0]
    minsize = params[1]
    maxsize = params[2]
    max_cls_size = params[3]

    # save genomes info
    fna_seq = bidict.bidict()  # : 1
    fna_path = {}

    # read dist matrix (represented by similarity: 1-dist)
    # output: dist, fna_path, fna_seq
    f = open(dist_matrix_file, "r")
    lines = f.readlines()
    f.close()
    index = 0
    for i in lines[0].rstrip().split("\t")[1:]:
        temp = i[i.rfind('/') + 1:].split(".")[0]
        fna_seq[temp] = index
        fna_path[index] = i
        index += 1
    dist = []
    for line in lines[1:]:
        dist.append(
            [np.array(list(map(float,
                               line.rstrip().split("\t")[1:])))])
    dist = np.concatenate(dist)

    # read initial clustering results. fna_mapping, from 1 for indexing
    f = open(cls_file, 'r')
    lines = f.readlines()
    f.close()
    fna_mapping = defaultdict(set)
    for line in lines:
        temp = line.rstrip().split("\t")
        for i in temp[2].split(","):
            fna_mapping[int(temp[0])].add(fna_seq[i])
    if (len(lines) == 1):
        tree = Tree()
        T0 = Node(identifier=list(fna_mapping.keys())[0])
        tree.add_node(T0)
        kmer_sta = defaultdict(int)
        kmer_index_dict = bidict.bidict()
        kmer_index = 1
        alpha_ratio = 1
        Lv = set()
        for i in fna_mapping[T0.identifier]:
            for seq_record in SeqIO.parse(fna_path[i], "fasta"):
                temp = str(seq_record.seq)
                for k in range(0, len(temp) - ksize + 1):  # include the final window
                    forward = temp[k:k + ksize]
                    reverse = seqpy.revcomp(forward)
                    for kmer in [forward, reverse]:
                        try:
                            kmer_sta[kmer_index_dict[kmer]] += 1
                        except KeyError:
                            kmer_index_dict[kmer] = kmer_index
                            kmer_sta[kmer_index] += 1
                            kmer_index += 1
        alpha = len(fna_mapping[T0.identifier]) * alpha_ratio
        for x in kmer_sta:
            if (kmer_sta[x] >= alpha):
                Lv.add(x)
        print(T0.identifier, len(Lv))
        # save2file
        kmerlist = set()
        with open(tree_dir + '/tree.pkl', 'wb') as fo:
            pickle.dump(tree, fo)
        f = open(tree_dir + "/tree_structure.txt", "w")
        os.system("mkdir " + tree_dir + "/kmers")
        os.system("mkdir " + tree_dir + "/overlapping_info")
        f.write("%d\t" % T0.identifier)
        f.close()
        os.system(f'cp {cls_file} {tree_dir}/')
        f = open(tree_dir + "/reconstructed_nodes.txt", "w")
        f.close()
        if (len(Lv) > maxsize):
            Lv = set(random.sample(list(Lv), maxsize))
        kmerlist = Lv
        length = len(Lv)
        f = open(tree_dir + "/kmers/" + str(T0.identifier), "w")
        for j in Lv:
            f.write("%d " % j)
        f.close()
        f = open(tree_dir + "/node_length.txt", "w")
        f.write("%d\t%d\n" % (T0.identifier, length))
        kmer_mapping = {}
        index = 0
        f = open(tree_dir + "/kmer.fa", "w")
        for i in kmerlist:
            f.write(">1\n")
            f.write(kmer_index_dict.inv[i])
            kmer_mapping[i] = index
            index += 1
            f.write("\n")
        f.close()

        # change index
        files = os.listdir(tree_dir + "/kmers")
        for i in files:
            f = open(tree_dir + "/kmers/" + i, "r")
            lines = f.readlines()
            f.close()
            if (len(lines) == 0):
                continue
            d = lines[0].rstrip().split(" ")
            d = map(int, d)
            f = open(tree_dir + "/kmers/" + i, "w")
            for j in d:
                f.write("%d " % kmer_mapping[j])
            f.close()
        end = time.time()
        print(
            '- The total running time of tree-based indexing structure building is ',
            str(end - start), ' s\n')
        return
    # initially build tree
    cls_dist, mapping, tree, depths, depths_mapping = hierarchy(
        fna_mapping, dist)

    # initially extract k-mers
    kmer_index_dict = bidict.bidict()
    kmer_index = 1
    Lv = defaultdict(set)
    spec = defaultdict(set)  # k-mers <= alpha
    leaves = tree.leaves()
    for i in leaves:
        kmer_index = extract_kmers(fna_mapping[i.identifier], fna_path, ksize,
                                   kmer_index_dict, kmer_index, Lv, spec,
                                   tree_dir, alpha_ratio, i.identifier)
    end = time.time()
    print('- The total running time of k-mer extraction is ', str(end - start),
          ' s\n')
    start = time.time()

    # leaf nodes check
    recls_label = 0

    leaves_check = []
    check_waitlist = reversed(leaves)
    while (True):
        if (recls_label):
            cls_dist, mapping, tree, depths, depths_mapping = hierarchy(
                fna_mapping, dist)
            leaves = tree.leaves()
            temp = {}
            temp2 = []
            for i in check_waitlist:
                if (i in fna_mapping):
                    temp2.append(i)
            check_waitlist = temp2.copy()
            for i in check_waitlist:
                temp[tree.get_node(i)] = depths[tree.get_node(i)]
            check_waitlist = []
            a = sorted(temp.items(), key=lambda x: x[1], reverse=True)
            for i in a:
                check_waitlist.append(i[0])
            for i in fna_mapping:
                if (i not in Lv):
                    kmer_index = extract_kmers(fna_mapping[i], fna_path, ksize,
                                               kmer_index_dict, kmer_index, Lv,
                                               spec, tree_dir, alpha_ratio, i)
        higher_union = defaultdict(set)
        for i in check_waitlist:
            diff, diff_nodes = get_leaf_union(depths[i], higher_union,
                                              depths_mapping, Lv, spec, i)
            kmer_t = Lv[i.identifier] - diff
            for j in diff_nodes:
                kmer_t = kmer_t - Lv[j.identifier]
            for j in diff_nodes:
                kmer_t = kmer_t - spec[j.identifier]
            print(str(i.identifier) + " checking", end="\t")
            print(len(kmer_t))
            if (len(kmer_t) < minsize):
                leaves_check.append(i)
        if (len(leaves_check) > 0):
            recls_label = 1
        else:
            break
        # re-clustering
        check_waitlist = []
        while (recls_label == 1):
            cluster_id = max(list(fna_mapping.keys())) + 1
            check_waitlist.append(cluster_id)
            leaf_a = leaves_check[0].identifier
            row_index = mapping[leaf_a]
            column_index = cls_dist[row_index].argmax()
            leaf_b = mapping.inv[column_index]  # (leaf_a, leaf_b)
            temp2 = fna_mapping[leaf_a] | fna_mapping[leaf_b]
            print(cluster_id, leaf_a, leaf_b, temp2)
            del fna_mapping[leaf_a], fna_mapping[leaf_b]
            if (leaf_a in Lv):
                del Lv[leaf_a], spec[leaf_a]
            if (leaf_b in Lv):
                del Lv[leaf_b], spec[leaf_b]
            del leaves_check[0]
            if (tree.get_node(leaf_b) in leaves_check):
                leaves_check.remove(tree.get_node(leaf_b))
            temp1 = [
                np.concatenate([[cls_dist[row_index]],
                                [cls_dist[column_index]]]).max(axis=0)
            ]
            cls_dist = np.concatenate([cls_dist, temp1], axis=0)
            temp1 = np.append(temp1, -1)
            temp1 = np.vstack(temp1)
            cls_dist = np.concatenate([cls_dist, temp1], axis=1)
            cls_dist = np.delete(cls_dist, [row_index, column_index], axis=0)
            cls_dist = np.delete(cls_dist, [row_index, column_index], axis=1)
            # change mapping
            del mapping[leaf_a], mapping[leaf_b]
            pending = list(fna_mapping.keys())
            pending.sort()
            for i in pending:
                if (mapping[i] > min([row_index, column_index])
                        and mapping[i] < max([row_index, column_index])):
                    mapping[i] -= 1
                elif (mapping[i] > max([row_index, column_index])):
                    mapping[i] -= 2
            fna_mapping[cluster_id] = temp2
            mapping[cluster_id] = len(cls_dist) - 1
            if (len(leaves_check) == 0):
                break
    del higher_union

    # rebuild identifiers
    all_nodes = tree.all_nodes()
    all_leaves_id = set([])
    leaves = set(tree.leaves())
    for i in leaves:
        all_leaves_id.add(i.identifier)
    id_mapping = bidict.bidict()
    index = 1
    index_internal = len(leaves) + 1
    for i in all_nodes:
        if (recls_label == 0):
            id_mapping[i.identifier] = i.identifier
        elif (i in leaves):
            id_mapping[i.identifier] = index
            index += 1
        else:
            id_mapping[i.identifier] = index_internal
            index_internal += 1
    leaves_identifier = list(range(1, len(leaves) + 1))
    all_identifier = list(id_mapping.values())
    all_identifier.sort()

    # save2file
    f = open(tree_dir + "/tree_structure.txt", "w")
    os.system("mkdir " + tree_dir + "/kmers")
    os.system("mkdir " + tree_dir + "/overlapping_info")
    for nn in all_identifier:
        i = id_mapping.inv[nn]
        f.write("%d\t" % id_mapping[i])
        if (i == all_nodes[0].identifier):
            f.write("N\t")
        else:
            f.write("%d\t" % id_mapping[tree.parent(i).identifier])
        if (nn in leaves_identifier):
            f.write("N\t")
        else:
            [child_a, child_b] = tree.children(i)
            f.write("%d %d\t" % (id_mapping[child_a.identifier],
                                 id_mapping[child_b.identifier]))
        if (len(fna_mapping[i]) == 1):
            temp = list(fna_mapping[i])[0]
            temp = fna_seq.inv[temp]
            f.write("%s" % temp)
        f.write("\n")
    f.close()
    f = open(tree_dir + "/hclsMap_95_recls.txt", "w")
    for nn in leaves_identifier:
        i = id_mapping.inv[nn]
        f.write("%d\t%d\t" % (nn, len(fna_mapping[i])))
        temp1 = list(fna_mapping[i])
        for j in temp1:
            temp = fna_seq.inv[j]
            if (j == temp1[-1]):
                f.write("%s\n" % temp)
            else:
                f.write("%s," % temp)
    f.close()
    end = time.time()
    print('- The total running time of re-clustering is ', str(end - start),
          ' s\n')
    start = time.time()

    # build indexing structure
    kmerlist = set([])  # all kmers used
    length = {}
    overload_label = 0
    if (len(tree.leaves()) > max_cls_size):
        overload_label = 1
    # from bottom to top (unique k-mers)
    uniq_temp = defaultdict(set)
    rebuilt_nodes = []
    descendant = defaultdict(set)  # including itself
    ancestor = defaultdict(set)
    descendant_leaves = defaultdict(set)
    ancestor[all_nodes[0].identifier].add(all_nodes[0].identifier)
    for i in all_nodes[1:]:
        ancestor[i.identifier] = ancestor[tree.parent(
            i.identifier).identifier].copy()
        ancestor[i.identifier].add(i.identifier)
    for i in reversed(all_nodes):
        print(str(id_mapping[i.identifier]) + " k-mer removing...")
        if (i in leaves):
            uniq_temp[i.identifier] = Lv[i.identifier]
            descendant_leaves[i.identifier].add(i.identifier)
        else:
            (child_a, child_b) = tree.children(i.identifier)
            descendant[i.identifier] = descendant[
                child_a.identifier] | descendant[child_b.identifier]
            descendant_leaves[i.identifier] = descendant_leaves[
                child_a.identifier] | descendant_leaves[child_b.identifier]
            uniq_temp[i.identifier] = uniq_temp[
                child_a.identifier] & uniq_temp[child_b.identifier]
            uniq_temp[child_a.identifier] = uniq_temp[
                child_a.identifier] - uniq_temp[i.identifier]
            uniq_temp[child_b.identifier] = uniq_temp[
                child_b.identifier] - uniq_temp[i.identifier]
        descendant[i.identifier].add(i.identifier)
    all_nodes_id = set(id_mapping.keys())
    # remove overlapping
    for i in reversed(all_nodes):
        print(str(id_mapping[i.identifier]) + " k-mer set building...")
        # no difference with sibling, subtree and ancestors
        if (i == all_nodes[0]):
            kmer_t = uniq_temp[i.identifier]
        else:
            diff = {}
            temp = all_nodes_id - descendant[i.identifier] - set([
                tree.siblings(i.identifier)[0].identifier
            ]) - ancestor[i.identifier]
            for j in temp:
                diff[j] = len(uniq_temp[j])
            a = sorted(diff.items(), key=lambda x: x[1], reverse=True)
            kmer_t = uniq_temp[i.identifier]
            for j in a:
                k = j[0]
                kmer_t = kmer_t - uniq_temp[k]
            # remove special k-mers
            temp = all_leaves_id - descendant_leaves[i.identifier]
            diff = {}
            for j in temp:
                diff[j] = len(spec[j])
            a = sorted(diff.items(), key=lambda x: x[1], reverse=True)
            for j in a:
                k = j[0]
                kmer_t = kmer_t - spec[k]
        if (len(kmer_t) < minsize and overload_label == 0):
            rebuilt_nodes.append(i)
            print("%d waiting for reconstruction..." %
                  id_mapping[i.identifier])
        else:
            if (len(kmer_t) > maxsize):
                kmer_t = set(random.sample(list(kmer_t), maxsize))
            f = open(tree_dir + "/kmers/" + str(id_mapping[i.identifier]), "w")
            for j in kmer_t:
                f.write("%d " % j)
            f.close()
            length[i] = len(kmer_t)
            kmerlist = kmerlist | kmer_t
    del uniq_temp

    # rebuild nodes
    overlapping = defaultdict(dict)
    intersection = defaultdict(set)
    higher_union = defaultdict(set)
    del_label = {}
    for i in leaves:
        del_label[i.identifier] = [0, 0]
    for i in rebuilt_nodes:
        print(str(id_mapping[i.identifier]) + " k-mer set rebuilding...")
        kmer_t = get_intersect(intersection, descendant_leaves[i.identifier],
                               Lv, del_label, i.identifier)
        diff = get_diff(higher_union, descendant_leaves, depths, all_nodes, i,
                        Lv, spec, del_label)
        for j in diff:
            kmer_t = kmer_t - j
        lower_leaves = set([])
        for j in leaves:
            if (depths[j] < depths[i]):
                lower_leaves.add(j)
        if (len(kmer_t) > maxsize):
            kmer_overlapping_sta = defaultdict(int)
            for j in lower_leaves:
                kmer_o = Lv[j.identifier] & kmer_t
                for k in kmer_o:
                    kmer_overlapping_sta[k] += 1
            temp = sorted(kmer_overlapping_sta.items(),
                          key=lambda kv: (kv[1], kv[0]))
            kmer_t = set([])
            for j in range(0, maxsize):
                kmer_t.add(temp[j][0])
        nkmer = {}
        f = open(tree_dir + "/kmers/" + str(id_mapping[i.identifier]), "w")
        index = 0
        for j in kmer_t:
            f.write("%d " % j)
            nkmer[j] = index
            index += 1
        f.close()
        length[i] = len(kmer_t)
        kmerlist = kmerlist | kmer_t
        # save overlapping info
        for j in lower_leaves:
            temp = Lv[j.identifier] & kmer_t
            if (len(temp) > 0):
                ii = id_mapping[i.identifier]
                jj = id_mapping[j.identifier]
                overlapping[jj][ii] = set([])
                for k in temp:
                    overlapping[jj][ii].add(nkmer[k])
        delete(Lv, spec, del_label)

    for i in overlapping:
        f = open(tree_dir + "/overlapping_info/" + str(i), "w")
        f1 = open(tree_dir + "/overlapping_info/" + str(i) + "_supple", "w")
        count = -1
        for j in overlapping[i]:
            f.write("%d\n" % j)
            for k in overlapping[i][j]:
                f.write("%d " % k)
            f.write("\n")
            count += 2
            f1.write("%d %d\n" % (j, count))
        f.close()
        f1.close()

    # final saving
    f = open(tree_dir + "/reconstructed_nodes.txt", "w")
    for i in rebuilt_nodes:
        f.write("%d\n" % id_mapping[i.identifier])
    f.close()

    f = open(tree_dir + "/node_length.txt", "w")
    for nn in all_identifier:
        i = id_mapping.inv[nn]
        f.write("%d\t%d\n" % (nn, length[tree[i]]))
    f.close()

    kmer_mapping = {}
    index = 0
    f = open(tree_dir + "/kmer.fa", "w")
    for i in kmerlist:
        f.write(">1\n")
        f.write(kmer_index_dict.inv[i])
        kmer_mapping[i] = index
        index += 1
        f.write("\n")
    f.close()

    # change index
    files = os.listdir(tree_dir + "/kmers")
    for i in files:
        f = open(tree_dir + "/kmers/" + i, "r")
        lines = f.readlines()
        f.close()
        if (len(lines) == 0):
            continue
        d = lines[0].rstrip().split(" ")
        d = map(int, d)
        f = open(tree_dir + "/kmers/" + i, "w")
        for j in d:
            f.write("%d " % kmer_mapping[j])
        f.close()

    end = time.time()
    print(
        '- The total running time of tree-based indexing structure building is ',
        str(end - start), ' s\n')
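# build_tree() relies on an external Tree/Node implementation (assumption:
# treelib, whose API matches the calls above). A minimal check of that API:
def _demo_treelib():
    from treelib import Node, Tree
    t = Tree()
    t.add_node(Node(identifier=3))
    t.add_node(Node(identifier=1), parent=3)
    t.add_node(Node(identifier=2), parent=3)
    print([n.identifier for n in t.leaves()])      # [1, 2]
    print([n.identifier for n in t.children(3)])   # [1, 2]
    print(t.parent(1).identifier)                  # 3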
def generate_kmer_match_from_global(input_gb, ksize, out_dir, dlabel, match_1,
                                    match_2, head_out, sid_match, label_match,
                                    knum, kid_match, kmatrix):
    f = open(input_gb, 'r')
    lines = f.read().split('\n')
    f.close()
    o = open(out_dir + '/all_kmer.fasta', 'w+')
    c = 1
    dtotal_kmer = {}
    for line in lines:
        if not line: continue
        if line[0] == 'a':
            if c == 1:
                dtotal_kmer = {}
                c += 1
            else:
                if len(dtotal_kmer) > 0:
                    for k in dtotal_kmer:
                        if k in kid_match: continue
                        if len(dlabel[k]) == 1: continue
                        kid_match[k] = knum
                        for e in dlabel[k]:
                            kmatrix[knum][sid_match[label_match[e]] - 1] = 1
                        knum += 1
                dtotal_kmer = {}
                c += 1

        if line[0] == 's':
            ele = line.split()
            for i in range(len(ele[-1]) - ksize + 1):
                kmer = ele[-1][i:i + ksize]
                if len(kmer) != ksize: continue
                if 'N' in kmer: continue
                rev_kmer = seqpy.revcomp(kmer)
                if len(dlabel[kmer]) != len(sid_match):
                    dtotal_kmer[kmer] = ''
                if len(dlabel[rev_kmer]) != len(sid_match):
                    dtotal_kmer[rev_kmer] = ''
    # flush the final block
    if len(dtotal_kmer) > 0:
        for k in dtotal_kmer:
            if k in kid_match: continue
            if len(dlabel[k]) == 1: continue
            kid_match[k] = knum
            for e in dlabel[k]:
                kmatrix[knum][sid_match[label_match[e]] - 1] = 1
            knum += 1
    kc = 1
    dlabel = {}
    gc.collect()
    row = len(kmatrix)
    column = len(match_1)
    for nk in kid_match:
        o.write('>' + str(kc) + '\n' + nk + '\n')
        kc += 1
    o.close()
    print('Separate mode - Fill the sparse matrix: Row: ', row, ' Column: ',
          column)

    def fill_matrix_separate(kmatrix, row_num, column, bid, used_kmr):
        # Fill up to row_num rows per block, skipping k-mer ids already
        # written by earlier blocks, then save the block as its own .npz.
        mat = sp.dok_matrix((row_num, column), dtype=np.int8)
        for kmr in sorted(kmatrix):
            if kmr <= used_kmr['already_used']: continue
            if used_kmr['current_used'] >= row_num: break
            mat[kmr - 1 - used_kmr['already_used'],
                list(kmatrix[kmr].keys())] = 1
            used_kmr['current_used'] += 1
        used_kmr['already_used'] += used_kmr['current_used']
        used_kmr['current_used'] = 0
        mat = mat.tocsr()
        sp.save_npz(out_dir + '/all_strains_' + str(bid) + '.npz', mat)

    def fill_loop(kmatrix, row_num, column):
        block_num = int(len(kmatrix) / row_num) + 1
        used_kmr = {'current_used': 0, 'already_used': 0}
        for b in range(block_num):
            fill_matrix_separate(kmatrix, row_num, column, b + 1, used_kmr)

    fill_loop(kmatrix, 5000000, column)
    with open(out_dir + '/all_kid.pkl', 'wb') as o2:
        pickle.dump(kid_match, o2, pickle.HIGHEST_PROTOCOL)
    exit()
    # Unreachable in separate mode: all_strains.npz is never written above,
    # so this re-clustering step (merging near-identical strains) is disabled.
    Recls_withR_new.remove_1per(out_dir + '/all_strains.npz',
                                out_dir + '/id2strain.pkl', out_dir)
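# A standalone sketch of the block-wise fill above (assumptions: scipy/numpy
# available; row_num is tiny so the split is visible). Blocks saved as
# all_strains_<bid>.npz can later be reassembled with sp.vstack:
def _demo_blockwise_fill():
    import tempfile
    out_dir = tempfile.mkdtemp()
    rows = {1: {0: 1}, 2: {1: 1}, 3: {0: 1, 1: 1}}   # k-mer id -> strain ids
    row_num, column = 2, 2
    kids, parts = sorted(rows), []
    for b in range(0, len(kids), row_num):
        chunk = kids[b:b + row_num]
        mat = sp.dok_matrix((len(chunk), column), dtype=np.int8)
        for r, kid in enumerate(chunk):
            mat[r, list(rows[kid].keys())] = 1
        mat = mat.tocsr()
        sp.save_npz(out_dir + '/all_strains_' + str(b // row_num + 1) + '.npz', mat)
        parts.append(mat)
    print(sp.vstack(parts).toarray())   # [[1 0], [0 1], [1 1]]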