コード例 #1
0
        def predict():
            """Run style transfer on the uploaded 'content' and 'style' images.

            Both uploads are stored in the configured UPLOAD_FOLDER under fixed
            names, reloaded at (self.image_rows, self.image_cols) and handed to
            StyleTransfer.fit_and_transform, which is asked to write into
            './static'.

            Returns:
                The rendered 'index.html' template with a status message.
            """

            def _load_upload(field_name, stored_name):
                # Persist the upload under a fixed server-side name (the
                # client-supplied filename is never used), then reload it at
                # the configured size and convert it to an array.
                stored_path = os.path.join(
                    self.app.config['UPLOAD_FOLDER'], stored_name)
                request.files[field_name].save(stored_path)
                loaded = image.load_img(
                    stored_path,
                    target_size=(self.image_rows, self.image_cols))
                return image.img_to_array(loaded)

            content_img = _load_upload('content', 'content_img')
            style_img = _load_upload('style', 'style_img')

            pretrained_model_dir_path = './pretrained-model'

            # The path used to be built as dir + " /imagenet-..." (note the
            # stray space), which can never match the real file location.
            vgg19_model_path = os.path.join(
                pretrained_model_dir_path, 'imagenet-vgg-verydeep-19.mat')
            ss = StyleTransfer(vgg19_model_path)

            generated_image = ss.fit_and_transform(content_img,
                                                   style_img,
                                                   output_dir_path='./static',
                                                   num_iterations=200)
            # NOTE(review): the output shape is hard-coded here rather than
            # derived from self.image_rows / self.image_cols - confirm that
            # they agree. The converted image is not used by the response.
            generated_image = np.reshape(generated_image, (300, 400, 3))
            generated_image = image.array_to_img(generated_image)
            return render_template('index.html',
                                   message='This is a generated image')
コード例 #2
0
def test_INT8():
    """Compare the INT8 'candy' network output with its stored reference.

    The stylised 'tram.jpg' must reach a PSNR of at least 35 against
    'tram_candy_int8.png'.
    """
    print('\nTest INT8 network')
    data_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'data')
    network_xml = os.path.join(data_dir, 'candy_int8.xml')
    network_bin = os.path.join(data_dir, 'candy_int8.bin')
    input_file = os.path.join(data_dir, 'tram.jpg')
    reference_file = os.path.join(data_dir, 'tram_candy_int8.png')

    model = StyleTransfer(network_xml, network_bin)

    source = cv.imread(input_file)
    expected = cv.imread(reference_file)
    produced = model.process(source)

    assert cv.PSNR(produced, expected) >= 35
コード例 #3
0
def test_FP32():
    """Compare the FP32 'candy' network output with its stored reference.

    The stylised 'tram.jpg' may differ from 'tram_candy.png' by at most 1
    in the infinity norm (largest absolute per-element difference).
    """
    print('\nTest FP32 network')
    data_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'data')
    network_xml = os.path.join(data_dir, 'candy.xml')
    network_bin = os.path.join(data_dir, 'candy.bin')
    input_file = os.path.join(data_dir, 'tram.jpg')
    reference_file = os.path.join(data_dir, 'tram_candy.png')

    model = StyleTransfer(network_xml, network_bin)

    source = cv.imread(input_file)
    expected = cv.imread(reference_file)
    produced = model.process(source)

    assert cv.norm(produced, expected, cv.NORM_INF) <= 1
コード例 #4
0
def style_transfer(source_image_path,
                   content_image_path,
                   decoder_state_path,
                   content_strength=0.5,
                   n_passes=5,
                   train=False):
    """Open two images and run a multi-pass style transfer on them.

    Args:
        source_image_path: path of the image passed first to StyleTransfer.
        content_image_path: path of the image passed second to StyleTransfer.
        decoder_state_path: forwarded to set_layer_decoders as state_dir_path.
        content_strength: value in [0, 1], forwarded to transfer().
        n_passes: number of passes, forwarded to transfer().
        train: forwarded to set_layer_decoders.

    Returns:
        A (StyleTransfer instance, result of transfer()) tuple.
    """
    assert 0 <= content_strength <= 1

    images = (Image.open(source_image_path), Image.open(content_image_path))
    transfer = StyleTransfer(images[0], images[1], observed_layers)
    transfer.set_layer_decoders(train=train, state_dir_path=decoder_state_path)

    # Gradients are not tracked while generating the images.
    with torch.no_grad():
        pass_generated_images = transfer.transfer(n_passes, content_strength)

    return transfer, pass_generated_images
コード例 #5
0
def main():
    """Command-line entry point: build the graph, optimise, save the result.

    Reads everything from parse_args(): model directory, content/style image
    paths, optional content mask, initial image type, layer ids and weights,
    iteration count and loss ratios. The result is written to args.output.

    Raises:
        ValueError: if args.initial_type is not 'content', 'style' or
            'random' (previously this surfaced later as an unbound local).
    """
    args = parse_args()
    model_file_path = args.model_path + '/' + vgg19.MODEL_FILE_NAME
    vgg_net = vgg19.VGG19(model_file_path)

    content_image = utils.load_image(args.content, max_size=args.max_size)
    # Every other image is loaded with shape=(shape[1], shape[0]) of the
    # content image.
    target_shape = (content_image.shape[1], content_image.shape[0])

    style_image = np.array([
        utils.load_image(style_image_path, shape=target_shape)
        for style_image_path in args.style
    ])

    content_mask = None
    if args.content_mask is not None:
        content_mask = utils.load_image(args.content_mask, shape=target_shape)
        content_mask = content_mask / 255.

    # initial guess for output
    if args.initial_type == 'content':
        init_image = content_image
    elif args.initial_type == 'style':
        init_image = style_image
    elif args.initial_type == 'random':
        init_image = np.random.normal(size=content_image.shape,
                                      scale=np.std(content_image))
    else:
        raise ValueError(
            'unknown initial_type: %r' % (args.initial_type,))

    CONTENT_LAYERS = dict(zip(args.content_layers,
                              args.content_layer_weights))
    STYLE_LAYERS = dict(zip(args.style_layers, args.style_layer_weights))

    # open session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    try:
        # build the graph
        st = StyleTransfer(session=sess,
                           content_layer_ids=CONTENT_LAYERS,
                           style_layer_ids=STYLE_LAYERS,
                           init_image=add_one_dim(init_image),
                           content_image=add_one_dim(content_image),
                           style_image=style_image,
                           net=vgg_net,
                           num_iter=args.num_iter,
                           loss_ratios=[args.loss_ratio_c,
                                        args.loss_ratio_tv],
                           content_loss_norm_type=args.content_loss_norm_type,
                           content_mask=content_mask)
        result_image = st.update()
    finally:
        # Close the session even when graph construction or the
        # optimisation fails, so its resources are not leaked.
        sess.close()

    # Drop the leading axis of the result before saving.
    shape = result_image.shape
    result_image = np.reshape(result_image, shape[1:])
    utils.save_image(result_image, args.output)
