# Forked from raphael-group/comet -- run_comet.py (320 lines, 13.1 KB).
# (GitHub page-scrape banner and display line numbers removed.)
#!/usr/bin/python
# Load required modules
import sys, os, json, re, time, comet as C, resource
from math import exp
# Try loading Multi-Dendrix
try:
import multi_dendrix as multi_dendrix
importMultidendrix = True
except ImportError:
importMultidendrix = False
sys.stderr.write("Warning: The Multi-Dendrix Python module could not"\
" be found. Using only random initializations...\n")
def get_parser():
    """Build and return the command-line argument parser for CoMEt."""
    # Parse arguments
    import argparse
    description = 'Runs CoMEt to find the optimal set M '\
        'of k genes for the weight function \\Phi(M).'
    parser = argparse.ArgumentParser(description=description)
    # General parameters
    parser.add_argument('-o', '--output_prefix', required=True,
                        help='Output path prefix (TSV format).')
    # BUG FIX: default was True, which made this store_true flag a no-op
    # (verbose output was always on and could never be disabled).
    parser.add_argument('-v', '--verbose', default=False, action="store_true",
                        help='Flag verbose output.')
    parser.add_argument('--seed', default=int(time.time()), type=int,
                        help='Set the seed of the PRNG.')
    # Mutation data
    parser.add_argument('-m', '--mutation_matrix', required=True,
                        help='File name for mutation data.')
    parser.add_argument('-mf', '--min_freq', type=int, default=0,
                        help='Minimum gene mutation frequency.')
    parser.add_argument('-pf', '--patient_file', default=None,
                        help='File of patients to be included (optional).')
    parser.add_argument('-gf', '--gene_file', default=None,
                        help='File of genes to be included (optional).')
    # Comet
    parser.add_argument('-ks', '--gene_set_sizes', nargs="*", type=int, required=True,
                        help='Gene set sizes (length must be t). This or -k must be set. ')
    parser.add_argument('-N', '--num_iterations', type=int, default=pow(10, 3),
                        help='Number of iterations of MCMC.')
    parser.add_argument('-NStop', '--n_stop', type=int, default=pow(10, 8),
                        help='Number of iterations of MCMC to stop the pipeline.')
    parser.add_argument('-s', '--step_length', type=int, default=100,
                        help='Number of iterations between samples.')
    parser.add_argument('-init', '--initial_soln', nargs="*",
                        help='Initial solution to use.')
    parser.add_argument('-acc', '--accelerator', default=1, type=int,
                        help='accelerating factor for target weight')
    parser.add_argument('-sub', '--subtype', default=None,
                        help='File with a list of subtype for performing subtype-comet.')
    parser.add_argument('-r', '--num_initial', default=1, type=int,
                        help='Number of different initial starts to use with MCMC.')
    parser.add_argument('--exact_cut', default=0.001, type=float,
                        help='Maximum accumulated table prob. to stop exact test.')
    parser.add_argument('--binom_cut', type=float, default=0.005,
                        help='Minimum pval cutoff for CoMEt to perform binom test.')
    parser.add_argument('-nt', '--nt', default=10, type=int,
                        help='Maximum co-occurrence cutoff to perform exact test.')
    parser.add_argument('-tv', '--total_distance_cutoff', type=float, default=0.005,
                        help='stop condition of convergence (total distance).')
    parser.add_argument('--precomputed_scores', default=None,
                        help='input file with lists of pre-run results.')
    return parser
def convert_solns(indexToGene, solns):
    """Translate sampled solutions from gene indices back to gene names.

    Each entry of a solution is [g1, ..., gk, weight, num_tables]. Every
    solution list is sorted in place by weight, descending, and converted
    to a (gene_sets, weights, num_tables) triple.
    """
    converted = []
    for soln in solns:
        # In-place sort: heaviest gene sets first (matches original behavior).
        soln.sort(key=lambda entry: entry[-2], reverse=True)
        geneSets = tuple(frozenset(indexToGene[i] for i in entry[:-2]) for entry in soln)
        weights = [entry[-2] for entry in soln]
        numTables = [entry[-1] for entry in soln]
        converted.append((geneSets, weights, numTables))
    return converted
