Example #1
    def create_model(self,
                     style_name=None,
                     train_mode=False,
                     style_image_path=None,
                     validation_path=None):
        '''
        Creates the FastNet model, which can be used in train mode, predict mode or validation mode.
        If train_mode = True, this model appends the VGG model to the end of the FastNet model.
        In train mode, a style image path must be supplied.
        If train_mode = False and validation_path = None, the model is in predict mode.
        In predict mode, a style_name must be supplied, whose weights it will try to load.
        If validation_path is not None, the model is in validation mode.
        In validation mode, it simply loads the weights from validation_path and does not append VGG.
        Args:
            style_name: Used in predict mode to load the correct weights for the style.
            train_mode: Activates train mode. Default is predict mode.
            style_image_path: Path to the style image. Required in train mode.
            validation_path: Path to the validation weights to be loaded.
        Returns: FastNet model (predict / validation mode) or FastNet + VGG model (train mode)
        '''

        if train_mode and style_image_path is None:
            raise ValueError(
                'Style reference path must be supplied if training mode is enabled'
            )

        self.mode = 2  # BatchNormalization mode 2 (Keras 1.x: feature-wise normalization with per-batch statistics)

        if K.image_dim_ordering() == "th":
            ip = Input(shape=(3, self.img_width, self.img_height),
                       name="X_input")
        else:
            ip = Input(shape=(self.img_width, self.img_height, 3),
                       name="X_input")

        c1 = ReflectionPadding2D((4, 4))(ip)

        c1 = Convolution2D(32,
                           9,
                           9,
                           activation='linear',
                           border_mode='valid',
                           name='conv1')(c1)
        c1_b = BatchNormalization(axis=1, mode=self.mode,
                                  name="batchnorm1")(c1)
        c1_b = Activation('relu')(c1_b)

        c2 = Convolution2D(self.features,
                           self.k,
                           self.k,
                           activation='linear',
                           border_mode='same',
                           subsample=(2, 2),
                           name='conv2')(c1_b)
        c2_b = BatchNormalization(axis=1, mode=self.mode,
                                  name="batchnorm2")(c2)
        c2_b = Activation('relu')(c2_b)

        c3 = Convolution2D(self.features,
                           self.k,
                           self.k,
                           activation='linear',
                           border_mode='same',
                           subsample=(2, 2),
                           name='conv3')(c2_b)
        x = BatchNormalization(axis=1, mode=self.mode, name="batchnorm3")(c3)
        x = Activation('relu')(x)

        if self.deep_model:
            c4 = Convolution2D(self.features,
                               self.k,
                               self.k,
                               activation='linear',
                               border_mode='same',
                               subsample=(2, 2),
                               name='conv4')(x)

            x = BatchNormalization(axis=1, mode=self.mode,
                                   name="batchnorm_4")(c4)
            x = Activation('relu')(x)

        # Transformation core: five stacked residual blocks at the downsampled resolution
        r1 = self._residual_block(x, 1)
        r2 = self._residual_block(r1, 2)
        r3 = self._residual_block(r2, 3)
        r4 = self._residual_block(r3, 4)
        x = self._residual_block(r4, 5)

        if self.deep_model:
            d4 = Deconvolution2D(self.features,
                                 self.k,
                                 self.k,
                                 activation="linear",
                                 border_mode="same",
                                 subsample=(2, 2),
                                 output_shape=(1, self.features,
                                               self.img_width // 4,
                                               self.img_height // 4),
                                 name="deconv4")(x)

            x = BatchNormalization(axis=1,
                                   mode=self.mode,
                                   name="batchnorm_extra4")(d4)
            x = Activation('relu')(x)

        d3 = Deconvolution2D(self.features,
                             self.k,
                             self.k,
                             activation="linear",
                             border_mode="same",
                             subsample=(2, 2),
                             output_shape=(1, self.features,
                                           self.img_width // 2,
                                           self.img_height // 2),
                             name="deconv3")(x)

        d3 = BatchNormalization(axis=1, mode=self.mode, name="batchnorm4")(d3)
        d3 = Activation('relu')(d3)

        d2 = Deconvolution2D(self.features,
                             self.k,
                             self.k,
                             activation="linear",
                             border_mode="same",
                             subsample=(2, 2),
                             output_shape=(1, self.features, self.img_width,
                                           self.img_height),
                             name="deconv2")(d3)

        d2 = BatchNormalization(axis=1, mode=self.mode, name="batchnorm5")(d2)
        d2 = Activation('relu')(d2)

        d1 = ReflectionPadding2D((4, 4))(d2)
        d1 = Convolution2D(3,
                           9,
                           9,
                           activation='tanh',
                           border_mode='valid',
                           name='fastnet_conv')(d1)

        # Scale output to range [0, 255] via custom Denormalize layer
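        # (conceptually x -> (x + 1) * 127.5, mapping tanh's [-1, 1] onto [0, 255];
        #  the exact formula is an assumption about the repo's custom layer)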
        d1 = Denormalize(name='fastnet_output')(d1)

        model = Model(ip, d1)

        if self.model_save_path is not None and self.model is None:
            model.save(self.model_save_path, overwrite=True)

        self.fastnet_outputs_dict = dict([(layer.name, layer.output)
                                          for layer in model.layers])
        fastnet_output_layer = model.layers[-1]

        if style_name is not None or validation_path is not None:
            try:
                if validation_path is not None:
                    path = validation_path
                else:
                    path = "weights/fastnet_%s.h5" % style_name

                model.load_weights(path)
                print('Fast Style Net weights loaded.')
            except Exception:
                print(
                    'Weights for this style do not exist. Model weights not loaded.'
                )

        # Add VGG layers to Fast Style Model
        if train_mode:
            model = VGG(self.img_height, self.img_width).append_vgg_model(
                model.input, x_in=model.output, pool_type=self.pool_type)

            if self.model is None:
                self.model = model

            self.vgg_output_dict = dict([(layer.name, layer.output)
                                         for layer in model.layers[-18:]])

            vgg_layers = dict([(layer.name, layer)
                               for layer in model.layers[-18:]])

            style = img_utils.preprocess_image(style_image_path,
                                               self.img_width, self.img_height)
            print('Getting style features from VGG network.')

            self.style_layers = ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']

            self.style_layer_outputs = []
            for layer in self.style_layers:
                self.style_layer_outputs.append(self.vgg_output_dict[layer])

            style_features = self.get_vgg_style_features(style)
            self.style_features = style_features

            # Style Reconstruction Loss
            if self.style_weight != 0.0:
                for i, layer_name in enumerate(self.style_layers):
                    layer = vgg_layers[layer_name]
                    style_loss = StyleReconstructionRegularizer(
                        style_feature_target=style_features[i][0],
                        weight=self.style_weight)(layer)

                    layer.add_loss(style_loss)

            # Feature Reconstruction Loss
            self.content_layer = 'conv4_2'
            self.content_layer_output = self.vgg_output_dict[
                self.content_layer]

            if self.content_weight != 0.0:
                layer = vgg_layers[self.content_layer]
                content_regularizer = FeatureReconstructionRegularizer(
                    weight=self.content_weight)(layer)
                layer.add_loss(content_regularizer)

        # Total Variation Regularization (penalizes high-frequency noise in the generated image)
        if self.tv_weight != 0.0:
            layer = fastnet_output_layer  # Fastnet Output layer
            tv_regularizer = TVRegularizer(img_width=self.img_width,
                                           img_height=self.img_height,
                                           weight=self.tv_weight)(layer)
            layer.add_loss(tv_regularizer)

        if self.model is None:
            self.model = model
        return model
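
A minimal usage sketch of the three modes, one call per mode (it reuses the models.FastStyleNet constructor call seen in Example 3; the style names and paths are illustrative assumptions, not from the source):

    net = models.FastStyleNet(width, height, kernel_size, pool_type)        # as constructed in Example 3
    net.create_model(train_mode=True, style_image_path='images/style.jpg')  # train mode: FastNet + VGG appended
    net.create_model(style_name='wave')                                     # predict mode: tries weights/fastnet_wave.h5
    net.create_model(validation_path='val_weights/fastnet_wave.h5')         # validation mode: loads given weights, no VGG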
Example #2
weights_path = "weights/fastnet_%s.h5" % style_name

# load_model reads the serialized architecture itself, so no manual file read is needed
model = load_model(model_path)
model.compile("adadelta", dummy_loss)

model.load_weights(weights_path)

# 69 layers in the shallow model, 73 in the deeper model
size_multiple = 4 if len(model.layers) == 69 else 8

img = img_utils.preprocess_image(content_path,
                                 load_dims=True,
                                 resize=True,
                                 img_width=-1,
                                 img_height=-1,
                                 size_multiple=size_multiple)
img /= 255.  # scale pixels to [0, 1]; the network's Denormalize layer maps back to [0, 255]
width, height = img.shape[2], img.shape[3]

t1 = time.time()
output = model.predict_on_batch(img)
t2 = time.time()

print("Saved image : %s" % output_image)
print("Prediction time : %0.2f seconds" % (t2 - t1))

img = output[0, :, :, :]
img = img_utils.deprocess_image(img)
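
The 'Saved image' print above refers to a write step that is elided from this snippet; a minimal completion sketch (assuming the imageio package rather than the repo's own helper):

    import imageio
    imageio.imwrite(output_image, img.astype('uint8'))  # output_image is the path variable used in the print above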
Example #3
        x = scipy.ndimage.zoom(x, (1, zoom_ratio, zoom_ratio), order=1)
        img_height, img_width = x.shape[-2:]

    if a_scale_mode == 'match':
        a_img_width = img_width
        a_img_height = img_height
    elif a_scale_mode == 'none':
        a_img_width = full_a_image.shape[1] * scale_factor
        a_img_height = full_a_image.shape[0] * scale_factor
    else:  # should just be 'ratio'
        a_img_width = full_a_image.shape[1] * scale_factor * b_scale_ratio_width
        a_img_height = full_a_image.shape[0] * scale_factor * b_scale_ratio_height
    a_img_width = int(round(args.a_scale * a_img_width))
    a_img_height = int(round(args.a_scale * a_img_height))

    a_image = img_utils.preprocess_image(full_a_image, a_img_width, a_img_height)
    ap_image = img_utils.preprocess_image(full_ap_image, a_img_width, a_img_height)
    b_image = img_utils.preprocess_image(full_b_image, img_width, img_height)

    print('Scale factor %s "A" shape %s "B" shape %s' % (scale_factor, a_image.shape, b_image.shape))

    # build the VGG16 network. It seems this needs to be rebuilt at each scale
    # or CPU users get an error from conv2d :(
    model = vgg16.get_model(img_width, img_height, weights_path=weights_path, pool_mode=args.pool_mode)
    first_layer = model.layers[0]
    vgg_input = first_layer.input
    print('Model loaded.')

    # get the symbolic outputs of each "key" layer (we gave them unique names).
    outputs_dict = dict([(layer.name, layer.get_output()) for layer in model.layers])
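
    # Aside (illustrative, not from the original script): outputs_dict maps layer
    # names to symbolic tensors; with Keras 1.x backend functions such a dict is
    # typically evaluated on real data like this ('conv3_1' is an assumed name):
    #
    #     get_features = K.function([vgg_input], [outputs_dict['conv3_1']])
    #     conv3_1_features = get_features([b_image])[0]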
            if prev_improvement == -1:
                prev_improvement = loss

            improvement = (prev_improvement - loss) / prev_improvement * 100
            prev_improvement = loss

            t2 = time.time()
            print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | Loss : %d" %
                 (iteration, num_iter, improvement, t2 - t1, loss))

            if iteration % val_checkpoint == 0:
                print("Producing validation image...")

                # This ensures that the image height and width are even numbers
                x = img_utils.preprocess_image(validation_img_path, resize=False)
                x /= 255.

                width, height = x.shape[2], x.shape[3]

                iter_path = style_name + "_epoch_%d_at_iteration_%d" % (i + 1, iteration)
                FastNet.save_fastnet_weights(iter_path, directory="val_weights/")

                path = "val_weights/fastnet_" + iter_path + ".h5"

                if validation_fastnet is None:
                    validation_fastnet = models.FastStyleNet(width, height, kernel_size, pool_type,
                                                             model_width=model_width, model_depth=model_depth)
                    validation_fastnet.create_model(validation_path=path)
                    validation_fastnet.model.compile(optimizer, dummy_loss)
                else: