forked from uci-cbcl/tree-hmm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
do_parallel.py
164 lines (146 loc) · 6.4 KB
/
do_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import re
import glob
import multiprocessing
import scipy as sp
import copy
import os
import tempfile
import time
#try:
# from ipdb import set_trace as breakpoint
#except ImportError:
# from pdb import set_trace as breakpoint
import sge
def _save_params(args, names):
    """Save each attribute of ``args`` listed in ``names`` as a file in ``args.out_dir``.

    The file name for each parameter is produced by formatting
    ``args.out_params`` with ``param=<name>`` plus every attribute of ``args``.
    """
    for name in names:
        filename = args.out_params.format(param=name, **args.__dict__)
        sp.save(os.path.join(args.out_dir, filename), args.__dict__[name])


def _has_converged(free_energy, epsilon):
    """Return True when the last two free energies are within ``epsilon``, relatively.

    A zero in either of the last two entries means "no energy recorded", so
    convergence is never reported for it.
    """
    if len(free_energy) < 2:
        return False
    previous, current = free_energy[-2], free_energy[-1]
    if previous == 0 or current == 0:
        return False
    # Divide by the magnitude: with a signed denominator a negative free
    # energy makes the ratio negative, which is always below epsilon.
    return abs(previous - current) / abs(previous) < epsilon


def _reconstruct_chromosomes(out_dir):
    """Join the per-chunk ``*_Q_*chunk<N>.npy`` files in ``out_dir`` along axis 1.

    The result is saved under the chunk's path with the ``chunk<N>.`` part
    removed, e.g. Q_chr16_all.trimmed.chunk*.npy => Q_chr16_all.trimmed.npy

    Raises:
        KeyError: if a chunk index between 0 and the highest one is missing.
    """
    print('# reconstructing chromosomes from *chunk*')
    in_order = {}
    for chunk in glob.glob(os.path.join(out_dir, '*_Q_*chunk*.npy')):
        print(chunk)
        chunk_num = int(re.search(r'chunk(\d+)', chunk).group(1))
        chrom_out = re.sub(r'chunk(\d+)\.', '', chunk)
        in_order.setdefault(chrom_out, {})[chunk_num] = sp.load(chunk)

    for chrom_out, chunks in in_order.items():
        print('reconstructing chromosomes from %s' % (chunks,))
        # Include the highest-numbered chunk; indexing every number from 0
        # makes a missing chunk fail loudly instead of producing a short array.
        ordered = [chunks[i] for i in range(max(chunks) + 1)]
        if len(ordered) > 1:
            final_array = sp.concatenate(ordered, axis=1)
        else:
            final_array = ordered[0]
        sp.save(chrom_out, final_array)


