Example #1
def _handler2(content_path,
              model_path,
              save_path=None,
              prefix=None,
              suffix=None):
    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content_image = tf.placeholder(tf.float32,
                                       shape=(1, None, None, 3),
                                       name='content_image')

        output_image = itn.transform(content_image)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = []
        for content in content_path:
            content_target = get_images(content)
            result = sess.run(output_image,
                              feed_dict={content_image: content_target})
            output.append(result[0])

    if save_path is not None:
        save_images(content_path,
                    output,
                    save_path,
                    prefix=prefix,
                    suffix=suffix)

    return output
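
A minimal usage sketch for _handler2, assuming the helpers it relies on (itn, get_images, save_images) are available in the surrounding module; the checkpoint and image paths below are hypothetical:

# Hypothetical paths: a list of content images and a trained checkpoint prefix.
content_paths = ['images/content1.jpg', 'images/content2.jpg']
stylized = _handler2(content_paths,
                     model_path='models/style.ckpt-20000',
                     save_path='outputs',
                     prefix='stylized-')
# stylized is a list with one stylized image array per input path.
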
Example #2
def _handler1(content_path,
              model_path,
              resize_height=None,
              resize_width=None,
              save_path=None,
              prefix=None,
              suffix=None):
    # get the actual image data, output shape: (num_images, height, width, color_channels)
    content_target = get_images(content_path, resize_height, resize_width)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content_image = tf.placeholder(tf.float32,
                                       shape=content_target.shape,
                                       name='content_image')

        output_image = itn.transform(content_image)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = sess.run(output_image,
                          feed_dict={content_image: content_target})

    if save_path is not None:
        save_images(content_path,
                    output,
                    save_path,
                    prefix=prefix,
                    suffix=suffix)

    return output
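
Unlike _handler2, this variant feeds all content images through one fixed-shape placeholder, so they must share a size (or be resized). A hedged call with hypothetical paths, resizing everything to 256x256:

stylized_batch = _handler1(['images/a.jpg', 'images/b.jpg'],
                           model_path='models/style.ckpt-20000',
                           resize_height=256,
                           resize_width=256,
                           save_path='outputs')
# stylized_batch holds all outputs in one array (inputs were resized to 256x256).
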
Example #3
def train(content_targets_path,
          style_target_path,
          content_weight,
          style_weight,
          tv_weight,
          vgg_path,
          save_path,
          debug=False,
          logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # ensure the number of content targets is a multiple of BATCH_SIZE
    mod = len(content_targets_path) % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed by %d samples...' % mod)
        content_targets_path = content_targets_path[:-mod]

    height, width, channels = TRAINING_IMAGE_SHAPE
    input_shape = (BATCH_SIZE, height, width, channels)

    # load the pre-trained VGG network
    vgg = VGG(vgg_path)

    # retrieve the style_target image
    style_target = get_images(style_target_path)  # shape: (1, height, width, channels)
    style_shape = style_target.shape

    # compute the style features
    style_features = {}
    with tf.Graph().as_default(), tf.Session() as sess:
        style_image = tf.placeholder(tf.float32,
                                     shape=style_shape,
                                     name='style_image')

        # pass style_image through the pre-trained VGG-19 network
        style_img_preprocess = preprocess(style_image)
        style_net = vgg.forward(style_img_preprocess)

        for style_layer in STYLE_LAYERS:
            features = style_net[style_layer].eval(
                feed_dict={style_image: style_target})
            features = np.reshape(features, [-1, features.shape[3]])

            gram = np.matmul(features.T, features) / features.size
            style_features[style_layer] = gram

    # compute the perceptual losses
    with tf.Graph().as_default(), tf.Session() as sess:
        content_images = tf.placeholder(tf.float32,
                                        shape=input_shape,
                                        name='content_images')

        # pass content_images through the pre-trained VGG-19 network
        content_imgs_preprocess = preprocess(content_images)
        content_net = vgg.forward(content_imgs_preprocess)

        # compute the content features
        content_features = {}
        content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

        # pass content_images through the Image Transform Net
        output_images = itn.transform(content_images)

        # pass output_images through the pre-trained VGG-19 network
        output_imgs_preprocess = preprocess(output_images)
        output_net = vgg.forward(output_imgs_preprocess)

        # ** compute the feature reconstruction loss **
        content_size = tf.size(content_features[CONTENT_LAYER])

        content_loss = (2 * tf.nn.l2_loss(output_net[CONTENT_LAYER] -
                                          content_features[CONTENT_LAYER]) /
                        tf.to_float(content_size))

        # ** compute the style reconstruction loss **
        style_losses = []
        for style_layer in STYLE_LAYERS:
            features = output_net[style_layer]
            shape = tf.shape(features)
            num_images, height, width, num_filters = (shape[0], shape[1],
                                                      shape[2], shape[3])
            features = tf.reshape(features,
                                  [num_images, height * width, num_filters])
            grams = (tf.matmul(features, features, transpose_a=True) /
                     tf.to_float(height * width * num_filters))
            style_gram = style_features[style_layer]
            layer_style_loss = (2 * tf.nn.l2_loss(grams - style_gram) /
                                tf.to_float(tf.size(grams)))
            style_losses.append(layer_style_loss)

        style_loss = tf.reduce_sum(tf.stack(style_losses))

        # ** compute the total variation loss **
        shape = tf.shape(output_images)
        height, width = shape[1], shape[2]
        y = (tf.slice(output_images, [0, 0, 0, 0], [-1, height - 1, -1, -1]) -
             tf.slice(output_images, [0, 1, 0, 0], [-1, -1, -1, -1]))
        x = (tf.slice(output_images, [0, 0, 0, 0], [-1, -1, width - 1, -1]) -
             tf.slice(output_images, [0, 0, 1, 0], [-1, -1, -1, -1]))

        tv_loss = (tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) +
                   tf.nn.l2_loss(y) / tf.to_float(tf.size(y)))

        # overall perceptual losses
        loss = content_weight * content_loss + style_weight * style_loss + tv_weight * tv_loss

        # Training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        n_batches = len(content_targets_path) // BATCH_SIZE

        if debug:
            elapsed_time = datetime.now() - start_time
            tf.logging.set_verbosity(tf.logging.INFO)
            tf.logging.info(
                'Elapsed time for preprocessing before actually training the model: %s'
                % elapsed_time)
            tf.logging.info('Now begin to train the model...')
            start_time = datetime.now()

        for epoch in range(EPOCHS):

            np.random.shuffle(content_targets_path)

            for batch in range(n_batches):
                # retrieve a batch of content target images
                content_batch_path = content_targets_path[
                    batch * BATCH_SIZE:(batch + 1) * BATCH_SIZE]
                content_batch = get_images(content_batch_path, input_shape[1],
                                           input_shape[2])

                # run the training step
                sess.run(train_op, feed_dict={content_images: content_batch})

                step += 1

                if step % 1000 == 0:
                    saver.save(sess, save_path, global_step=step)

                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _content_loss, _style_loss, _tv_loss, _loss = sess.run(
                            [content_loss, style_loss, tv_loss, loss],
                            feed_dict={content_images: content_batch})

                        tf.logging.info(
                            'step: %d,  total loss: %f,  elapsed time: %s' %
                            (step, _loss, elapsed_time))
                        tf.logging.info(
                            'content loss: %f,  weighted content loss: %f' %
                            (_content_loss, content_weight * _content_loss))
                        tf.logging.info(
                            'style loss  : %f,  weighted style loss  : %f' %
                            (_style_loss, style_weight * _style_loss))
                        tf.logging.info(
                            'tv loss     : %f,  weighted tv loss     : %f' %
                            (_tv_loss, tv_weight * _tv_loss))
                        tf.logging.info('\n')

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            tf.logging.info('Done training! Elapsed time: %s' % elapsed_time)
            tf.logging.info('Model is saved to: %s' % save_path)
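
A hedged invocation sketch for this train function. The module-level constants it uses (BATCH_SIZE, EPOCHS, LEARNING_RATE, TRAINING_IMAGE_SHAPE, STYLE_LAYERS, CONTENT_LAYER) are assumed to be defined elsewhere in the repository; every path and weight below is a placeholder, not a recommended setting:

from glob import glob

content_paths = glob('data/train2014/*.jpg')   # hypothetical training-image directory
train(content_paths,
      style_target_path='images/style/wave.jpg',
      content_weight=1.0,
      style_weight=5.0,
      tv_weight=1e-5,
      vgg_path='models/imagenet-vgg-verydeep-19.mat',
      save_path='models/wave/style.ckpt',
      debug=True,
      logging_period=100)
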
def train(content_path, style_path, content_weight, style_weight, tv_weight,
          vgg_path, save_path):
    height, width, channels = IMG_SHAPE
    input_shape = (batch_size, height, width, channels)

    start_time = datetime.now()

    vgg = VGG(vgg_path)

    style_target = get_images(style_path, height, width)
    style_shape = style_target.shape

    style_features = {}
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # style net
    with tf.Session(config=config) as sess:
        style_img = tf.placeholder(tf.float32,
                                   shape=style_shape,
                                   name='style_image')
        style_net = vgg.forward(preprocess(style_img))

        for layer in STYLE:
            features = style_net[layer].eval(
                feed_dict={style_img: style_target})
            features = np.reshape(features, [-1, features.shape[3]])

            gram = np.matmul(features.T, features) / features.size
            style_features[layer] = gram

    # content net
    with tf.Session(config=config) as sess:
        content_img = tf.placeholder(tf.float32,
                                     shape=input_shape,
                                     name='content_img')

        content_net = vgg.forward(preprocess(content_img))
        content_features = content_net[CONTENT]

        trans_images = itn.transform(content_img)
        output_net = vgg.forward(preprocess(trans_images))

        # content (feature) reconstruction loss
        content_size = tf.size(content_features)
        content_loss = (tf.nn.l2_loss(output_net[CONTENT] - content_features) *
                        2 / tf.to_float(content_size))

        # style reconstruction loss
        style_losses = []
        for layer in STYLE:
            features = output_net[layer]
            shape = tf.shape(features)
            num_images, height, width, num_filters = (shape[0], shape[1],
                                                      shape[2], shape[3])

            features = tf.reshape(features,
                                  [num_images, height * width, num_filters])

            grams = (tf.matmul(features, features, transpose_a=True) /
                     tf.to_float(height * width * num_filters))
            style_gram = style_features[layer]

            layer_style_loss = (tf.nn.l2_loss(grams - style_gram) * 2 /
                                tf.to_float(tf.size(grams)))
            style_losses.append(layer_style_loss)

        style_loss = tf.reduce_sum(tf.stack(style_losses))

        # total variation loss
        shape = tf.shape(trans_images)
        height, width = shape[1], shape[2]
        y = (tf.slice(trans_images, [0, 0, 0, 0], [-1, height - 1, -1, -1]) -
             tf.slice(trans_images, [0, 1, 0, 0], [-1, -1, -1, -1]))
        x = (tf.slice(trans_images, [0, 0, 0, 0], [-1, -1, width - 1, -1]) -
             tf.slice(trans_images, [0, 0, 1, 0], [-1, -1, -1, -1]))

        tv_loss = (tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) +
                   tf.nn.l2_loss(y) / tf.to_float(tf.size(y)))

        # overall perceptual losses
        loss = content_weight * content_loss + style_weight * style_loss + tv_weight * tv_loss

        # Training step
        train_op = tf.train.AdamOptimizer(lr).minimize(loss)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        step = 0
        n_batches = len(content_path) // batch_size

        elapsed_time = datetime.now() - start_time
        tf.logging.set_verbosity(tf.logging.INFO)
        tf.logging.info(
            'Elapsed time for preprocessing before actually training the model: %s'
            % elapsed_time)
        tf.logging.info('Now begin to train the model...')
        start_time = datetime.now()

        c_loss = []
        s_loss = []
        tv = []
        total_loss = []
        for epoch in range(EPOCHS):

            np.random.shuffle(content_path)

            for batch in range(n_batches):
                # retrieve a batch of content target images
                content_batch_path = content_path[
                    batch * batch_size:(batch + 1) * batch_size]
                content_batch = get_images(content_batch_path, input_shape[1],
                                           input_shape[2])

                # run the training step
                sess.run(train_op, feed_dict={content_img: content_batch})

                step += 1

                if step % 1000 == 0:
                    saver.save(sess, save_path, global_step=step)

                is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                if is_last_step or step % 100 == 0:
                    elapsed_time = datetime.now() - start_time
                    _content_loss, _style_loss, _tv_loss, _loss = sess.run(
                        [content_loss, style_loss, tv_loss, loss],
                        feed_dict={content_img: content_batch})

                    tf.logging.info(
                        'step: %d,  total loss: %f,  elapsed time: %s' %
                        (step, _loss, elapsed_time))
                    tf.logging.info(
                        'content loss: %f,  weighted content loss: %f' %
                        (_content_loss, content_weight * _content_loss))
                    tf.logging.info(
                        'style loss  : %f,  weighted style loss  : %f' %
                        (_style_loss, style_weight * _style_loss))
                    tf.logging.info(
                        'tv loss     : %f,  weighted tv loss     : %f' %
                        (_tv_loss, tv_weight * _tv_loss))
                    tf.logging.info('\n')
                    c_loss.append(_content_loss)
                    s_loss.append(_style_loss)
                    tv.append(_tv_loss)
                    total_loss.append(_loss)

        saver.save(sess, save_path)
        return c_loss, s_loss, tv, total_loss
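
Both train variants reduce each style layer to a Gram matrix normalised by the feature-map size. A small NumPy sketch of that step on a dummy activation map, mirroring np.matmul(features.T, features) / features.size above:

import numpy as np

# Dummy VGG activation for one image: (1, height, width, num_filters).
features = np.random.rand(1, 32, 32, 64).astype(np.float32)

# Flatten the spatial dimensions: one row per spatial position.
flat = features.reshape(-1, features.shape[3])   # (height * width, num_filters)

# Gram matrix of filter co-activations, normalised by the total element count.
gram = flat.T @ flat / features.size             # (num_filters, num_filters)
print(gram.shape)                                # (64, 64)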