コード例 #6
0
ファイル: main.py プロジェクト: hhappy06/image_style_transfer
def main():
    """Stylise the content image, starting from the previously saved result.

    The content image is resized to _IMAGE_SIZE x _IMAGE_SIZE; the style and
    initial images are resized to match it. The initial image is read from
    _SAVE_IMAGE_PATH, which is also where the generated image is written.
    """
    # read image from file
    # np.float64 replaces the np.float alias (same dtype), which is
    # deprecated since NumPy 1.20 and removed in 1.24.
    # NOTE(review): scipy.misc.imread/imresize/imsave only exist in old
    # SciPy releases - confirm the pinned version.
    content_image = scipy.misc.imread(_CONTENT_IMAGE_PATH).astype(np.float64)
    content_image = scipy.misc.imresize(content_image,
                                        [_IMAGE_SIZE, _IMAGE_SIZE])
    style_image = scipy.misc.imread(_STYLE_IMAGE_PATH).astype(np.float64)
    style_image = scipy.misc.imresize(style_image, content_image.shape[:2])

    init_image = scipy.misc.imread(_SAVE_IMAGE_PATH).astype(np.float64)
    init_image = scipy.misc.imresize(init_image, content_image.shape[:2])

    # style transfer model
    image_style_transfer = StyleTransfer(_PRETRAINED_VGG19_MODEL)

    # NOTE(review): the meaning of the positional arguments 0.05, 1 and 15
    # is defined by StyleTransfer.image_style_transfer - not visible here.
    generated_image = image_style_transfer.image_style_transfer(
        content_image, style_image, 0.05, 1, 15, init_image=init_image)
    scipy.misc.imsave(_SAVE_IMAGE_PATH, generated_image)
コード例 #7
0
def main():
    """Command-line entry point: train a style transfer and save its images.

    Parses the options from build_parser(), validates that the VGG network
    file exists and that the checkpoint pattern contains `%s`, then iterates
    over StyleTransfer.train(), printing the losses and saving each yielded
    image whose iteration is not None.
    """
    parser = build_parser()
    options = parser.parse_args()

    if not os.path.isfile(VGG_PATH):
        parser.error("Network %s does not exist." % VGG_PATH)

    content_image = load_image(options.content)
    style_image = load_image(options.style)

    initial = options.initial
    if initial is not None:
        # Resize the optional initial image to the content image's size.
        initial = scipy.misc.imresize(imread(initial), content_image.shape[:2])

    if options.checkpoint_output and "%s" not in options.checkpoint_output:
        parser.error("To save intermediate images, the checkpoint output "
                     "parameter must contain `%s` (e.g. `foo%s.jpg`)")

    device = '/gpu:0' if options.use_gpu else '/cpu:0'

    # The TV weight used to be taken from options.style_weight, which looks
    # like a copy-paste slip. Prefer a dedicated option when the parser
    # defines one; otherwise keep the previous value.
    # NOTE(review): confirm that build_parser() exposes `tv_weight`.
    tv_weight = getattr(options, 'tv_weight', options.style_weight)

    style_transfer = StyleTransfer(
        vgg_path=VGG_PATH,
        content=content_image,
        style=style_image,
        content_weight=options.content_weight,
        style_weight=options.style_weight,
        tv_weight=tv_weight,
        initial=initial,
        device=device)

    for iteration, image, losses in style_transfer.train(
        learning_rate=options.learning_rate,
        iterations=options.iterations,
        checkpoint_iterations=options.checkpoint_iterations
    ):
        print_losses(losses)

        # NOTE(review): nothing is saved when `iteration` is None - confirm
        # that train() never yields the final image that way.
        output_file = None
        if iteration is not None:
            if options.checkpoint_output:
                output_file = options.output + (options.checkpoint_output % iteration)
            else:
                output_file = options.output
        if output_file:
            imsave(output_file, image)
コード例 #8
0
    def __init__(self, mirror=False, style=False):
        """
        Open capture device 0 and create the helper objects.

        mirror: Support camera mirror mode
        style: Style transfer application
        """
        # Most recent frame and a flag telling whether it is ready.
        self.data, self.data_ready = None, False
        self.cam = cv2.VideoCapture(0)

        self.style, self.mirror = style, mirror

        # Frame size handed to the helper objects below.
        self.WIDTH, self.HEIGHT = 640, 480

        # Centre position used while the zoom function is active.
        self.center_x, self.center_y = self.WIDTH / 2, self.HEIGHT / 2
        self.touched_zoom = False

        # Queues for captured images and recorded videos.
        self.image_queue = Queue()
        self.video_queue = Queue()

        # Creates the UI buttons.
        self.btn_manager = ButtonManager(self.WIDTH, self.HEIGHT)

        # Zoom factor of the screen.
        self.scale = 1

        # True while a recording is in progress.
        self.recording = False

        # Whether the style transfer is applied to the face only.
        self.face_transfer = False
        # Performs the style transfers.
        self.style_transfer = StyleTransfer(self.WIDTH, self.HEIGHT)
        # Recognises the face and segments it.
        self.image_segmentation = ImageSegmentation(self.WIDTH, self.HEIGHT)

        self.__setup()
