Example #1
0
 def _get_model(self):
     """Build and return the "transformnet" ``tf.keras.Model``.

     Graph: InputNormalize -> three conv_bn_relu stages -> two res_conv
     blocks -> three dconv_bn_nolinear stages (the last with 3 filters and
     a tanh activation) -> Denormalize named 'transform_output'.

     The input shape is (None, None, 3), so the spatial size is not fixed
     when the model is built.
     """
     image = tf.keras.Input(shape=(None, None, 3))
     net = InputNormalize()(image)

     # Encoder stages as (filters, kernel size, stride).
     encoder_specs = ((8, 9, 1), (16, 3, 2), (32, 3, 2))
     for filters, kernel, stride in encoder_specs:
         net = conv_bn_relu(filters, kernel, kernel, stride=(stride, stride))(net)

     for _ in range(2):
         net = res_conv(32, 3, 3)(net)

     # The first two decoder stages rely on dconv_bn_nolinear's default
     # stride/activation; only the final stage overrides them.
     for filters in (16, 8):
         net = dconv_bn_nolinear(filters, 3, 3)(net)
     net = dconv_bn_nolinear(3, 9, 9, stride=(1, 1), activation="tanh")(net)

     # Scale output to range [0, 255] via custom Denormalize layer
     styled = Denormalize(name='transform_output')(net)
     return tf.keras.Model(image, styled, name="transformnet")
Example #2
0
    def create_sr_model(self, ip):
        """Build the super-resolution branch on top of the tensor ``ip``.

        Layout:
          * two stages of 5x5 'same' Convolution2D -> BatchNormalization ->
            LeakyReLU(alpha=0.25);
          * residual blocks: 5 when ``self.small_model`` is truthy, else 15;
          * ``self.nb_scales`` upscale blocks;
          * a final 3-filter 5x5 tanh Convolution2D that carries a
            TVRegularizer activity regularizer, followed by Denormalize.

        Args:
            ip: Keras tensor the branch is applied to.

        Returns:
            The output tensor of the Denormalize layer.
        """

        def conv_bn_lrelu(tensor, conv_name, bn_name, act_name):
            # One 5x5 linear conv -> BatchNormalization -> LeakyReLU stage.
            tensor = Convolution2D(self.filters, 5, 5,
                                   activation='linear',
                                   border_mode='same',
                                   name=conv_name)(tensor)
            tensor = BatchNormalization(axis=channel_axis,
                                        mode=self.mode,
                                        name=bn_name)(tensor)
            return LeakyReLU(alpha=0.25, name=act_name)(tensor)

        out = conv_bn_lrelu(ip, 'sr_res_conv1', 'sr_res_bn_1', 'sr_res_lr1')
        out = conv_bn_lrelu(out, 'sr_res_conv2', 'sr_res_bn_2', 'sr_res_lr2')

        residual_count = 5 if self.small_model else 15
        for block_id in range(1, residual_count + 1):
            out = self._residual_block(out, block_id)

        for scale_id in range(1, self.nb_scales + 1):
            out = self._upscale_block(out, scale_id)

        # The TV regularizer is sized for the input dimensions multiplied by
        # 2 ** nb_scales (presumably each upscale block doubles resolution).
        upscale_factor = 2 ** self.nb_scales
        tv_regularizer = TVRegularizer(img_width=self.img_width * upscale_factor,
                                       img_height=self.img_height * upscale_factor,
                                       weight=self.tv_weight)

        out = Convolution2D(3, 5, 5,
                            activation='tanh',
                            border_mode='same',
                            activity_regularizer=tv_regularizer,
                            name='sr_res_conv_final')(out)

        return Denormalize()(out)
