def block_e(x, res):
    """Run one encoder stage on ``x`` and return the downscaled result.

    The stage applies, in order: ``conv2d`` (3x3 kernel, ``nf(res - 2)``
    feature maps), ``apply_bias``, ``batchnorm``, ``leaky_relu`` and finally
    ``downscale2d``.  Variables live under the scope
    ``Gen_Enc<2**res>x<2**res>/Conv``.

    ``self`` and ``nf`` are not parameters; they are expected to come from
    the enclosing scope in which this helper is defined.

    Args:
        x: Input tensor for this stage.
        res: Stage index; ``2**res`` is used in the scope name and
            ``res - 2`` selects the number of feature maps.

    Returns:
        The tensor produced by ``downscale2d``.
    """
    channels_first = self.channel_first
    side = 2**res
    with tf.variable_scope('Gen_Enc%dx%d' % (side, side)):
        with tf.variable_scope('Conv'):
            features = conv2d(x, fmaps=nf(res - 2), kernel=3, cf=channels_first)
            features = apply_bias(features, cf=channels_first)
            features = batchnorm(features, cf=channels_first)
            features = leaky_relu(features)
        # NOTE(review): the source layout was flattened; the downscale is
        # placed after the 'Conv' scope here -- confirm against the original
        # indentation (presumably only op naming could differ).
        pooled = downscale2d(features, cf=channels_first)
    return pooled
def discriminator(self, unknown, input_image):
    """Build the discriminator graph and return its patch-wise sigmoid output.

    ``input_image`` and ``unknown`` are concatenated along ``self.conc_axis``
    and passed through a stack of downscaling conv stages.  At every stage
    the result is blended (``lerp_clip``) with a ``fromrgb`` projection of the
    correspondingly downscaled input, controlled by ``self.lod_in``.  The
    final stage (``block(x, 4)``) is a 3x3 conv with a single feature map
    followed by a sigmoid.

    Args:
        unknown: Tensor concatenated after ``input_image``.
        input_image: Tensor concatenated first.

    Returns:
        The tensor produced by the final 'Patch' stage.

    Raises:
        AssertionError: If ``self.resolution`` is not ``2**self.resolution_log2``
            or is smaller than 4.
        ValueError: If ``self.structure`` is not ``'linear'`` (the only
            structure implemented here).
    """
    assert self.resolution == 2**self.resolution_log2 and self.resolution >= 4
    if self.structure != 'linear':
        # Fail early with a clear message instead of an unbound-name error
        # further down.
        raise ValueError(
            "discriminator only supports structure='linear', got %r" % (self.structure,))

    cf = self.channel_first

    def nf(stage):
        """Number of feature maps for ``stage``, capped at ``self.fmap_max``."""
        return min(int(self.fmap_base / (2.0**(stage * self.fmap_decay))), self.fmap_max)

    def fromrgb(x, res):
        """1x1 conv + bias + leaky ReLU projecting the input to feature maps."""
        with tf.variable_scope('Disc_FromRGB_lod%d' % (self.resolution_log2 - res)):
            x = conv2d(x, fmaps=nf(res - 1), kernel=1, cf=cf)
            x = apply_bias(x, cf=cf)
            return leaky_relu(x)

    def block(x, res):
        """One discriminator stage: conv + downscale, or the final patch head."""
        with tf.variable_scope('Disc_%dx%d' % (2**res, 2**res)):
            if res > 4:
                with tf.variable_scope('Conv'):
                    x = conv2d(x, fmaps=nf(res - 2), kernel=3, cf=cf)
                    x = apply_bias(x, cf=cf)
                    x = batchnorm(x, cf=cf)
                    x = leaky_relu(x)
                # NOTE(review): source layout was flattened; downscale is
                # placed after the 'Conv' scope -- confirm against the
                # original indentation.
                x = downscale2d(x, cf=cf)
            else:
                with tf.variable_scope('Patch'):
                    x = conv2d(x, fmaps=1, kernel=3, cf=cf)
                    x = apply_bias(x, cf=cf)
                    x = tf.sigmoid(x)
        return x

    img = tf.concat([input_image, unknown], axis=self.conc_axis)
    x = fromrgb(img, self.resolution_log2)
    for res in range(self.resolution_log2, 4, -1):
        lod = self.resolution_log2 - res
        x = block(x, res)
        # Blend the conv path with a direct projection of the downscaled
        # input.
        img = downscale2d(img, cf=cf)
        y = fromrgb(img, res - 1)
        with tf.variable_scope('Disc_Grow_lod%d' % lod):
            x = lerp_clip(x, y, self.lod_in - lod)
    return block(x, 4)
