forked from GFleishman/greedypy
/
optimizer.py
221 lines (174 loc) · 7.85 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
GreedyPy - greedy weak diffeomorphic registration in python
Copyright: Greg M. Fleishman
Began: November 2019
"""
# to get rid of annoying hdf5 warning; comment out if you want
# to read the warning
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import numpy as np
import smoother
import matcher
import transformer
import inout
import time
from scipy.ndimage import zoom
from os import makedirs
from os.path import splitext, abspath, dirname
def initialize_constants(args):
    """Process command line args and read the input images.

    All run-wide constants are collected in, and returned as, a plain dict
    (``CONS``) with the keys:

        dtype, iterations, lcc_radius, gradient_step, tolerance,
        field_abcd, grad_abcd, initial_transform, outdir, log,
        fixed, fixed_meta, moving, moving_meta, grid, spacing

    ``initial_transform`` is None when no initial transform was given.
    ``log`` is an open, writable text file (``<outdir>/greedypy.log``);
    the caller is responsible for closing it.

    Parameters
    ----------
    args : argparse-style namespace
        Attributes read: precision ('single' selects float32, anything else
        float64), iterations ('x'-separated ints, one per scale level),
        lcc_radius, gradient_step, optimization_tolerance,
        field_regularizer / grad_regularizer ('x'-separated floats),
        output, fixed, moving, n5_fixed_path, n5_moving_path,
        initial_transform, mask.

    Returns
    -------
    dict
        The constants container described above.

    Raises
    ------
    ValueError
        If an initial transform is given whose extension is not ``.mat``.
    """
    from os.path import join

    fff = np.float32 if args.precision == 'single' else np.float64
    CONS = {}
    CONS['dtype'] = fff
    CONS['iterations'] = [int(x) for x in args.iterations.split('x')]
    CONS['lcc_radius'] = int(args.lcc_radius)
    CONS['gradient_step'] = fff(args.gradient_step)
    CONS['tolerance'] = fff(args.optimization_tolerance)
    CONS['field_abcd'] = [fff(x) for x in args.field_regularizer.split('x')]
    CONS['grad_abcd'] = [fff(x) for x in args.grad_regularizer.split('x')]

    # Always define the key so downstream code can test for its presence;
    # validate the file type before any expensive image I/O.
    CONS['initial_transform'] = None
    if args.initial_transform:
        if splitext(args.initial_transform)[1] != '.mat':
            raise ValueError(
                'unsupported initial transform format (expected a .mat '
                'text matrix): ' + str(args.initial_transform))
        matrix = np.loadtxt(abspath(args.initial_transform))
        CONS['initial_transform'] = fff(matrix)

    # Create the output directory BEFORE opening the log inside it, and
    # build the log path from the absolute directory: dirname() of a bare
    # filename is '' and would otherwise put the log at '/greedypy.log'.
    CONS['outdir'] = abspath(dirname(args.output))
    makedirs(CONS['outdir'], exist_ok=True)
    CONS['log'] = open(join(CONS['outdir'], 'greedypy.log'), 'w')

    fixed, fspacing, fmeta = inout.read_image(args.fixed, CONS['dtype'],
                                              args.n5_fixed_path)
    # NOTE(review): the moving image's spacing is read but not stored; the
    # rest of the optimizer uses the fixed image's grid and spacing for
    # both images -- confirm the inputs are expected to share a grid.
    moving, _moving_spacing, mmeta = inout.read_image(
        args.moving, CONS['dtype'], args.n5_moving_path)
    CONS['fixed'] = fixed
    CONS['fixed_meta'] = fmeta
    CONS['moving'] = moving
    CONS['moving_meta'] = mmeta
    CONS['grid'] = fixed.shape
    CONS['spacing'] = fspacing

    if args.mask:
        if isinstance(args.mask, str):
            # TODO: implement mask support -- the mask is read but unused.
            # NOTE(review): unlike the calls above, no dtype / n5 path is
            # passed here; confirm inout.read_image has defaults for them.
            _mask, _, _ = inout.read_image(args.mask)
    return CONS
def initialize_variables(CONS, phi, level):
    """Build the per-scale-level optimizer state.

    Smooths and downsamples the fixed and moving images by ``2**level``,
    resamples ``phi`` onto the new grid (or starts from zeros), and
    constructs the transformer, smoothers and matcher for this level.

    Parameters
    ----------
    CONS : dict
        Run-wide constants as built by ``initialize_constants``.
    phi : ndarray or None
        Deformation from the previous level, with a trailing vector axis
        (shape ``grid + (ndim,)``); None starts from a zero field.
    level : int
        Scale level; images are downsampled by a factor of ``2**level``.

    Returns
    -------
    dict
        Keys: fixed, moving, spacing, warped_transform, phi, transformer,
        field_smoother, grad_smoother, matcher.
    """
    dtype = CONS['dtype']
    scale = 2**level
    VARS = {}

    # Smooth before downsampling; the copies keep each smoothed result
    # independent of the next smooth() call on the same smoother object.
    # The smoother arguments' meaning is defined in the smoother module.
    aaSmoother = smoother.smoother(1+level, 0, 1, 2,
                                   CONS['spacing'], CONS['grid'], dtype)
    fix_smooth = np.copy(aaSmoother.smooth(CONS['fixed']))
    mov_smooth = np.copy(aaSmoother.smooth(CONS['moving']))
    VARS['fixed'] = zoom(fix_smooth, 1./scale, mode='wrap')
    VARS['moving'] = zoom(mov_smooth, 1./scale, mode='wrap')

    shape = VARS['fixed'].shape
    ndim = len(shape)
    VARS['spacing'] = CONS['spacing'] * scale
    VARS['warped_transform'] = np.empty(shape + (ndim,), dtype=dtype)

    # Initialize, or resample each vector component of, the deformation.
    if phi is None:
        VARS['phi'] = np.zeros(shape + (ndim,), dtype=dtype)
    else:
        zoom_factor = np.array(shape) / np.array(phi.shape[:-1])
        components = [zoom(phi[..., i], zoom_factor, mode='nearest')
                      for i in range(phi.shape[-1])]
        VARS['phi'] = np.ascontiguousarray(np.stack(components, axis=-1))

    # Only install an initial transform when one was actually provided.
    VARS['transformer'] = transformer.transformer(shape, VARS['spacing'], dtype)
    initial_transform = CONS.get('initial_transform')
    if initial_transform is not None:
        VARS['transformer'].set_initial_transform(initial_transform)

    # Regularizers: the first parameter is scaled with the level.
    VARS['field_smoother'] = smoother.smoother(
        CONS['field_abcd'][0] * scale,
        *CONS['field_abcd'][1:], VARS['spacing'], shape, dtype)
    VARS['grad_smoother'] = smoother.smoother(
        CONS['grad_abcd'][0] * scale,
        *CONS['grad_abcd'][1:], VARS['spacing'], shape, dtype)

    VARS['matcher'] = matcher.matcher(VARS['fixed'], VARS['moving'],
                                      CONS['lcc_radius'])
    return VARS
def _report(message, log):
    """Print ``message`` to stdout and to the open ``log`` file."""
    print(message)
    print(message, file=log)


def _optimize_levels(CONS):
    """Run the coarse-to-fine greedy optimization.

    Returns ``(VARS, lowest_phi)``: the state of the final (finest) level
    and the lowest-energy deformation found on that level.
    """
    log = CONS['log']
    use_initial = CONS.get('initial_transform') is not None

    # record initial energy
    # TODO: include initial transform in this energy calculation
    ff, mm, rad = CONS['fixed'], CONS['moving'], CONS['lcc_radius']
    energy = matcher.matcher(ff, mm, rad).lcc(ff, mm, rad)
    _report('initial energy: ' + str(energy), log)

    # time.clock() was removed in Python 3.8; perf_counter is the
    # recommended wall-clock interval timer.
    start_time = time.perf_counter()
    n_levels = len(CONS['iterations'])
    lowest_phi = None
    VARS = None
    for level_index, local_iterations in enumerate(CONS['iterations']):
        # levels run coarse to fine: n_levels-1, ..., 0
        level = n_levels - 1 - level_index
        phi_ = None if level_index == 0 else lowest_phi
        VARS = initialize_variables(CONS, phi_, level)
        iteration, backstep_count, converged = 0, 0, False
        local_step = CONS['gradient_step']
        # Energy is minimized. Starting from +inf guarantees the first
        # iteration of each level records a same-resolution lowest_phi, so
        # a backstep can never restore a stale (coarser-level) field.
        # Assumes 0 <= tolerance < 1.
        lowest_energy = np.inf
        vector_dim = VARS['phi'].shape[-1]

        while iteration < local_iterations and not converged:
            t0 = time.perf_counter()
            # compute the (smoothed, normalized) residual
            warped = VARS['transformer'].apply_transform(
                VARS['moving'], VARS['spacing'], VARS['phi'],
                initial_transform=use_initial)
            energy, residual = VARS['matcher'].lcc_grad(
                VARS['fixed'], warped, CONS['lcc_radius'], VARS['spacing'])
            residual = VARS['grad_smoother'].smooth(residual)
            max_residual = np.linalg.norm(residual, axis=-1).max()
            # Scale so the largest vector has length min(spacing); skip when
            # the residual is identically zero to avoid a 0/0 NaN field.
            if max_residual > 0:
                residual *= VARS['spacing'].min() / max_residual

            # monitor the optimization
            if energy > (1 - CONS['tolerance']) * lowest_energy:
                # insufficient improvement: restore the best field, halve
                # the step and shrink the field smoother's first parameter
                VARS['phi'] = np.copy(lowest_phi)
                local_step *= 0.5
                backstep_count += 1
                iteration -= 1  # backsteps do not consume iteration budget
                VARS['field_smoother'] = smoother.smoother(
                    CONS['field_abcd'][0] * 2**level / 4**backstep_count,
                    *CONS['field_abcd'][1:], VARS['spacing'],
                    VARS['fixed'].shape, CONS['dtype'])
                if backstep_count >= max(local_iterations // 10, 5):
                    converged = True
            else:
                if energy < lowest_energy:
                    lowest_energy, lowest_phi = energy, np.copy(VARS['phi'])
                backstep_count = max(0, backstep_count - 1)

            # the gradient descent update: compose phi with the step
            residual *= -local_step
            for i in range(vector_dim):
                VARS['warped_transform'][..., i] = \
                    VARS['transformer'].apply_transform(
                        VARS['phi'][..., i], VARS['spacing'], residual)
            VARS['phi'] = VARS['warped_transform'] + residual
            VARS['phi'] = VARS['field_smoother'].smooth(VARS['phi'])
            iteration += 1

            # record progress
            _report(f'it: {iteration!s}, en: {energy!s}, '
                    f'time: {time.perf_counter() - t0!s}, '
                    f'bsc: {backstep_count!s}', log)

    _report('total optimization time: '
            + str(time.perf_counter() - start_time), log)
    return VARS, lowest_phi


def _write_outputs(args, CONS, VARS, phi):
    """Write the warped image / final LCC map (if requested) and ``phi``.

    ``phi`` may be modified in place when composing with the initial
    transform; it is not used afterwards.
    """
    use_initial = CONS.get('initial_transform') is not None
    if args.final_lcc is not None or args.warped_image is not None:
        warped = VARS['transformer'].apply_transform(
            CONS['moving'], CONS['spacing'], phi,
            initial_transform=use_initial)
        # write the warped image
        if args.warped_image is not None:
            inout.write_image(warped, args.warped_image)
        # write the final lcc
        if args.final_lcc is not None:
            final_lcc = VARS['matcher'].lcc(
                CONS['fixed'], warped, CONS['lcc_radius'], mean=False)
            inout.write_image(final_lcc, args.final_lcc)
    # write the deformation field
    output = phi
    if args.compose_output_with_it:
        # NOTE(review): relies on the transformer exposing Xit/X; confirm
        # Xit is defined when no initial transform was provided.
        output += VARS['transformer'].Xit - VARS['transformer'].X
    inout.write_image(output, args.output)


def register(args):
    """Register ``args.moving`` onto ``args.fixed`` and write the results.

    Reads constants and images via ``initialize_constants``, runs the
    multiscale greedy optimization, then writes the deformation to
    ``args.output`` and, when not None, the warped image
    (``args.warped_image``) and voxelwise LCC map (``args.final_lcc``).
    Progress is printed and mirrored to the run's log file, which is
    closed on exit (including on error).
    """
    CONS = initialize_constants(args)
    try:
        _report(args, CONS['log'])
        VARS, lowest_phi = _optimize_levels(CONS)
        _write_outputs(args, CONS, VARS, lowest_phi)
    finally:
        CONS['log'].close()