def shortcut_convolution(high_res_img, low_res_target, nb_channels_out):
    """Downsample ``high_res_img`` to the spatial size of ``low_res_target``.

    Applies a time-distributed, spectrally-normalized ``Conv2D`` producing
    ``nb_channels_out`` channels, choosing stride/padding/kernel so the
    output spatial size matches ``img_size(low_res_target)``.  Used as the
    shortcut path of the discriminator (see ``make_discriminator``).
    """
    if img_size(low_res_target) == 1:
        # Target is a single pixel: one kernel covering the whole image
        # collapses it in a single convolution.
        kernel_size = img_size(high_res_img)
        downsampled_input = kl.TimeDistributed(
            SpectralNormalization(
                kl.Conv2D(nb_channels_out, kernel_size,
                          activation=LeakyReLU(0.2))),
            name='shortcut_conv_1')(high_res_img)
    else:
        # Conv arithmetic: out = (in + 2*padding - kernel)/strides + 1 must
        # equal img_size(low_res_target); padding gets an extra safety margin
        # and kernel_size is then solved exactly from the other two.
        strides = int(
            tf.math.ceil(
                (2 + img_size(high_res_img)) / (img_size(low_res_target) - 1)))
        margin = 2
        padding = int(
            tf.math.ceil((strides * (img_size(low_res_target) - 1) -
                          img_size(high_res_img)) / 2) + 1 + margin)
        kernel_size = int(strides * (1 - img_size(low_res_target)) +
                          img_size(high_res_img) + 2 * padding)
        downsampled_input = kl.TimeDistributed(
            kl.ZeroPadding2D(padding=padding))(high_res_img)
        downsampled_input = kl.TimeDistributed(
            SpectralNormalization(
                kl.Conv2D(nb_channels_out,
                          kernel_size,
                          strides=strides,
                          activation=LeakyReLU(0.2))),
            name='shortcut_conv')(downsampled_input)
    # NOTE(review): the original formatting was lost; this layer-norm is
    # placed at function level (applied after both branches) — confirm it was
    # not intended for the multi-pixel branch only.
    downsampled_input = kl.LayerNormalization()(downsampled_input)
    return downsampled_input
def __init__(self,
             conv,
             filters,
             kernel_size,
             stride,
             padding,
             activation,
             spectral_norm=False,
             batch_norm=False,
             name=None):
    """Configurable convolution block: conv (+ optional spectral norm),
    optional batch norm, then an activation layer.

    Args:
        conv: convolution layer class to instantiate (e.g. ``Conv2D``).
        filters, kernel_size, stride, padding: forwarded to ``conv``.
        activation: activation layer class to instantiate (e.g. ``LeakyReLU``).
        spectral_norm: wrap the convolution in ``SpectralNormalization``.
        batch_norm: add a ``BatchNormalization`` layer.
        name: base name for the sub-layers; defaults to the Keras
            auto-generated block name.
    """
    super(CustomConvBlock, self).__init__(name=name)
    # Bug fix: with the previous code ``name=None`` (the default) crashed on
    # the ``name + '_conv'`` concatenations below.  Fall back to the Keras
    # auto-generated name assigned by the super().__init__ call above.
    if name is None:
        name = self.name
    # self.sn / self.bn start as the booleans and are overwritten with the
    # actual layers below when enabled (pattern kept for caller compatibility).
    self.sn = spectral_norm
    self.bn = batch_norm
    self.conv_type = conv
    self.activation_type = activation
    if not spectral_norm:
        self.conv = self.conv_type(filters,
                                   kernel_size,
                                   stride,
                                   padding,
                                   name=name + '_conv')
    if spectral_norm:
        # NOTE(review): ``self.conv`` is left unset on this path; call() is
        # assumed to use ``self.sn`` instead — confirm.
        self.sn = SpectralNormalization(self.conv_type(filters,
                                                       kernel_size,
                                                       stride,
                                                       padding,
                                                       name=name + '_conv'),
                                        name=name + '_conv_sn')
    if batch_norm:
        self.bn = BatchNormalization(name=name + '_bn')
    self.relu = self.activation_type(name=name + '_act')
def conv2d(layer_input,
           filters,
           kernel_size,
           stride,
           padding='same',
           activation=None,
           bias=True,
           sn=False):
    """Apply a 2D convolution with He-style weight init and zero bias init,
    optionally wrapped in spectral normalization.

    Returns the output tensor of the convolution applied to ``layer_input``.
    """
    fan = (kernel_size**2) * filters
    layer = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=stride,
                   padding=padding,
                   activation=activation,
                   bias_initializer=tf.constant_initializer(0.0),
                   kernel_initializer=RandomNormal(mean=0.,
                                                   stddev=np.sqrt(2 / fan)),
                   use_bias=bias)
    if sn:
        layer = SpectralNormalization(layer)
    return layer(layer_input)