def block(x, res):
    """Run one discriminator stage on ``x``.

    For ``res > 4`` the stage is ``conv2d`` (3x3, ``nf(res - 2)`` feature
    maps) -> ``apply_bias`` -> ``batchnorm`` -> ``leaky_relu`` ->
    ``downscale2d`` under the ``Conv`` scope.  Otherwise it is the final
    head: ``conv2d`` (3x3, one feature map) -> ``apply_bias`` ->
    ``tf.sigmoid`` under the ``Patch`` scope.

    ``self`` and ``nf`` are not parameters; they are expected to come from
    the enclosing scope in which this helper is defined.

    Args:
        x: Input tensor for this stage.
        res: Stage index; ``2**res`` is used in the scope name.

    Returns:
        The tensor produced by the selected branch.
    """
    cf = self.channel_first
    with tf.variable_scope('Disc_%dx%d' % (2**res, 2**res)):
        if res > 4:
            with tf.variable_scope('Conv'):
                x = conv2d(x, fmaps=nf(res - 2), kernel=3, cf=cf)
                x = apply_bias(x, cf=cf)
                x = batchnorm(x, cf=cf)
                x = leaky_relu(x)
            # NOTE(review): source layout was flattened; downscale is placed
            # after the 'Conv' scope -- confirm against the original
            # indentation.
            x = downscale2d(x, cf=cf)
        else:
            with tf.variable_scope('Patch'):
                x = conv2d(x, fmaps=1, kernel=3, cf=cf)
                x = apply_bias(x, cf=cf)
                x = tf.sigmoid(x)
    return x