コード例 #9
0
def style_transfer():
    """Handle the style-transfer form.

    Without a valid submission the empty form is rendered. Otherwise both
    uploaded images are saved, StyleTransfer(...).transfer() is run on them,
    and the page is rendered with the target, style and result image paths.
    """
    form = PhotoForm()

    # Nothing valid submitted: show the form only.
    if not form.validate_on_submit():
        return render_template('transfer/image_transfer.html', form=form)

    username = current_user.username
    image_path = ImagesPath(username, current_app.static_folder)

    target_upload = form.target_image.data
    style_upload = form.style_reference_image.data

    # Sanitise the client-supplied file names before building paths.
    target_name = secure_filename(target_upload.filename)
    style_name = secure_filename(style_upload.filename)

    target_path = image_path.buffer_image_path(target_name)
    style_path = image_path.buffer_image_path(style_name)

    target_abs = image_path.abs_path(target_path)
    style_abs = image_path.abs_path(style_path)

    target_upload.save(target_abs)
    style_upload.save(style_abs)

    flash('Wait for result image.')

    model = StyleTransfer(
        target_abs,
        style_abs,
        save_path=image_path.model_buffer(absolute=True),
        iterations=current_app.config['MODEL_ITERATION'])
    model.transfer()

    result_path = image_path.last_file_path()

    return render_template(
        'transfer/image_transfer.html',
        target_image_path=image_path.convert(target_path),
        style_reference_image_path=image_path.convert(style_path),
        form=form,
        result_image_path=image_path.convert(result_path),
        username=username)
コード例 #10
0
ファイル: app.py プロジェクト: Andy1621/SaoImage
def mix():
    """Build, train and export a style transfer from the posted form fields.

    Reads 'content', 'style' and 'train' from request.form, stores the model
    in the module-level `style_transfer`, and returns a JSON string with the
    model's `res` ('gif') and `mix_img` ('mix_img') attributes.
    """
    global style_transfer

    setup()
    form = request.form
    content_img = form['content']
    style_img = form['style']
    train = form['train']

    # Fixed pixel dimensions.
    img_width, img_height = 400, 300

    # style transfer
    style_transfer = StyleTransfer(content_img, style_img, img_width,
                                   img_height)
    style_transfer.build()
    style_transfer.train(int(train))
    style_transfer.gif()
    print(train)

    return json.dumps({
        'gif': style_transfer.res,
        'mix_img': style_transfer.mix_img,
    })
コード例 #11
0
import os
import numpy as np
import cv2 as cv
import telebot
import argparse
from style_transfer import StyleTransfer

# Module-level setup: everything below runs at import time.

# The bot token is a required command-line argument.
parser = argparse.ArgumentParser()
parser.add_argument('--token', help='Telegram bot token', required=True)
args = parser.parse_args()

# Network files are looked up in ../../data relative to this file.
xml_path = os.path.join(os.path.dirname(__file__), '..', '..', 'data',
                        'candy_int8.xml')
bin_path = os.path.join(os.path.dirname(__file__), '..', '..', 'data',
                        'candy_int8.bin')
# One model instance is created for the whole module.
model = StyleTransfer(xml_path, bin_path)

bot = telebot.TeleBot(args.token)


def get_image(message):
    """Download a photo attached to *message* and decode it with OpenCV.

    The last entry of ``message.photo`` is used (presumably the largest
    size - confirm against the Telegram Bot API). Returns the result of
    ``cv.imdecode`` in colour mode.
    """
    photo = message.photo[-1]
    file_info = bot.get_file(photo.file_id)
    raw_bytes = bot.download_file(file_info.file_path)
    return cv.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8),
                       cv.IMREAD_COLOR)


def send_image(message, img):
    """Encode *img* as a quality-90 JPEG and send it to the message's chat."""
    jpeg_params = [cv.IMWRITE_JPEG_QUALITY, 90]
    # imencode returns (success flag, buffer); only the buffer is used.
    encoded = cv.imencode(".jpg", img, jpeg_params)[1]
    bot.send_photo(message.chat.id, encoded)
コード例 #12
0
# Experiment settings. `config`, `content_file` and `style_file` are defined
# before this point (not shown in this excerpt).
config.learning_rate = 1.0
config.decayed_learning_rate = .03333
config.start_with_content = True
config.input_samples = 188
config.content_layer = 0
# Pairs of (number, float); judging by the inline notes these are
# (layer index, weight) - TODO confirm against StyleTransfer.
config.style_layers = (
    #(0, 5e-2), # Just first layer
    #(1, 1e-9), # Just 2nd layer
    (0, 3e-2),  # Both first and 2nd layer
    (1, 1e-9),
)
config.reg = 0.0
config.alpha = 1.45e-11
config.channels_as_filters = False

# Run the transfer; it returns the output array and a second value that is
# later passed to save_spectrogram_as_audio together with it.
st = StyleTransfer(config)
out, out_sr = st.transfer_style(content_file, style_file)
# Print basic statistics of the output, including a NaN check.
print("shape: ", out.shape)
print("Contains nan: ", np.any(np.isnan(out)))
print("min: ", out.min())
print("max: ", out.max())
print("mean: ", out.mean())
print("std: ", out.std())

# Save the transposed output under log/<experiment_name>_out.wav. On any
# failure the error is printed and the output is clipped to [-40, 40].
# NOTE(review): no second save attempt is visible after the clipping here.
try:
    save_spectrogram_as_audio(out.T, out_sr,
                              "log/" + config.experiment_name + "_out.wav")