def _discriminator_block(self,
                         previous_output,
                         filter_out: int,
                         strides: Tuple[int, int] = (2, 2)):
    """Single discriminator stage: 4x4 strided conv (optionally spectrally
    normalized), optional batch norm, then LeakyReLU(0.2)."""
    conv = Conv2D(filter_out, (4, 4),
                  strides=strides,
                  padding='same',
                  kernel_initializer=self.disc_init_fn)
    if self.spec_normalization:
        conv = SpectralNormalization(conv)
    features = conv(previous_output)
    if self.batch_normalization:
        features = BatchNormalization()(features)
    return LeakyReLU(alpha=0.2)(features)
def _patch_discriminator(self, shape: Tuple[int, ...],
                         input_channels: int) -> Model:
    """Build a PatchGAN discriminator over an image and a condition image.

    Args:
        shape: (height, width, channels) of the main input image.
        input_channels: channel count of the condition image.

    Returns:
        A Model mapping [in_image, cond_image] to a sigmoid patch map.
    """
    in_image = Input(shape=shape)
    # Generalized: the condition image shares the spatial dims of ``shape``
    # (previously hard-coded to 225x225, which only worked for that size —
    # Concatenate requires matching spatial dims anyway).
    cond_image = Input(tuple(shape[:2]) + (input_channels, ))
    conc_img = Concatenate()([in_image, cond_image])
    # Progressively downsample while widening channels; the last block keeps
    # the spatial size (stride 1).
    block1 = self._discriminator_block(conc_img, 64)
    block2 = self._discriminator_block(block1, 128)
    block3 = self._discriminator_block(block2, 256)
    block4 = self._discriminator_block(block3, 512)
    block5 = self._discriminator_block(block4, 512, strides=(1, 1))
    final_layer = Conv2D(1, (4, 4),
                         padding='same',
                         activation='sigmoid',
                         kernel_initializer=self.disc_init_fn)
    if self.spec_normalization:
        final_layer = SpectralNormalization(final_layer)
    output = final_layer(block5)
    return Model([in_image, cond_image], output)
