MatryoshkaModules.py
import numpy as np
import theano
import theano.tensor as T
from theano.sandbox.cuda.dnn import dnn_conv, dnn_pool
from theano.sandbox.cuda.rng_curand import CURAND_RandomStreams as RandStream
from lib import activations
from lib import updates
from lib import inits
from lib.rng import py_rng, np_rng
from lib.ops import batchnorm, conv_cond_concat, deconv, dropout
from lib.theano_utils import floatX, sharedX
relu = activations.Rectify()
sigmoid = activations.Sigmoid()
lrelu = activations.LeakyRectify()
bce = T.nnet.binary_crossentropy
#############################
# BASIC CONVOLUTIONAL LAYER #
#############################
class BasicConvModule(object):
"""
    Simple convolutional layer for general-purpose use.
Params:
filt_shape: filter shape, should be square and odd dim
in_chans: number of channels in input
out_chans: number of channels to produce as output
apply_bn: whether to apply batch normalization after conv
act_func: should be "relu" or "lrelu"
init_func: function for initializing module parameters
mod_name: text name to identify this module in theano graph
"""
def __init__(self, filt_shape, in_chans, out_chans,
apply_bn=True, act_func='lrelu', init_func=None,
mod_name='dm_conv'):
assert ((filt_shape[0] % 2) > 0), "filter dim should be odd (not even)"
self.filt_dim = filt_shape[0]
self.in_chans = in_chans
self.out_chans = out_chans
self.apply_bn = apply_bn
self.act_func = act_func
self.mod_name = mod_name
if init_func is None:
self.init_func = inits.Normal(scale=0.02)
else:
self.init_func = init_func
self._init_params() # initialize parameters
return
def _init_params(self):
"""
Initialize parameters for the layers in this discriminator module.
"""
self.w1 = self.init_func((self.out_chans, self.in_chans, self.filt_dim, self.filt_dim),
"{}_w1".format(self.mod_name))
self.params = [self.w1]
# make gains and biases for transforms that will get batch normed
if self.apply_bn:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g1 = gain_ifn((self.out_chans,), "{}_g1".format(self.mod_name))
            self.b1 = bias_ifn((self.out_chans,), "{}_b1".format(self.mod_name))
self.params.extend([self.g1, self.b1])
return
def apply(self, input):
"""
Apply this convolutional module to the given input.
"""
bm = int((self.filt_dim - 1) / 2) # use "same" mode convolutions
# apply first conv layer
h1 = dnn_conv(input, self.w1, subsample=(1, 1), border_mode=(bm, bm))
if self.apply_bn:
h1 = batchnorm(h1, g=self.g1, b=self.b1)
if self.act_func == 'lrelu':
h1 = lrelu(h1)
elif self.act_func == 'relu':
h1 = relu(h1)
else:
assert False, "Unsupported activation function."
return h1
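# Usage sketch (illustrative, not from the original training scripts): build a
# 3x3 "same"-convolution module and apply it to a symbolic batch of feature
# maps. The names `x_sym`, `conv_mod`, and `demo_conv` are hypothetical.
#
#   x_sym = T.tensor4()  # shape: (batch, 64, rows, cols)
#   conv_mod = BasicConvModule(filt_shape=(3, 3), in_chans=64, out_chans=128,
#                              apply_bn=True, act_func='lrelu',
#                              mod_name='demo_conv')
#   h_sym = conv_mod.apply(x_sym)  # shape: (batch, 128, rows, cols)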
#############################################
# DISCRIMINATOR DOUBLE CONVOLUTIONAL MODULE #
#############################################
class DiscConvModule(object):
"""
    Module that does one layer of convolution with stride 1 followed by
    another layer of convolution with adjustable stride.
Following the second layer of convolution, an additional convolution
is performed that produces a single "discriminator" channel.
Params:
filt_shape: filter shape, should be square and odd dim
in_chans: number of channels in input
out_chans: number of channels to produce as output
apply_bn_1: whether to apply batch normalization after first conv
apply_bn_2: whether to apply batch normalization after second conv
ds_stride: "downsampling" stride for the second convolution
use_pooling: whether to use max pooling or multi-striding
init_func: function for initializing module parameters
mod_name: text name to identify this module in theano graph
"""
def __init__(self, filt_shape, in_chans, out_chans,
apply_bn_1=True, apply_bn_2=True, ds_stride=2,
use_pooling=True, init_func=None, mod_name='dm_conv'):
assert ((filt_shape[0] % 2) > 0), "filter dim should be odd (not even)"
self.filt_dim = filt_shape[0]
self.in_chans = in_chans
self.out_chans = out_chans
self.apply_bn_1 = apply_bn_1
self.apply_bn_2 = apply_bn_2
self.ds_stride = ds_stride
self.use_pooling = use_pooling
self.mod_name = mod_name
if init_func is None:
self.init_func = inits.Normal(scale=0.02)
else:
self.init_func = init_func
self._init_params() # initialize parameters
return
def _init_params(self):
"""
Initialize parameters for the layers in this discriminator module.
"""
self.w1 = self.init_func((self.out_chans, self.in_chans, self.filt_dim, self.filt_dim),
"{}_w1".format(self.mod_name))
self.w2 = self.init_func((self.out_chans, self.out_chans, self.filt_dim, self.filt_dim),
"{}_w2".format(self.mod_name))
self.wd = self.init_func((1, self.out_chans, self.filt_dim, self.filt_dim),
"{}_wd".format(self.mod_name))
self.params = [self.w1, self.w2, self.wd]
# make gains and biases for transforms that will get batch normed
if self.apply_bn_1:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g1 = gain_ifn((self.out_chans,), "{}_g1".format(self.mod_name))
            self.b1 = bias_ifn((self.out_chans,), "{}_b1".format(self.mod_name))
self.params.extend([self.g1, self.b1])
if self.apply_bn_2:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g2 = gain_ifn((self.out_chans,), "{}_g2".format(self.mod_name))
            self.b2 = bias_ifn((self.out_chans,), "{}_b2".format(self.mod_name))
self.params.extend([self.g2, self.b2])
return
def apply(self, input):
"""
Apply this discriminator module to the given input. This produces a
collection of filter responses for feedforward and a spatial grid of
discriminator outputs.
"""
bm = int((self.filt_dim - 1) / 2) # use "same" mode convolutions
ss = self.ds_stride # stride for "learned downsampling"
# apply first conv layer
h1 = dnn_conv(input, self.w1, subsample=(1, 1), border_mode=(bm, bm))
if self.apply_bn_1:
h1 = batchnorm(h1, g=self.g1, b=self.b1)
h1 = lrelu(h1)
# apply second conv layer (may include downsampling)
if self.use_pooling:
h2 = dnn_conv(h1, self.w2, subsample=(1, 1), border_mode=(bm, bm))
if self.apply_bn_2:
h2 = batchnorm(h2, g=self.g2, b=self.b2)
h2 = lrelu(h2)
            h2 = dnn_pool(h2, (ss, ss), stride=(ss, ss), mode='max', pad=(0, 0))
else:
h2 = dnn_conv(h1, self.w2, subsample=(ss, ss), border_mode=(bm, bm))
if self.apply_bn_2:
h2 = batchnorm(h2, g=self.g2, b=self.b2)
h2 = lrelu(h2)
# apply discriminator layer
y = dnn_conv(h2, self.wd, subsample=(1, 1), border_mode=(bm, bm))
y = sigmoid(T.flatten(y, 2)) # flatten to (batch_size, num_preds)
return h2, y
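# Usage sketch (hypothetical names, a minimal illustration): downsample a
# 32x32 feature map by strided convolution and read off a per-location
# discriminator signal alongside the feature output.
#
#   x_sym = T.tensor4()  # shape: (batch, 64, 32, 32)
#   disc_mod = DiscConvModule(filt_shape=(3, 3), in_chans=64, out_chans=128,
#                             ds_stride=2, use_pooling=False,
#                             mod_name='demo_disc_conv')
#   h_sym, y_sym = disc_mod.apply(x_sym)
#   # h_sym: (batch, 128, 16, 16) features for the next module
#   # y_sym: (batch, 16*16) sigmoid outputs, one per spatial location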
########################################
# DISCRIMINATOR FULLY CONNECTED MODULE #
########################################
class DiscFCModule(object):
"""
Module that feeds forward through a single fully connected hidden layer
and then produces a single scalar "discriminator" output.
Params:
fc_dim: dimension of the fully connected layer
in_dim: dimension of the inputs to the module
apply_bn: whether to apply batch normalization at fc layer
init_func: function for initializing module parameters
mod_name: text name for identifying module in theano graph
"""
def __init__(self, fc_dim, in_dim, apply_bn=True,
init_func=None, mod_name='dm_fc'):
self.fc_dim = fc_dim
self.in_dim = in_dim
self.apply_bn = apply_bn
self.mod_name = mod_name
if init_func is None:
self.init_func = inits.Normal(scale=0.02)
else:
self.init_func = init_func
self._init_params() # initialize parameters
return
def _init_params(self):
"""
Initialize parameters for the layers in this discriminator module.
"""
self.w1 = self.init_func((self.in_dim, self.fc_dim),
"{}_w1".format(self.mod_name))
self.w2 = self.init_func((self.fc_dim, 1),
"{}_w2".format(self.mod_name))
self.params = [self.w1, self.w2]
# make gains and biases for transforms that will get batch normed
if self.apply_bn:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g1 = gain_ifn((self.fc_dim,), "{}_g1".format(self.mod_name))
            self.b1 = bias_ifn((self.fc_dim,), "{}_b1".format(self.mod_name))
self.params.extend([self.g1, self.b1])
return
def apply(self, input):
"""
Apply this discriminator module to the given input. This produces a
scalar discriminator output for each input observation.
"""
# flatten input to 1d per example
input = T.flatten(input, 2)
# feedforward to fully connected layer
h1 = T.dot(input, self.w1)
if self.apply_bn:
h1 = batchnorm(h1, g=self.g1, b=self.b1)
h1 = lrelu(h1)
# feedforward to discriminator outputs
y = sigmoid(T.dot(h1, self.w2))
return y
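# Usage sketch (hypothetical names): collapse a stack of feature maps into a
# single per-example "realness" probability; in_dim must match the flattened
# size of the input.
#
#   h_sym = T.tensor4()  # shape: (batch, 128, 8, 8)
#   disc_fc = DiscFCModule(fc_dim=1024, in_dim=128 * 8 * 8,
#                          mod_name='demo_disc_fc')
#   p_sym = disc_fc.apply(h_sym)  # shape: (batch, 1), values in (0, 1)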
#########################################
# GENERATOR DOUBLE CONVOLUTIONAL MODULE #
#########################################
class GenConvModule(object):
"""
Module of one "fractionally strided" convolution layer followed by one
regular convolution layer. Inputs to the fractionally strided convolution
can optionally be augmented with some random values.
Params:
filt_shape: shape for convolution filters -- should be square and odd
in_chans: number of channels in the inputs to module
out_chans: number of channels in the outputs from module
rand_chans: number of random channels to augment input
use_rand: flag for whether or not to augment inputs
apply_bn_1: flag for whether to batch normalize following first conv
apply_bn_2: flag for whether to batch normalize following second conv
us_stride: upsampling ratio in the fractionally strided convolution
use_pooling: whether to use unpooling or fractional striding
init_func: function for initializing module parameters
mod_name: text name for identifying module in theano graph
rand_type: whether to use Gaussian or uniform randomness
"""
def __init__(self, filt_shape, in_chans, out_chans, rand_chans,
use_rand=True, apply_bn_1=True, apply_bn_2=True,
us_stride=2, use_pooling=True,
init_func=None, mod_name='gm_conv',
rand_type='normal'):
assert ((filt_shape[0] % 2) > 0), "filter dim should be odd (not even)"
self.filt_dim = filt_shape[0]
self.in_chans = in_chans
self.out_chans = out_chans
self.rand_chans = rand_chans
self.use_rand = use_rand
self.apply_bn_1 = apply_bn_1
self.apply_bn_2 = apply_bn_2
self.us_stride = us_stride
self.use_pooling = use_pooling
self.mod_name = mod_name
self.rand_type = rand_type
self.rng = RandStream(123)
if init_func is None:
self.init_func = inits.Normal(scale=0.02)
else:
self.init_func = init_func
self._init_params() # initialize parameters
return
def _init_params(self):
"""
Initialize parameters for the layers in this generator module.
"""
if self.use_rand:
# random values will be stacked on exogenous input
self.w1 = self.init_func((self.out_chans, (self.in_chans+self.rand_chans), self.filt_dim, self.filt_dim),
"{}_w1".format(self.mod_name))
else:
# random values won't be stacked on exogenous input
self.w1 = self.init_func((self.out_chans, self.in_chans, self.filt_dim, self.filt_dim),
"{}_w1".format(self.mod_name))
self.w2 = self.init_func((self.out_chans, self.out_chans, self.filt_dim, self.filt_dim),
"{}_w2".format(self.mod_name))
self.params = [self.w1, self.w2]
# make gains and biases for transforms that will get batch normed
if self.apply_bn_1:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g1 = gain_ifn((self.out_chans,), "{}_g1".format(self.mod_name))
            self.b1 = bias_ifn((self.out_chans,), "{}_b1".format(self.mod_name))
self.params.extend([self.g1, self.b1])
if self.apply_bn_2:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g2 = gain_ifn((self.out_chans,), "{}_g2".format(self.mod_name))
            self.b2 = bias_ifn((self.out_chans,), "{}_b2".format(self.mod_name))
self.params.extend([self.g2, self.b2])
return
def apply(self, input, rand_vals=None):
"""
Apply this generator module to some input.
"""
batch_size = input.shape[0]
bm = int((self.filt_dim - 1) / 2) # use "same" mode convolutions
ss = self.us_stride # stride for "learned upsampling"
if self.use_pooling:
# "unpool" the input if desired
input = input.repeat(ss, axis=2).repeat(ss, axis=3)
# get shape for random values that will augment input
rand_shape = (batch_size, self.rand_chans, input.shape[2], input.shape[3])
if self.use_rand:
# augment input with random channels
if rand_vals is None:
if self.rand_type == 'normal':
rand_vals = self.rng.normal(size=rand_shape, avg=0.0, std=1.0, \
dtype=theano.config.floatX)
else:
rand_vals = self.rng.uniform(size=rand_shape, low=-1.0, high=1.0, \
dtype=theano.config.floatX)
rand_vals = rand_vals.reshape(rand_shape)
# stack random values on top of input
full_input = T.concatenate([rand_vals, input], axis=1)
else:
# don't augment input with random channels
full_input = input
# apply first convolution, perhaps with fractional striding
if self.use_pooling:
h1 = dnn_conv(full_input, self.w1, subsample=(1, 1), border_mode=(bm, bm))
else:
# apply first conv layer (with fractional stride for upsampling)
h1 = deconv(full_input, self.w1, subsample=(ss, ss), border_mode=(bm, bm))
if self.apply_bn_1:
h1 = batchnorm(h1, g=self.g1, b=self.b1)
h1 = relu(h1)
# apply second conv layer
h2 = dnn_conv(h1, self.w2, subsample=(1, 1), border_mode=(bm, bm))
if self.apply_bn_2:
h2 = batchnorm(h2, g=self.g2, b=self.b2)
h2 = relu(h2)
return h2
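# Usage sketch (hypothetical names): upsample an 8x8 feature map to 16x16 via
# unpooling while injecting 16 channels of fresh Gaussian noise; pass
# rand_vals explicitly to reuse externally sampled noise instead.
#
#   h_sym = T.tensor4()  # shape: (batch, 256, 8, 8)
#   gen_conv = GenConvModule(filt_shape=(3, 3), in_chans=256, out_chans=128,
#                            rand_chans=16, us_stride=2, use_pooling=True,
#                            mod_name='demo_gen_conv')
#   h_up = gen_conv.apply(h_sym)  # shape: (batch, 128, 16, 16)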
####################################
# GENERATOR FULLY CONNECTED MODULE #
####################################
class GenFCModule(object):
"""
    Module that transforms random values through a single fully connected
    layer, followed by a second linear transform (optionally with another relu).
"""
def __init__(self, rand_dim, out_dim, fc_dim,
apply_bn_1=True, apply_bn_2=True,
init_func=None, rand_type='normal',
final_relu=True, mod_name='dm_fc'):
self.rand_dim = rand_dim
self.out_dim = out_dim
self.fc_dim = fc_dim
self.apply_bn_1 = apply_bn_1
self.apply_bn_2 = apply_bn_2
self.mod_name = mod_name
self.rand_type = rand_type
self.final_relu = final_relu
self.rng = RandStream(123)
if init_func is None:
self.init_func = inits.Normal(scale=0.02)
else:
self.init_func = init_func
self._init_params() # initialize parameters
return
def _init_params(self):
"""
Initialize parameters for the layers in this generator module.
"""
self.w1 = self.init_func((self.rand_dim, self.fc_dim),
"{}_w1".format(self.mod_name))
self.w2 = self.init_func((self.fc_dim, self.out_dim),
"{}_w2".format(self.mod_name))
self.params = [self.w1, self.w2]
# make gains and biases for transforms that will get batch normed
if self.apply_bn_1:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g1 = gain_ifn((self.fc_dim,), "{}_g1".format(self.mod_name))
            self.b1 = bias_ifn((self.fc_dim,), "{}_b1".format(self.mod_name))
self.params.extend([self.g1, self.b1])
if self.apply_bn_2:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g2 = gain_ifn((self.out_dim,), "{}_g2".format(self.mod_name))
            self.b2 = bias_ifn((self.out_dim,), "{}_b2".format(self.mod_name))
self.params.extend([self.g2, self.b2])
return
def apply(self, batch_size=None, rand_vals=None):
"""
Apply this generator module. Pass _either_ batch_size or rand_vals.
"""
assert not ((batch_size is None) and (rand_vals is None)), "need either batch_size or rand_vals"
if rand_vals is None:
rand_shape = (batch_size, self.rand_dim)
if self.rand_type == 'normal':
rand_vals = self.rng.normal(size=rand_shape, avg=0.0, std=1.0, \
dtype=theano.config.floatX)
else:
rand_vals = self.rng.uniform(size=rand_shape, low=-1.0, high=1.0, \
dtype=theano.config.floatX)
else:
rand_shape = (rand_vals.shape[0], self.rand_dim)
rand_vals = rand_vals.reshape(rand_shape)
# transform random values into fc layer
h1 = T.dot(rand_vals, self.w1)
if self.apply_bn_1:
h1 = batchnorm(h1, g=self.g1, b=self.b1)
h1 = relu(h1)
# transform from fc layer to output
h2 = T.dot(h1, self.w2)
if self.apply_bn_2:
h2 = batchnorm(h2, g=self.g2, b=self.b2)
if self.final_relu:
h2 = relu(h2)
return h2
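# Usage sketch (hypothetical names): map latent noise to a flat vector sized
# for reshaping into the seed feature maps of a convolutional generator.
#
#   gen_fc = GenFCModule(rand_dim=100, out_dim=256 * 8 * 8, fc_dim=1024,
#                        mod_name='demo_gen_fc')
#   h_sym = gen_fc.apply(batch_size=128)        # shape: (128, 256*8*8)
#   h_maps = h_sym.reshape((128, 256, 8, 8))    # seed maps for GenConvModule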
###########################
# GENERATOR LINEAR MODULE #
###########################
class GenUniModule(object):
"""
    Module that applies a linear transform, optionally followed by a relu
    non-linearity.
"""
def __init__(self, rand_dim, out_dim,
apply_bn=True, init_func=None,
rand_type='normal', final_relu=True,
mod_name='dm_uni'):
self.rand_dim = rand_dim
self.out_dim = out_dim
self.apply_bn = apply_bn
self.mod_name = mod_name
self.rand_type = rand_type
self.final_relu = final_relu
self.rng = RandStream(123)
if init_func is None:
self.init_func = inits.Normal(scale=0.02)
else:
self.init_func = init_func
self._init_params() # initialize parameters
return
def _init_params(self):
"""
Initialize parameters for the layers in this generator module.
"""
self.w1 = self.init_func((self.rand_dim, self.out_dim),
"{}_w1".format(self.mod_name))
        self.params = [self.w1]
# make gains and biases for transforms that will get batch normed
if self.apply_bn:
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)
            self.g1 = gain_ifn((self.out_dim,), "{}_g1".format(self.mod_name))
            self.b1 = bias_ifn((self.out_dim,), "{}_b1".format(self.mod_name))
self.params.extend([self.g1, self.b1])
return
def apply(self, batch_size=None, rand_vals=None):
"""
Apply this generator module. Pass _either_ batch_size or rand_vals.
"""
assert not ((batch_size is None) and (rand_vals is None)), "need either batch_size or rand_vals"
if rand_vals is None:
rand_shape = (batch_size, self.rand_dim)
if self.rand_type == 'normal':
rand_vals = self.rng.normal(size=rand_shape, avg=0.0, std=1.0, \
dtype=theano.config.floatX)
else:
rand_vals = self.rng.uniform(size=rand_shape, low=-1.0, high=1.0, \
dtype=theano.config.floatX)
else:
rand_shape = (rand_vals.shape[0], self.rand_dim)
rand_vals = rand_vals.reshape(rand_shape)
# transform random values linearly
h1 = T.dot(rand_vals, self.w1)
if self.apply_bn:
h1 = batchnorm(h1, g=self.g1, b=self.b1)
if self.final_relu:
h1 = relu(h1)
return h1
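# Usage sketch (hypothetical names): apply the linear transform to externally
# supplied latent codes; omitting rand_vals and passing batch_size instead
# would draw the codes from the module's own RNG.
#
#   z_sym = T.matrix()  # shape: (batch, 100)
#   gen_uni = GenUniModule(rand_dim=100, out_dim=512, mod_name='demo_gen_uni')
#   h_sym = gen_uni.apply(rand_vals=z_sym)  # shape: (batch, 512)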
##############
# EYE BUFFER #
##############