def stylize(contents_path,
            styles_path,
            output_dir,
            encoder_path,
            model_path,
            style_ratio=0.6,
            repeat_pipeline=1,
            autoencoder_levels=None):

    if isinstance(contents_path, str):
        contents_path = [contents_path]
    if isinstance(styles_path, str):
        styles_path = [styles_path]

    style_ratio = np.clip(style_ratio, 0, 1)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=(1, None, None, 3),
                                 name='content_input')
        style = tf.placeholder(tf.float32,
                               shape=(1, None, None, 3),
                               name='style_input')

        stn = StyleTransferNet(encoder_path, autoencoder_levels)

        output_image = stn.transform(content, style, style_ratio,
                                     repeat_pipeline)

        sess.run(tf.global_variables_initializer())

        # restore the trained model and run the style transfer
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:

            content_img = get_images(content_path)

            for style_path in styles_path:

                style_img = get_images(style_path)

                result = sess.run(output_image,
                                  feed_dict={
                                      content: content_img,
                                      style: style_img
                                  })

                outputs.append(result[0])

    save_images(outputs, contents_path, styles_path, output_dir)

    return outputs
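
# Hedged usage sketch (added for illustration, not part of the original example):
# one possible call to the stylize() defined above. All paths and the
# autoencoder level list are placeholder assumptions.
stylize(contents_path='images/content/avril.jpg',       # hypothetical path
        styles_path=['images/style/mosaic.jpg'],        # hypothetical path
        output_dir='outputs',
        encoder_path='vgg19_normalised.npz',
        model_path='models/autoencoder.ckpt-done',      # hypothetical checkpoint
        style_ratio=0.8,
        repeat_pipeline=1,
        autoencoder_levels=[5, 4, 3, 2, 1])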
Example #2
def stylize(contents_path,
            styles_path,
            output_dir,
            encoder_path,
            model_path,
            resize_height=None,
            resize_width=None,
            suffix=None):

    if isinstance(contents_path, str):
        contents_path = [contents_path]
    if isinstance(styles_path, str):
        styles_path = [styles_path]

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=(1, None, None, 3),
                                 name='content')
        style = tf.placeholder(tf.float32,
                               shape=(1, None, None, 3),
                               name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        sess.run(tf.global_variables_initializer())

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:

            content_img = get_images(content_path,
                                     height=resize_height,
                                     width=resize_width)

            for style_path in styles_path:

                style_img = get_images(style_path)

                result = sess.run(output_image,
                                  feed_dict={
                                      content: content_img,
                                      style: style_img
                                  })

                outputs.append(result[0])

    save_images(outputs, contents_path, styles_path, output_dir, suffix=suffix)

    return outputs
Example #3
def _handler(content_path,
             style_path,
             encoder_path,
             model_path,
             model_pre_path,
             output_path=None):

    with tf.Graph().as_default(), tf.Session() as sess:
        index = 2
        content_path = content_path + str(index) + '.jpg'
        style_path = style_path + str(index) + '.jpg'

        content_img = get_train_images(content_path)
        style_img = get_train_images(style_path)

        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=content_img.shape,
                                 name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        stn = StyleTransferNet(encoder_path, model_pre_path)

        enc_c, enc_s = stn.encoder_process(content, style)

        target = tf.placeholder(tf.float32, shape=enc_c.shape, name='target')

        # output_image = stn.transform(content, style)
        output_image = stn.decoder_process(target)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        # get the output
        enc_c, enc_s = sess.run([enc_c, enc_s],
                                feed_dict={
                                    content: content_img,
                                    style: style_img
                                })
        feature = L1_Max(enc_c, enc_s)
        # feature = enc_s
        output = sess.run(output_image, feed_dict={target: feature})

    if output_path is not None:
        save_images(content_path,
                    output,
                    output_path,
                    prefix='fused' + str(index) + '_',
                    suffix='deep')

    return output
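
# Hedged usage sketch (added for illustration): _handler() appends '2.jpg' to
# the content/style arguments internally, so the values below are path
# prefixes, not full file names. All paths are placeholder assumptions.
fused = _handler(content_path='images/IR/IR',          # hypothetical prefix
                 style_path='images/VIS/VIS',          # hypothetical prefix
                 encoder_path='vgg19_normalised.npz',
                 model_path='models/fusion.ckpt',      # hypothetical checkpoint
                 model_pre_path='vgg19.npz',           # hypothetical weights file
                 output_path='outputs/')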
Example #4
def test_autoencoder(autoencoder_levels, model_save_path):
    
    input_imgs_paths = list_images(TEST_IMG_DIR)

    with tf.Graph().as_default(), tf.Session() as sess:

        input_img = tf.placeholder(
            tf.float32, shape=(1, None, None, 3), name='input_img')

        stn = StyleTransferNet(ENCODER_WEIGHTS_PATH, autoencoder_levels)

        input_encs = [encoder.encode(input_img)[0] for encoder in stn.encoders]

        output_imgs = [decoder.decode(input_enc) for decoder, input_enc in zip(stn.decoders, input_encs)]

        sess.run(tf.global_variables_initializer())

        # restore the trained model and run the reconstruction
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.restore(sess, model_save_path)

        for input_img_path in input_imgs_paths:

            img = get_images(input_img_path)

            for autoencoder_id, output_img in zip(autoencoder_levels, output_imgs):

                out = sess.run(output_img, feed_dict={input_img: img})

                save_single_image(out[0], input_img_path, OUTPUT_DIR, prefix=str(autoencoder_id) + '-')
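
# Hedged usage sketch (added): reconstruct the images in TEST_IMG_DIR through
# the level-3 and level-4 autoencoders. The checkpoint path is a placeholder
# assumption.
test_autoencoder(autoencoder_levels=[3, 4],
                 model_save_path='models/autoencoder.ckpt-done')
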
def _handler1(content_path,
              style_path,
              encoder_path,
              model_path,
              resize_height=None,
              resize_width=None,
              output_path=None,
              prefix=None,
              suffix=None):

    # get the actual image data, output shape:
    # (num_images, height, width, color_channels)
    content_img = get_images(content_path, resize_height, resize_width)
    style_img = get_images(style_path)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=content_img.shape,
                                 name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = sess.run(output_image,
                          feed_dict={
                              content: content_img,
                              style: style_img
                          })

    if output_path is not None:
        save_images(content_path,
                    output,
                    output_path,
                    prefix=prefix,
                    suffix=suffix)

    return output
def _handler2(content_path,
              style_path,
              encoder_path,
              model_path,
              output_path=None,
              prefix=None,
              suffix=None):

    style_img = get_images(style_path)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=(1, None, None, 3),
                                 name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = []
        for path in content_path:
            content_img = get_images(path)
            result = sess.run(output_image,
                              feed_dict={
                                  content: content_img,
                                  style: style_img
                              })
            output.append(result[0])

    if output_path is not None:
        save_images(content_path,
                    output,
                    output_path,
                    prefix=prefix,
                    suffix=suffix)

    return output
def train(training_imgs_paths, encoder_weights_path, model_save_path, 
    autoencoder_levels=None, debug=False, logging_period=100):

    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the number of training imgs is a multiple of BATCH_SIZE
    mod = len(training_imgs_paths) % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed by %d samples...' % mod)
        training_imgs_paths = training_imgs_paths[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:

        # create encoders & decoders through StyleTransferNet
        stn = StyleTransferNet(encoder_weights_path, autoencoder_levels)

        # initialize all the variables
        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(var_list=tf.trainable_variables())

        for index, (encoder, decoder) in enumerate(zip(stn.encoders, stn.decoders)):

            autoencoder_id = stn.autoencoder_levels[index]

            input_imgs = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='input_imgs_%d' % autoencoder_id)

            # logic: input_img -> encode() -> img_features -> decode() -> output_img
            input_encs, input_features = encoder.encode(input_imgs)
            output_imgs = decoder.decode(input_encs)
            output_encs, output_features = encoder.encode(output_imgs)

            # compute the pixel loss
            pixel_loss = tf.losses.mean_squared_error(input_imgs, output_imgs)

            # compute the feature loss
            feature_loss = tf.reduce_sum([
                tf.losses.mean_squared_error(input_feat, output_feat) for input_feat, output_feat in zip(input_features, output_features)
            ])

            # total loss
            total_loss = PIXEL_LOSS_WEIGHT * pixel_loss + FEATURE_LOSS_WEIGHT * feature_loss

            # Training step
            global_step = tf.Variable(0, trainable=False)
            learning_rate = tf.train.inverse_time_decay(LEARNING_RATE, global_step, DECAY_STEPS, LR_DECAY_RATE)
            trainer = tf.train.AdamOptimizer(learning_rate)
            train_op = trainer.minimize(total_loss, global_step=global_step)
            trainer_initializers = [var.initializer for var in trainer.variables()]
            trainer_initializers.append(global_step.initializer)

            sess.run(trainer_initializers)

            """ Start Training """
            step = 0
            n_batches = int(len(training_imgs_paths) // BATCH_SIZE)

            if debug:
                elapsed_time = datetime.now() - start_time
                print('\nElapsed time for preprocessing before actually training Decoder_%d: %s' % (autoencoder_id, elapsed_time))
                print('Now begin to train the Decoder_%d...\n' % autoencoder_id)
                start_time = datetime.now()

            try:
                for epoch in range(EPOCHS):

                    np.random.shuffle(training_imgs_paths)

                    for batch in range(n_batches):
                        # retrieve a batch of training images
                        img_batch_paths = training_imgs_paths[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]

                        # img_batch = get_images(img_batch_paths, height=HEIGHT, width=WIDTH)
                        img_batch = get_training_images(img_batch_paths, crop_height=HEIGHT, crop_width=WIDTH)

                        # run the training step
                        sess.run(train_op, feed_dict={input_imgs: img_batch})

                        step += 1

                        if step % 1000 == 0:
                            saver.save(sess, '%s_%d' % (model_save_path, autoencoder_id), 
                                global_step=step, write_meta_graph=False)

                        if debug:
                            is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                            if is_last_step or step == 1 or step % logging_period == 0:
                                elapsed_time = datetime.now() - start_time
                                _pixel_loss, _feature_loss, _loss = sess.run([pixel_loss, feature_loss, total_loss], 
                                    feed_dict={input_imgs: img_batch})

                                print('step: %d,  total loss: %.3f,  elapsed time: %s' % (step, _loss, elapsed_time))
                                print('pixel   loss: %.3f' % (_pixel_loss))
                                print('feature loss: %.3f' % (_feature_loss))
                                print('total   loss: %.3f' % (_loss))
                                print('\n')

                # finish training current decoder, save the model
                saver.save(sess, '%s_%d' % (model_save_path, autoencoder_id), global_step=step)

                if debug:
                    elapsed_time = datetime.now() - start_time
                    print('>>> Successfully trained decoder_%d! Elapsed time: %s\n' % (autoencoder_id, elapsed_time))

            except:
                saver.save(sess, '%s_%d' % (model_save_path, autoencoder_id), global_step=step)
                print('\nSomething went wrong! Current model is saved at step: %d\n' % step)

                if debug:
                    elapsed_time = datetime.now() - start_time
                    print('Elapsed time: %s\n' % elapsed_time)

                exit()

        """ Done all trainings & Save the final model """
        saver.save(sess, model_save_path + '-done')

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % (model_save_path + '-done'))
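
# Hedged usage sketch (added): one way the per-level autoencoder trainer above
# might be invoked. list_images() is the helper used elsewhere in these
# examples; the dataset directory and checkpoint prefix are placeholder
# assumptions.
train(training_imgs_paths=list_images('datasets/train2014'),   # hypothetical dir
      encoder_weights_path='vgg19_normalised.npz',
      model_save_path='models/autoencoder.ckpt',               # hypothetical prefix
      autoencoder_levels=[3, 4, 5],
      debug=True)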
Example #8
def train(ssim_weight,
          original_imgs_path_name,
          source_a_imgs_path,
          source_b_imgs_path_name,
          encoder_path,
          save_path,
          model_pre_path,
          debug=False,
          logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # num_imgs = len(source_a_imgs_path)
    num_imgs = 10000
    source_a_imgs_path = source_a_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE

    print('Number of training images: %d.\n' % num_imgs)
    print('Number of training batches: %s.\n' % str(num_imgs / BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed by %d samples...\n' % mod)
        source_a_imgs_path = source_a_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE_OR,
                                  name='original')
        source_a = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE,
                                  name='source_a')
        source_b = tf.placeholder(tf.float32,
                                  shape=INPUT_SHAPE,
                                  name='source_b')

        print('source:', source_a.shape)

        # create the style transfer net
        stn = StyleTransferNet(encoder_path, model_pre_path)

        # pass content and style to the stn, getting the generated_img, fused image
        generated_img = stn.transform(source_a, source_b)

        # # get the target feature maps which is the output of AdaIN
        # target_features = stn.target_features

        pixel_loss = tf.reduce_sum(
            tf.reduce_mean(tf.square(original - generated_img), axis=[1, 2]))
        pixel_loss = pixel_loss / (HEIGHT * WIDTH)

        # compute the SSIM loss
        ssim_loss = 1 - SSIM.tf_ssim(original, generated_img)

        # compute the total loss
        loss = pixel_loss + ssim_weight * ssim_loss

        # Training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        n_batches = int(len(source_a_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print(
                '\nElapsed time for preprocessing before actually training the model: %s'
                % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        Loss_all = [i for i in range(EPOCHS * n_batches)]
        for epoch in range(EPOCHS):

            np.random.shuffle(source_a_imgs_path)

            for batch in range(n_batches):
                # retrieve a batch of content and style images

                source_a_path = source_a_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]
                source_a_str = source_a_path[0]
                name_f = source_a_str.find('\\')
                source_image_name = source_a_str[name_f + 1:]
                source_image_name_comm = source_image_name[2:]

                source_b_path = [source_b_imgs_path_name + source_image_name]
                original_path = [
                    original_imgs_path_name + source_image_name_comm
                ]

                original_batch = get_train_images(original_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH,
                                                  flag=False)
                source_a_batch = get_train_images(source_a_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH)
                source_b_batch = get_train_images(source_b_path,
                                                  crop_height=HEIGHT,
                                                  crop_width=WIDTH)

                original_batch = original_batch.reshape([1, 256, 256, 1])

                # run the training step
                sess.run(train_op,
                         feed_dict={
                             original: original_batch,
                             source_a: source_a_batch,
                             source_b: source_b_batch
                         })
                step += 1
                # if step % 1000 == 0:
                #     saver.save(sess, save_path, global_step=step)
                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch
                                                              == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _pixel_loss, _ssim_loss, _loss = sess.run(
                            [pixel_loss, ssim_loss, loss],
                            feed_dict={
                                original: original_batch,
                                source_a: source_a_batch,
                                source_b: source_b_batch
                            })
                        Loss_all[count_loss] = _loss
                        count_loss += 1
                        print(
                            'step: %d,  total loss: %.3f,  elapsed time: %s' %
                            (step, _loss, elapsed_time))
                        print('pixel loss: %.3f' % (_pixel_loss))
                        print('ssim loss : %.3f\n' % (_ssim_loss))
                        # print('pca or shape  : ', _pca_or.shape)
                        # print('pca gen shape : ', _pca_gen.shape)

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        iter_index = [i for i in range(count_loss)]
        plt.plot(iter_index, Loss_all[:count_loss])
        plt.show()

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
Example #9
def train(style_weight, content_imgs_path, style_imgs_path, encoder_path, 
          model_save_path, debug=False, logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the number of content and style images is a multiple of BATCH_SIZE
    num_imgs = min(len(content_imgs_path), len(style_imgs_path))
    content_imgs_path = content_imgs_path[:num_imgs]
    style_imgs_path   = style_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed by %d samples...\n' % mod)
        content_imgs_path = content_imgs_path[:-mod]
        style_imgs_path   = style_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
        style   = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

        # create the style transfer net
        stn = StyleTransferNet(encoder_path)

        # pass content and style to the stn, getting the generated_img
        generated_img = stn.transform(content, style)

        # get the target feature maps which is the output of AdaIN
        target_features = stn.target_features

        # pass the generated_img to the encoder, and use the output to compute the loss
        generated_img = tf.reverse(generated_img, axis=[-1])  # switch RGB to BGR
        generated_img = stn.encoder.preprocess(generated_img) # preprocess image
        enc_gen, enc_gen_layers = stn.encoder.encode(generated_img)

        # compute the content loss
        # content_loss = fft_loss(enc_gen, target_features)
        content_loss = tf.reduce_sum(tf.reduce_mean(tf.square(enc_gen - target_features), axis=[1, 2]))

        # compute the style loss
        style_layer_loss = []
        for layer in STYLE_LAYERS:
            
            enc_style_feat = stn.encoded_style_layers[layer]
            enc_gen_feat   = enc_gen_layers[layer]

            meanS, varS = tf.nn.moments(enc_style_feat, [1, 2])
            meanG, varG = tf.nn.moments(enc_gen_feat,   [1, 2])
            # fft_pred = tf.abs(tf.spectral.rfft2d(tf.transpose(enc_gen_feat, (0, 3, 1, 2))))
            # fft_true = tf.abs(tf.spectral.rfft2d(tf.transpose(enc_style_feat, (0, 3, 1, 2))))
            # meanS, varS = tf.nn.moments(fft_pred, [2, 3])
            # meanG, varG = tf.nn.moments(fft_true,   [2, 3])

            sigmaS = tf.sqrt(varS + EPSILON)
            sigmaG = tf.sqrt(varG + EPSILON)

            l2_mean  = tf.reduce_sum(tf.square(meanG - meanS))
            l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS))

            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        # compute the total loss
        loss = content_loss + style_weight * style_loss

        # Training step
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.inverse_time_decay(LEARNING_RATE, global_step, DECAY_STEPS, LR_DECAY_RATE)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

        sess.run(tf.global_variables_initializer())

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        ###### Start Training ######
        step = 0
        n_batches = int(len(content_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            start_time = datetime.now()
            print('\nElapsed time for preprocessing before actually training the model: %s' % elapsed_time)
            print('Now begin to train the model...\n')

        try:
            for epoch in range(EPOCHS):

                np.random.shuffle(content_imgs_path)
                np.random.shuffle(style_imgs_path)

                for batch in range(n_batches):
                    # retrieve a batch of content and style images
                    content_batch_path = content_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]
                    style_batch_path   = style_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]

                    content_batch = get_train_images(content_batch_path, crop_height=HEIGHT, crop_width=WIDTH)
                    style_batch   = get_train_images(style_batch_path,   crop_height=HEIGHT, crop_width=WIDTH)

                    # run the training step
                    sess.run(train_op, feed_dict={content: content_batch, style: style_batch})

                    step += 1

                    if step % 1000 == 0:
                        saver.save(sess, model_save_path, global_step=step, write_meta_graph=False)

                    if debug:
                        is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                        if is_last_step or step == 1 or step % logging_period == 0:
                            elapsed_time = datetime.now() - start_time
                            _content_loss, _style_loss, _loss = sess.run([content_loss, style_loss, loss], 
                                feed_dict={content: content_batch, style: style_batch})

                            print('step: %d,  total loss: %.3f,  elapsed time: %s' % (step, _loss, elapsed_time))
                            print('content loss: %.3f' % (_content_loss))
                            print('style loss  : %.3f,  weighted style loss: %.3f\n' % (_style_loss, style_weight * _style_loss))
        except Exception as ex:
            saver.save(sess, model_save_path, global_step=step)
            print('\nSomething went wrong! Current model is saved to <%s>' % model_save_path)
            print('Error message: %s' % str(ex))

        ###### Done Training & Save the model ######
        saver.save(sess, model_save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % model_save_path)
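
# For reference, a small self-contained NumPy sketch (added, not part of the
# original code) of the per-layer statistic matching used in the style loss
# above: channel-wise means and standard deviations over the spatial
# dimensions are compared with squared L2 distances, mirroring what
# tf.nn.moments computes inside the graph.
import numpy as np

def style_layer_loss_np(style_feat, gen_feat, epsilon=1e-5):
    # both feature maps have shape (batch, height, width, channels)
    mean_s = style_feat.mean(axis=(1, 2))
    mean_g = gen_feat.mean(axis=(1, 2))
    sigma_s = np.sqrt(style_feat.var(axis=(1, 2)) + epsilon)
    sigma_g = np.sqrt(gen_feat.var(axis=(1, 2)) + epsilon)
    return np.sum((mean_g - mean_s) ** 2) + np.sum((sigma_g - sigma_s) ** 2)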
Example #10
def train(style_weight,
          content_imgs_path,
          style_imgs_path,
          encoder_path,
          save_path,
          debug=False,
          logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the number of content and style images is a multiple of BATCH_SIZE
    num_imgs = min(len(content_imgs_path), len(style_imgs_path))
    content_imgs_path = content_imgs_path[:num_imgs]
    style_imgs_path = style_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed by %d samples...\n' % mod)
        content_imgs_path = content_imgs_path[:-mod]
        style_imgs_path = style_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
        style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

        # create the style transfer net
        stn = StyleTransferNet(encoder_path)

        # pass content and style to the stn, getting the generated_img
        generated_img = stn.transform(content, style)

        # get the target feature maps which is the output of AdaIN
        target_features = stn.target_features

        # pass the generated_img to the encoder, and use the output to compute the loss
        generated_img = tf.reverse(generated_img,
                                   axis=[-1])  # switch RGB to BGR
        generated_img = stn.encoder.preprocess(
            generated_img)  # preprocess image
        enc_gen, enc_gen_layers = stn.encoder.encode(generated_img)

        # compute the content loss
        content_loss = tf.reduce_sum(
            tf.reduce_mean(tf.square(enc_gen - target_features), axis=[1, 2]))
        # compute the style loss
        style_layer_loss = []
        for layer in STYLE_LAYERS:
            enc_style_feat = stn.encoded_style_layers[layer]
            enc_gen_feat = enc_gen_layers[layer]

            meanS, varS = tf.nn.moments(enc_style_feat, [1, 2])
            meanG, varG = tf.nn.moments(enc_gen_feat, [1, 2])

            sigmaS = tf.sqrt(varS + EPSILON)
            sigmaG = tf.sqrt(varG + EPSILON)

            l2_mean = tf.reduce_sum(tf.square(meanG - meanS))
            l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS))

            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        # compute the total loss
        loss = content_loss + style_weight * style_loss

        # save loss to tensorboard
        tf.summary.scalar('content_loss', content_loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('total_loss', loss)
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter('runs', sess.graph)
        # Training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)
        """Start Training"""
        step = 0
        n_batches = int(len(content_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print(
                '\nElapsed time for preprocessing before actually training the model: %s'
                % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        for epoch in range(EPOCHS):

            np.random.shuffle(content_imgs_path)
            np.random.shuffle(style_imgs_path)

            for batch in range(n_batches):
                print("current step: {}/{}".format(batch, n_batches))
                print(
                    "select content images: {}~{}/{}, style images: {}~{}/{}".
                    format(batch * BATCH_SIZE, batch * BATCH_SIZE + BATCH_SIZE,
                           len(content_imgs_path),
                           batch * BATCH_SIZE, batch * BATCH_SIZE + BATCH_SIZE,
                           len(style_imgs_path)))
                # retrieve a batch of content and style images
                content_batch_path = content_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]
                style_batch_path = style_imgs_path[batch * BATCH_SIZE:(
                    batch * BATCH_SIZE + BATCH_SIZE)]

                content_batch = get_train_images(content_batch_path,
                                                 crop_height=HEIGHT,
                                                 crop_width=WIDTH)
                style_batch = get_train_images(style_batch_path,
                                               crop_height=HEIGHT,
                                               crop_width=WIDTH)

                # run the training step
                print("start training step")
                c_loss, s_loss, t_loss, summary, _ = sess.run(
                    [content_loss, style_loss, loss, merged, train_op],
                    feed_dict={
                        content: content_batch,
                        style: style_batch
                    })
                train_writer.add_summary(summary, batch + epoch * n_batches)
                print(f'content: {c_loss}, style: {s_loss}, total: {t_loss}')
                print("stop training step")
                step += 1

                if step % 1000 == 0:
                    saver.save(sess, save_path, global_step=step)

                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch
                                                              == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _content_loss, _style_loss, _loss = sess.run(
                            [content_loss, style_loss, loss],
                            feed_dict={
                                content: content_batch,
                                style: style_batch
                            })

                        print(
                            'step: %d,  total loss: %.3f,  elapsed time: %s' %
                            (step, _loss, elapsed_time))
                        print('content loss: %.3f' % (_content_loss))
                        print(
                            'style loss  : %.3f,  weighted style loss: %.3f\n'
                            % (_style_loss, style_weight * _style_loss))
        """ Done training. Save the model."""
        saver.save(sess, save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
Example #11
# get the training image shape
HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
INPUT_SHAPE = [None, HEIGHT, WIDTH, CHANNELS]
# create the graph
tf_config = tf.ConfigProto()
#tf_config.gpu_options.per_process_gpu_memory_fraction=0.5
tf_config.gpu_options.allow_growth = True
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:

    content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
    style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')
    label = tf.placeholder(tf.int64, shape=None, name="label")
    #style   = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

    # create the style transfer net
    stn = StyleTransferNet(encoder_path)

    # pass content and style to the stn, getting the gen_img
    # decoded image from normal one, adversarial image, and input
    dec_img, adv_img = stn.transform(content, style)
    img = content

    print(adv_img.shape.as_list())
    stn_vars = []

    # get the target feature maps which is the output of AdaIN
    target_features = stn.target_features

    # pass the gen_img to the encoder, and use the output to compute the loss
    enc_gen_adv, enc_gen_layers_adv = stn.encode(adv_img)
    enc_gen, enc_gen_layers = stn.encode(dec_img)
# get the training image shape
HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
INPUT_SHAPE = [None, HEIGHT, WIDTH, CHANNELS]
# create the graph
tf_config = tf.ConfigProto()
#tf_config.gpu_options.per_process_gpu_memory_fraction=0.5
tf_config.gpu_options.allow_growth = True
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:

    content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
    style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')
    label = tf.placeholder(tf.int64, shape=None, name="label")
    #style   = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

    # create the style transfer net
    stn = StyleTransferNet(encoder_path)

    # pass content and style to the stn, getting the generated_img
    generated_img, generated_img_adv = stn.transform(content, style)
    adv_img = generated_img_adv
    img = generated_img

    print(adv_img.shape.as_list())
    stn_vars = []  #get_scope_var("transform")
    # get the target feature maps which is the output of AdaIN
    target_features = stn.target_features

    # pass the generated_img to the encoder, and use the output to compute the loss
    generated_img_adv = tf.reverse(generated_img_adv,
                                   axis=[-1])  # switch RGB to BGR
    adv_img_bgr = generated_img_adv
Example #13
def stylize(contents_path,
            styles_path,
            output_dir,
            encoder_path,
            model_path,
            resize_height=None,
            resize_width=None,
            suffix=None):

    if isinstance(contents_path, str):
        contents_path = [contents_path]
    if isinstance(styles_path, str):
        styles_path = [styles_path]

    with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(
            log_device_placement=True)) as sess:
        # this block only prints the TF device-placement info for debugging; it has no other purpose
        #a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
        #b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
        #c = tf.matmul(a, b)
        #print(sess.run(c))

        # build the dataflow graph
        content = tf.placeholder(tf.float32,
                                 shape=(1, None, None, 3),
                                 name='content')
        style = tf.placeholder(tf.float32,
                               shape=(1, None, None, 3),
                               name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        print(sess.run(tf.global_variables_initializer()))

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:

            content_img = get_images(content_path,
                                     height=resize_height,
                                     width=resize_width)

            for style_path in styles_path:

                style_img = get_images(style_path)

                print('--> processing %s with style %s' %
                      (content_path, style_path))
                result = sess.run(output_image,
                                  feed_dict={
                                      content: content_img,
                                      style: style_img
                                  })

                outputs.append(result[0])

                save_image(result[0],
                           content_path,
                           style_path,
                           output_dir,
                           suffix=suffix)

    #save_images(outputs, contents_path, styles_path, output_dir, suffix=suffix)

    return outputs
Example #14
OUTPUTS_DIR = 'outputs'
ENCODER_WEIGHTS_PATH = 'vgg19_normalised.npz'
MODEL_SAVE_PATH = 'models/style_weight_2e0.ckpt'

content_img = imageio.imread('./images/content/karya.jpg')
content_img = np.expand_dims(content_img, axis=0)
style_img = imageio.imread('./images/style/mosaic.jpg')
style_img = np.expand_dims(style_img, axis=0)

sess = tf.InteractiveSession()

content = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content')
style = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='style')

model = StyleTransferNet(ENCODER_WEIGHTS_PATH)
output = model.transform(content, style)

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()
saver.restore(sess, MODEL_SAVE_PATH)

output_img = sess.run(output,
                      feed_dict={
                          content: content_img,
                          style: style_img
                      })

output_img = output_img[0]
print('output_img shape', output_img.shape)
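
# A possible final step (added as an assumption; not in the original snippet):
# clip the result to a valid pixel range and write it to OUTPUTS_DIR. The
# output file name is a hypothetical choice based on the input image names.
import os
output_img = np.clip(output_img, 0, 255).astype(np.uint8)
imageio.imwrite(os.path.join(OUTPUTS_DIR, 'karya-mosaic.jpg'), output_img)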