def make_generator(image_size: int,
                   in_channels: int,
                   noise_channels: int,
                   out_channels: int,
                   n_timesteps: int,
                   batch_size: int = None,
                   feature_channels=128):
    """Build the recurrent encoder/decoder generator.

    Takes an image sequence plus a noise sequence, downsamples twice (keeping
    residuals), runs a ConvLSTM bottleneck, then upsamples back with skip
    connections to produce a sequence of ``out_channels`` images at full size.

    Args:
        image_size: spatial size (must be a multiple of 4).
        in_channels / noise_channels: channels of the two inputs.
        out_channels: channels of the predicted image.
        n_timesteps: sequence length.
        batch_size: optional fixed batch size (None for dynamic).
        feature_channels: bottleneck width (must be a multiple of 8).

    Returns:
        A Model mapping [input_image, input_noise] to the predicted sequence.
    """
    # Make sure we have nice multiples everywhere
    assert image_size % 4 == 0
    assert feature_channels % 8 == 0
    total_in_channels = in_channels + noise_channels
    img_shape = (image_size, image_size)
    tshape = (n_timesteps, ) + img_shape
    input_image = kl.Input(shape=tshape + (in_channels, ),
                           batch_size=batch_size,
                           name='input_image')
    input_noise = kl.Input(shape=tshape + (noise_channels, ),
                           batch_size=batch_size,
                           name='input_noise')
    # Concatenate inputs
    x = kl.Concatenate()([input_image, input_noise])
    # Add features and decrease image size - in 2 steps
    intermediate_features = min(total_in_channels * 8, feature_channels)
    x = kl.TimeDistributed(kl.ZeroPadding2D(padding=3))(x)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(intermediate_features, (8, 8),
                      strides=2,
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 2,
                              image_size // 2, intermediate_features)
    res_2 = x  # Keep residuals for later
    x = kl.TimeDistributed(kl.ZeroPadding2D(padding=1))(x)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (4, 4),
                      strides=2,
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels)
    res_4 = x  # Keep residuals for later
    # Recurrent unit
    x = kl.ConvLSTM2D(feature_channels, (3, 3),
                      padding='same',
                      return_sequences=True)(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels)
    # Re-increase image size and decrease features
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels // 2, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels // 2)
    # Re-introduce residuals from before (skip connection)
    x = kl.Concatenate()([x, res_4])
    # Fix: filters must be an int — '/' produced a float (the assert below
    # already expects feature_channels // 4).
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2DTranspose(feature_channels // 4, (2, 2),
                               strides=2,
                               activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 2,
                              image_size // 2, feature_channels // 4)
    # Skip connection 2
    x = kl.Concatenate()([x, res_2])
    if feature_channels / 8 >= out_channels:
        x = kl.TimeDistributed(
            kl.UpSampling2D(size=(2, 2), interpolation='bilinear'))(x)
        x = kl.TimeDistributed(
            kl.Conv2DTranspose(feature_channels // 8, (5, 5),
                               padding='same',
                               activation=LeakyReLU(0.2)))(x)
        assert tuple(x.shape) == (batch_size, n_timesteps, image_size,
                                  image_size, feature_channels // 8)
    else:
        # Fix: this branch never restored full resolution although its assert
        # (and the final assert) require image_size — upsample first, as the
        # other branch does.
        x = kl.TimeDistributed(
            kl.UpSampling2D(size=(2, 2), interpolation='bilinear'))(x)
        x = kl.TimeDistributed(
            kl.Conv2D(out_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2)))(x)
        assert tuple(x.shape) == (batch_size, n_timesteps, image_size,
                                  image_size, out_channels)
    x = kl.BatchNormalization()(x)
    x = kl.TimeDistributed(kl.Conv2D(out_channels, (3, 3),
                                     padding='same',
                                     activation='linear'),
                           name='predicted_image')(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size, image_size,
                              out_channels)
    return Model(inputs=[input_image, input_noise],
                 outputs=x,
                 name='generator')
def make_discriminator(low_res_size: int,
                       high_res_size: int,
                       low_res_channels: int,
                       high_res_channels: int,
                       n_timesteps: int,
                       batch_size: int = None,
                       feature_channels: int = 16):
    """Build the sequence discriminator.

    Two ConvLSTM branches (high-res only, and low-res/high-res mix) are merged
    and repeatedly downsampled — with a convolutional shortcut added back in —
    down to a per-timestep scalar that is averaged over time into one score.

    Args:
        low_res_size / high_res_size: spatial sizes (must be equal).
        low_res_channels / high_res_channels: channel counts of the inputs.
        n_timesteps: sequence length.
        batch_size: optional fixed batch size (None for dynamic).
        feature_channels: base feature width after the input branches.

    Raises:
        NotImplementedError: if the two inputs differ in spatial size.
    """
    low_res = kl.Input(shape=(n_timesteps, low_res_size, low_res_size,
                              low_res_channels),
                       batch_size=batch_size,
                       name='low_resolution_image')
    high_res = kl.Input(shape=(n_timesteps, high_res_size, high_res_size,
                               high_res_channels),
                        batch_size=batch_size,
                        name='high_resolution_image')
    if tuple(low_res.shape)[:-1] != tuple(high_res.shape)[:-1]:
        raise NotImplementedError(
            "The discriminator assumes that the low res and high res images have the same size."
            "Perhaps you should upsample your low res image first?")
    # First branch: high res only
    hr = kl.ConvLSTM2D(high_res_channels, (3, 3),
                       padding='same',
                       return_sequences=True)(high_res)
    hr = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(hr)
    hr = kl.LayerNormalization()(hr)
    # Second branch: Mix both inputs
    mix = kl.Concatenate()([low_res, high_res])
    mix = kl.ConvLSTM2D(feature_channels, (3, 3),
                        padding='same',
                        return_sequences=True)(mix)
    mix = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(mix)
    mix = kl.LayerNormalization()(mix)
    # Merge everything together
    x = kl.Concatenate()([hr, mix])
    assert tuple(x.shape) == (batch_size, n_timesteps, low_res_size,
                              low_res_size, 2 * feature_channels)

    def _reduce(t):
        # Shared body of the two reduction loops below (previously
        # duplicated): pad, strided 7x7 SN conv doubling the channels, then
        # layer-norm.  The layer name uses the padded size, as before.
        t = kl.TimeDistributed(kl.ZeroPadding2D())(t)
        t = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(t) * 2, (7, 7),
                      strides=3,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(t)}')(t)
        return kl.LayerNormalization()(t)

    while img_size(x) >= 16:
        x = _reduce(x)
    shortcut = x  # branch point for the convolutional shortcut below
    while img_size(x) >= 4:
        x = _reduce(x)
    shortcut = shortcut_convolution(shortcut, x, channels(x))
    # Skip connection: add the downsampled shortcut back in
    x = kl.add([x, shortcut])
    while img_size(x) > 2:
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (3, 3),
                      strides=2,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)
    x = kl.TimeDistributed(kl.Flatten())(x)
    assert tuple(x.shape)[:-1] == (batch_size, n_timesteps
                                   )  # Unknown number of channels
    x = kl.TimeDistributed(kl.Dense(1, activation='linear'))(x)
    x = kl.GlobalAveragePooling1D(name='score')(x)
    return Model(inputs=[low_res, high_res], outputs=x, name='discriminator')