except Exception as e:
    print(e)
    print("Trying clipping")
    out = np.clip(out, -40, 40)
コード例 #13
0
# Command-line options. `parser` is created before this point (not shown in
# this excerpt). All options are optional and fall back to the defaults below.
parser.add_argument('--style_weight',
                    nargs='?',
                    default=100000000000,
                    type=int)
parser.add_argument('--content_weight', nargs='?', default=100000, type=int)
parser.add_argument('--res', nargs='?', default=256, type=int)
parser.add_argument('--lr', nargs='?', default=1e-3, type=float)
parser.add_argument('--train_epoch', nargs='?', default=60, type=int)
parser.add_argument('--test_perc', nargs='?', default=.1, type=float)
parser.add_argument('--data_perc', nargs='?', default=1, type=float)
parser.add_argument('--beta1', nargs='?', default=.5, type=float)
parser.add_argument('--beta2', nargs='?', default=.999, type=float)
parser.add_argument('--workers', nargs='?', default=4, type=int)
parser.add_argument('--save_every', nargs='?', default=5, type=int)
parser.add_argument('--save_img_every', nargs='?', default=1, type=int)
# --ids takes one or more integers.
parser.add_argument('--ids', type=int, nargs='+', default=[10, 20])
parser.add_argument('--style_image', nargs='?', default='franc.jpg', type=str)
parser.add_argument('--save_root', nargs='?', default='franc_style', type=str)
# --load_state has no default, so it is None unless given.
parser.add_argument('--load_state', nargs='?', type=str)

# Parsed at module level (also when imported) and exposed as a plain dict.
params = vars(parser.parse_args())

# if load_state arg is not used, then train model from scratch
if __name__ == '__main__':
    style_transfer = StyleTransfer(params)
    if params['load_state']:
        # Load the given state before training.
        style_transfer.load_state(params['load_state'])
    else:
        print('Starting From Scratch')
    style_transfer.train()
コード例 #14
0
from style_transfer import StyleTransfer

if __name__ == '__main__':
    # Four input files, passed positionally in this order; judging by their
    # names: style image, face image, style mask, face mask.
    input_paths = (
        'input/style1.jpg',
        'input/face.jpg',
        'input/mask_style1.jpg',
        'input/mask_face.jpg',
    )
    ST = StyleTransfer(*input_paths)
    ST.run()
