def test_vgg19_rev_save_weights():
    """Build a freshly initialized reversed-VGG19 decoder and save its weights.

    The decoder is created with ``pretrained=False`` and decodes down to
    ``conv1_1`` from 512-channel features. Its weights are written to
    ``./trained_models/temp_vgg19_rev_weights.h5``. The target directory is
    created first if it is missing.
    """
    import os
    import os.path as osp
    from vgg import vgg19_rev
    MODEL_SAVE_PATH = './trained_models/'
    dec_c_net = vgg19_rev(pretrained=False,
                          end_with='conv1_1',
                          input_depth=512,
                          name='stylized_dec')
    # Writing an .h5 file into a nonexistent directory fails, so make sure
    # the folder exists before saving.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    dec_c_net.save_weights(
        osp.join(MODEL_SAVE_PATH, 'temp_vgg19_rev_weights.h5'))
def test_vgg19_rev_load_weights():
    """Load the latest decoder checkpoint into a batch-normalized reversed VGG19.

    There are no assertions. The test passes as long as building the decoder
    and loading ``pretrained_models/dec_latest_weights.h5`` do not raise.
    TensorLayer logging is switched to DEBUG first so that load details are
    visible.
    """
    from vgg import vgg19_rev
    weights_file = 'pretrained_models/dec_latest_weights.h5'
    tl.logging.set_verbosity(tl.logging.DEBUG)
    decoder_options = dict(
        pretrained=False,
        batch_norm=True,
        end_with='conv1_1',
        input_depth=512,
        name='stylized_dec',
    )
    decoder = vgg19_rev(**decoder_options)
    decoder.load_weights(weights_file, skip=True)
def test_vgg_rev_load_vgg_weights():
    """Try loading the predefined VGG19 (up to conv4_1) weight file into a decoder.

    The weight file is the one used for the VGG19 encoder. It is loaded into a
    reversed-VGG19 decoder with ``skip=True``, which presumably makes
    TensorLayer skip layers whose names do not match. TODO: confirm against
    the TensorLayer ``load_weights`` docs.

    There are no assertions. The test only checks that the call does not
    raise.
    """
    from vgg import vgg19_rev
    VGG19_WEIGHTS_PATH = 'pretrained_models/predefined_vgg19_endwith(conv4_1)_weights.h5'
    dec_c_net = vgg19_rev(pretrained=False,
                          end_with='conv1_1',
                          input_depth=512,
                          name='content_dec')
    dec_c_net.load_weights(VGG19_WEIGHTS_PATH, skip=True)
 def __init__(self, style_weight = STYLE_WEIGHT):
     """Create the shared VGG19 encoder and the reversed-VGG19 decoder.

     Local weight files are loaded when they exist on disk.

     Args:
         style_weight: stored as ``self.style_weight``. Defaults to the
             module-level ``STYLE_WEIGHT``.
     """
     super(StyleTransferModel, self).__init__()
     # Set to True to let vgg19() fetch the complete pretrained weights
     # instead of using the local partial-weights file below.
     fetch_full_vgg19 = False
     self.enc_net = vgg19(pretrained=fetch_full_vgg19, end_with='conv4_1', name='content_and_style_enc')
     use_local_encoder_weights = (not fetch_full_vgg19) and osp.exists(VGG19_PARTIAL_WEIGHTS_PATH)
     if use_local_encoder_weights:
         self.enc_net.load_weights(VGG19_PARTIAL_WEIGHTS_PATH, in_order=False)
         tl.logging.info(f"Encoder weights loaded from: {VGG19_PARTIAL_WEIGHTS_PATH}")
     # NOTE: switching batch_norm from False to True lowers the quality of the
     # generated images, so the decoder may need retraining.
     self.dec_net = vgg19_rev(pretrained=False, batch_norm=USE_BATCH_NORM, input_depth=512, name='stylized_dec')
     decoder_checkpoint_found = osp.exists(DEC_LATEST_WEIGHTS_PATH)
     if decoder_checkpoint_found:
         self.dec_net.load_weights(DEC_LATEST_WEIGHTS_PATH, skip=True)
         tl.logging.info(f"Decoder weights loaded from: {DEC_LATEST_WEIGHTS_PATH}")
     self.style_weight = style_weight
     # Loss terms start empty and are filled in later.
     self.content_loss = None
     self.style_loss = None
     self.loss = None