def comet(mutations, n, t, ks, numIters, stepLen, initialSoln,
          amp, subt, nt, hybridPvalThreshold, pvalThresh, verbose):
    """Run the CoMEt MCMC sampler (C extension) and collate its samples.

    Returns (results, lastSoln): `results` maps a canonical collection key
    to dict(freq, sets, total_weight, target_weight); `lastSoln` is the flat
    gene list of the final sampled collection, used to restart the chain.
    """
    # Convert mutation data to C-ready format
    if subt: mutations = mutations + (subt, )
    cMutations = C.convert_mutations_to_C_format(*mutations)
    iPatientToGenes, iGeneToCases, geneToNumCases, geneToIndex, indexToGene = cMutations
    # Map the user-supplied initial solution (gene names) to C-side indices.
    initialSolnIndex = [geneToIndex[g] for g in initialSoln]
    solns = C.comet(t, mutations[0], mutations[1], iPatientToGenes, geneToNumCases,
                    ks, numIters, stepLen, amp, nt, hybridPvalThreshold,
                    initialSolnIndex, len(subt), pvalThresh, verbose)
    # Collate the results and sort them descending by sampling frequency
    solnsWithWeights = convert_solns( indexToGene, solns )
    def collection_key(collection):
        # Canonical, order-independent string key for a collection of gene sets.
        return " ".join(sorted([",".join(sorted(M)) for M in collection]))
    results = dict()
    # store last soln of sampling for more iterations
    lastSoln = list()
    for gset in solnsWithWeights[-1][0]:
        for g in gset:
            lastSoln.append(g)
    for collection, Ws, Cs in solnsWithWeights:
        key = collection_key(collection)
        # Repeated samples of the same collection only bump its frequency.
        if key in results: results[key]["freq"] += 1
        else:
            sets = []
            for i in range(len(collection)):
                M = collection[i]
                W = Ws[i]
                F = Cs[i]
                P = exp(-W)  # per-set probability recovered from its weight
                sets.append( dict(genes=M, W=W, num_tbls=F, prob=P) )
            totalWeight = sum([ S["W"] for S in sets ])
            # exp() overflows for large weights; 1e1000 evaluates to float inf.
            targetWeight = exp( totalWeight ) if totalWeight < 700 else 1e1000
            results[key] = dict(freq=1, sets=sets, total_weight=totalWeight,
                                target_weight=targetWeight)
    return results, lastSoln