def do_parallel_inference(args):
    """Perform inference in parallel on several observations matrices with
    joint parameters

    Each iteration runs ``do_inference`` once per observation matrix (on a
    local ``multiprocessing.Pool`` when ``args.run_local`` is set, otherwise on
    an ``sge.SGEPool``), sums the parameters every job saved, renormalizes the
    sums, hands them back to the jobs, and saves and plots the result. The
    loop stops after ``args.max_iterations`` iterations or once the total free
    energy has converged to within ``args.epsilon``.

    ``args`` is modified in place. Attributes read here: ``observe_matrix``
    (a sequence of paths to saved arrays), ``K``, ``out_dir`` (a format string
    that may use ``{timestamp}`` and any attribute of ``args``), ``out_params``,
    ``warm_start``, ``run_local``, ``max_iterations``, ``epsilon`` and
    ``save_Q``.

    Raises:
        OSError: if ``args.out_dir`` cannot be created.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import with histone_tree_hmm -- confirm before moving.
    from histone_tree_hmm import (random_params, do_inference, plot_params,
                                  plot_energy, load_params)
    from vb_mf import normalize_trans

    # The first observation matrix is three-dimensional; its first and last
    # dimensions become args.I and args.L.
    args.I, _, args.L = sp.load(args.observe_matrix[0]).shape
    K = args.K
    L = args.L
    args.T = 'all'
    args.free_energy = []
    args.observe = 'all.npy'
    args.last_free_energy = 0
    args.emit_sum = 0

    timestamp = time.strftime('%x_%X').replace('/', '-')
    args.out_dir = args.out_dir.format(timestamp=timestamp, **args.__dict__)
    print('making %s' % (args.out_dir,))
    try:
        os.makedirs(args.out_dir)
    except OSError:
        # An already existing directory is fine; any other failure is not.
        if not os.path.isdir(args.out_dir):
            raise

    if args.warm_start:
        print('# loading previous params for warm start from %s' % args.warm_start)
        # Load from the warm-start directory without disturbing args.out_dir.
        tmpargs = copy.deepcopy(args)
        tmpargs.out_dir = args.warm_start
        tmpargs.observe = 'all.npy'
        (args.free_energy, args.theta, args.alpha, args.beta, args.gamma,
         args.emit_probs, args.emit_sum) = load_params(tmpargs)
        try:
            args.free_energy = list(args.free_energy)
        except TypeError:  # no previous free energy
            args.free_energy = []
        print('done')
        args.warm_start = False
    else:
        (args.theta, args.alpha, args.beta, args.gamma, args.emit_probs) = \
            random_params(args.K, args.L)
        _save_params(args, ['free_energy', 'theta', 'alpha', 'beta', 'gamma',
                            'emit_probs', 'last_free_energy', 'emit_sum'])

    print('# setting up job arguments')
    # set up new versions of args for other jobs
    job_args = [copy.copy(args) for _ in args.observe_matrix]
    for j, a in enumerate(job_args):
        a.observe_matrix = args.observe_matrix[j]
        a.observe = os.path.split(args.observe_matrix[j])[1]
        a.subtask = True
        a.func = None
        a.iteration = 0
        a.max_iterations = 1  # each job performs a single iteration per round
        a.quiet_mode = True
        if j % 1000 == 0:
            print(j)  # progress indicator

    if args.run_local:
        pool = multiprocessing.Pool()
    else:
        pool = sge.SGEPool()

    converged = False
    try:
        # args.iteration is the loop variable itself, so it stays visible to
        # everything that receives args (plotting, file-name formatting).
        for args.iteration in range(args.max_iterations):
            # fresh parameters-- to be aggregated after jobs are run
            print('iteration %s' % (args.iteration,))
            total_free = 0
            args.theta = sp.zeros((K, K, K), dtype=sp.float64)
            args.alpha = sp.zeros((K, K), dtype=sp.float64)
            args.beta = sp.zeros((K, K), dtype=sp.float64)
            args.gamma = sp.zeros((K), dtype=sp.float64)
            args.emit_probs = sp.zeros((K, L), dtype=sp.float64)
            args.emit_sum = sp.zeros(K, sp.float64)

            # Run every job and wait for all of them to finish.
            if args.run_local:
                for _ in pool.imap_unordered(do_inference, job_args):
                    pass
            else:
                for job in pool.map_async(do_inference, job_args, chunksize=100):
                    job.wait()

            # sum free energies and parameters from jobs
            for a in job_args:
                (free_energy, theta, alpha, beta, gamma,
                 emit_probs, emit_sum) = load_params(a)
                if len(free_energy) > 0:
                    total_free += free_energy[-1]
                args.theta += theta
                args.alpha += alpha
                args.beta += beta
                args.gamma += gamma
                args.emit_probs += emit_probs
                args.emit_sum += emit_sum

            # renormalize and plot
            print('normalize aggregation... total free energy is: %s' % (total_free,))
            args.free_energy.append(total_free)
            if _has_converged(args.free_energy, args.epsilon):
                converged = True

            normalize_trans(args.theta, args.alpha, args.beta, args.gamma)
            # Scale each row of emit_probs by the reciprocal of its emit_sum entry.
            args.emit_probs[:] = sp.dot(sp.diag(1. / args.emit_sum), args.emit_probs)

            # Hand the aggregated parameters to every job for the next round.
            for a in job_args:
                a.theta, a.alpha, a.beta, a.gamma, a.emit_probs = (
                    args.theta, args.alpha, args.beta, args.gamma, args.emit_probs)

            _save_params(args, ['free_energy', 'theta', 'alpha', 'beta',
                                'gamma', 'emit_probs'])
            plot_params(args)
            plot_energy(args)

            if args.save_Q >= 3:
                _reconstruct_chromosomes(args.out_dir)

            if converged:
                break
    finally:
        if args.run_local:
            # Shut the worker processes down instead of leaving them running.
            pool.terminate()
            pool.join()