コード例 #15
0
class Camera:
    """Webcam streamer with zoom, snapshots, recording and style transfer.

    Frames are read from capture device 0 on a background thread started by
    stream(); the most recent processed frame is published in ``self.data``
    and displayed by show(), which also handles keyboard and mouse input.
    """

    def __init__(self, mirror=False, style=False):
        """
        mirror: Support camera mirror mode
        style: Style transfer application
        """
        self.data = None
        self.data_ready = False
        self.cam = cv2.VideoCapture(0)

        self.style = style
        self.mirror = mirror

        self.WIDTH = 640
        self.HEIGHT = 480

        # This parameter is used to know the center position when the zoom function is in use.
        self.center_x = self.WIDTH / 2
        self.center_y = self.HEIGHT / 2
        self.touched_zoom = False

        # Queue for image capture and video recording.
        self.image_queue = Queue()
        self.video_queue = Queue()

        # Button manager object for creating UI buttons.
        self.btn_manager = ButtonManager(self.WIDTH, self.HEIGHT)

        # scale is a variable that determines zoom of the screen.
        self.scale = 1

        # It is a variable to check whether it is currently recording.
        self.recording = False

        # Whether to apply the style transfer to the face only.
        self.face_transfer = False
        # An object that performs style transfers.
        self.style_transfer = StyleTransfer(self.WIDTH, self.HEIGHT)
        # It is an object that recognizes the face and segments it.
        self.image_segmentation = ImageSegmentation(self.WIDTH, self.HEIGHT)

        self.__setup()

    def __setup(self):
        """Apply the capture size, set up the buttons and load the model."""
        self.cam.set(cv2.CAP_PROP_FRAME_WIDTH, self.WIDTH)
        self.cam.set(cv2.CAP_PROP_FRAME_HEIGHT, self.HEIGHT)

        self.btn_manager.button_setting()
        self.style_transfer.load()
        time.sleep(2)

    def get_location(self, x, y):
        """Use (x, y) as the zoom centre and switch to touch-zoom mode."""
        self.center_x = x
        self.center_y = y
        self.touched_zoom = True

    def stream(self):
        """Start a background thread that reads and processes frames.

        Each frame is optionally mirrored, zoomed and style-transferred, then
        stored in ``self.data``. The thread stops when a read reports failure
        or when 'q' is pressed.
        """

        def streaming():
            self.ret = True
            while self.ret:
                self.ret, np_image = self.cam.read()
                self.data_ready = False
                if np_image is None:
                    continue
                if self.mirror:
                    # Mirror mode: flip around the vertical axis.
                    np_image = cv2.flip(np_image, 1)
                if self.touched_zoom:
                    # Zoom around the position chosen by double-click.
                    np_image = self.__zoom(np_image,
                                           (self.center_x, self.center_y))
                elif self.scale != 1:
                    # Zoom around the image centre.
                    np_image = self.__zoom(np_image)

                if self.style:
                    # Convert image to style transfer
                    np_image = self.transform(np_image)

                self.data = np_image
                self.data_ready = True
                if cv2.waitKey(1) == ord('q'):
                    self.release()
                    break

        Thread(target=streaming).start()

    def __zoom(self, img, center=None):
        """Crop *img* around a centre by ``self.scale`` and resize it back.

        center: optional (x, y); when None the image centre is used.
        Returns the cropped region stretched to the original size.
        """
        height, width = img.shape[:2]
        if center is None:
            # No explicit centre: zoom around the middle of the image.
            center_x = int(width / 2)
            center_y = int(height / 2)
            radius_x, radius_y = int(width / 2), int(height / 2)
        else:
            # Explicit centre (zoom function activated).
            rate = height / width
            center_x, center_y = center

            # Clamp the centre into [size * (1 - rate), size * rate].
            if center_x < width * (1 - rate):
                center_x = width * (1 - rate)
            elif center_x > width * rate:
                center_x = width * rate
            if center_y < height * (1 - rate):
                center_y = height * (1 - rate)
            elif center_y > height * rate:
                center_y = height * rate

            center_x, center_y = int(center_x), int(center_y)
            # Largest half-extent that still fits inside the image.
            radius_x = min(center_x, int(width - center_x))
            radius_y = min(int(height - center_y), center_y)

        # Shrink the half-extents by the current scale.
        radius_x = int(self.scale * radius_x)
        radius_y = int(self.scale * radius_y)

        min_x, max_x = center_x - radius_x, center_x + radius_x
        min_y, max_y = center_y - radius_y, center_y + radius_y

        # Crop the image to fit the calculated size.
        cropped = img[min_y:max_y, min_x:max_x]
        # Stretch the cropped image to the original image size.
        return cv2.resize(cropped, (width, height))

    def touch_init(self):
        """Reset zoom centre, touch-zoom mode and scale to their defaults."""
        self.center_x = self.WIDTH / 2
        self.center_y = self.HEIGHT / 2
        self.touched_zoom = False
        self.scale = 1

    def zoom_out(self):
        """Zoom out one step (scale + 0.1, at most 1).

        When the scale reaches 1 the zoom centre is reset and touch-zoom
        mode is switched off.
        """
        if self.scale < 1:
            # Round to one decimal and clamp: repeated 0.1 steps accumulate
            # float error (e.g. 0.9999999999999999), which used to push the
            # scale past 1 and skip the reset below.
            self.scale = min(1, round(self.scale + 0.1, 1))
        if self.scale == 1:
            # Same centre as touch_init() (previously WIDTH / HEIGHT).
            self.center_x = self.WIDTH / 2
            self.center_y = self.HEIGHT / 2
            self.touched_zoom = False

    def zoom_in(self):
        """Zoom in one step (scale - 0.1, never below 0.2)."""
        if self.scale > 0.2:
            # Rounded for the same reason as in zoom_out(); without it the
            # scale could drift to 0.2000...04 and drop below the floor.
            self.scale = max(0.2, round(self.scale - 0.1, 1))

    def zoom(self, num):
        """Dispatch by index: 0 zooms in, 1 zooms out, 2 resets."""
        if num == 0:
            self.zoom_in()
        elif num == 1:
            self.zoom_out()
        elif num == 2:
            self.touch_init()

    def transform(self, img):
        """Blend the style-transferred frame with *img* using the face mask.

        With ``face_transfer`` set, styled pixels are used where the mask is
        non-zero; otherwise they are used where the mask is zero.
        """
        # Each predictor gets its own copy of the frame.
        style_img = self.style_transfer.predict(img.copy())
        seg_mask = self.image_segmentation.predict(img.copy())
        mask = cv2.cvtColor(seg_mask, cv2.COLOR_GRAY2RGB)

        if self.face_transfer:
            return np.where(mask, style_img, img)
        return np.where(mask, img, style_img)

    def event(self, i):
        """Enable style transfer and switch to style index *i*."""
        self.style = True
        self.style_transfer.change_style(i)

    def save_picture(self):
        """Read one frame, write it under ./images and queue its file name."""
        ret, img = self.cam.read()
        if ret:
            now = datetime.datetime.now()
            date = now.strftime('%Y%m%d')
            hour = now.strftime('%H%M%S')
            user_id = '00001'
            filename = './images/mevia_{}_{}_{}.png'.format(
                date, hour, user_id)
            cv2.imwrite(filename, img)
            self.image_queue.put_nowait(filename)

    def record_video(self):
        """Record to videos/mevia_<date>_<hour>_<n>.avi while recording is on.

        A new file is started after 600 seconds. The writer is always
        released before returning or rotating, so the file is finalised.
        """
        fc = 20.0
        record_start_time = time.time()
        now = datetime.datetime.now()
        date = now.strftime('%Y%m%d')
        t = now.strftime('%H')
        # Pick the first unused sequence number for this date and hour.
        num = 1
        filename = 'videos/mevia_{}_{}_{}.avi'.format(date, t, num)
        while os.path.exists(filename):
            num += 1
            filename = 'videos/mevia_{}_{}_{}.avi'.format(date, t, num)
        codec = cv2.VideoWriter_fourcc('D', 'I', 'V', 'X')
        out = cv2.VideoWriter(filename, codec, fc,
                              (int(self.cam.get(3)), int(self.cam.get(4))))
        rotate = False
        try:
            while self.recording:
                if time.time() - record_start_time >= 600:
                    rotate = True
                    break
                ret, frame = self.cam.read()
                if ret:
                    # NOTE(review): the file name is queued once per frame
                    # and get() blocks on an empty queue - confirm this
                    # clean-up of old videos behaves as intended.
                    if len(os.listdir('./videos')) >= 100:
                        name = self.video_queue.get()
                        if os.path.exists(name):
                            os.remove(name)
                    # Prefer the processed frame when one is available.
                    if self.data is not None:
                        out.write(self.data)
                    else:
                        out.write(frame)
                    self.video_queue.put_nowait(filename)
                if cv2.waitKey(1) == ord('q'):
                    break
        finally:
            out.release()
        if rotate:
            # Continue in a fresh file once the current one is closed.
            self.record_video()

    def show(self):
        """
        Show the streaming window and handle keyboard keys.

        q: Close & Quit
        z: Zoom-in
        x: Zoom-out
        p: Save Picture
        v: Return initial State
        r: Video recording
        """
        while True:
            frame = self.data
            if frame is not None:
                self.btn_manager.draw(frame)
                cv2.imshow('Mevia', frame)
                cv2.setMouseCallback('Mevia', self.mouse_callback)
            key = cv2.waitKey(1)
            if key == ord('q'):
                # q : close
                self.release()
                cv2.destroyAllWindows()
                break

            elif key == ord('z'):
                # z : zoom - in
                self.zoom_in()

            elif key == ord('x'):
                # x : zoom - out
                self.zoom_out()

            elif key == ord('p'):
                # p : take picture and save image (image folder)
                self.save_picture()

            elif key == ord('v'):
                # v : original state
                self.touch_init()

            elif key == ord('r'):
                # r : toggle recording
                self.recording = not self.recording
                if self.recording:
                    # Use this instance; the previous code referenced a
                    # global named `cam` that the class does not define.
                    Thread(target=self.record_video).start()

    def release(self):
        """Release the capture device and close all OpenCV windows."""
        self.cam.release()
        cv2.destroyAllWindows()

    def mouse_callback(self, event, x, y, flag, param):
        """Handle mouse events of the streaming window.

        Left click: forward the position to the button manager and update
        face/style state from its flags. Double left click: zoom in around
        the position. Right click: zoom out.
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            # Let the button manager decide which button was hit.
            self.btn_manager.btn_on_click(x, y)
            flags = self.btn_manager.button_flag
            # The last flag toggles face-only transfer.
            self.face_transfer = flags[-1] != 0
            if 1 in flags:
                last = len(flags) - 1
                for i, state in enumerate(flags):
                    if state == 1 and i != last:
                        # Change the style to suit the clicked button
                        self.event(i)
                        break
            else:
                self.style = False

        elif event == cv2.EVENT_LBUTTONDBLCLK:
            # When double-clicking the left button, activate the zoom function
            self.get_location(x, y)
            self.zoom_in()
        elif event == cv2.EVENT_RBUTTONDOWN:
            # Right click, zoom out
            self.zoom_out()
コード例 #16
0
def transfer_runner(content_image, style_image):
    """Run the VGG19-based style transfer with fixed settings.

    Args:
        content_image: content image array; also the initial image.
        style_image: style image array.

    Returns:
        The array returned by StyleTransfer.update() with its leading
        axis removed.

    Raises:
        ValueError: if the built-in layer/weight lists do not match, the
            maximum size is too small, or the model file has an unexpected
            size.
        FileNotFoundError: if the pre-trained model file does not exist.
    """
    model_path = 'tfb/pre_trained_model'  # directory of the pre-trained model
    loss_ratio = 1e-3  # weight of content-loss relative to style-loss
    content_layers = ['conv4_2']  # VGG19 layers used for content loss
    style_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1',
                    'relu5_1']  # VGG19 layers used for style loss
    # Each layer's loss is multiplied by the corresponding weight.
    content_layer_weights = [1.0]
    style_layer_weights = [.2, .2, .2, .2, .2]
    # choices = ['random', 'content', 'style']: the initial image for
    # optimization (notation in the paper: x)
    initial_type = 'content'
    max_size = 101  # 512; the maximum width or height of input images
    content_loss_norm_type = 3  # choices = [1, 2, 3]
    num_iter = 400  # the number of iterations to run

    # Explicit checks replace `try: assert ... except: raise ('text')`,
    # which raised TypeError (a str is not an exception) instead of the
    # intended message and disappeared entirely under `python -O`.
    if len(content_layers) != len(content_layer_weights):
        raise ValueError('content layer info and weight info must be matched')
    if len(style_layers) != len(style_layer_weights):
        raise ValueError('style layer info and weight info must be matched')
    if max_size <= 100:
        raise ValueError('Too small size')

    model_file_path = model_path + '/' + vgg19.MODEL_FILE_NAME
    if not os.path.exists(model_file_path):
        raise FileNotFoundError('There is no %s' % model_file_path)

    # os.path.getsize reports bytes; the expected size is 534904783.
    size_in_bytes = os.path.getsize(model_file_path)
    if abs(size_in_bytes - 534904783) >= 10:
        print('check file size of \'imagenet-vgg-verydeep-19.mat\'')
        print('there are some files with the same name')
        print('pre_trained_model used here can be downloaded from bellow')
        print(
            'http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat'
        )
        raise ValueError('unexpected file size for %s' % model_file_path)

    # initiate VGG19 model
    vgg_net = vgg19.VGG19(model_file_path)

    # initial guess for output
    if initial_type == 'content':
        init_image = content_image
    elif initial_type == 'style':
        init_image = style_image
    elif initial_type == 'random':
        init_image = np.random.normal(size=content_image.shape,
                                      scale=np.std(content_image))
    else:
        raise ValueError('unknown initial_type: %r' % (initial_type,))

    # maps of layer id -> weight
    CONTENT_LAYERS = dict(zip(content_layers, content_layer_weights))
    STYLE_LAYERS = dict(zip(style_layers, style_layer_weights))

    with tf.Graph().as_default():
        # open session
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        try:
            # build the graph
            st = StyleTransfer(
                session=sess,
                content_layer_ids=CONTENT_LAYERS,
                style_layer_ids=STYLE_LAYERS,
                init_image=add_one_dim(init_image),
                content_image=add_one_dim(content_image),
                style_image=add_one_dim(style_image),
                net=vgg_net,
                num_iter=num_iter,
                loss_ratio=loss_ratio,
                content_loss_norm_type=content_loss_norm_type,
            )
            # launch the graph in a session
            result_image = st.update()
        finally:
            # Close the session even if building or running the graph fails.
            sess.close()

    # remove the leading dimension
    shape = result_image.shape
    result_image = np.reshape(result_image, shape[1:])
    return result_image
コード例 #17
0
from style_transfer import Config, StyleTransfer

# Input audio files, both relative to the working directory.
content_file = "inputs/test/content.wav"
# The leading "/" was dropped: it made this path absolute (filesystem root)
# while the content file sits in the relative inputs/test directory.
style_file = "inputs/test/style.wav"
config = Config()
config.experiment_name = "first-test"

# Run the transfer with the default configuration plus the name above.
st = StyleTransfer(config)
st.transfer_style(content_file, style_file)
コード例 #18
0

def to_tensor(image):
    """Open SAVING_PATH + *image* and return `loader(...)` with a leading axis.

    image: file name appended to SAVING_PATH.
    """
    opened = Image.open(SAVING_PATH + image, 'r')
    # unsqueeze(0) adds a dimension of size one at the front.
    return loader(opened).unsqueeze(0)


# Module-level setup: logging, bot, dispatcher and the style-transfer model.
logging.basicConfig(level=logging.INFO)

bot = Bot(token=API_TOKEN, proxy=PROXY_URL,
          proxy_auth=PROXY_AUTH)  # create the bot
# The dispatcher uses MemoryStorage as its storage.
dp = Dispatcher(bot, storage=MemoryStorage())
dp.middleware.setup(LoggingMiddleware())

# One StyleTransfer instance is created at import time.
st = StyleTransfer()


@dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message):
    """Handle /start and /help: set the first state and reply with a hint."""
    user_state = dp.current_state(user=message.from_user.id)
    first_state = States.all()[0]
    await user_state.set_state(first_state)
    # Tells the user to send /style_transfer to begin.
    hint = "Отправьте боту /style_transfer,чтобы начать перенос стиля "
    await message.reply(hint)


@dp.message_handler()
async def bad_message(message: types.Message):
    """Fallback handler (no filters): ask the user to send /start or /help."""
    reply_text = 'Напишите боту "/start" или "/help" '
    await message.reply(reply_text)
コード例 #19
0
## PREPROCESS ##
################

# Load the content and style images, then derive the initial image from them.
cont_img = utils.load_image(cont_path, img_shape)
styl_img = utils.load_image(styl_path, img_shape)
init_img = utils.load_init_image(cont_img, styl_img, img_shape, choice="rand_uni")

# Apply the same preprocessing step to all three images, in the order
# content, style, initial.
cont_img, styl_img, init_img = map(utils.img_preprocess,
                                   (cont_img, styl_img, init_img))

# Build the model; arguments are positional, so their order matters.
model = StyleTransfer(init_img, cont_img, styl_img,
                      cont_layers, styl_layers,
                      cont_weights, styl_weights,
                      alpha, beta)

##############
## TRAINING ##
##############

with tf.Session(graph=model.graph) as sess:
    sess.run(tf.global_variables_initializer())

    optimizer = ScipyOptimizerInterface(model.total_loss, method="L-BFGS-B", options={'maxiter': num_steps})
    optimizer.minimize(sess,
                       fetches=[model.styl_loss, model.cont_loss, model.total_loss,
コード例 #20
0
import tensorflow as tf
from style_transfer import StyleTransfer
import utils

# enable eager execution
tf.enable_eager_execution()
print('Eager execution: {}'.format(tf.executing_eagerly()))

painter = StyleTransfer()
# Read once up front; later changes to the attribute are not picked up.
threshold = painter.miss_percentage_threshold
# NOTE(review): no exit condition is visible in this excerpt, so the loop
# only ends when painter.run() raises or the process is stopped.
while True:
    best_img, miss_percentage = painter.run()
    # A non-None best image triggers cleanup of the checkpoint directory.
    if best_img is not None:
        utils.remove_earlier_checkpoints('./checkpoints')
    # Above the threshold, scale the learning rate by the decay rate.
    if miss_percentage > threshold:
        painter.learning_rate = painter.learning_rate * painter.lr_decay_rate
コード例 #21
0
ファイル: test.py プロジェクト: abhishekvasu94/Style-Transfer
    noisy_img = np.float32(noisy_img)

    vgg16 = VGG(path)

    #Play around with these parameters
    alpha = 5e-4
    beta = 1

    with tf.Session() as sess:

        init = tf.global_variables_initializer()
        sess.run(init)

        st_transfer = StyleTransfer(content_img,
                                    stylized_img,
                                    noisy_img,
                                    alpha,
                                    beta,
                                    lr=5,
                                    n_epochs=2000,
                                    session=sess,
                                    model=vgg16)

        st_transfer.build()
        final_img = st_transfer.final_img

    sess.close()

    cv2.imwrite('final_img.jpg', final_img[0])
コード例 #22
0
ファイル: app.py プロジェクト: Andy1621/SaoImage
#! env python3
# -*- coding: UTF-8 -*-

from flask import Flask, render_template, request
from style_transfer import setup, StyleTransfer
import json

# Flask application object; the route decorators below attach to it.
app = Flask(__name__)

# Module-level StyleTransfer instance, created once at import time.
style_transfer = StyleTransfer()

# NOTE(review): presumably a progress counter; its use is not visible in this
# excerpt -- confirm against the rest of the module.
prog = 0


@app.route('/')
@app.route('/ist')
def ist():
    """Render the 'ist.html' template for both '/' and '/ist'."""
    page = render_template('ist.html')
    return page


@app.route('/about')
def about():
    """Render the 'about.html' template for '/about'."""
    page = render_template('about.html')
    return page


@app.route('/api/v1/mix', methods=['POST'])
def mix():
    setup()
    content_img = request.form['content']
    style_img = request.form['style']
    train = request.form['train']
コード例 #23
0
def main():
    """Export the best checkpoint of a trained model as a frozen graph.

    Reads ``result.json`` from ``--input_dir`` to find the model name, the
    best checkpoint and (when present) the image size, rebuilds the model
    with ``is_training=False``, restores the checkpoint, writes a
    verification image (``export_verify.png``) and then freezes the graph
    into ``--output`` inside ``--input_dir``.

    Intermediate files are written to ``<input_dir>/tmp``; that directory is
    removed only after a successful export.

    Raises:
        Whatever the restore / verify / export steps raise. The exception
        type is printed to stdout and the exception is re-raised.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--input_dir',
                        type=str,
                        default="output",
                        help='directory of checkpoint files')
    parser.add_argument('--output',
                        type=str,
                        default=DEFAULT_MODEL,
                        help='exported file')
    # The two size options used to carry a help text copy-pasted from an
    # unrelated loss-weight option; describe what they actually control.
    parser.add_argument('--image_h',
                        type=int,
                        default=-1,
                        help='image height (negative: take it from '
                             'result.json, else the VGG default size)')
    parser.add_argument('--image_w',
                        type=int,
                        default=-1,
                        help='image width (negative: take it from '
                             'result.json, else the VGG default size)')

    parser.add_argument('--noise',
                        type=float,
                        default=0.,
                        help='noise magnitude')

    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO,
                        datefmt='%I:%M:%S')

    args = parser.parse_args()

    # Scratch area for the re-saved checkpoint and the temporary graph def.
    tmp_dir = os.path.join(args.input_dir, 'tmp')
    ckpt_dir = os.path.join(tmp_dir, 'ckpt')
    for directory in (tmp_dir, ckpt_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    args.save_model = os.path.join(ckpt_dir, 'model')

    with open(os.path.join(args.input_dir, 'result.json'), 'r') as f:
        result = json.load(f)

    model_name = result['model_name']
    # The recorded checkpoint path is '/'-separated; its first component is
    # replaced by the input directory given on the command line.
    best_model_arr = result['best_model'].split('/')
    best_model_arr[0] = args.input_dir
    best_model = os.path.join(*best_model_arr)

    # A negative size means "not given": fall back to the value stored with
    # the training result, and to the VGG default if that is absent too.
    if args.image_w < 0:
        args.image_w = result.get('image_w', vgg.DEFAULT_SIZE)
    if args.image_h < 0:
        args.image_h = result.get('image_h', vgg.DEFAULT_SIZE)

    if args.output == DEFAULT_MODEL:
        args.output = model_name + ".pb"

    logging.info("loading best model from %s", best_model)

    graph = tf.Graph()
    with graph.as_default():
        with tf.name_scope(model_name):
            model = StyleTransfer(is_training=False,
                                  batch_size=1,
                                  image_h=args.image_h,
                                  image_w=args.image_w,
                                  inf_noise=args.noise)
        model_saver = tf.train.Saver(name='saver', sharded=True)
    try:
        with tf.Session(graph=graph) as session:

            logging.info("Loading model")
            model_saver.restore(session, best_model)

            logging.info("Verify model")
            batch_gen_valid = BatchGenerator(1,
                                             args.image_h,
                                             args.image_w,
                                             valid=True)
            # Only the fourth returned value (the output image) is used.
            _, _, _, test_out, _ = model.run_epoch(session,
                                                   tf.no_op(),
                                                   None,
                                                   batch_gen_valid,
                                                   num_iterations=1)

            utils.write_image(
                os.path.join(args.input_dir, 'export_verify.png'), test_out)

            logging.info("Exporting model")
            # Re-save the restored variables into the scratch directory so
            # freeze_graph can read them from a known location.
            model_saver.save(session, args.save_model)
            # Save graph def
            tf.train.write_graph(session.graph_def, tmp_dir, "temp_model.pb",
                                 False)

            saver_def = model_saver.as_saver_def()
            input_graph_path = os.path.join(tmp_dir, "temp_model.pb")
            input_saver_def_path = ""  # we dont have this
            input_binary = True
            input_checkpoint_path = args.save_model
            output_node_names = model_name + "/output"
            restore_op_name = saver_def.restore_op_name
            filename_tensor_name = saver_def.filename_tensor_name
            output_graph_path = os.path.join(args.input_dir, args.output)
            clear_devices = False

            freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                                      input_binary, input_checkpoint_path,
                                      output_node_names, restore_op_name,
                                      filename_tensor_name, output_graph_path,
                                      clear_devices, None)
            # Scratch files are removed only when the export succeeded.
            shutil.rmtree(tmp_dir)
    except:
        # Report the exception type, then let it propagate unchanged.
        print("Unexpected error:", sys.exc_info()[0])
        raise