def test_conv_and_deconv():
    """Round-trip a single image through the VGG19 encoder and the reversed decoder.

    The image is encoded up to conv4_1. Those content features are decoded
    directly, without any style mixing. The reconstruction is saved to the
    input path with ``'!generated.jpg'`` appended.
    """
    encoder_weights = 'pretrained_models/predefined_vgg19_endwith(conv4_1)_weights.h5'
    decoder_weights = 'pretrained_models/dec_best_weights (before use DeConv2d).h5'
    image_path = './temp_images/53154.jpg'

    encoder = vgg19(pretrained=False, end_with='conv4_1')
    decoder = vgg19_rev(pretrained=False, end_with='conv1_1', input_depth=512)
    encoder.load_weights(encoder_weights)
    decoder.load_weights(decoder_weights, skip=True)
    for net in (encoder, decoder):
        net.eval()

    source = imread(image_path, mode='RGB')
    source = imresize_square(source, long_side=512, interp='nearest')
    features = encoder([source])
    reconstruction = decoder(features)
    output_path = image_path + '!generated.jpg'
    imsave(output_path, reconstruction[0].numpy())
Esempio n. 6
0
 def __init__(self, *args, **kwargs):
     """Create the encoder/decoder pair and load local weights when present.

     All positional and keyword arguments are forwarded unchanged to the
     parent class initializer.
     """
     super(StyleTransferModel, self).__init__(*args, **kwargs)
     # A single vgg19 instance serves as both the content encoder and the
     # style encoder, the same arrangement used in train.py.
     self.enc_net = vgg19(pretrained=False, end_with='conv4_1', name='content_and_style_enc')
     encoder_weights_on_disk = os.path.exists(VGG19_PARTIAL_WEIGHTS_PATH)
     if encoder_weights_on_disk:
         self.enc_net.load_weights(VGG19_PARTIAL_WEIGHTS_PATH, in_order=False)
     self.dec_net = vgg19_rev(pretrained=False, end_with='conv1_1', input_depth=512, name='stylized_dec')
     decoder_weights_on_disk = os.path.exists(DEC_BEST_WEIGHTS_PATH)
     if decoder_weights_on_disk:
         self.dec_net.load_weights(DEC_BEST_WEIGHTS_PATH, skip=True)
def test_test_model_single_call():
    """Stylize one content/style pair with encoder -> AdaIN -> decoder and save it.

    Steps:
      * Load the partial VGG19 encoder weights and the best decoder weights.
      * Encode ``brad_pitt_01.jpg`` (content) and ``cat.jpg`` (style).
      * Blend the features with ``utils.AdaIN`` (alpha=1).
      * Decode the blended features.
      * Write the image to ``./test_images/output/cat+brad_pitt_01.jpg``.
        The directory is created if needed.
    """
    import os
    import os.path as osp
    import tensorflow as tf
    import tensorlayer as tl
    from vgg import vgg19, vgg19_rev
    VGG19_PARTIAL_WEIGHTS_PATH = 'pretrained_models/predefined_vgg19_endwith(conv4_1)_weights.h5'
    DEC_BEST_WEIGHTS_PATH = 'pretrained_models/dec_best_weights.h5'
    CONTENT_DATA_PATH = './test_images/content'
    STYLE_DATA_PATH = './test_images/style'
    test_content_filenames = ['brad_pitt_01.jpg']
    test_style_filenames = ['cat.jpg']
    TEST_OUTPUT_PATH = './test_images/output'

    tl.logging.set_verbosity(tl.logging.DEBUG)
    enc_net = vgg19(pretrained=False, end_with='conv4_1')
    enc_net.load_weights(VGG19_PARTIAL_WEIGHTS_PATH, in_order=False)
    dec_net = vgg19_rev(pretrained=False, batch_norm=False, input_depth=512)
    dec_net.load_weights(DEC_BEST_WEIGHTS_PATH, skip=True)

    i = 0  # only test 1 pair of input
    content_name = test_content_filenames[i]
    style_name = test_style_filenames[i]
    # Per the original notes, utils.imread handles the BGR->RGB conversion.
    test_content = utils.imread(osp.join(CONTENT_DATA_PATH, content_name))
    test_style = utils.imread(osp.join(STYLE_DATA_PATH, style_name))

    content_features = enc_net(test_content, is_train=False)
    style_features = enc_net(test_style, is_train=False)
    target_features = utils.AdaIN(content_features, style_features, alpha=1)
    del content_features, style_features  # release encoder outputs early
    generated = dec_net(target_features, is_train=False)
    if isinstance(generated, tf.Tensor):
        if generated.dtype == tf.float32:
            # A plain tf.cast does not clamp, so decoder outputs below 0 or
            # above 255 would overflow when converted to uint8 pixels.
            # saturate_cast clamps to the uint8 range first.
            generated = tf.saturate_cast(generated, tf.uint8)
        # Drop the batch dimension and convert to a NumPy array for saving.
        generated = generated[0].numpy()
    saved_stem = f"{osp.splitext(style_name)[0]}+{osp.splitext(content_name)[0]}"
    saved_path = osp.join(TEST_OUTPUT_PATH, f"{saved_stem}.jpg")
    os.makedirs(TEST_OUTPUT_PATH, exist_ok=True)
    # Per the original notes, utils.imsave handles the RGB->BGR conversion.
    utils.imsave(saved_path, generated)
    tl.logging.info(f"saved_path = {saved_path}")
    tl.logging.info(f"generated.shape = {generated.shape}")