Example #3
0
def image_transform_net(img_width, img_height, tv_weight=1):
    """Build the image transformation network as a Keras ``Model``.

    Pipeline: InputNormalize -> ReflectionPadding2D(padding=(40, 40)) ->
    conv_bn_relu stages (32 @ 9x9 stride 1, 64 @ 9x9 stride 2,
    128 @ 3x3 stride 2) -> five res_conv(128, 3, 3) blocks ->
    dconv_bn_nolinear stages (64, 32, then 3 filters @ 9x9 with tanh) ->
    Denormalize named 'transform_output'.

    Args:
        img_width: first spatial dimension of the input shape.
        img_height: second spatial dimension of the input shape.
        tv_weight: when > 0, ``add_total_variation_loss`` is attached to the
            model's last layer with this weight.

    Returns:
        The constructed ``Model``.
    """
    # NOTE(review): with channels-last Keras the shape is (rows, cols, channels),
    # so img_width is used as the row count here -- confirm callers feed
    # arrays in that layout.
    input_shape = (img_width, img_height, 3)
    image = Input(shape=input_shape)

    net = InputNormalize()(image)
    net = ReflectionPadding2D(padding=(40, 40), input_shape=input_shape)(net)

    # Encoder stages as (filters, kernel size, stride).
    for filters, kernel, stride in ((32, 9, 1), (64, 9, 2), (128, 3, 2)):
        net = conv_bn_relu(filters, kernel, kernel, stride=(stride, stride))(net)

    for _ in range(5):
        net = res_conv(128, 3, 3)(net)

    # Default stride/activation for the first two decoder stages.
    for filters in (64, 32):
        net = dconv_bn_nolinear(filters, 3, 3)(net)
    net = dconv_bn_nolinear(3, 9, 9, stride=(1, 1), activation="tanh")(net)

    # Scale output to range [0, 255] via custom Denormalize layer
    styled = Denormalize(name='transform_output')(net)

    model = Model(inputs=image, outputs=styled)

    if tv_weight > 0:
        add_total_variation_loss(model.layers[-1], tv_weight)

    return model
Example #4
0
def image_transform_net(img_width, img_height, tv_weight=1):
    """Build the image transformation network model.

    Pipeline: InputNormalize -> ReflectionPadding2D(padding=(40, 40)) ->
    conv_bn_relu stages (32 @ 9x9 stride 1, 64 @ 3x3 stride 2,
    128 @ 3x3 stride 2; the second stage previously used a 9x9 kernel) ->
    five res_conv(128, 3, 3) blocks -> dconv_bn_nolinear stages (64, 32,
    then 3 filters @ 9x9 with tanh) -> Denormalize named 'transform_output'.

    Args:
        img_width: first spatial dimension of the RGB input shape.
        img_height: second spatial dimension of the RGB input shape.
        tv_weight: total-variation weight; the loss is attached to the
            model's last layer only when this is > 0.

    Returns:
        The constructed ``Model``.
    """
    rgb_shape = (img_width, img_height, 3)
    rgb_input = Input(shape=rgb_shape)

    # Normalize, then reflection-pad the input image.
    features = InputNormalize()(rgb_input)
    features = ReflectionPadding2D(padding=(40, 40), input_shape=rgb_shape)(features)

    # Feature extraction: (filters, kernel size, stride) per conv_bn_relu stage.
    encoder = ((32, 9, 1), (64, 3, 2), (128, 3, 2))
    for n_filters, ksize, step in encoder:
        features = conv_bn_relu(n_filters, ksize, ksize, stride=(step, step))(features)

    for _ in range(5):
        features = res_conv(128, 3, 3)(features)

    for n_filters in (64, 32):
        features = dconv_bn_nolinear(n_filters, 3, 3)(features)
    features = dconv_bn_nolinear(3, 9, 9, stride=(1, 1), activation="tanh")(features)

    # Scale output to range [0, 255] via custom Denormalize layer
    output = Denormalize(name='transform_output')(features)

    model = Model(inputs=rgb_input, outputs=output)

    # Attach the total-variation loss to the final (Denormalize) layer.
    if tv_weight > 0:
        add_total_variation_loss(model.layers[-1], tv_weight)

    return model
