def shortcut_convolution(high_res_img, low_res_target, nb_channels_out):
    """Project `high_res_img` to the spatial size of `low_res_target`.

    A single spectrally-normalized convolution (preceded by zero
    padding when the target is larger than 1x1) maps the high-res
    tensor to `nb_channels_out` channels at the target resolution.
    Returns the layer-normalized result.
    """
    in_size = img_size(high_res_img)
    out_size = img_size(low_res_target)

    if out_size == 1:
        # A kernel covering the whole image collapses the spatial
        # dimensions to 1x1 in one step.
        out = kl.TimeDistributed(
            SpectralNormalization(
                kl.Conv2D(nb_channels_out,
                          in_size,
                          activation=LeakyReLU(0.2))),
            name='shortcut_conv_1')(high_res_img)
    else:
        # Solve the convolution shape equation backwards: choose a
        # stride, then padding and kernel size so the output is exactly
        # out_size x out_size.
        stride = int(
            tf.math.ceil((2 + in_size) / (out_size - 1)))
        margin = 2
        pad = int(
            tf.math.ceil((stride * (out_size - 1) - in_size) / 2)
            + 1 + margin)
        kernel = int(stride * (1 - out_size) + in_size + 2 * pad)
        out = kl.TimeDistributed(
            kl.ZeroPadding2D(padding=pad))(high_res_img)
        out = kl.TimeDistributed(
            SpectralNormalization(
                kl.Conv2D(nb_channels_out,
                          kernel,
                          strides=stride,
                          activation=LeakyReLU(0.2))),
            name='shortcut_conv')(out)
    return kl.LayerNormalization()(out)
 def __init__(self,
              conv,
              filters,
              kernel_size,
              stride,
              padding,
              activation,
              spectral_norm=False,
              batch_norm=False,
              name=None):
     """Configurable conv block: conv (optionally spectrally
     normalized), optional batch norm, then an activation layer.

     Args:
         conv: Convolution layer class to instantiate (e.g. Conv2D).
         filters: Number of output filters.
         kernel_size: Convolution kernel size.
         stride: Convolution stride.
         padding: Padding mode passed to the conv layer.
         activation: Activation layer class to instantiate.
         spectral_norm: If True, wrap the conv in SpectralNormalization.
         batch_norm: If True, append a BatchNormalization layer.
         name: Base name; sub-layer names derive from it, so it must
             not be None.
     """
     super(CustomConvBlock, self).__init__(name=name)
     # NOTE(review): these flags are overwritten with layer objects
     # below when enabled — presumably call() branches on their
     # truthiness; confirm against the (unseen) call() implementation.
     self.sn = spectral_norm
     self.bn = batch_norm
     self.conv_type = conv
     self.activation_type = activation
     if not spectral_norm:
         self.conv = self.conv_type(filters,
                                    kernel_size,
                                    stride,
                                    padding,
                                    name=name + '_conv')
     # NOTE(review): when spectral_norm is True, self.conv is never
     # assigned — only self.sn holds the wrapped conv layer.
     if spectral_norm:
         self.sn = SpectralNormalization(self.conv_type(filters,
                                                        kernel_size,
                                                        stride,
                                                        padding,
                                                        name=name +
                                                        '_conv'),
                                         name=name + '_conv_sn')
     if batch_norm:
         self.bn = BatchNormalization(name=name + '_bn')
     self.relu = self.activation_type(name=name + '_act')
Beispiel #3
0
def conv2d(layer_input,
           filters,
           kernel_size,
           stride,
           padding='same',
           activation=None,
           bias=True,
           sn=False):
    """Apply a 2D convolution with He-style normal weight init.

    Kernel weights are drawn from N(0, sqrt(2 / (k^2 * filters))) and
    biases start at zero. When `sn` is True the conv layer is wrapped
    in spectral normalization before being applied.

    Returns the output tensor of the convolution over `layer_input`.
    """
    fan = (kernel_size**2) * filters
    kernel_init = RandomNormal(mean=0., stddev=np.sqrt(2 / fan))
    bias_init = tf.constant_initializer(0.0)

    conv = Conv2D(filters=filters,
                  kernel_size=kernel_size,
                  strides=stride,
                  padding=padding,
                  activation=activation,
                  bias_initializer=bias_init,
                  kernel_initializer=kernel_init,
                  use_bias=bias)
    if sn:
        conv = SpectralNormalization(conv)
    return conv(layer_input)
