forked from ishansd/swg
-
Notifications
You must be signed in to change notification settings - Fork 1
/
generator.py
73 lines (63 loc) · 1.87 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import tensorflow as tf
import tensorflow.contrib.layers as layers
def generator(z, reuse=False, is_training=True):
    """Map latent noise vectors to 64x64 RGB images (DCGAN-style generator).

    Architecture: fully connected projection to a 4x4x1024 map, three
    stride-2 transposed convolutions (each doubling the spatial size), one
    extra stride-1 convolution, and a final stride-2 transposed convolution
    to 3 channels with a sigmoid.

    params:
        z: Input noise [batch size, latent dimension].
        reuse: Passed to ``tf.variable_scope("generator", reuse=...)``; set
            True to share variables with a generator already built in the
            same graph.
        is_training: Forwarded to every batch-norm layer. True (the default,
            identical to the previous behavior) normalizes with the current
            batch's statistics; pass False at inference to use the moving
            averages instead.
            NOTE(review): those moving averages are only populated if the
            training step runs the ops batch_norm registers (by default in
            tf.GraphKeys.UPDATE_OPS) -- confirm the training loop does.
    returns:
        x_hat: Artificial image [batch size, 64, 64, 3], values in (0, 1)
            from the final sigmoid.
    """
    # Shared settings for every hidden layer: batch norm followed by ReLU.
    # The nested dict is only unpacked by the layers, never mutated.
    hidden_layer_kwargs = {
        "activation_fn": tf.nn.relu,
        "normalizer_fn": layers.batch_norm,
        "normalizer_params": {"is_training": is_training},
    }
    # NOTE: layers are created in the same order as before. contrib layers
    # derive default variable-scope names from creation order, so changing
    # the order would likely change variable names (and break checkpoints).
    with tf.variable_scope("generator", reuse=reuse):
        # Project the noise and reshape it to [4,4,1024].
        h = layers.fully_connected(
            inputs=z,
            num_outputs=4 * 4 * 1024,
            **hidden_layer_kwargs)
        h = tf.reshape(h, [-1, 4, 4, 1024])
        # [4,4,1024] -> [8,8,512] -> [16,16,256] -> [32,32,128]
        for num_outputs in (512, 256, 128):
            h = layers.conv2d_transpose(
                inputs=h,
                num_outputs=num_outputs,
                kernel_size=4,
                stride=2,
                **hidden_layer_kwargs)
        # Extra stride-1 conv layer, as in the WGAN reference setup;
        # the shape stays [32,32,128].
        h = layers.conv2d(
            inputs=h,
            num_outputs=128,
            kernel_size=4,
            stride=1,
            **hidden_layer_kwargs)
        # Output layer: no batch norm and no bias; the sigmoid maps pixels
        # into (0, 1). [32,32,128] -> [64,64,3]
        x_hat = layers.conv2d_transpose(
            inputs=h,
            num_outputs=3,
            kernel_size=4,
            stride=2,
            activation_fn=tf.nn.sigmoid,
            biases_initializer=None)
    return x_hat