forked from jpritt/fm-update
-
Notifications
You must be signed in to change notification settings - Fork 0
/
iterativeUpdate.py
executable file
·80 lines (67 loc) · 2.89 KB
/
iterativeUpdate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#! /usr/bin/env python
import bwt
import random
import time
''' Repeatedly match all possible reads to the genome, update, then repeat '''
def iterativeUpdate(fm, b, alphabet, reads, starts, errors, maxIters, threshold=False, readLen=50, genomeLen=5000):
    """Repeatedly align reads to the index, apply well-supported edits, repeat.

    Each iteration aligns every still-unmatched read with
    ``bwt.findApproximate``, tallies the edits implied by those alignments,
    and applies to the index every edit supported by at least
    ``0.5 * numUnmatched * readLen / genomeLen`` reads, where
    ``numUnmatched`` is measured at the start of the iteration.

    Iteration stops when any of these holds:
      - every read has matched;
      - ``maxIters`` iterations have run;
      - an iteration reduced the unmatched count by no more than 10%.

    Args:
        fm, b, alphabet: index structures, passed through to the ``bwt``
            helpers. ``fm`` is rebound locally to each ``bwt`` edit call's
            return value. The updated index is NOT returned to the caller.
        reads: reads to align. Each is passed through ``''.join`` before
            matching, so it is a string or a sequence of characters.
        starts: true start position of each read, used only to score
            accuracy. It is mutated in place: shifted by applied insertions
            and deletions.
        errors: edit budget handed to ``bwt.findApproximate``. It is also
            the +/- tolerance when comparing an alignment to ``starts``.
        maxIters: maximum number of iterations.
        threshold: ignored. It is kept only for interface compatibility;
            the support cutoff is recomputed each iteration (see above).
        readLen, genomeLen: used only in the support cutoff.

    Returns:
        ``(initialAcc, finalAcc)``. ``initialAcc`` is the fraction of reads
        aligned near their true start after the first iteration.
        ``finalAcc`` is the same fraction accumulated over all iterations.
        Both are 0.0 when ``reads`` is empty or no iteration runs.
    """
    numReads = len(reads)
    if numReads == 0:
        # Nothing to score; avoids dividing by zero below.
        return 0.0, 0.0

    unmatched = [1] * numReads  # 1 = not yet aligned, 0 = aligned
    numUnmatched = numReads
    # Twice the read count makes the progress test (> 10%) pass on entry.
    prevSize = 2 * numReads
    currIter = 0
    initialAcc = 0.0
    correct = 0
    incorrect = 0
    while numUnmatched > 0 and currIter < maxIters and float(prevSize - numUnmatched) / prevSize > 0.1:
        # Minimum read support for an edit to be applied this iteration.
        # NOTE: the `threshold` argument is intentionally not consulted,
        # matching the historical behavior of this function.
        cutoff = 0.5 * numUnmatched * readLen / genomeLen
        currIter += 1
        prevSize = numUnmatched

        mutations, newCorrect, newIncorrect = _matchReads(fm, b, alphabet, reads, starts, errors, unmatched)
        correct += newCorrect
        incorrect += newIncorrect
        if currIter == 1:
            initialAcc = float(correct) / numReads

        fm = _applyMutations(fm, b, alphabet, mutations, starts, cutoff)
        numUnmatched = sum(unmatched)
        print(" Iter " + str(currIter) + " - " + str(correct) + " correct, " + str(incorrect) + " incorrect, " + str(numReads - correct - incorrect) + ' unmatched')
    return initialAcc, float(correct) / numReads


def _matchReads(fm, b, alphabet, reads, starts, errors, unmatched):
    """Align each unmatched read once; return (edit support counts, #correct, #incorrect).

    Reads that produce at least one alignment are marked matched in
    ``unmatched`` (mutated in place). A newly aligned read counts as correct
    when some alignment key lies within ``errors`` of its ``starts`` entry.
    Edit keys are ``(kind, genomePos)`` for kind 2 and
    ``(kind, genomePos, char)`` otherwise.
    """
    mutations = {}
    correct = 0
    incorrect = 0
    window = range(-errors, errors + 1)
    for i, read in enumerate(reads):
        if unmatched[i] != 1:
            continue
        m = bwt.findApproximate(fm, b, alphabet, ''.join(read), errors)
        if len(m) == 0:
            continue
        unmatched[i] = 0
        # Keys of m are compared with true start positions below, so each
        # key is treated as an alignment start in genome coordinates.
        for alignStart, edits in m.items():
            for v in edits:
                # Shift the read-relative edit offset into genome coordinates.
                # Kind 2 (applied via bwt.delete) carries no character.
                if v[0] == 2:
                    key = (v[0], v[1] + alignStart)
                else:
                    key = (v[0], v[1] + alignStart, v[2])
                mutations[key] = mutations.get(key, 0) + 1
        if any(starts[i] + j in m for j in window):
            correct += 1
        else:
            incorrect += 1
    return mutations, correct, incorrect


def _applyMutations(fm, b, alphabet, mutations, starts, cutoff):
    """Apply every edit with support >= cutoff to the index; return the new ``fm``.

    Edit kinds: 0 = ``bwt.substitute``, 1 = ``bwt.insert``,
    2 = ``bwt.delete``. Every position refers to the index the reads were
    aligned against. Edits are applied from the highest position down, so
    an insertion or deletion never shifts the target of an edit that is
    still pending. ``starts`` is shifted in place after each insertion and
    deletion.
    """
    # Tie-break at a shared position: edit the existing character
    # (substitute, then delete) before inserting there. This assumes
    # bwt.insert places the new character at index pos (as the `starts`
    # shift below implies) -- TODO confirm against bwt.
    samePosRank = {0: 0, 2: 1, 1: 2}
    accepted = [k for k, support in mutations.items() if support >= cutoff]
    accepted.sort(key=lambda k: (-k[1], samePosRank.get(k[0], 3)))
    for k in accepted:
        kind, pos = k[0], k[1]
        if kind == 1:
            fm = bwt.insert(fm, b, alphabet, pos, k[2])
            shift = 1
        elif kind == 0:
            fm = bwt.substitute(fm, b, alphabet, pos, k[2])
            continue  # substitutions do not move any coordinates
        elif kind == 2:
            fm = bwt.delete(fm, b, alphabet, pos)
            # NOTE(review): a start exactly at pos is also decremented here;
            # confirm this matches bwt.delete's semantics.
            shift = -1
        else:
            print('Error: k[0] = ' + str(kind))
            continue
        for i in range(len(starts)):
            if starts[i] >= pos:
                starts[i] += shift
    return fm