コード例 #24
0
ファイル: stylize.py プロジェクト: vonum/style-transfer
# Settings read from the parsed arguments and passed to StyleTransfer below
# as save_it / save_it_dir.
SAVE_IT = args["save_it"]
SAVE_IT_DIR = args["save_it_dir"]

# Network built from the file at MODEL_PATH.
net = vgg19.VGG19(MODEL_PATH)
# Soft placement is enabled for the session (per the TensorFlow docs this
# lets ops fall back to another device when the requested one is missing).
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

# The first thirteen arguments are positional, so their order must match the
# StyleTransfer constructor; the remaining options are passed by keyword.
st = StyleTransfer(sess,
                   net,
                   ITERATIONS,
                   CONTENT_LAYERS,
                   STYLE_LAYERS,
                   content_image,
                   style_image,
                   CONTENT_LAYER_WEIGHTS,
                   STYLE_LAYER_WEIGHTS,
                   CONTENT_LOSS_WEIGHT,
                   STYLE_LOSS_WEIGHT,
                   TV_LOSS_WEIGHT,
                   OPTIMIZER,
                   learning_rate=LEARNING_RATE,
                   init_img_type=INIT_TYPE,
                   preserve_colors=PRESERVE_COLORS,
                   cvt_type=CVT_TYPE,
                   content_factor_type=CONTENT_FACTOR_TYPE,
                   save_it=SAVE_IT,
                   save_it_dir=SAVE_IT_DIR)

# Run the transfer and collect the loss summary before closing the session.
mixed_image = st.run()
summary = st.loss_summary()

# NOTE(review): the session is not closed if st.run() or st.loss_summary()
# raises; acceptable for a one-shot script, revisit if this code is reused.
sess.close()