Beispiel #4
0
 def _discriminator_block(self, previous_output, filter_out: int, strides: Tuple[int, int] = (2, 2)):
     """One discriminator stage: 4x4 conv, optional norms, LeakyReLU.

     The conv is wrapped in spectral normalization when the model was
     configured with `spec_normalization`; batch normalization is
     appended when `batch_normalization` is set.
     """
     conv = Conv2D(filter_out, (4, 4), strides=strides, padding='same', kernel_initializer=self.disc_init_fn)
     if self.spec_normalization:
         conv = SpectralNormalization(conv)
     features = conv(previous_output)
     if self.batch_normalization:
         features = BatchNormalization()(features)
     return LeakyReLU(alpha=0.2)(features)
Beispiel #5
0
    def _patch_discriminator(self, shape: Tuple[int, ...], input_channels: int) -> Model:
        """Build a PatchGAN-style conditional discriminator.

        Concatenates the candidate image with a 225x225 conditioning
        map, runs five conv blocks (four stride-2, one stride-1), and
        emits a per-patch sigmoid score map.
        """
        image_in = Input(shape=shape)
        # NOTE(review): conditioning input is hard-coded to 225x225 —
        # presumably matched to `shape` by the caller; confirm.
        condition_in = Input((225, 225, input_channels))
        features = Concatenate()([image_in, condition_in])

        for n_filters, strides in ((64, (2, 2)), (128, (2, 2)),
                                   (256, (2, 2)), (512, (2, 2)),
                                   (512, (1, 1))):
            features = self._discriminator_block(features, n_filters,
                                                 strides=strides)

        head = Conv2D(1, (4, 4), padding='same', activation='sigmoid',
                      kernel_initializer=self.disc_init_fn)
        if self.spec_normalization:
            head = SpectralNormalization(head)
        return Model([image_in, condition_in], head(features))