def test_test_arbitrary_sized_inputs():
    """Stylize arbitrarily sized content/style pairs one at a time and save them.

    Images come from ``utils.single_inputs_generator`` with a size constraint
    of 800. Encoding uses a pretrained VGG19 up to conv4_1; the features are
    blended with ``utils.AdaIN`` and decoded by the reversed VGG19. The
    decoder loads its latest checkpoint if one exists. Each result is saved
    to ``./temp_images/temp_<style>+<content>_epoch<epoch>.jpg``, and the
    directory is created if needed.
    """
    import os
    import os.path as osp
    import tensorlayer as tl
    from vgg import vgg19, vgg19_rev
    DEC_LATEST_WEIGHTS_PATH = 'pretrained_models/dec_latest_weights.h5'
    CONTENT_DATA_PATH = './dataset/content_samples'  # COCO_train_2014/'
    STYLE_DATA_PATH = './dataset/style_samples'  # wiki_all_images/'
    test_content_filenames = ['000000532397.jpg']  # , '000000048289.jpg', '000000526781.jpg']
    test_style_filenames = ['53154.jpg']  # , '2821.jpg', '216.jpg']
    TEST_INPUT_CONSTRAINTED_SIZE = 800
    TEMP_IMAGE_PATH = './temp_images/'

    tl.logging.set_verbosity(tl.logging.DEBUG)
    enc_net = vgg19(pretrained=True, end_with='conv4_1')
    # NOTE: batch_norm=True will lower quality of the generated image = need retrain
    dec_net = vgg19_rev(pretrained=False, batch_norm=False, input_depth=512)
    if osp.exists(DEC_LATEST_WEIGHTS_PATH):
        dec_net.load_weights(DEC_LATEST_WEIGHTS_PATH, skip=True)

    enc_net.eval()
    dec_net.eval()
    os.makedirs(TEMP_IMAGE_PATH, exist_ok=True)
    pairs = list(zip(test_content_filenames, test_style_filenames))
    for epoch in range(1):  # for test generator validity
        # A generator is exhausted after one pass, so build a fresh one per epoch.
        test_inputs_gen = utils.single_inputs_generator(
            pairs, CONTENT_DATA_PATH, STYLE_DATA_PATH,
            TEST_INPUT_CONSTRAINTED_SIZE)
        for i, (test_content, test_style) in enumerate(test_inputs_gen):
            # shape=[1, w, h, c], so as to feed arbitrary sized test images one by one
            content_features = enc_net(test_content)
            style_features = enc_net(test_style)
            target_features = utils.AdaIN(content_features,
                                          style_features,
                                          alpha=1)
            del content_features, style_features  # release encoder outputs early
            generated_images = dec_net(target_features)
            content_name, style_name = pairs[i]
            paired_name = f"{osp.splitext(style_name)[0]}+{osp.splitext(content_name)[0]}"
            utils.imsave(
                osp.join(TEMP_IMAGE_PATH,
                         f"temp_{paired_name}_epoch{epoch}.jpg"),
                generated_images[0].numpy())