Example #1
def run(self, image, style, filename="test.png", intermediates=False):
    self.WCT = WCT()
    self.set_style(style)
    img = load_image(image,
                     size=self.current_style["content_size"],
                     noise=True)
    print(self.current_style["procedure"])
    i = 0
    for level, weight in json.loads(self.current_style["procedure"]):
        img = self.style(img, weight, level)
        if intermediates:
            print("---" * 20)
            ff = filename.split(".")
            fname = ".".join(ff[:-1]) + "{:02d}.".format(i) + ff[-1]
            print("saving as " + fname)
            vutils.save_image(img.data.cpu().float(), fname)
        i += 1
    if not intermediates:
        print("---" * 20)
        print("saving as " + filename)
        vutils.save_image(img.data.cpu().float(), filename)
    del self.WCT
    try:
        del self.encoder_cache
        del self.decoder_cache
    except AttributeError:  # caches may not exist yet
        pass
    self.last_level = -1
    torch.cuda.empty_cache()
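
For reference, run() parses current_style["procedure"] as a JSON string of [level, weight] pairs. A plausible style_settings.json entry, with illustrative values (not taken from the original project), might be:

{
    "starry_night": {
        "image": "Resources/Styles/starry_night.jpg",
        "style_size": 512,
        "content_size": 768,
        "procedure": "[[5, 0.8], [4, 0.7], [3, 0.6], [2, 0.5], [1, 0.4]]"
    }
}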
Example #2
def main():
    start = time.time()
    if len(args.style_paths) > 2:
        raise Exception('A maximum of 2 style images is supported')

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride,
                    alpha=args.alpha,
                    beta=args.beta)

    os.makedirs(args.out_path, exist_ok=True)

    content_img = get_img(args.content_path, args.content_size)
    styles = [get_img(path, args.style_size) for path in args.style_paths]

    _, content_ext = os.path.splitext(args.content_path)
    output_filename = os.path.join(
        args.out_path, "result.jpg"
    )  #f'{args.content_path}_{args.style_path_a}.{content_ext}')
    output = stylize_output(wct_model, content_img, styles)
    save_img(output_filename, output)

    print("Finished stylizing in {}s".format(time.time() - start))
Example #3
def main():
    # Prepare WCT model
    vgg1 = 'models/vgg_normalised_conv1_1.pth'
    vgg2 = 'models/vgg_normalised_conv2_1.pth'
    vgg3 = 'models/vgg_normalised_conv3_1.pth'
    vgg4 = 'models/vgg_normalised_conv4_1.pth'
    vgg5 = 'models/vgg_normalised_conv5_1.pth'
    decoder1 = 'models/feature_invertor_conv1_1.pth'
    decoder2 = 'models/feature_invertor_conv2_1.pth'
    decoder3 = 'models/feature_invertor_conv3_1.pth'
    decoder4 = 'models/feature_invertor_conv4_1.pth'
    decoder5 = 'models/feature_invertor_conv5_1.pth'
    paths = vgg1, vgg2, vgg3, vgg4, vgg5, decoder1, decoder2, decoder3, decoder4, decoder5
    wct = WCT(paths)

    # Prepare images
    content_image = Image.open(args.content).resize((args.content_w, args.content_h))
    contentImg = TF.to_tensor(content_image)
    contentImg.unsqueeze_(0)
    style_image = Image.open(args.style).resize((args.style_w, args.style_h))
    styleImg = TF.to_tensor(style_image)
    styleImg.unsqueeze_(0)
    csF = torch.Tensor()

    # Variable(..., volatile=True) is deprecated; wrap inference in
    # torch.no_grad() instead.
    cImg = contentImg.cuda(0)
    sImg = styleImg.cuda(0)
    csF = csF.cuda(0)
    wct.cuda(0)

    # Run style transfer
    start_time = time.time()
    with torch.no_grad():
        styleTransfer(wct, args.alpha, cImg, sImg, csF, args.output)
    end_time = time.time()
    print('Elapsed time is: %f' % (end_time - start_time))
Example #4
def main():
    # init style model
    # Load the WCT model
    global wct_model
    if wct_model is None:
        wct_model = WCT(
            checkpoints=["".join([BACKEND_PATH, chkp]) for chkp in CHECKPOINTS],
            relu_targets=RELU_TARGETS,
            vgg_path=BACKEND_PATH + VGG_PATH,
            device=DEVICE,
            ss_patch_size=PATCH_SIZE,
            ss_stride=STRIDE)
    app.run(host='0.0.0.0', debug=False, port=8080)
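
The routes themselves are not shown. A hypothetical Flask endpoint using the global wct_model (it would have to be registered before app.run is called; the route name, request fields, and fixed alpha are all assumptions):

from io import BytesIO

import numpy as np
from flask import request, send_file
from PIL import Image

@app.route('/stylize', methods=['POST'])
def stylize():
    # Decode the uploaded content and style images into RGB arrays
    content = np.array(Image.open(request.files['content']).convert('RGB'))
    style = np.array(Image.open(request.files['style']).convert('RGB'))
    stylized = wct_model.predict(content, style, 0.6)  # alpha is assumed
    # Encode the stylized frame as PNG and stream it back
    buf = BytesIO()
    Image.fromarray(np.uint8(stylized)).save(buf, format='PNG')
    buf.seek(0)
    return send_file(buf, mimetype='image/png')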
Example #5
parser.add_argument('--width',
                    type=int,
                    help='Video frame width',
                    default=1920)
parser.add_argument('--height',
                    type=int,
                    help='Video frame height',
                    default=1080)
parser.add_argument('--fps', type=int, help='Video FPS', default=30)

arguments = parser.parse_args()

#load model
wct_model = WCT(checkpoints=arguments.checkpoints,
                relu_targets=arguments.relu_targets,
                vgg_path=arguments.vgg_path,
                device=arguments.device,
                ss_patch_size=arguments.ss_patch_size,
                ss_stride=arguments.ss_stride)

#load style image
style = get_img(arguments.style_path)


def changeImage(image):
    return wct_model.predict(image, style, arguments.alpha, arguments.swap5,
                             arguments.ss_alpha, arguments.adain)


camera = cv2.VideoCapture(0)  # get camera video

# build stream input
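
The excerpt stops at the stream setup. A minimal sketch of the capture loop that might follow, feeding each frame through changeImage (the loop is an assumption, not part of the original script):

while True:
    ret, frame = camera.read()  # frames arrive in BGR order
    if not ret:
        break
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    stylized = changeImage(frame_rgb)
    cv2.imshow('stylized', cv2.cvtColor(stylized.astype('uint8'), cv2.COLOR_RGB2BGR))
    if cv2.waitKey(1) & 0xFF == ord('q'):  # quit on 'q'
        break

camera.release()
cv2.destroyAllWindows()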
Example #6
import time

import sys, os
sys.path.append('./wct/')
from utils import preserve_colors_np
from utils import get_files, get_img, get_img_crop, save_img, resize_to, center_crop
from wct import WCT

checkpoints = ['models/wct/relu5_1', 'models/wct/relu4_1',
               'models/wct/relu3_1', 'models/wct/relu2_1', 'models/wct/relu1_1']
relu_targets = ['relu5_1', 'relu4_1', 'relu3_1', 'relu2_1', 'relu1_1']

# Load the WCT model
wct_model = WCT(checkpoints=checkpoints,
                relu_targets=relu_targets,
                vgg_path='models/wct/vgg_normalised.t7',
                device='/cpu:0',
                ss_patch_size=3,
                ss_stride=1)


def get_stylize_image(content_fullpath, style_fullpath, output_path,
                      content_size=256, style_size=256, alpha=0.6,
                      swap5=False, ss_alpha=0.6, adain=False):
    content_img = get_img(content_fullpath)
    content_img = resize_to(content_img, int(content_size))

    style_img = get_img(style_fullpath)
    style_img = resize_to(style_img, int(style_size))

    stylized_rgb = wct_model.predict(
        content_img, style_img, alpha, swap5, ss_alpha, adain)
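
The excerpt is truncated after the predict call; a hypothetical completion of the function body would write the result to output_path:

    # Hypothetical final step: persist the stylized image.
    save_img(output_path, stylized_rgb)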
Example #7
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(args.in_path):
        in_path = get_files(args.in_path)
    else: # Single image file
        in_path = [args.in_path]

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
    else: # Single image file
        style_files = [args.style_path]

    print(style_files)
    import time
    # time.sleep(999)

    in_args = [
        'ffmpeg',
        '-i', args.in_path,
        '%s/frame_%%d.png' % in_dir
    ]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    s = time.time()
    for content_fullpath in in_path:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)


        try:
            for style_fullpath in style_files:
                style_img = get_img(style_fullpath)
                if args.style_size > 0:
                    style_img = resize_to(style_img, args.style_size)
                if args.crop_size > 0:
                    style_img = center_crop(style_img, args.crop_size)

                style_prefix, _ = os.path.splitext(style_fullpath)
                style_prefix = os.path.basename(style_prefix)

                # print("ARRAY:  ", style_img)
                out_v = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
                print("OUT:", out_v)
                if os.path.isfile(out_v):
                    print("SKIP", out_v)
                    continue
                
                for in_f, out_f in zip(in_files, out_files):
                    print('{} -> {}'.format(in_f, out_f))
                    content_img = get_img(in_f)

                    if args.keep_colors:
                        style_rgb = preserve_colors_np(style_img, content_img)
                    else:
                        style_rgb = style_img

                    stylized = wct_model.predict(content_img, style_rgb, args.alpha, args.swap5, args.ss_alpha)

                    if args.passes > 1:
                        for _ in range(args.passes-1):
                            stylized = wct_model.predict(stylized, style_rgb, args.alpha)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(style_rgb, (stylized.shape[0], stylized.shape[0]))
                        stylized = np.hstack([style_img_resized, stylized])

                    save_img(out_f, stylized)

                fr = 30
                out_args = [
                    'ffmpeg',
                    '-i', '%s/frame_%%d.png' % out_dir,
                    '-f', 'mp4',
                    '-q:v', '0',
                    '-vcodec', 'mpeg4',
                    '-r', str(fr),
                    '"' + out_v + '"'
                ]
                print(out_args)

                subprocess.call(" ".join(out_args), shell=True)
                print('Video at: %s' % out_v)

                if args.keep_tmp or len(style_files) > 1:
                    continue
                else:
                    shutil.rmtree(args.tmp_dir)
                print('Processed in:', (time.time() - s))

            print('Processed in:', (time.time() - s))

        except Exception as e:
            print("EXCEPTION: ", e)
Example #8
def main():
    start = time.time()

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else: # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else: # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    ### Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)
        
        for style_fullpath in style_files: 
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(style_prefix)  # Extract filename prefix without ext

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

            style_img = get_img(style_fullpath)

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)

            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes-1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'
            
            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(count, time.time() - start))
Example #9
class Styletransfer:
    def __init__(self):
        self.h = 0
        self._style_settings = None
        self.net = None
        self.current_style = None
        self.style_vars = None
        self.current_cache_key = None
        self.encoder_cache = None
        self.decoder_cache = None
        self.last_level = -1

    def run(self, image, style, filename="test.png", intermediates=False):
        self.WCT = WCT()
        self.set_style(style)
        img = load_image(image,
                         size=self.current_style["content_size"],
                         noise=True)
        print(self.current_style["procedure"])
        i = 0
        for level, weight in json.loads(self.current_style["procedure"]):
            img = self.style(img, weight, level)
            if intermediates:
                print("---" * 20)
                ff = filename.split(".")
                fname = ".".join(ff[:-1]) + "{:02d}.".format(i) + ff[-1]
                print("saving as " + fname)
                vutils.save_image(img.data.cpu().float(), fname)
            i += 1
        if not intermediates:
            print("---" * 20)
            print("saving as " + filename)
            vutils.save_image(img.data.cpu().float(), filename)
        del self.WCT
        try:
            del self.encoder_cache
            del self.decoder_cache
        except AttributeError:  # caches may not exist yet
            pass
        self.last_level = -1
        torch.cuda.empty_cache()

    def style(self, image, weight, level):
        print("loading encoders & decoders for level " + str(level))
        with torch.no_grad():
            if (self.last_level
                    == level) or ((min(level, self.last_level) == 5) and
                                  (max(self.last_level, level) == 6)):
                encoder = self.encoder_cache
                decoder = self.decoder_cache
            else:
                try:
                    del self.encoder_cache
                    del self.decoder_cache
                    torch.cuda.empty_cache()
                except AttributeError:  # nothing cached yet
                    pass
                if level == 1:
                    encoder = level_1_encoder()
                    decoder = level_1_decoder()
                elif level == 2:
                    encoder = level_2_encoder()
                    decoder = level_2_decoder()
                elif level == 3:
                    encoder = level_3_encoder()
                    decoder = level_3_decoder()
                elif level == 4:
                    encoder = level_4_encoder()
                    decoder = level_4_decoder()
                elif level in [5, 6]:
                    encoder = level_5_encoder()
                    decoder = level_5_decoder()
            level_str = str(level)
            if level == 6:
                osize = image.size()
                image = self.lvl6_downscale(image)
                level_str = "10"
            style_params = self.style_vars[level_str]
            print("encoding content features")
            content_feature = encoder(image)
            print(content_feature.size())
            print("coloring")
            colored = self.WCT.content_coloring(content_feature, style_params,
                                                weight)
            print("reconstructing image")
            ret = decoder(colored)
            print("ret size: " + str(ret.size()))
            self.encoder_cache = encoder
            self.decoder_cache = decoder
            self.last_level = level
        if level == 6:
            ret = self.lvl6_upscale(ret, osize)
        return ret

    def lvl6_downscale(self, image, scale=224):
        i_size = image.size()
        if i_size[2] > i_size[3]:
            new_2 = int(i_size[2] / i_size[3] * scale)
            new_3 = scale
        else:
            new_3 = int(i_size[3] / i_size[2] * scale)
            new_2 = scale
        ret = torch.nn.functional.interpolate(image,
                                              size=(new_2, new_3),
                                              mode="bilinear",
                                              align_corners=True)
        return ret

    def lvl6_upscale(self, image, osize):
        return torch.nn.functional.interpolate(image,
                                               size=(osize[2], osize[3]),
                                               mode="bilinear",
                                               align_corners=True)

    def set_style(self, style_key):
        if self._style_settings is None:
            with open("style_settings.json", "r") as f:
                self._style_settings = json.load(f)
        self.current_style = self._style_settings[style_key]
        cache_key = self.get_style_cache_key(style_key)
        if not os.path.isfile(cache_key):
            print("baking style cache")
            self.make_style_cache(style_key, cache_key)
        elif self.current_cache_key != cache_key:
            print("loading style from disk")
            self.style_vars = torch.load(cache_key)
            print("done")
            self.current_cache_key = cache_key

    def get_style_cache_key(self, key):
        if self._style_settings is None:
            with open("style_settings.json", "r") as f:
                self._style_settings = json.load(f)
        style = self._style_settings[key]
        cache_key = hashlib.sha1(
            (str(style["style_size"]) + style["image"]).encode("utf-8")).hexdigest()
        return "Resources/Styles/cache/" + cache_key + ".pt"

    def make_style_cache(self, key, cache_key):
        self.style_vars = None
        torch.cuda.empty_cache()
        image = load_image(self._style_settings[key]["image"],
                           size=self._style_settings[key]["style_size"])
        net = style_features().cuda()
        net_output = net(image)
        ret = {}
        i = 1
        for style_feature in net_output:
            ret[str(i)] = self.WCT.style_params(style_feature)
            i += 1
        # scale the input image according to the smaller size
        ss = self._style_settings[key]
        image = self.lvl6_downscale(
            image, int(224 / (ss["content_size"] / ss["style_size"])))
        net_output = net(image)
        for style_feature in net_output:
            ret[str(i)] = self.WCT.style_params(style_feature)
            print(i)
            i += 1
        torch.save(ret, cache_key)
        del image
        del net
        torch.cuda.empty_cache()
        self.style_vars = ret
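
A hypothetical driver for the class above (the style key and file names are assumptions; valid keys come from style_settings.json):

if __name__ == "__main__":
    st = Styletransfer()
    # Writes one intermediate image per (level, weight) step of the procedure.
    st.run("content.jpg", "starry_night", filename="out.png", intermediates=True)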
Example #10
def main():
    # start = time.time()
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else:  # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else:  # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    thetotal = len(content_files) * len(style_files)

    count = 1

    ### Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(
            content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)

        for style_fullpath in style_files:

            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(
                style_prefix)  # Extract filename prefix without ext

            if args.mask:
                mask_prefix_, _ = os.path.splitext(args.mask)
                mask_prefix = os.path.basename(mask_prefix_)

            if args.keep_colors:
                style_prefix = "KPT_" + style_prefix
            if args.concat:
                style_prefix = "CON_" + style_prefix
            if args.adain:
                style_prefix = "ADA_" + style_prefix
            if args.swap5:
                style_prefix = "SWP_" + style_prefix
            if args.mask:
                style_prefix = "MSK_" + mask_prefix + '_' + style_prefix
            if args.remaster:
                style_prefix = style_prefix + '_REMASTERED'

            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            if os.path.isfile(out_f):
                print("SKIP", out_f)
                count += 1
                continue

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])
            style_img = get_img(style_fullpath)
            # get_img is assumed to return this sentinel string for unreadable files
            if isinstance(style_img, str) and style_img == "IMAGE IS BROKEN":
                continue

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)
            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img,
                                             args.alpha, args.swap5,
                                             args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

            if args.mask:
                import cv2
                cv2.imwrite('./tmp.png', stylized_rgb)
                stylized_rgb = cv2.imread('./tmp.png')

                mask = cv2.imread(args.mask, cv2.IMREAD_GRAYSCALE)

                # from scipy.misc import bytescale
                # mask = bytescale(mask)
                # mask = scipy.ndimage.imread(args.mask,flatten=True,mode='L')
                height, width = stylized_rgb.shape[:2]
                # print(height, width)

                # Resize the mask to fit the image.
                mask = scipy.misc.imresize(mask, (height, width),
                                           interp='bilinear')
                stylized_rgb = cv2.bitwise_and(stylized_rgb,
                                               stylized_rgb,
                                               mask=mask)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                # style_prefix = style_prefix + "CON_"
                # content_img_resized = scipy.misc.imresize(content_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            if args.remaster:
                # outf = os.path.join(args.out_path, '{}_{}_REMASTERED{}'.format(content_prefix, style_prefix, content_ext))
                stylized_rgb = remaster_pic(stylized_rgb)
            save_img(out_f, stylized_rgb)
            totalfiles = len([
                name for name in os.listdir(args.out_path)
                if os.path.isfile(os.path.join(args.out_path, name))
            ])
            # percent = math.floor(float(totalfiles/thetotal))
            print("{}/{} TOTAL FILES".format(count, thetotal))
            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))
Example #11
def main():
    start = time.time()

    session_conf = tf.ConfigProto(
        allow_soft_placement=args.allow_soft_placement,
        log_device_placement=args.log_device_placement)
    session_conf.gpu_options.per_process_gpu_memory_fraction = args.gpu_fraction

    with tf.Graph().as_default():
        # with tf.Session(config=session_conf) as sess:
        # Load the WCT model
        wct_model = WCT(checkpoints=args.checkpoints,
                        relu_targets=args.relu_targets,
                        vgg_path=args.vgg_path,
                        device=args.device,
                        ss_patch_size=args.ss_patch_size,
                        ss_stride=args.ss_stride)

        with wct_model.sess as sess:
            # style_img isn't needed during training, so the checkpoint is re-saved once at inference time
            saver = tf.train.Saver()
            log_path = args.log_path if args.log_path is not None else os.path.join(
                args.checkpoint, 'log')
            summary_writer = tf.summary.FileWriter(log_path, sess.graph)

            # some nodes aren't saved in the checkpoint and need to be re-initialized
            # sess.run(tf.global_variables_initializer())

            # Get content & style full paths
            if os.path.isdir(args.content_path):
                content_files = get_files(args.content_path)
            else:  # Single image file
                content_files = [args.content_path]
            if os.path.isdir(args.style_path):
                style_files = get_files(args.style_path)
                if args.random > 0:
                    style_files = np.random.choice(style_files, args.random)
            else:  # Single image file
                style_files = [args.style_path]

            os.makedirs(args.out_path, exist_ok=True)

            count = 0

            # Apply each style to each content image
            for content_fullpath in content_files:
                content_prefix, content_ext = os.path.splitext(
                    content_fullpath)
                content_prefix = os.path.basename(
                    content_prefix)  # Extract filename prefix without ext

                content_img = get_img(content_fullpath)
                if args.content_size > 0:
                    content_img = resize_to(content_img, args.content_size)

                for style_fullpath in style_files:
                    style_prefix, _ = os.path.splitext(style_fullpath)
                    style_prefix = os.path.basename(
                        style_prefix)  # Extract filename prefix without ext

                    # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
                    # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

                    style_img = get_img(style_fullpath)

                    if args.style_size > 0:
                        style_img = resize_to(style_img, args.style_size)
                    if args.crop_size > 0:
                        style_img = center_crop(style_img, args.crop_size)

                    if args.keep_colors:
                        style_img = preserve_colors_np(style_img, content_img)

                    # if args.noise:  # Generate textures from noise instead of images
                    #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
                    #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

                    # Run the frame through the style network
                    stylized_rgb = wct_model.predict(content_img, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

                    if args.passes > 1:
                        for _ in range(args.passes - 1):
                            stylized_rgb = wct_model.predict(
                                stylized_rgb, style_img, args.alpha,
                                args.swap5, args.ss_alpha, args.adain)

                    save_path = saver.save(
                        sess, os.path.join(args.checkpoint, 'model.ckpt'))
                    print("Model saved in file: %s" % save_path)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(
                            style_img,
                            (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                        # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                        stylized_rgb = np.hstack(
                            [style_img_resized, stylized_rgb])

                    # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
                    out_f = os.path.join(
                        args.out_path,
                        '{}_{}{}'.format(content_prefix, style_prefix,
                                         content_ext))
                    # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'

                    # print(stylized_rgb, stylized_rgb.shape, type(stylized_rgb))
                    # print(out_f)
                    save_img(out_f, stylized_rgb)

                    count += 1
                    print("{}: Wrote stylized output image to {} at {}".format(
                        count, out_f, time.time()))
                    print("breaking...")
                    break
                break

            print("Finished stylizing {} outputs in {}s".format(
                count,
                time.time() - start))
Example #12
def style_transfer(args, content_img, style_img, imname, csF):
    wct = WCT(args)
    if (args.cuda):
        wct = wct.cuda(args.gpu)
    sF5 = wct.e5(style_img)
    cF5 = wct.e5(content_img)
    sF5 = sF5.data.cpu().squeeze(0)
    cF5 = cF5.data.cpu().squeeze(0)
    csF5 = wct.transform(cF5, sF5, csF, args.alpha)
    Im5 = wct.d5(csF5)

    sF4 = wct.e4(style_img)
    cF4 = wct.e4(Im5)
    sF4 = sF4.data.cpu().squeeze(0)
    cF4 = cF4.data.cpu().squeeze(0)
    csF4 = wct.transform(cF4, sF4, csF, args.alpha)
    Im4 = wct.d4(csF4)

    sF3 = wct.e3(style_img)
    cF3 = wct.e3(Im4)
    sF3 = sF3.data.cpu().squeeze(0)
    cF3 = cF3.data.cpu().squeeze(0)
    csF3 = wct.transform(cF3, sF3, csF, args.alpha)
    Im3 = wct.d3(csF3)

    sF2 = wct.e2(style_img)
    cF2 = wct.e2(Im3)
    sF2 = sF2.data.cpu().squeeze(0)
    cF2 = cF2.data.cpu().squeeze(0)
    csF2 = wct.transform(cF2, sF2, csF, args.alpha)
    Im2 = wct.d2(csF2)

    sF1 = wct.e1(style_img)
    cF1 = wct.e1(Im2)
    sF1 = sF1.data.cpu().squeeze(0)
    cF1 = cF1.data.cpu().squeeze(0)
    csF1 = wct.transform(cF1, sF1, csF, args.alpha)
    Im1 = wct.d1(csF1)
    # save_image has the weird default behavior of padding images with 4 pixels.
    vutils.save_image(Im1.data.cpu().float(),
                      os.path.join(args.output_dir, imname))
    return
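
For context, the whitening-and-coloring transform that wct.transform applies at each level can be written compactly. The sketch below is a generic PyTorch rendition of WCT on a single feature map, not the repository's API (the function name, arguments, and eps regularizer are illustrative):

import torch

def wct_transform(cF, sF, alpha=0.6, eps=1e-5):
    # cF, sF: (C, H, W) content/style features from the same VGG layer.
    C, H, W = cF.shape
    cf = cF.reshape(C, -1)
    sf = sF.reshape(C, -1)

    # Center both feature sets
    c_mean = cf.mean(dim=1, keepdim=True)
    s_mean = sf.mean(dim=1, keepdim=True)
    cf = cf - c_mean
    sf = sf - s_mean

    # Whitening: remove the content covariance structure
    c_cov = cf @ cf.t() / (cf.shape[1] - 1) + eps * torch.eye(C)
    c_e, c_v = torch.linalg.eigh(c_cov)
    whitened = c_v @ torch.diag(c_e.clamp(min=eps).rsqrt()) @ c_v.t() @ cf

    # Coloring: impose the style covariance, then re-add the style mean
    s_cov = sf @ sf.t() / (sf.shape[1] - 1) + eps * torch.eye(C)
    s_e, s_v = torch.linalg.eigh(s_cov)
    colored = s_v @ torch.diag(s_e.clamp(min=eps).sqrt()) @ s_v.t() @ whitened + s_mean

    # Blend stylized features with the original content features
    out = alpha * colored + (1.0 - alpha) * (cf + c_mean)
    return out.reshape(C, H, W)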
Example #13
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(args.in_path):
        in_path = get_files(args.in_path)
    else:  # Single image file
        in_path = [args.in_path]

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
    else:  # Single image file
        style_files = [args.style_path]

    print(style_files)
    # time.sleep(999)

    in_args = [
        'ffmpeg',
        '-i', args.in_path,
        '%s/frame_%%d.png' % in_dir
    ]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    s = time.time()
    for content_fullpath in in_path:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)

        try:

            for style_fullpath in style_files:
                style_img = get_img(style_fullpath)
                if args.style_size > 0:
                    style_img = resize_to(style_img, args.style_size)
                if args.crop_size > 0:
                    style_img = center_crop(style_img, args.crop_size)

                style_prefix, _ = os.path.splitext(style_fullpath)
                style_prefix = os.path.basename(style_prefix)

                # print("ARRAY:  ", style_img)
                out_v = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
                print("OUT:", out_v)
                if os.path.isfile(out_v):
                    print("SKIP", out_v)
                    continue

                for in_f, out_f in zip(in_files, out_files):
                    print('{} -> {}'.format(in_f, out_f))
                    content_img = get_img(in_f)

                    if args.keep_colors:
                        style_rgb = preserve_colors_np(style_img, content_img)
                    else:
                        style_rgb = style_img

                    stylized = wct_model.predict(content_img, style_rgb, args.alpha, args.swap5, args.ss_alpha)

                    if args.passes > 1:
                        for _ in range(args.passes-1):
                            stylized = wct_model.predict(stylized, style_rgb, args.alpha)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(style_rgb, (stylized.shape[0], stylized.shape[0]))
                        stylized = np.hstack([style_img_resized, stylized])

                    save_img(out_f, stylized)

                fr = 30
                out_args = [
                    'ffmpeg',
                    '-i', '%s/frame_%%d.png' % out_dir,
                    '-f', 'mp4',
                    '-q:v', '0',
                    '-vcodec', 'mpeg4',
                    '-r', str(fr),
                    '"' + out_v + '"'
                ]
                print(out_args)

                subprocess.call(" ".join(out_args), shell=True)
                print('Video at: %s' % out_v)

                if args.keep_tmp is True or len(style_files) > 1:
                    continue
                else:
                    shutil.rmtree(args.tmp_dir)
                print('Processed in:', (time.time() - s))

            print('Processed in:', (time.time() - s))

        except Exception as e:
            print("EXCEPTION: ", e)
Example #14
def main():
    start = time.time()

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else:  # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else:  # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    # Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(
            content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)

        for style_fullpath in style_files:
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(
                style_prefix)  # Extract filename prefix without ext

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

            style_img = get_img(style_fullpath)

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)

            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img,
                                             args.alpha, args.swap5,
                                             args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'

            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(
        count,
        time.time() - start))
Example #15
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device)

    # Load a panel to control style settings
    style_window = StyleWindow(args.style_path, args.style_size,
                               args.crop_size, args.scale, args.alpha)

    # Start the webcam stream
    cap = WebcamVideoStream(args.video_source, args.width, args.height).start()

    _, frame = cap.read()

    # Grab a sample frame to calculate frame size
    frame_resize = cv2.resize(frame, None, fx=args.scale, fy=args.scale)
    img_shape = frame_resize.shape

    # Setup video out writer
    if args.video_out is not None:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        if args.concat:
            out_shape = (img_shape[1] + img_shape[0], img_shape[0]
                         )  # Make room for the style img
        else:
            out_shape = (img_shape[1], img_shape[0])
        print('Video Out Shape:', out_shape)
        video_writer = cv2.VideoWriter(args.video_out, fourcc, args.fps,
                                       out_shape)

    fps = FPS().start()  # Track FPS processing speed

    keep_colors = args.keep_colors

    count = 0

    while True:
        ret, frame = cap.read()

        if ret is True:
            frame_resize = cv2.resize(frame,
                                      None,
                                      fx=style_window.scale,
                                      fy=style_window.scale)

            if args.noise:  # Generate textures from noise instead of images
                frame_resize = np.random.randint(0, 256, frame_resize.shape,
                                                 np.uint8)

            count += 1
            print("Frame:", count, "Orig shape:", frame.shape, "New shape",
                  frame_resize.shape)

            content_rgb = cv2.cvtColor(
                frame_resize,
                cv2.COLOR_BGR2RGB)  # OpenCV uses BGR, we need RGB

            if args.random > 0 and count % args.random == 0:
                style_window.set_style(random=True)

            if keep_colors:
                style_rgb = preserve_colors_np(style_window.style_rgb,
                                               content_rgb)
            else:
                style_rgb = style_window.style_rgb

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_rgb, style_rgb,
                                             style_window.alpha)

            if args.passes > 1:
                for i in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_rgb,
                                                     style_window.alpha)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_rgb_resized = cv2.resize(
                    style_rgb, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                stylized_rgb = np.hstack([style_rgb_resized, stylized_rgb])

            stylized_bgr = cv2.cvtColor(stylized_rgb, cv2.COLOR_RGB2BGR)

            if args.video_out is not None:
                stylized_bgr = cv2.resize(
                    stylized_bgr,
                    out_shape)  # Make sure frame matches video size
                video_writer.write(stylized_bgr)

            cv2.imshow('WCT Universal Style Transfer', stylized_bgr)

            fps.update()

            key = cv2.waitKey(10)
            if key & 0xFF == ord('r'):  # Load new random style
                style_window.set_style(random=True)
            elif key & 0xFF == ord('c'):  # Toggle color preservation
                keep_colors = not keep_colors
                print('Switching to keep_colors', keep_colors)
            elif key & 0xFF == ord('s'):  # Save stylized frame
                out_f = "{}.png".format(time.time())
                save_img(out_f, stylized_rgb)
                print('Saved image to', out_f)
            elif key & 0xFF == ord('q'):  # Quit gracefully
                break
        else:
            break

    fps.stop()
    print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
    print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))

    cap.stop()

    if args.video_out is not None:
        video_writer.release()

    cv2.destroyAllWindows()
Example #16
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    in_args = ['ffmpeg', '-i', args.in_path, '%s/frame_%%d.png' % in_dir]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    style_img = get_img(args.style_path)

    if args.style_size > 0:
        style_img = resize_to(style_img, args.style_size)
    if args.crop_size > 0:
        style_img = center_crop(style_img, args.crop_size)

    s = time.time()
    for in_f, out_f in zip(in_files, out_files):
        print('{} -> {}'.format(in_f, out_f))
        content_img = get_img(in_f)

        if args.keep_colors:
            style_rgb = preserve_colors_np(style_img, content_img)
        else:
            style_rgb = style_img

        stylized = wct_model.predict(content_img, style_rgb, args.alpha,
                                     args.swap5, args.ss_alpha)

        if args.passes > 1:
            for _ in range(args.passes - 1):
                stylized = wct_model.predict(stylized, style_rgb, args.alpha)

        # Stitch the style + stylized output together, but only if there's one style image
        if args.concat:
            # Resize style img to same height as frame
            style_img_resized = scipy.misc.imresize(
                style_rgb, (stylized.shape[0], stylized.shape[0]))
            stylized = np.hstack([style_img_resized, stylized])

        save_img(out_f, stylized)

    fr = 30
    out_args = [
        'ffmpeg', '-i',
        '%s/frame_%%d.png' % out_dir, '-f', 'mp4', '-q:v', '0', '-vcodec',
        'mpeg4', '-r',
        str(fr), args.out_path
    ]

    subprocess.call(" ".join(out_args), shell=True)
    print('Video at: %s' % args.out_path)

    if args.keep_tmp is False:
        shutil.rmtree(args.tmp_dir)

    print('Processed in:', (time.time() - s))
Example #17
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints, 
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path, 
                    device=args.device,
                    ss_patch_size=args.ss_patch_size, 
                    ss_stride=args.ss_stride)

    # Load a panel to control style settings
    style_window = StyleWindow(args.style_path, 
                               args.style_size, 
                               args.crop_size, 
                               args.scale, 
                               args.alpha, 
                               args.swap5, 
                               args.ss_alpha,
                               args.passes)

    # Start the webcam stream
    cap = WebcamVideoStream(args.video_source, args.width, args.height).start()

    _, frame = cap.read()

    # Grab a sample frame to calculate frame size
    frame_resize = cv2.resize(frame, None, fx=args.scale, fy=args.scale)
    img_shape = frame_resize.shape

    # Setup video out writer
    if args.video_out is not None:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        if args.concat:
            out_shape = (img_shape[1]+img_shape[0],img_shape[0]) # Make room for the style img
        else:
            out_shape = (img_shape[1],img_shape[0])
        print('Video Out Shape:', out_shape)
        video_writer = cv2.VideoWriter(args.video_out, fourcc, args.fps, out_shape)
    
    fps = FPS().start() # Track FPS processing speed

    # Toggles changed with kb shortcuts
    keep_colors = args.keep_colors
    swap_style = args.swap5
    use_adain = args.adain

    count = 0

    while True:
        ret, frame = cap.read()

        if ret is True:       
            frame_resize = cv2.resize(frame, None, fx=style_window.scale, fy=style_window.scale)

            if args.noise:  # Generate textures from noise instead of images
                frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)

            count += 1
            print("Frame:",count,"Orig shape:",frame.shape,"New shape",frame_resize.shape)

            content_rgb = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2RGB)  # OpenCV uses BGR, we need RGB

            if args.random > 0 and count % args.random == 0:
                style_window.set_style(random=True)

            if keep_colors:
                style_rgb = preserve_colors_np(style_window.style_rgb, content_rgb)
            else:
                style_rgb = style_window.style_rgb

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_rgb, style_rgb, style_window.alpha, swap_style, style_window.ss_alpha, use_adain)

            # Repeat stylization pipeline
            if style_window.passes > 1:
                for i in range(style_window.passes-1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_rgb, style_window.alpha, swap_style, style_window.ss_alpha, use_adain)

            # Stitch the style + stylized output together
            if args.concat:
                # Resize style img to same height as frame
                style_rgb_resized = cv2.resize(style_rgb, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                stylized_rgb = np.hstack([style_rgb_resized, stylized_rgb])
            
            stylized_bgr = cv2.cvtColor(stylized_rgb, cv2.COLOR_RGB2BGR)
                
            if args.video_out is not None:
                stylized_bgr = cv2.resize(stylized_bgr, out_shape) # Make sure frame matches video size
                video_writer.write(stylized_bgr)

            cv2.imshow('WCT Universal Style Transfer', stylized_bgr)

            fps.update()

            key = cv2.waitKey(10) 
            if key & 0xFF == ord('r'):   # Load new random style
                style_window.set_style(random=True)
            elif key & 0xFF == ord('c'): # Toggle color preservation
                keep_colors = not keep_colors
                print('Switching to keep_colors:',keep_colors)
            elif key & 0xFF == ord('s'): # Toggle style swap
                swap_style = not swap_style
                print('New value for flag swap_style:',swap_style)
            elif key & 0xFF == ord('a'): # Toggle AdaIN
                use_adain = not use_adain
                print('New value for flag use_adain:',use_adain)
            elif key & 0xFF == ord('w'): # Write stylized frame
                out_f = "{}.png".format(time.time())
                save_img(out_f, stylized_rgb)
                print('Saved image to:',out_f)
            elif key & 0xFF == ord('q'): # Quit gracefully
                break
        else:
            break

    fps.stop()
    print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
    print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))

    cap.stop()
    
    if args.video_out is not None:
        video_writer.release()
    
    cv2.destroyAllWindows()