def iter_num(prefix, numIters, ks, acc):
    """Append a '.k<sizes>.<iterations>.<accelerator>' tag to an output prefix.

    The iteration count is abbreviated with a B/M/K suffix (e.g. 2000000 -> 2M).
    """
    # BUG FIX: use floor division (//) so the abbreviation stays an int.
    # Plain / is floor division only under Python 2; under Python 3 it would
    # produce float tags like "2.5M" -> "2.5...M".
    if numIters >= 1e9: iterations = "%sB" % (numIters // int(1e9))
    elif numIters >= 1e6: iterations = "%sM" % (numIters // int(1e6))
    elif numIters >= 1e3: iterations = "%sK" % (numIters // int(1e3))
    else: iterations = "%s" % numIters
    prefix += ".k%s.%s.%s" % ("".join(map(str, ks)), iterations, acc)
    return prefix
def call_multidendrix(mutations, k, t):
    """Run the Multi-Dendrix ILP and return its genes as one flat list."""
    # Multi-Dendrix defaults: alpha=1.0, delta=0, lambda=1.
    alpha, delta, lmbda = 1.0, 0, 1
    geneSetsWithWeights = multi_dendrix.ILP(mutations, t, k, k, alpha, delta, lmbda)
    # Flatten the (gene_set, weight) pairs into a single gene multiset.
    return [g for geneSet, W in geneSetsWithWeights for g in geneSet]
def initial_solns_generator(r, mutations, ks, assignedInitSoln, subtype):
runInit = list()
totalOut = list()
if assignedInitSoln:
if len(assignedInitSoln) == sum(ks):
print 'load init soln', "\t".join(assignedInitSoln)
runInit.append(assignedInitSoln)
elif len(assignedInitSoln) < sum(ks): # fewer initials than sampling size, randomly pick from genes
import random
rand = assignedInitSoln + random.sample(set(mutations[2])-set(assignedInitSoln), sum(ks)-len(assignedInitSoln))
print 'load init soln with random', rand
else:
sys.stderr.write('Too many initial solns for CoMEt.\n')
exit(1)
if importMultidendrix and not subtype and ks.count(ks[0])==len(ks):
md_init = call_multidendrix(mutations, ks[0], len(ks))
print ' load multi-dendrix solns', md_init
runInit.append(list(md_init))
# assign empty list to runInit as random initials
for i in range(len(runInit), r):
runInit.append(list())
for i in range(r):
totalOut.append(dict())
return runInit, totalOut
def load_precomputed_scores(infile, mutations, subt):
    """Feed gene-set scores from a previous run's TSV into the C module."""
    if subt: mutations = mutations + (subt,)
    cMutations = C.convert_mutations_to_C_format(*mutations)
    iPatientToGenes, iGeneToCases, geneToNumCases, geneToIndex, indexToGene = cMutations
    baseI = 3 # sampling freq., total weight, target weight
    setI = 3 # gene set, score, weight function
    # t is inferred from the digit count of the ".k<digits>." filename tag
    # (one digit per gene set) -- assumes single-digit k values; see iter_num.
    matchObj = re.match( r'.+\.k(\d+)\..+?', infile)
    loadingT = len(matchObj.group(1)) # determine t:the number of gene sets.
    for l in open(infile):
        if not l.startswith("#"):
            v = l.rstrip().split("\t")
            j = 0
            # Each row holds loadingT (gene set, score, weight-function) triples.
            for i in range(loadingT):
                gSet = [geneToIndex[g] for g in v[baseI + j].split(", ")]
                C.load_precomputed_scores(float(v[baseI + j + 1]), len(v[baseI + j].split(", ")), int(v[baseI + j + 2]), gSet)
                j += setI
def printParameters(args, ks, finaltv):
    """Dump the run parameters plus the final total distance to a JSON file."""
    # NOTE: vars(args) aliases args.__dict__, so this also mutates args
    # (matches the original behavior).
    params = vars(args)
    params['total distance'] = finaltv
    jsonPrefix = iter_num(args.output_prefix + '.para', args.num_iterations, ks, args.accelerator)
    with open(jsonPrefix + '.json', 'w') as fh:
        json.dump(params, fh)
def merge_results(convResults):
    """Combine per-chain result dicts, summing frequencies of shared keys."""
    merged = dict()
    for chainResults in convResults:
        for key, entry in chainResults.items():
            if key in merged:
                merged[key]["freq"] += entry["freq"]
            else:
                merged[key] = entry
    return merged
def merge_runs(resultsPre, resultsNew):
    """Fold resultsNew into resultsPre in place, summing sampling frequencies."""
    for key, entry in resultsNew.items():
        if key in resultsPre:
            resultsPre[key]["freq"] += entry["freq"]
        else:
            resultsPre[key] = entry
def run( args ):
    """Main driver: load mutation data, run CoMEt (optionally with a
    multi-chain convergence pipeline), and write the summary TSV and
    parameter JSON. Returns (collection_key, freq, total_weight) triples
    sorted by total weight, descending."""
    # Parse the arguments into shorter variable handles
    mutationMatrix = args.mutation_matrix
    geneFile = args.gene_file
    patientFile = args.patient_file
    minFreq = args.min_freq
    rc = args.num_initial
    t = len(args.gene_set_sizes) # number of pathways
    ks = args.gene_set_sizes # size of each pathway
    N = args.num_iterations # number of iteration
    s = args.step_length # step
    NStop = args.n_stop
    acc = args.accelerator
    nt = args.nt
    hybridCutoff = args.binom_cut
    NInc = 1.5 # growth factor for the iteration count of a non-converged chain
    tc = 1
    # Load the mutation data
    mutations = C.load_mutation_data(mutationMatrix, patientFile, geneFile, minFreq)
    m, n, genes, patients, geneToCases, patientToGenes = mutations
    if args.subtype:
        with open(args.subtype) as f:
            subSet = [ l.rstrip() for l in f ]
    else:
        subSet = list()
    if args.verbose:
        print 'Mutation data: %s genes x %s patients' % (m, n)
    # Precompute factorials
    C.precompute_factorials(max(m, n))
    C.set_random_seed(args.seed)
    # stored the score of pre-computed collections into C
    if args.precomputed_scores:
        load_precomputed_scores(args.precomputed_scores, mutations, subSet)
    # num_initial > 1, perform convergence pipeline, otherwise, perform one run only
    if args.num_initial > 1:
        # collect initial soln from users, multidendrix and random.
        initialSolns, totalOut = initial_solns_generator(args.num_initial, mutations, ks, args.initial_soln, subSet )
        runN = N
        while True:
            lastSolns = list()
            for i in range(len(initialSolns)):
                init = initialSolns[i]
                outresults, lastSoln = comet(mutations, n, t, ks, runN, s, init, acc, subSet, nt, hybridCutoff, args.exact_cut, True)
                print "Mem usage: ", resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1000
                merge_runs(totalOut[i], outresults)
                lastSolns.append(lastSoln)
            # Convergence is measured across chains by total (variation) distance.
            finalTv = C.discrete_convergence(totalOut, int(N/s))
            print finalTv, N
            newN = int(N*NInc)
            # Stop when the iteration budget is exhausted or chains converged.
            if newN > NStop or finalTv < args.total_distance_cutoff:
                break
            # Not converged: resume each chain from its last sampled solution.
            runN = newN - N
            N = newN
            initialSolns = lastSolns
        runNum = len(totalOut)
        results = merge_results(totalOut)
        printParameters(args, ks, finalTv) # store and output parameters into .json
    else:
        init = list()
        outresults, lastSoln = comet(mutations, n, t, ks, N, s, init, acc, subSet, nt, hybridCutoff, args.exact_cut, True)
        results = outresults
        runNum = 1
        printParameters(args, ks, 1)
    C.free_factorials()
    # Output Comet results to TSV
    collections = sorted(results.keys(), key=lambda S: results[S]["total_weight"], reverse=True)
    header = "#Freq\tTotal Weight\tTarget Weight\t"
    header += "\t".join(["Gene set %s (k=%s)\tProb %s\tWeight function %s" % (i, ks[i-1], i, i) for i in range(1, len(ks)+1)])
    tbl = [header]
    for S in collections:
        data = results[S]
        row = [ data["freq"], data["total_weight"], format(data["target_weight"], 'g') ]
        for d in sorted(data["sets"], key=lambda d: d["W"]):
            row += [", ".join(sorted(d["genes"])), d["prob"], d["num_tbls"] ]
        tbl.append("\t".join(map(str, row)))
    outputFile = "%s.tsv" % iter_num(args.output_prefix + '.sum', N*(runNum), ks, args.accelerator)
    with open(outputFile, "w") as outfile: outfile.write( "\n".join(tbl) )
    return [ (S, results[S]["freq"], results[S]["total_weight"]) for S in collections ]
if __name__ == "__main__": run( get_parser().parse_args(sys.argv[1:]) )