def discriminator(self, x, reuse=None):
    """Build the spectrally-normalised discriminator network.

    :param x: input images, NHWC
    :param reuse: re-use the variable scope when truthy
    :return: raw logits (no sigmoid applied)
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        n_feats = self.gf_dim
        half = self.n_layer // 2

        # Stem convolution.
        x = t.conv2d_alt(x, n_feats, 4, 2, pad=1, sn=True, name='disc-conv2d-1')
        x = tf.nn.leaky_relu(x, alpha=0.1)

        for layer in range(self.n_layer):
            # The self-attention layer sits between the two conv stacks.
            if layer == half:
                x = self.attention(x, n_feats, reuse=reuse)
            x = t.conv2d_alt(x, n_feats * 2, 4, 2, pad=1, sn=True,
                             name='disc-conv2d-%d' % (layer + 2))
            x = tf.nn.leaky_relu(x, alpha=0.1)
            n_feats *= 2

        x = t.flatten(x)
        x = t.dense_alt(x, 1, sn=True, name='disc-fc-1')
        return x
def attention(self, x, f_, reuse=None):
    """Self-attention block (SAGAN style).

    BUG FIX: the method was declared without ``self`` although it is
    invoked as ``self.attention(x, f, reuse=reuse)`` by the generator and
    discriminator, which silently bound the instance to ``x`` and shifted
    every other argument.

    :param x: input feature map, NHWC
    :param f_: channel count for the value projection
               (query/key projections use ``f_ // 8``)
    :param reuse: re-use the variable scope when truthy
    :return: input plus gated attention output, same shape as ``x``
    """
    with tf.variable_scope("attention", reuse=reuse):
        # 1x1 convolutions projecting x to query (f), key (g) and value (h).
        f = t.conv2d_alt(x, f_ // 8, 1, 1, sn=True, name='attention-conv2d-f')
        g = t.conv2d_alt(x, f_ // 8, 1, 1, sn=True, name='attention-conv2d-g')
        h = t.conv2d_alt(x, f_, 1, 1, sn=True, name='attention-conv2d-h')

        f, g, h = t.hw_flatten(f), t.hw_flatten(g), t.hw_flatten(h)

        # Attention energies over spatial positions, normalised per row.
        s = tf.matmul(g, f, transpose_b=True)
        attention_map = tf.nn.softmax(s, axis=-1, name='attention_map')

        # Weighted sum of values, reshaped back to the input's shape.
        o = tf.reshape(tf.matmul(attention_map, h), shape=x.get_shape())

        # Learnable residual gate, zero-initialised so the block starts
        # as an identity mapping.
        gamma = tf.get_variable('gamma', shape=[1],
                                initializer=tf.zeros_initializer())

        return gamma * o + x
def generator(self, z, reuse=None, is_train=True):
    """Build the generator network from a noise vector.

    :param z: latent noise vector
    :param reuse: re-use the variable scope when truthy
    :param is_train: whether batch-norm runs in training mode
    :return: generated images in [-1, 1] (tanh output)
    """
    with tf.variable_scope("generator", reuse=reuse):
        n_feats = self.gf_dim * 8
        half = self.n_layer // 2

        # Project the noise to a 4x4 spatial seed.
        x = t.dense_alt(z, 4 * 4 * n_feats, sn=True, name='gen-fc-1')
        x = tf.reshape(x, (-1, 4, 4, n_feats))

        for layer in range(self.n_layer):
            # The self-attention layer sits between the two up-scaling stacks.
            if layer == half:
                x = self.attention(x, n_feats, reuse=reuse)

            if self.up_sampling:
                # Nearest-neighbour resize followed by conv (avoids
                # checkerboard artifacts of strided deconvolution).
                x = t.up_sampling(x, interp=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                x = t.conv2d_alt(x, n_feats // 2, 5, 1, pad=2, sn=True,
                                 use_bias=False, name='gen-conv2d-%d' % (layer + 1))
            else:
                x = t.deconv2d_alt(x, n_feats // 2, 4, 2, sn=True,
                                   use_bias=False, name='gen-deconv2d-%d' % (layer + 1))

            x = t.batch_norm(x, is_train=is_train, name='gen-bn-%d' % layer)
            x = tf.nn.relu(x)
            n_feats //= 2

        # Final projection to image channels.
        x = t.conv2d_alt(x, self.channel, 5, 1, pad=2, sn=True,
                         name='gen-conv2d-%d' % (self.n_layer + 1))
        return tf.nn.tanh(x)
def res_block(x, f, scale_type, use_bn=True, name=""):
    """Pre-activation residual block with up- or down-scaling.

    :param x: input feature map, NHWC
    :param f: number of output filters
    :param scale_type: "up" (deconv up-sampling) or "down" (average pooling)
    :param use_bn: apply batch norm before each activation
    :param name: suffix for the block's variable scope
    :return: transformed features plus the shortcut

    NOTE(review): the shortcut ``ssc`` is added back without being re-scaled,
    while the main branch is pooled/deconvolved — the addition only works if
    the project's conv/deconv helpers keep shapes compatible. Verify against
    ``t.conv2d_alt`` / ``t.deconv2d_alt`` defaults.
    """
    with tf.variable_scope("res_block-%s" % name):
        assert scale_type in ["up", "down"]
        scale_up = scale_type == "up"

        ssc = x  # shortcut connection

        x = t.batch_norm(x, name="bn-1") if use_bn else x
        x = tf.nn.relu(x)
        x = t.conv2d_alt(x, f, sn=True, name="conv2d-1")

        x = t.batch_norm(x, name="bn-2") if use_bn else x
        x = tf.nn.relu(x)

        if scale_up:
            x = t.deconv2d_alt(x, f, sn=True, name="up-sampling")
        else:
            x = t.conv2d_alt(x, f, sn=True, name="conv2d-2")
            # BUG FIX: `strides` is a required argument of
            # tf.layers.average_pooling2d; omitting it raised a TypeError
            # at graph-construction time. 2x2 pool with stride 2 halves
            # the spatial resolution, matching the "down" intent.
            x = tf.layers.average_pooling2d(x, pool_size=(2, 2), strides=(2, 2))

        return x + ssc
def generator(self, z, reuse=None):
    """Build the BigGAN-style generator from a noise vector.

    :param z: latent noise vector
    :param reuse: re-use the variable scope when truthy
    :return: generated images in [-1, 1] (tanh output)
    """
    with tf.variable_scope("generator", reuse=reuse):
        # Split the noise into 4 chunks, expected [None, 32] * 4.
        # NOTE(review): the resulting *list* is fed to dense_alt as a whole
        # rather than one chunk per residual block (hierarchical latent as
        # in BigGAN) — confirm this is intentional.
        z = tf.split(z, num_or_size_splits=4, axis=-1)

        # Linear projection to a 4x4 spatial seed.
        x = t.dense_alt(z, f=4 * 4 * 16 * self.channel, sn=True,
                        use_bias=False, name="gen-dense-1")
        x = tf.nn.relu(x)
        x = tf.reshape(x, (-1, 4, 4, 16 * self.channel))

        res = x
        f = 16 * self.channel
        for i in range(4):
            res = self.res_block(res, f=f, scale_type="up",
                                 name="gen-res%d" % (i + 1))
            f //= 2

        x = self.self_attention(res, f_=f * 2)

        # BUG FIX: this block was named "gen-res4", colliding with the last
        # loop iteration's variable scope; tf.variable_scope would then fail
        # when re-creating variables of the same name. Renamed to "gen-res5".
        x = self.res_block(x, f=1 * self.channel, scale_type="up",
                           name="gen-res5")

        x = t.batch_norm(x, name="gen-bn-last")
        x = tf.nn.relu(x)
        x = t.conv2d_alt(x, f=self.channel, k=3, sn=True,
                         name="gen-conv2d-last")

        return tf.nn.tanh(x)