Example #5
0
    def create_model(self,
                     style_name=None,
                     train_mode=False,
                     style_image_path=None,
                     validation_path=None):
        """Create the FastNet model in train, predict or validation mode.

        Modes:
          * Train mode (``train_mode=True``): the VGG network is appended to
            the FastNet output, and style / content reconstruction losses are
            attached to VGG layers. ``style_image_path`` is required.
          * Predict mode (``train_mode=False`` and ``validation_path is None``):
            if ``style_name`` is given, weights are loaded from
            ``weights/fastnet_<style_name>.h5``.
          * Validation mode (``validation_path`` is not None): weights are
            loaded from ``validation_path``, which takes precedence over
            ``style_name``. VGG is only appended if ``train_mode`` is also set.

        In every mode, a total-variation loss is attached to the FastNet output
        layer when ``self.tv_weight != 0``.

        Weight loading is best-effort. On failure, a message and the reason are
        printed, and the model is returned without those weights.

        Args:
            style_name: style whose weights are loaded in predict mode.
            train_mode: build FastNet + VGG for training. Default is predict mode.
            style_image_path: path to the style image; required in train mode.
            validation_path: path to the weights to load in validation mode.

        Returns:
            FastNet model (predict / validation mode) or FastNet + VGG model
            (train mode). If ``self.model`` is None, it is set to the model
            built here.

        Raises:
            ValueError: if ``train_mode`` is set without ``style_image_path``.
        """

        if train_mode and style_image_path is None:
            raise ValueError(
                'Style reference path must be supplied if training mode is enabled'
            )

        self.mode = 2

        # 'th' means channels-first; otherwise the Input below is
        # channels-last. Both the BatchNormalization axis and the
        # Deconvolution2D output_shape must follow the same ordering as the
        # Input. Previously they were hard-coded channels-first.
        channels_first = K.image_dim_ordering() == "th"
        channel_axis = 1 if channels_first else -1

        def deconv_output_shape(rows, cols):
            # Full output shape for Deconvolution2D (batch size fixed to 1),
            # laid out in the backend's dimension ordering.
            if channels_first:
                return (1, self.features, rows, cols)
            return (1, rows, cols, self.features)

        if channels_first:
            ip = Input(shape=(3, self.img_width, self.img_height),
                       name="X_input")
        else:
            ip = Input(shape=(self.img_width, self.img_height, 3),
                       name="X_input")

        c1 = ReflectionPadding2D((4, 4))(ip)

        c1 = Convolution2D(32,
                           9,
                           9,
                           activation='linear',
                           border_mode='valid',
                           name='conv1')(c1)
        c1_b = BatchNormalization(axis=channel_axis, mode=self.mode,
                                  name="batchnorm1")(c1)
        c1_b = Activation('relu')(c1_b)

        c2 = Convolution2D(self.features,
                           self.k,
                           self.k,
                           activation='linear',
                           border_mode='same',
                           subsample=(2, 2),
                           name='conv2')(c1_b)
        c2_b = BatchNormalization(axis=channel_axis, mode=self.mode,
                                  name="batchnorm2")(c2)
        c2_b = Activation('relu')(c2_b)

        c3 = Convolution2D(self.features,
                           self.k,
                           self.k,
                           activation='linear',
                           border_mode='same',
                           subsample=(2, 2),
                           name='conv3')(c2_b)
        x = BatchNormalization(axis=channel_axis, mode=self.mode,
                               name="batchnorm3")(c3)
        x = Activation('relu')(x)

        if self.deep_model:
            c4 = Convolution2D(self.features,
                               self.k,
                               self.k,
                               activation='linear',
                               border_mode='same',
                               subsample=(2, 2),
                               name='conv4')(x)

            x = BatchNormalization(axis=channel_axis, mode=self.mode,
                                   name="batchnorm_4")(c4)
            x = Activation('relu')(x)

        # Five chained residual blocks, ids 1..5.
        for block_id in range(1, 6):
            x = self._residual_block(x, block_id)

        if self.deep_model:
            d4 = Deconvolution2D(self.features,
                                 self.k,
                                 self.k,
                                 activation="linear",
                                 border_mode="same",
                                 subsample=(2, 2),
                                 output_shape=deconv_output_shape(
                                     self.img_width // 4,
                                     self.img_height // 4),
                                 name="deconv4")(x)

            x = BatchNormalization(axis=channel_axis,
                                   mode=self.mode,
                                   name="batchnorm_extra4")(d4)
            x = Activation('relu')(x)

        d3 = Deconvolution2D(self.features,
                             self.k,
                             self.k,
                             activation="linear",
                             border_mode="same",
                             subsample=(2, 2),
                             output_shape=deconv_output_shape(
                                 self.img_width // 2,
                                 self.img_height // 2),
                             name="deconv3")(x)

        d3 = BatchNormalization(axis=channel_axis, mode=self.mode,
                                name="batchnorm4")(d3)
        d3 = Activation('relu')(d3)

        d2 = Deconvolution2D(self.features,
                             self.k,
                             self.k,
                             activation="linear",
                             border_mode="same",
                             subsample=(2, 2),
                             output_shape=deconv_output_shape(
                                 self.img_width, self.img_height),
                             name="deconv2")(d3)

        d2 = BatchNormalization(axis=channel_axis, mode=self.mode,
                                name="batchnorm5")(d2)
        d2 = Activation('relu')(d2)

        d1 = ReflectionPadding2D((4, 4))(d2)
        d1 = Convolution2D(3,
                           9,
                           9,
                           activation='tanh',
                           border_mode='valid',
                           name='fastnet_conv')(d1)

        # Scale output to range [0, 255] via custom Denormalize layer
        d1 = Denormalize(name='fastnet_output')(d1)

        model = Model(ip, d1)

        if self.model_save_path is not None and self.model is None:
            model.save(self.model_save_path, overwrite=True)

        self.fastnet_outputs_dict = {layer.name: layer.output
                                     for layer in model.layers}
        fastnet_output_layer = model.layers[-1]

        if style_name is not None or validation_path is not None:
            if validation_path is not None:
                path = validation_path
            else:
                path = "weights/fastnet_%s.h5" % style_name

            # Best-effort load: report why it failed (missing file, layer
            # mismatch, ...) instead of a bare except that would also swallow
            # KeyboardInterrupt / SystemExit.
            try:
                model.load_weights(path)
            except Exception as err:
                print(
                    'Weights for this style do not exist. Model weights not loaded.'
                )
                print('Could not load weights from %s: %s' % (path, err))
            else:
                print('Fast Style Net weights loaded.')

        # Add VGG layers to Fast Style Model
        if train_mode:
            model = VGG(self.img_height, self.img_width).append_vgg_model(
                model.input, x_in=model.output, pool_type=self.pool_type)

            if self.model is None:
                self.model = model

            # Assumes append_vgg_model contributes exactly the last 18 layers
            # -- TODO confirm if the VGG definition changes.
            vgg_tail = model.layers[-18:]
            self.vgg_output_dict = {layer.name: layer.output
                                    for layer in vgg_tail}
            vgg_layers = {layer.name: layer for layer in vgg_tail}

            style = img_utils.preprocess_image(style_image_path,
                                               self.img_width, self.img_height)
            print('Getting style features from VGG network.')

            self.style_layers = ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']
            self.style_layer_outputs = [self.vgg_output_dict[name]
                                        for name in self.style_layers]

            style_features = self.get_vgg_style_features(style)
            self.style_features = style_features

            # Style Reconstruction Loss
            if self.style_weight != 0.0:
                for i, layer_name in enumerate(self.style_layers):
                    layer = vgg_layers[layer_name]
                    style_loss = StyleReconstructionRegularizer(
                        style_feature_target=style_features[i][0],
                        weight=self.style_weight)(layer)

                    layer.add_loss(style_loss)

            # Feature Reconstruction Loss
            self.content_layer = 'conv4_2'
            self.content_layer_output = self.vgg_output_dict[
                self.content_layer]

            if self.content_weight != 0.0:
                layer = vgg_layers[self.content_layer]
                content_regularizer = FeatureReconstructionRegularizer(
                    weight=self.content_weight)(layer)
                layer.add_loss(content_regularizer)

        # Total Variation Regularization on the FastNet output layer
        if self.tv_weight != 0.0:
            layer = fastnet_output_layer
            tv_regularizer = TVRegularizer(img_width=self.img_width,
                                           img_height=self.img_height,
                                           weight=self.tv_weight)(layer)
            layer.add_loss(tv_regularizer)

        if self.model is None:
            self.model = model
        return model