Code example #1
File: stylize.py Project: man11sha/Imagical
def get_stylize_image(content_fullpath, style_fullpath, output_path,
                      content_size=256, style_size=256, alpha=0.6,
                      swap5=False, ss_alpha=0.6, adain=False):
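    # Stylize one content image with one style image using the module-level
    # wct_model; alpha blends content and style features, swap5/ss_alpha
    # control the optional style-swap step, and adain switches the feature
    # transform to AdaIN. The result is written to output_path.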
    content_img = get_img(content_fullpath)
    content_img = resize_to(content_img, int(content_size))

    style_img = get_img(style_fullpath)
    style_img = resize_to(style_img, int(style_size))

    stylized_rgb = wct_model.predict(
        content_img, style_img, alpha, swap5, ss_alpha, adain)

    save_img(output_path, stylized_rgb)
    print("stylized image saved "+output_path)
Code example #2
def post_params():
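    # Flask handler: reads the style/content image names, alpha, and
    # style_scale from the POST form, runs one stylization pass, saves the
    # result under RESULT_IMG_PATH with a timestamp-based name, and returns
    # that name.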
    myid = str(int(time.time() * 10000)) + ".jpg"
    tmp_img_post_path = RESULT_IMG_PATH + myid

    style_img_url = STYLE_IMG_PATH + request.form['style_img_url']
    content_img_url = CONTENT_IMG_PATH + request.form['content_img_url']
    alpha = float(request.form['alpha'])
    content_img = get_img(content_img_url)
    style_size = int(float(request.form['style_scale']) * 512)
    keep_colors = False

    # style_img = get_img_crop(style_img_url, resize=style_size)
    style_img = get_img_crop(style_img_url)

    if style_size > 0:
        style_img = resize_to(style_img, style_size)

    if keep_colors:
        style_img = preserve_colors_np(style_img, content_img)

    # Run the frame through the style network
    stylized_rgb = wct_model.predict(content_img, style_img, alpha, False, 0.6,
                                     False)

    save_img(tmp_img_post_path, stylized_rgb)
    return Response(myid, status=200, mimetype='application/json')
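
A hedged client-side sketch for the handler above, assuming it is mounted on a local Flask server at a route such as '/stylize' (the route and host are assumptions; only the form field names come from post_params()):

import requests

# Hypothetical endpoint URL; the form keys mirror the ones read in post_params().
resp = requests.post('http://localhost:5000/stylize', data={
    'style_img_url': 'starry_night.jpg',
    'content_img_url': 'portrait.jpg',
    'alpha': '0.8',
    'style_scale': '1.0',
})
print(resp.text)  # id of the saved result image, e.g. '16123456789012.jpg'
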
Code example #3
File: webcam.py Project: calvinlcchen/WCT-TF
    def set_style(self, idx=None, random=False, window='Style Controls'):
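        # Pick a style image by index (or at random), load it with an optional
        # crop/resize, and display it in the style-controls window.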
        if idx is not None:
            self.idx = idx
        if random:
            self.idx = np.random.randint(len(self.style_imgs))

        style_file = self.style_imgs[self.idx]
        print('Loading style image', style_file)
        if self.crop_size > 0:
            self.style_rgb = get_img_crop(style_file, resize=self.img_size, crop=self.crop_size)
        else:
            self.style_rgb = resize_to(get_img(style_file), self.img_size)
        self.show_style(window, self.style_rgb)
Code example #4
File: webcam.py Project: xmartlabs/WCT-TF
    def set_style(self, idx=None, random=False, window='Style Controls'):
        if idx is not None:
            self.idx = idx
        if random:
            self.idx = np.random.randint(len(self.style_imgs))

        style_file = self.style_imgs[self.idx]
        print('Loading style image', style_file)
        if self.crop_size > 0:
            self.style_rgb = get_img_crop(style_file,
                                          resize=self.img_size,
                                          crop=self.crop_size)
        else:
            self.style_rgb = resize_to(get_img(style_file), self.img_size)
        self.show_style(window, self.style_rgb)
Code example #5
def main():
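    # Single-style video pipeline: extract frames from args.in_path with
    # ffmpeg, stylize each frame with the WCT model, then re-encode the
    # stylized frames into an mp4 at args.out_path.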
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    in_args = ['ffmpeg', '-i', args.in_path, '%s/frame_%%d.png' % in_dir]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    style_img = get_img(args.style_path)

    if args.style_size > 0:
        style_img = resize_to(style_img, args.style_size)
    if args.crop_size > 0:
        style_img = center_crop(style_img, args.crop_size)

    s = time.time()
    for in_f, out_f in zip(in_files, out_files):
        print('{} -> {}'.format(in_f, out_f))
        content_img = get_img(in_f)

        if args.keep_colors:
            style_rgb = preserve_colors_np(style_img, content_img)
        else:
            style_rgb = style_img

        stylized = wct_model.predict(content_img, style_rgb, args.alpha,
                                     args.swap5, args.ss_alpha)

        if args.passes > 1:
            for _ in range(args.passes - 1):
                stylized = wct_model.predict(stylized, style_rgb, args.alpha)

        # Stitch the style + stylized output together, but only if there's one style image
        if args.concat:
            # Resize style img to same height as frame
            style_img_resized = scipy.misc.imresize(
                style_rgb, (stylized.shape[0], stylized.shape[0]))
            stylized = np.hstack([style_img_resized, stylized])

        save_img(out_f, stylized)

    fr = 30
    out_args = [
        'ffmpeg', '-i',
        '%s/frame_%%d.png' % out_dir, '-f', 'mp4', '-q:v', '0', '-vcodec',
        'mpeg4', '-r',
        str(fr), args.out_path
    ]

    subprocess.call(" ".join(out_args), shell=True)
    print('Video at: %s' % args.out_path)

    if args.keep_tmp is False:
        shutil.rmtree(args.tmp_dir)

    print('Processed in:', (time.time() - s))
Code example #6
def main():
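    # Batch image stylization: apply every style image (or a random subset of
    # them) to every content image and write each result to args.out_path.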
    start = time.time()

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else:  # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else:  # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    # Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(
            content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)

        for style_fullpath in style_files:
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(
                style_prefix)  # Extract filename prefix without ext

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

            style_img = get_img(style_fullpath)

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)

            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img,
                                             args.alpha, args.swap5,
                                             args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'

            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(
        count,
        time.time() - start))
Code example #7
File: stylize_video.py Project: calvinlcchen/WCT-TF
def main():
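    # Video pipeline that loops over multiple style images, skipping output
    # videos that already exist and encoding one mp4 per content/style pair.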
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(args.in_path):
        in_path = get_files(args.in_path)
    else: # Single image file
        in_path = [args.in_path]

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
    else: # Single image file
        style_files = [args.style_path]

    print(style_files)
    import time
    # time.sleep(999)

    in_args = [
        'ffmpeg',
        '-i', args.in_path,
        '%s/frame_%%d.png' % in_dir
    ]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    s = time.time()
    for content_fullpath in in_path:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)


        try:

            for style_fullpath in style_files:
                style_img = get_img(style_fullpath)
                if args.style_size > 0:
                    style_img = resize_to(style_img, args.style_size)
                if args.crop_size > 0:
                    style_img = center_crop(style_img, args.crop_size)

                style_prefix, _ = os.path.splitext(style_fullpath)
                style_prefix = os.path.basename(style_prefix)

                # print("ARRAY:  ", style_img)
                out_v = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
                print("OUT:",out_v)
                if os.path.isfile(out_v):
                    print("SKIP" , out_v)
                    continue
                
                for in_f, out_f in zip(in_files, out_files):
                    print('{} -> {}'.format(in_f, out_f))
                    content_img = get_img(in_f)

                    if args.keep_colors:
                        style_rgb = preserve_colors_np(style_img, content_img)
                    else:
                        style_rgb = style_img

                    stylized = wct_model.predict(content_img, style_rgb, args.alpha, args.swap5, args.ss_alpha)

                    if args.passes > 1:
                        for _ in range(args.passes-1):
                            stylized = wct_model.predict(stylized, style_rgb, args.alpha)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(style_rgb, (stylized.shape[0], stylized.shape[0]))
                        stylized = np.hstack([style_img_resized, stylized])

                    save_img(out_f, stylized)

                fr = 30
                out_args = [
                    'ffmpeg',
                    '-i', '%s/frame_%%d.png' % out_dir,
                    '-f', 'mp4',
                    '-q:v', '0',
                    '-vcodec', 'mpeg4',
                    '-r', str(fr),
                    '"' + out_v + '"'
                ]
                print(out_args)

                subprocess.call(" ".join(out_args), shell=True)
                print('Video at: %s' % out_v)

                if args.keep_tmp is True or len(style_files) > 1:
                    continue
                else:
                    shutil.rmtree(args.tmp_dir)
                print('Processed in:', (time.time() - s))

            print('Processed in:', (time.time() - s))

        except Exception as e:
            print("EXCEPTION: ", e)
Code example #8
File: stylize.py Project: calvinlcchen/WCT-TF
def main():
    start = time.time()

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else: # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else: # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    ### Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)
        
        for style_fullpath in style_files: 
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(style_prefix)  # Extract filename prefix without ext

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

            style_img = get_img(style_fullpath)

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)

            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes-1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'
            
            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(count, time.time() - start))
Code example #9
def main():
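    # Batch stylization with option flags (keep_colors, concat, adain, swap5,
    # mask, remaster) encoded into the output filename so existing results can
    # be skipped; an optional grayscale mask is applied with cv2.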
    # start = time.time()
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else:  # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else:  # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    thetotal = len(content_files) * len(style_files)

    count = 1

    ### Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(
            content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)

        for style_fullpath in style_files:

            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(
                style_prefix)  # Extract filename prefix without ext

            if args.mask:
                mask_prefix_, _ = os.path.splitext(args.mask)
                mask_prefix = os.path.basename(mask_prefix_)

            if args.keep_colors:
                style_prefix = "KPT_" + style_prefix
            if args.concat:
                style_prefix = "CON_" + style_prefix
            if args.adain:
                style_prefix = "ADA_" + style_prefix
            if args.swap5:
                style_prefix = "SWP_" + style_prefix
            if args.mask:
                style_prefix = "MSK_" + mask_prefix + '_' + style_prefix
            if args.remaster:
                style_prefix = style_prefix + '_REMASTERED'

            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            if os.path.isfile(out_f):
                print("SKIP", out_f)
                count += 1
                continue

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])
            style_img = get_img(style_fullpath)
            # Skip files the image loader flagged as unreadable
            if isinstance(style_img, str) and style_img == "IMAGE IS BROKEN":
                continue

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)
            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img,
                                             args.alpha, args.swap5,
                                             args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

            if args.mask:
                import cv2
                cv2.imwrite('./tmp.png', stylized_rgb)
                stylized_rgb = cv2.imread('./tmp.png')

                mask = cv2.imread(args.mask, cv2.IMREAD_GRAYSCALE)

                # from scipy.misc import bytescale
                # mask = bytescale(mask)
                # mask = scipy.ndimage.imread(args.mask,flatten=True,mode='L')
                height, width = stylized_rgb.shape[:2]
                # print(height, width)

                # Resize the mask to fit the image.
                mask = scipy.misc.imresize(mask, (height, width),
                                           interp='bilinear')
                stylized_rgb = cv2.bitwise_and(stylized_rgb,
                                               stylized_rgb,
                                               mask=mask)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                # style_prefix = style_prefix + "CON_"
                # content_img_resized = scipy.misc.imresize(content_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            if args.remaster:
                # outf = os.path.join(args.out_path, '{}_{}_REMASTERED{}'.format(content_prefix, style_prefix, content_ext))
                stylized_rgb = remaster_pic(stylized_rgb)
            save_img(out_f, stylized_rgb)
            totalfiles = len([
                name for name in os.listdir(args.out_path)
                if os.path.isfile(os.path.join(args.out_path, name))
            ])
            # percent = math.floor(float(totalfiles/thetotal))
            print("{}/{} TOTAL FILES".format(count, thetotal))
            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))
Code example #10
File: stylize_saveCkpt.py Project: graceGuor/WCT-TF
def main():
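    # Same stylization loop, but run inside the model's own session so that a
    # tf.train.Saver can re-save the checkpoint at inference time and the
    # graph can be written for TensorBoard; it stops after the first
    # content/style pair.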
    start = time.time()

    session_conf = tf.ConfigProto(
        allow_soft_placement=args.allow_soft_placement,
        log_device_placement=args.log_device_placement)
    session_conf.gpu_options.per_process_gpu_memory_fraction = args.gpu_fraction

    with tf.Graph().as_default():
        # with tf.Session(config=session_conf) as sess:
        # Load the WCT model
        wct_model = WCT(checkpoints=args.checkpoints,
                        relu_targets=args.relu_targets,
                        vgg_path=args.vgg_path,
                        device=args.device,
                        ss_patch_size=args.ss_patch_size,
                        ss_stride=args.ss_stride)

        with wct_model.sess as sess:
            # style_img is not needed during training, so re-save the checkpoint once at inference time
            saver = tf.train.Saver()
            log_path = args.log_path if args.log_path is not None else os.path.join(
                args.checkpoint, 'log')
            summary_writer = tf.summary.FileWriter(log_path, sess.graph)

            # Some nodes are not saved in the checkpoint and need to be re-initialized
            # sess.run(tf.global_variables_initializer())

            # Get content & style full paths
            if os.path.isdir(args.content_path):
                content_files = get_files(args.content_path)
            else:  # Single image file
                content_files = [args.content_path]
            if os.path.isdir(args.style_path):
                style_files = get_files(args.style_path)
                if args.random > 0:
                    style_files = np.random.choice(style_files, args.random)
            else:  # Single image file
                style_files = [args.style_path]

            os.makedirs(args.out_path, exist_ok=True)

            count = 0

            # Apply each style to each content image
            for content_fullpath in content_files:
                content_prefix, content_ext = os.path.splitext(
                    content_fullpath)
                content_prefix = os.path.basename(
                    content_prefix)  # Extract filename prefix without ext

                content_img = get_img(content_fullpath)
                if args.content_size > 0:
                    content_img = resize_to(content_img, args.content_size)

                for style_fullpath in style_files:
                    style_prefix, _ = os.path.splitext(style_fullpath)
                    style_prefix = os.path.basename(
                        style_prefix)  # Extract filename prefix without ext

                    # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
                    # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

                    style_img = get_img(style_fullpath)

                    if args.style_size > 0:
                        style_img = resize_to(style_img, args.style_size)
                    if args.crop_size > 0:
                        style_img = center_crop(style_img, args.crop_size)

                    if args.keep_colors:
                        style_img = preserve_colors_np(style_img, content_img)

                    # if args.noise:  # Generate textures from noise instead of images
                    #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
                    #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

                    # Run the frame through the style network
                    stylized_rgb = wct_model.predict(content_img, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

                    if args.passes > 1:
                        for _ in range(args.passes - 1):
                            stylized_rgb = wct_model.predict(
                                stylized_rgb, style_img, args.alpha,
                                args.swap5, args.ss_alpha, args.adain)

                    save_path = saver.save(
                        sess, os.path.join(args.checkpoint, 'model.ckpt'))
                    print("Model saved in file: %s" % save_path)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(
                            style_img,
                            (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                        # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                        stylized_rgb = np.hstack(
                            [style_img_resized, stylized_rgb])

                    # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
                    out_f = os.path.join(
                        args.out_path,
                        '{}_{}{}'.format(content_prefix, style_prefix,
                                         content_ext))
                    # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'

                    # print(stylized_rgb, stylized_rgb.shape, type(stylized_rgb))
                    # print(out_f)
                    save_img(out_f, stylized_rgb)

                    count += 1
                    print("{}: Wrote stylized output image to {} at {}".format(
                        count, out_f, time.time()))
                    print("breaking...")
                    break
                break

            print("Finished stylizing {} outputs in {}s".format(
                count,
                time.time() - start))
Code example #11
File: stylize.py Project: svjack/PhotoWCT
def main():
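    # PhotoWCT variant: each content/style image is paired with a segmentation
    # map derived from its path ('2017' -> '2017_seg', .jpg -> .png), which is
    # resized/cropped alongside it and concatenated as an extra channel before
    # prediction.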
    start = time.time()

    # Load the WCT model
    wct_model = WCT(input_checkpoint=args.input_checkpoint,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    print("model construct end !")

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else: # Single image file
        content_files = [args.content_path]

    content_seg_files = list(map(lambda x: x.replace("2017", "2017_seg").replace("jpg", "png"), content_files))
    assert reduce(lambda a, b: a + b, map(lambda x: int(os.path.exists(x)), content_seg_files + content_files)) > 0

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else: # Single image file
        style_files = [args.style_path]

    style_seg_files = list(map(lambda x: x.replace("2017", "2017_seg").replace("jpg", "png"), style_files))
    assert reduce(lambda a, b: a + b, map(lambda x: int(os.path.exists(x)), style_seg_files + style_files)) > 0

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    ### Apply each style to each content image
    for i in range(len(content_files)):
        content_fullpath = content_files[i]
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = np.asarray(Image.fromarray(content_img.astype(np.uint8)).resize((args.content_size, args.content_size))).astype(np.float32)
        content_seg = get_img_ori(content_seg_files[i])
        if args.content_size > 0:
            content_seg = np.asarray(Image.fromarray(content_seg.astype(np.uint8)).resize((args.content_size, args.content_size))).astype(np.float32)
            content_seg = image_to_pixel_map(content_seg)[...,np.newaxis]
        content_img = np.concatenate([content_img, content_seg], axis=-1)

        for j in range(len(style_files)):
            style_fullpath = style_files[j]
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(style_prefix)  # Extract filename prefix without ext

            style_img = get_img(style_fullpath)
            style_seg = get_img_ori(style_seg_files[j])

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
                style_seg = resize_to(style_seg, args.style_size)

            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)
                style_seg = center_crop(style_seg, args.crop_size)

            style_seg = image_to_pixel_map(style_seg)[...,np.newaxis]
            style_img = np.concatenate([style_img, style_seg], axis=-1)

            assert not args.keep_colors
            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)


            if args.passes > 1:
                for _ in range(args.passes-1):
                    stylized_rgb = np.concatenate([stylized_rgb, content_seg], axis=-1)
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            assert not args.concat
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            ####+++++++++++++++++++++++++++++++++++

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(count, time.time() - start))
Code example #12
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if tf.train.latest_checkpoint('{}/{}'.format(
                    checkpoint_dir, decoder_name)) is None:
                raise ValueError(
                    'No checkpoint found for {}'.format(decoder_name))
            saver.restore(
                sess,
                tf.train.latest_checkpoint('{}/{}'.format(
                    checkpoint_dir, decoder_name)))

            encoded_image = sess.run(encoded, feed_dict={inp: img})
            encoded_style = sess.run(encoded, feed_dict={inp: style})

            styled_feature = sess.run(WCT(encoded_image, encoded_style, alpha))

            img = sess.run(decoded, feed_dict={encoded: styled_feature})

    return img[0]


if __name__ == '__main__':
    img = resize_to(get_img(args.content_path), 512)
    style = resize_to(get_img(args.style_path), 512)
    styled_img = stylize(img,
                         style,
                         args.target_layers,
                         args.alpha,
                         checkpoint_dir=args.checkpoint_dir)
    save_image(styled_img, '{}/output.jpg'.format(args.out_path))
Code example #13
def freeze_graph_test(pb_file, content_path, style_path):
    '''
    :param pb_file: path to the frozen .pb graph file
    :param content_path: path to the content test image
    :param style_path: path to the style test image
    :return:
    '''
    content_prefix, content_ext = os.path.splitext(content_path)
    content_prefix = os.path.basename(
        content_prefix)  # Extract filename prefix without ext
    style_prefix, _ = os.path.splitext(style_path)
    style_prefix = os.path.basename(
        style_prefix)  # Extract filename prefix without ext

    # Load the test images content_img and style_img
    content_img = get_img(content_path)
    if args.content_size > 0:
        content_img = resize_to(content_img, args.content_size)

    style_img = get_img(style_path)

    if args.style_size > 0:
        style_img = resize_to(style_img, args.style_size)
    if args.crop_size > 0:
        style_img = center_crop(style_img, args.crop_size)

    # if args.keep_colors:
    #     style_img = preserve_colors_np(style_img, content_img)

    content_img = preprocess(content_img)
    style_img = preprocess(style_img)

    # Print all the variables in the checkpoint
    chkp.print_tensors_in_checkpoint_file("models/inference/model.ckpt",
                                          tensor_name='',
                                          all_tensors=False)

    # for i in range(5):
    #     print("=" * 10, i + 1)
    #     chkp.print_tensors_in_checkpoint_file("models/relu" + str(i+1) + "_1/model.ckpt-15002", tensor_name='', all_tensors=True)

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        # with open(pb_file, "rb") as f:
        #     x = f.read()
        #     # print(x)
        #     output_graph_def.ParseFromString(x)
        #     output_graph_def.ParseFromString(f.read())
        #     tf.import_graph_def(output_graph_def, name="")
        #     for i, n in enumerate(output_graph_def.node):
        #         print("Name of the node - %s" % n.name)
        #         print(n)

        with tf.Session() as sess:
            # sess.run(tf.global_variables_initializer())
            # tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)

            # Define the input tensor names, matching the network's input tensors
            content_input_tensor = sess.graph.get_tensor_by_name(
                "encoder_decoder_relu5_1/content_encoder_relu5_1/content_imgs:0"
            )
            style_input_tensor = sess.graph.get_tensor_by_name("style_img:0")

            # Define the output tensor name
            output_tensor_name = sess.graph.get_tensor_by_name(
                "encoder_decoder_relu5_1/decoder_relu5_1/decoder_model_relu5_1/relu5_1_16/relu5_1_16/BiasAdd:0"
            )

            # Sanity-check the loaded model; note that the names passed here are the input/output tensors' names, not the op node names
            stylized_rgb = sess.run(output_tensor_name,
                                    feed_dict={
                                        content_input_tensor: content_img,
                                        style_input_tensor: style_img
                                    })

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            save_img(out_f, stylized_rgb)
Code example #14
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(args.in_path):
        in_path = get_files(args.in_path)
    else:  # Single image file
        in_path = [args.in_path]

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
    else:  # Single image file
        style_files = [args.style_path]

    print(style_files)
    # time.sleep(999)

    in_args = [
        'ffmpeg',
        '-i', args.in_path,
        '%s/frame_%%d.png' % in_dir
    ]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    s = time.time()
    for content_fullpath in in_path:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)

        try:

            for style_fullpath in style_files:
                style_img = get_img(style_fullpath)
                if args.style_size > 0:
                    style_img = resize_to(style_img, args.style_size)
                if args.crop_size > 0:
                    style_img = center_crop(style_img, args.crop_size)

                style_prefix, _ = os.path.splitext(style_fullpath)
                style_prefix = os.path.basename(style_prefix)

                # print("ARRAY:  ", style_img)
                out_v = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
                print("OUT:", out_v)
                if os.path.isfile(out_v):
                    print("SKIP", out_v)
                    continue

                for in_f, out_f in zip(in_files, out_files):
                    print('{} -> {}'.format(in_f, out_f))
                    content_img = get_img(in_f)

                    if args.keep_colors:
                        style_rgb = preserve_colors_np(style_img, content_img)
                    else:
                        style_rgb = style_img

                    stylized = wct_model.predict(content_img, style_rgb, args.alpha, args.swap5, args.ss_alpha)

                    if args.passes > 1:
                        for _ in range(args.passes-1):
                            stylized = wct_model.predict(stylized, style_rgb, args.alpha)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(style_rgb, (stylized.shape[0], stylized.shape[0]))
                        stylized = np.hstack([style_img_resized, stylized])

                    save_img(out_f, stylized)

                fr = 30
                out_args = [
                    'ffmpeg',
                    '-i', '%s/frame_%%d.png' % out_dir,
                    '-f', 'mp4',
                    '-q:v', '0',
                    '-vcodec', 'mpeg4',
                    '-r', str(fr),
                    '"' + out_v + '"'
                ]
                print(out_args)

                subprocess.call(" ".join(out_args), shell=True)
                print('Video at: %s' % out_v)

                if args.keep_tmp is True or len(style_files) > 1:
                    continue
                else:
                    shutil.rmtree(args.tmp_dir)
                print('Processed in:', (time.time() - s))

            print('Processed in:', (time.time() - s))

        except Exception as e:
            print("EXCEPTION: ", e)