def make_generator(image_size: int,
                   in_channels: int,
                   noise_channels: int,
                   out_channels: int,
                   n_timesteps: int,
                   batch_size: int = None,
                   feature_channels=128):
    """Build the recurrent encoder/decoder generator.

    The network concatenates an input image sequence with a noise
    sequence, downsamples twice (keeping encoder residuals), runs a
    ConvLSTM at quarter resolution, then upsamples back to the full
    resolution using skip connections from the encoder.

    Args:
        image_size: Height/width of the square frames; must be a
            multiple of 4 so both stride-2 stages divide evenly.
        in_channels: Channels of the conditioning image input.
        noise_channels: Channels of the noise input.
        out_channels: Channels of the predicted image.
        n_timesteps: Sequence length.
        batch_size: Optional static batch size (None for dynamic).
        feature_channels: Bottleneck width; must be a multiple of 8 so
            the decoder's channel divisions stay integral.

    Returns:
        A Model mapping [input_image, input_noise] to a predicted
        sequence of shape (batch, time, image_size, image_size,
        out_channels).
    """
    # Make sure we have nice multiples everywhere
    assert image_size % 4 == 0
    assert feature_channels % 8 == 0
    total_in_channels = in_channels + noise_channels
    img_shape = (image_size, image_size)
    tshape = (n_timesteps, ) + img_shape
    input_image = kl.Input(shape=tshape + (in_channels, ),
                           batch_size=batch_size,
                           name='input_image')
    input_noise = kl.Input(shape=tshape + (noise_channels, ),
                           batch_size=batch_size,
                           name='input_noise')

    # Concatenate inputs
    x = kl.Concatenate()([input_image, input_noise])

    # Add features and decrease image size - in 2 steps
    intermediate_features = min(total_in_channels * 8, feature_channels)
    x = kl.TimeDistributed(kl.ZeroPadding2D(padding=3))(x)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(intermediate_features, (8, 8),
                      strides=2,
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 2,
                              image_size // 2, intermediate_features)
    res_2 = x  # Keep residuals for later

    x = kl.TimeDistributed(kl.ZeroPadding2D(padding=1))(x)
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (4, 4),
                      strides=2,
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels)
    res_4 = x  # Keep residuals for later

    # Recurrent unit
    x = kl.ConvLSTM2D(feature_channels, (3, 3),
                      padding='same',
                      return_sequences=True)(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels)

    # Re-increase image size and decrease features
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels // 2, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 4,
                              image_size // 4, feature_channels // 2)

    # Re-introduce residuals from before (skip connection)
    x = kl.Concatenate()([x, res_4])
    # BUGFIX: was `feature_channels / 4`, which passes a float as the
    # filter count — Conv2DTranspose requires an integer (and the shape
    # assertion below already uses `// 4`).
    x = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2DTranspose(feature_channels // 4, (2, 2),
                               strides=2,
                               activation=LeakyReLU(0.2))))(x)
    x = kl.BatchNormalization()(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size // 2,
                              image_size // 2, feature_channels // 4)

    # Skip connection 2
    x = kl.Concatenate()([x, res_2])
    # `//` is safe here: feature_channels % 8 == 0 was asserted above.
    if feature_channels // 8 >= out_channels:
        x = kl.TimeDistributed(
            kl.UpSampling2D(size=(2, 2), interpolation='bilinear'))(x)
        x = kl.TimeDistributed(
            kl.Conv2DTranspose(feature_channels // 8, (5, 5),
                               padding='same',
                               activation=LeakyReLU(0.2)))(x)
        assert tuple(x.shape) == (batch_size, n_timesteps, image_size,
                                  image_size, feature_channels // 8)
    else:
        x = kl.TimeDistributed(
            kl.Conv2D(out_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2)))(x)
        assert tuple(x.shape) == (batch_size, n_timesteps, image_size,
                                  image_size, out_channels)
    x = kl.BatchNormalization()(x)
    # Final linear projection to the requested number of channels.
    x = kl.TimeDistributed(kl.Conv2D(out_channels, (3, 3),
                                     padding='same',
                                     activation='linear'),
                           name='predicted_image')(x)
    assert tuple(x.shape) == (batch_size, n_timesteps, image_size, image_size,
                              out_channels)
    return Model(inputs=[input_image, input_noise],
                 outputs=x,
                 name='generator')
def make_discriminator(low_res_size: int,
                       high_res_size: int,
                       low_res_channels: int,
                       high_res_channels: int,
                       n_timesteps: int,
                       batch_size: int = None,
                       feature_channels: int = 16):
    """Build a two-branch recurrent discriminator.

    One branch looks at the high-res image alone, the other at the
    concatenation of low- and high-res images; their features are
    merged and progressively downsampled (with one additive shortcut
    across the second downsampling stage) to a single scalar score.

    Args:
        low_res_size: Spatial size of the low-res input (square).
        high_res_size: Spatial size of the high-res input (square).
        low_res_channels: Channels of the low-res input.
        high_res_channels: Channels of the high-res input.
        n_timesteps: Sequence length of both inputs.
        batch_size: Optional static batch size (None for dynamic).
        feature_channels: Base number of feature channels per branch.

    Returns:
        A Model mapping [low_res, high_res] to a scalar score per
        sample.

    Raises:
        NotImplementedError: If the two inputs differ in spatial size.
    """
    low_res = kl.Input(shape=(n_timesteps, low_res_size, low_res_size,
                              low_res_channels),
                       batch_size=batch_size,
                       name='low_resolution_image')
    high_res = kl.Input(shape=(n_timesteps, high_res_size, high_res_size,
                               high_res_channels),
                        batch_size=batch_size,
                        name='high_resolution_image')
    # Both inputs must share (time, height, width); only channel counts
    # may differ.
    if tuple(low_res.shape)[:-1] != tuple(high_res.shape)[:-1]:
        raise NotImplementedError(
            "The discriminator assumes that the low res and high res images have the same size."
            "Perhaps you should upsample your low res image first?")
    # First branch: high res only
    hr = kl.ConvLSTM2D(high_res_channels, (3, 3),
                       padding='same',
                       return_sequences=True)(high_res)
    hr = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(hr)
    hr = kl.LayerNormalization()(hr)

    # Second branch: Mix both inputs
    mix = kl.Concatenate()([low_res, high_res])
    mix = kl.ConvLSTM2D(feature_channels, (3, 3),
                        padding='same',
                        return_sequences=True)(mix)
    mix = kl.TimeDistributed(
        SpectralNormalization(
            kl.Conv2D(feature_channels, (3, 3),
                      padding='same',
                      activation=LeakyReLU(0.2))))(mix)
    mix = kl.LayerNormalization()(mix)

    # Merge everything together
    x = kl.Concatenate()([hr, mix])
    assert tuple(x.shape) == (batch_size, n_timesteps, low_res_size,
                              low_res_size, 2 * feature_channels)

    # Downsample (stride 3, doubling channels) until smaller than 16.
    # Layer names encode the input size at each step; the size ranges of
    # the three loops do not overlap, so names stay unique.
    while img_size(x) >= 16:
        x = kl.TimeDistributed(kl.ZeroPadding2D())(x)
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (7, 7),
                      strides=3,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)

    # Remember features before the second downsampling stage so a
    # shortcut convolution can re-inject them afterwards.
    shortcut = x
    while img_size(x) >= 4:
        x = kl.TimeDistributed(kl.ZeroPadding2D())(x)
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (7, 7),
                      strides=3,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)
    # Project the saved features to x's size/channels and add them in.
    shortcut = shortcut_convolution(shortcut, x, channels(x))
    # Split connection
    x = kl.add([x, shortcut])

    # Final downsampling to a 2x2 (or smaller) map.
    while img_size(x) > 2:
        x = kl.TimeDistributed(SpectralNormalization(
            kl.Conv2D(channels(x) * 2, (3, 3),
                      strides=2,
                      activation=LeakyReLU(0.2))),
                               name=f'conv_{img_size(x)}')(x)
        x = kl.LayerNormalization()(x)
    x = kl.TimeDistributed(kl.Flatten())(x)
    assert tuple(x.shape)[:-1] == (batch_size, n_timesteps
                                   )  # Unknown number of channels
    # One score per timestep, averaged over time into a single scalar.
    x = kl.TimeDistributed(kl.Dense(1, activation='linear'))(x)
    x = kl.GlobalAveragePooling1D(name='score')(x)

    return Model(inputs=[low_res, high_res], outputs=x, name='discriminator')