def generator(self, labels_in):
    """Build the encoder/decoder generator graph with skip connections.

    The encoder projects ``labels_in`` with ``fromrgb`` and repeatedly applies
    ``block_e`` (conv + downscale), storing intermediate tensors in ``skip``.
    The decoder applies ``block_d`` (upscale + conv), concatenating the stored
    tensors along ``self.conc_axis`` in reverse order.  From stage 5 upwards
    the image output is blended with the upscaled previous output via
    ``lerp_clip``, controlled by ``self.lod_in``.

    Args:
        labels_in: Input tensor fed to the first ``fromrgb`` projection.

    Returns:
        The output image tensor, wrapped in ``tf.identity`` with the name
        ``'images_out'``.

    Raises:
        AssertionError: If ``self.resolution`` is not ``2**self.resolution_log2``
            or is smaller than 4, or if the output dtype differs from
            ``tf.as_dtype(self.dtype)``.
        ValueError: If ``self.structure`` is not ``'linear'`` (the only
            structure implemented here).
    """
    assert self.resolution == 2**self.resolution_log2 and self.resolution >= 4
    if self.structure != 'linear':
        # Fail early with a clear message instead of an unbound-name error
        # further down.
        raise ValueError(
            "generator only supports structure='linear', got %r" % (self.structure,))

    cf = self.channel_first

    def nf(stage):
        """Number of feature maps for ``stage``, capped at ``self.fmap_max``."""
        return min(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_max)

    def fromrgb(x, res):
        """First layer from image: 1x1 conv + bias + leaky ReLU."""
        with tf.variable_scope('Gen_Enc_FromRGB_lod%d' % (self.resolution_log2 - res)):
            x = conv2d(x, fmaps=nf(res - 1), kernel=1, cf=cf)
            x = apply_bias(x, cf=cf)
            return leaky_relu(x)

    def block_e(x, res):
        """Encoder block: conv (+bias) -> batch norm -> activation -> downscale."""
        with tf.variable_scope('Gen_Enc%dx%d' % (2**res, 2**res)):
            with tf.variable_scope('Conv'):
                x = conv2d(x, fmaps=nf(res - 2), kernel=3, cf=cf)
                x = apply_bias(x, cf=cf)
                x = batchnorm(x, cf=cf)
                x = leaky_relu(x)
            # NOTE(review): source layout was flattened; downscale is placed
            # after the 'Conv' scope -- confirm against the original
            # indentation.
            x = downscale2d(x, cf=cf)
        return x

    def torgb(x, res):
        """Last layer to image: 1x1 conv to ``self.num_channels`` + bias."""
        lod = self.resolution_log2 - res
        with tf.variable_scope('Gen_Dec_ToRGB_lod%d' % lod):
            x = conv2d(x, fmaps=self.num_channels, kernel=1, cf=cf)
            return apply_bias(x, cf=cf)

    def block_d(x, res):
        """Decoder block: upscale -> conv (+bias) -> batch norm -> activation."""
        with tf.variable_scope('Gen_Dec_%dx%d' % (2**res, 2**res)):
            x = upscale2d(x, cf=cf)
            with tf.variable_scope('Conv'):
                x = conv2d(x, fmaps=nf(res - 1), kernel=3, cf=cf)
                x = apply_bias(x, cf=cf)
                x = batchnorm(x, cf=cf)
                x = leaky_relu(x)
        return x

    # =====================================================================
    # Encoder
    # =====================================================================
    skip = []
    img = labels_in
    x = fromrgb(img, self.resolution_log2)
    for res in range(self.resolution_log2, 4, -1):
        lod = self.resolution_log2 - res
        x = block_e(x, res)
        # Blend the conv path with a direct projection of the downscaled
        # input.
        img = downscale2d(img, cf=cf)
        y = fromrgb(img, res - 1)
        with tf.variable_scope('Gen_Enc_Grow_lod%d' % lod):
            x = lerp_clip(x, y, self.lod_in - lod)
        skip.append(x)

    for res in range(4, 0, -1):
        x = block_e(x, res)
        # The output of the innermost stage (res == 1) is the bottleneck and
        # is not kept as a skip connection.
        if res >= 2:
            skip.append(x)

    # =====================================================================
    # Decoder
    # =====================================================================
    # Stages 1..3: upscale + conv, then concatenate the matching skip tensor
    # (most recently stored first).
    for res in range(1, 4):
        x = block_d(x, res)
        x = tf.concat([x, skip[-res]], axis=self.conc_axis)

    x = block_d(x, 4)
    images_out = torgb(x, 4)  # image output extracted at stage 4
    # NOTE(review): skip[-4] needs at least four stored skip tensors, which
    # the encoder above only produces when resolution_log2 >= 5; the assert
    # at the top allows smaller values -- confirm the intended minimum.
    x = tf.concat([x, skip[-4]], axis=self.conc_axis)

    # Growing the decoder.
    for res in range(5, self.resolution_log2 + 1):
        lod = self.resolution_log2 - res
        x = block_d(x, res)
        img = torgb(x, res)
        # The last stage has no remaining skip tensor to concatenate.
        if res < self.resolution_log2:
            x = tf.concat([x, skip[-res]], axis=self.conc_axis)
        images_out = upscale2d(images_out, cf=cf)
        with tf.variable_scope('Gen_Dec_Grow_lod%d' % lod):
            images_out = lerp_clip(img, images_out, self.lod_in - lod)

    assert images_out.dtype == tf.as_dtype(self.dtype)
    images_out = tf.identity(images_out, name='images_out')
    return images_out