def __init__(self,
             opts,
             fin,
             fout,
             use_spade=True,
             use_spectral_norm=True,
             norm_layer=BatchNormalization,
             *args,
             **kwargs):
    """SPADE-style residual block constructor.

    Builds three layer pipelines stored in ``self.inner_layers``:
    a shortcut (learned 1x1 projection when ``fin != fout``, identity
    otherwise) and two normalization->LeakyReLU->3x3-conv pipelines.

    Args:
        opts: options object; ``opts.semantic_nc`` is read for SPADE layers.
        fin / fout: input / output channel counts.
        use_spade: use SPADE normalization (else identity Lambdas).
        use_spectral_norm: wrap the convolutions in SpectralNormalization.
        norm_layer: stored on the instance.
    """
    super().__init__(*args, **kwargs)
    self.opts = opts
    self.fin = fin
    self.fout = fout
    self.use_spectral_norm = use_spectral_norm
    # NOTE(review): norm_layer is stored but the SPADE layers below are
    # hard-wired to SyncBatchNormalization — confirm which is intended.
    self.norm_layer = norm_layer
    fmiddle = min(fin, fout)
    # Shortcut path: learned norm + 1x1 conv only when channels change.
    norm_s = SPADE(opts, fin, opts.semantic_nc,
                   name='norm_s') if fin != fout else Lambda(
                       lambda x: x[0], trainable=False)
    conv_s = Conv2D(fout, kernel_size=1, use_bias=False,
                    name='conv_s') if fin != fout else Lambda(
                        lambda x: x, trainable=False)
    shortcut = [norm_s, conv_s]
    if use_spade:
        conv_0 = SPADE(opts,
                       fin,
                       opts.semantic_nc,
                       norm_layer=SyncBatchNormalization,
                       name='norm_0')
    else:
        conv_0 = Lambda(lambda x: x[0], trainable=False)
    conv_0 = [
        conv_0,
        LeakyReLU(alpha=2e-1),
        Conv2D(fmiddle, kernel_size=3, padding='same', name='conv_0')
    ]
    if use_spade:
        conv_1 = SPADE(opts,
                       fmiddle,
                       opts.semantic_nc,
                       norm_layer=SyncBatchNormalization,
                       name='norm_1')
    else:
        # NOTE(review): this identity Lambda takes ``x`` while conv_0's takes
        # ``x[0]`` — presumably one of the two is wrong for whichever calling
        # convention call() uses ((features, segmap) pair vs bare features);
        # confirm before changing.
        conv_1 = Lambda(lambda x: x, trainable=False)
    conv_1 = [
        conv_1,
        LeakyReLU(alpha=2e-1),
        Conv2D(fout, kernel_size=3, padding='same', name='conv_1')
    ]
    if use_spectral_norm:
        # Re-wrap the conv layers in spectral normalization, reusing the
        # original layer names.
        conv_0[-1] = SpectralNormalization(conv_0[-1], name='conv_0')
        conv_1[-1] = SpectralNormalization(conv_1[-1], name='conv_1')
        # NOTE(review): original formatting was lost; this guard is read as
        # nested under use_spectral_norm (shortcut conv gets SN only when
        # both conditions hold) — confirm.
        if fin != fout:
            shortcut[-1] = SpectralNormalization(conv_s, name='conv_s')
    self.inner_layers = [shortcut, conv_0, conv_1]
def __init__(self, units, sn=False, *args, **kwargs):
    """Fully connected layer, optionally wrapped in spectral normalization.

    Extra positional/keyword arguments are forwarded to ``Dense``.
    """
    super(DenselyConnected, self).__init__()
    layer = Dense(units, *args, **kwargs)
    self.dense = SpectralNormalization(layer) if sn else layer
def __init__(self, filters, kernel_size, sn=False, *args, **kwargs):
    """Transposed 2D convolution, optionally wrapped in spectral
    normalization.

    Extra positional/keyword arguments are forwarded to ``Conv2DTranspose``.
    """
    super(TransposedConvolution, self).__init__()
    layer = Conv2DTranspose(filters, kernel_size, *args, **kwargs)
    self.conv = SpectralNormalization(layer) if sn else layer