Beispiel #8
0
    def __init__(self,
                 opts,
                 fin,
                 fout,
                 use_spade=True,
                 use_spectral_norm=True,
                 norm_layer=BatchNormalization,
                 *args,
                 **kwargs):
        """SPADE residual block: a learned shortcut plus two
        norm -> LeakyReLU -> conv branches.

        Args:
            opts: Options object; `opts.semantic_nc` is read for the
                SPADE segmentation-map channel count.
            fin: Number of input channels.
            fout: Number of output channels.
            use_spade: If True, use SPADE normalization before each
                conv; otherwise an identity Lambda is used.
            use_spectral_norm: If True, wrap each conv in
                SpectralNormalization.
            norm_layer: Stored on the instance; not used in this
                constructor.
        """
        super().__init__(*args, **kwargs)

        self.opts = opts
        self.fin = fin
        self.fout = fout
        self.use_spectral_norm = use_spectral_norm
        self.norm_layer = norm_layer

        # Middle width of the residual branch.
        fmiddle = min(fin, fout)
        # Shortcut path: only transforms when the channel count changes;
        # otherwise identity Lambdas pass the input through.
        # NOTE(review): the identity for norm_s takes x[0] while the one
        # for conv_s takes x — presumably norm layers receive
        # [features, segmap] pairs in call(); confirm against call().
        norm_s = SPADE(opts, fin, opts.semantic_nc,
                       name='norm_s') if fin != fout else Lambda(
                           lambda x: x[0], trainable=False)
        conv_s = Conv2D(fout, kernel_size=1, use_bias=False,
                        name='conv_s') if fin != fout else Lambda(
                            lambda x: x, trainable=False)
        shortcut = [norm_s, conv_s]

        if use_spade:
            conv_0 = SPADE(opts,
                           fin,
                           opts.semantic_nc,
                           norm_layer=SyncBatchNormalization,
                           name='norm_0')
        else:
            conv_0 = Lambda(lambda x: x[0], trainable=False)

        # First residual sub-branch: norm -> LeakyReLU -> 3x3 conv.
        conv_0 = [
            conv_0,
            LeakyReLU(alpha=2e-1),
            Conv2D(fmiddle, kernel_size=3, padding='same', name='conv_0')
        ]

        if use_spade:
            conv_1 = SPADE(opts,
                           fmiddle,
                           opts.semantic_nc,
                           norm_layer=SyncBatchNormalization,
                           name='norm_1')
        else:
            # NOTE(review): identity here takes x (not x[0]) unlike
            # conv_0's fallback above — verify this asymmetry is
            # intentional in call().
            conv_1 = Lambda(lambda x: x, trainable=False)

        # Second residual sub-branch: norm -> LeakyReLU -> 3x3 conv.
        conv_1 = [
            conv_1,
            LeakyReLU(alpha=2e-1),
            Conv2D(fout, kernel_size=3, padding='same', name='conv_1')
        ]

        # Spectral normalization replaces the last (conv) element of
        # each sub-branch, reusing the same layer name.
        if use_spectral_norm:
            conv_0[-1] = SpectralNormalization(conv_0[-1], name='conv_0')
            conv_1[-1] = SpectralNormalization(conv_1[-1], name='conv_1')
            if fin != fout:
                shortcut[-1] = SpectralNormalization(conv_s, name='conv_s')

        self.inner_layers = [shortcut, conv_0, conv_1]
Beispiel #9
0
 def __init__(self, units, sn=False, *args, **kwargs):
     """Dense layer with optional spectral normalization.

     Extra positional/keyword args are forwarded to Dense.
     """
     super(DenselyConnected, self).__init__()
     layer = Dense(units, *args, **kwargs)
     self.dense = SpectralNormalization(layer) if sn else layer
Beispiel #10
0
 def __init__(self, filters, kernel_size, sn=False, *args, **kwargs):
     """Transposed convolution with optional spectral normalization.

     Extra positional/keyword args are forwarded to Conv2DTranspose.
     """
     super(TransposedConvolution, self).__init__()
     layer = Conv2DTranspose(filters, kernel_size, *args, **kwargs)
     self.conv = SpectralNormalization(layer) if sn else layer