Example no. 1
class PhotoStyle:
    def __init__(self, model_path=MODEL_PATH, use_cuda=True):
        self.p_wct = PhotoWCT()
        self.p_wct.load_state_dict(torch.load(model_path))
        self.use_cuda = use_cuda
        if use_cuda:
            self.p_wct.cuda(0)

    def stylize(self,
                content_image_path,
                style_image_path,
                output_image_path,
                content_seg_path=None,
                style_seg_path=None,
                smooth=True,
                verbose=False):
        process_stylization.stylization(
            p_wct=self.p_wct,
            content_image_path=content_image_path,
            style_image_path=style_image_path,
            content_seg_path=content_seg_path,
            style_seg_path=style_seg_path,
            output_image_path=output_image_path,
            cuda=self.use_cuda,
            smooth=smooth,
            verbose=verbose,
        )
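A minimal usage sketch of the wrapper above (it would live in the same module; the sample image paths are borrowed from the later examples and purely illustrative):

# Sketch only: assumes MODEL_PATH points at the pretrained photo_wct.pth.
styler = PhotoStyle(use_cuda=True)
styler.stylize(
    content_image_path='./images/content1.png',
    style_image_path='./images/style1.png',
    output_image_path='./results/example1.png',
    smooth=True,
)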
Example no. 2
def __init__(self):
    # Load model
    self.p_wct = PhotoWCT()
    self.p_wct.load_state_dict(
        torch.load('./PhotoWCTModels/photo_wct.pth'))
    self.p_wct.cuda(0)
    self.p_pro = Propagator()
Example no. 3
class CyclicPhotoWCT(nn.Module):
    def __init__(self):
        super(CyclicPhotoWCT, self).__init__()
        self.fw = PhotoWCT()
        self.bw = PhotoWCT()

    def transform(self, cont_img, styl_img, cont_seg, styl_seg):
        stylized_img = self.fw.transform(cont_img, styl_img, cont_seg, styl_seg)
        reversed_img = self.bw.transform(stylized_img, cont_img, cont_seg, styl_seg)
        return stylized_img, reversed_img

    def forward(self, *input):
        # nn.Module requires forward(); stylization is done through transform() instead.
        pass
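A hedged sketch of how the cyclic pair might be exercised, assuming 4-D image batches and the empty segmentation arrays that process_stylization passes when no masks are given; the MSE between the content and the reversed image mirrors the check printed in Example 11. Random tensors stand in for real data, so the numbers are illustrative only.

import numpy as np
import torch
import torch.nn.functional as F

cyclic = CyclicPhotoWCT()
cont_img = torch.rand(1, 3, 256, 256)   # placeholder content batch
styl_img = torch.rand(1, 3, 256, 256)   # placeholder style batch
empty_seg = np.asarray([])              # "no segmentation" case

stylized, reversed_img = cyclic.transform(cont_img, styl_img, empty_seg, empty_seg)
recon_error = F.mse_loss(reversed_img, cont_img)  # cyclic-consistency check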
Example no. 4
class StyleTransfer_Engine:
    def __init__(self):
        # Load model
        self.p_wct = PhotoWCT()
        self.p_wct.load_state_dict(
            torch.load('./PhotoWCTModels/photo_wct.pth'))
        self.p_wct.cuda(0)
        self.p_pro = Propagator()

    def run(self, content, style):
        out = stylization(stylization_module=self.p_wct,
                          smoothing_module=self.p_pro,
                          cont_img=content,
                          styl_img=style,
                          cuda=1,
                          save_intermediate=False,
                          no_post=False)
        return out
Example no. 5
def load_model(model='./PhotoWCTModels/photo_wct.pth', fast=True, cuda=1):
    """Load the PhotoWCT model; fast=True selects the lighter GIFSmoothing instead of Propagator."""
    # Load model
    p_wct = PhotoWCT()
    p_wct.load_state_dict(torch.load(model))

    if fast:
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()
    if cuda:
        p_wct.cuda(0)

    return p_wct, p_pro
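The helper above pairs with the stylization_module/smoothing_module call signature seen in Examples 4 and 10; a hedged end-to-end sketch using the sample paths from the other examples:

# Sketch only: wires load_model() into process_stylization with placeholder paths.
import process_stylization

p_wct, p_pro = load_model(fast=True, cuda=1)
process_stylization.stylization(
    stylization_module=p_wct,
    smoothing_module=p_pro,
    content_image_path='./images/content1.png',
    style_image_path='./images/style1.png',
    content_seg_path=[],
    style_seg_path=[],
    output_image_path='./results/example1.png',
    cuda=1,
    save_intermediate=False,
    no_post=False,
)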
Example no. 6
def setup(opts):
    p_wct = PhotoWCT()
    p_wct.load_state_dict(torch.load(PRETRAINED_MODEL_PATH))

    if opts['propagation_mode'] == 'fast':
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()
    if torch.cuda.is_available():
        p_wct.cuda(0)

    return {
        'p_wct': p_wct,
        'p_pro': p_pro,
    }
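A small sketch of how the dict returned by setup() might be consumed (the opts key mirrors the propagation_mode check above; any value other than 'fast' falls back to Propagator):

modules = setup({'propagation_mode': 'fast'})
p_wct, p_pro = modules['p_wct'], modules['p_pro']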
Example no. 7
def main():
    parser = argparse.ArgumentParser(
        description='Photorealistic Image Stylization')
    parser.add_argument(
        '--model',
        default='./PhotoWCTModels/photo_wct.pth',
        help=
        'Path to the PhotoWCT model. These are provided by the PhotoWCT submodule, please use `git submodule update --init --recursive` to pull.'
    )
    parser.add_argument('--content_image_path',
                        default='./images/content1.png')
    parser.add_argument('--content_seg_path', default=[])
    parser.add_argument('--style_image_path', default='./images/style1.png')
    parser.add_argument('--style_seg_path', default=[])
    parser.add_argument('--output_image_path',
                        default='./results/example1.png')
    parser.add_argument('--cuda', type=int, default=1, help='Enable CUDA.')
    args = parser.parse_args()

    # Load model
    p_wct = PhotoWCT()
    try:
        p_wct.load_state_dict(torch.load(args.model))
    except Exception:
        print("Fail to load PhotoWCT models. PhotoWCT submodule not updated?")
        exit()

    if args.cuda:
        p_wct.cuda(0)

    process_stylization.stylization(
        p_wct=p_wct,
        content_image_path=args.content_image_path,
        style_image_path=args.style_image_path,
        content_seg_path=args.content_seg_path,
        style_seg_path=args.style_seg_path,
        output_image_path=args.output_image_path,
        cuda=args.cuda,
    )
Example no. 8
parser.add_argument('--model', default='./PhotoWCTModels/photo_wct.pth',
                    help='Path to the PhotoWCT model. These are provided by the PhotoWCT submodule, please use `git submodule update --init --recursive` to pull.')
parser.add_argument('--cuda', type=int, default=1, help='Enable CUDA.')  # note: type=bool would treat any non-empty string as True
args = parser.parse_args()

folder = 'examples'
cont_img_folder = os.path.join(folder, 'content_img')
cont_seg_folder = os.path.join(folder, 'content_seg')
styl_img_folder = os.path.join(folder, 'style_img')
styl_seg_folder = os.path.join(folder, 'style_seg')
outp_img_folder = os.path.join(folder, 'results')
cont_img_list = [f for f in os.listdir(cont_img_folder) if os.path.isfile(os.path.join(cont_img_folder, f))]
cont_img_list.sort()

# Load model
p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load(args.model))

for f in cont_img_list:
    print("Process " + f)
    
    content_image_path = os.path.join(cont_img_folder, f)
    content_seg_path = os.path.join(cont_seg_folder, f).replace(".png", ".pgm")
    style_image_path = os.path.join(styl_img_folder, f)
    style_seg_path = os.path.join(styl_seg_folder, f).replace(".png", ".pgm")
    output_image_path = os.path.join(outp_img_folder, f)
    
    process_stylization.stylization(
        p_wct=p_wct,
        content_image_path=content_image_path,
        style_image_path=style_image_path,
Example no. 9
                    default='./models/feature_invertor_conv4_1_mask.t7',
                    help='Path to the decoder4')
parser.add_argument('--decoder3',
                    default='./models/feature_invertor_conv3_1_mask.t7',
                    help='Path to the decoder3')
parser.add_argument('--decoder2',
                    default='./models/feature_invertor_conv2_1_mask.t7',
                    help='Path to the decoder2')
parser.add_argument('--decoder1',
                    default='./models/feature_invertor_conv1_1_mask.t7',
                    help='Path to the decoder1')
parser.add_argument('--content_image_path', default='./images/content1.png')
parser.add_argument('--content_seg_path', default=[])
parser.add_argument('--style_image_path', default='./images/style1.png')
parser.add_argument('--style_seg_path', default=[])
parser.add_argument('--output_image_path', default='./results/example1.png')
args = parser.parse_args()

# Load model
p_wct = PhotoWCT(args)
p_wct.cuda(0)

process_stylization.stylization(
    p_wct=p_wct,
    content_image_path=args.content_image_path,
    style_image_path=args.style_image_path,
    content_seg_path=args.content_seg_path,
    style_seg_path=args.style_seg_path,
    output_image_path=args.output_image_path,
)
Example no. 10
    description='Photorealistic Image Stylization')
parser.add_argument('--model', default='./PhotoWCTModels/photo_wct.pth')
parser.add_argument('--content_image_path', default='./images/content1.png')
parser.add_argument('--content_seg_path', default=[])
parser.add_argument('--style_image_path', default='./images/style1.png')
parser.add_argument('--style_seg_path', default=[])
parser.add_argument('--output_image_path', default='./results/example1.png')
parser.add_argument('--save_intermediate', action='store_true', default=False)
parser.add_argument('--fast', action='store_true', default=False)
parser.add_argument('--no_post', action='store_true', default=False)
parser.add_argument('--cuda', type=int, default=1, help='Enable CUDA.')
parser.add_argument('--device', type=int, default=0, help='CUDA device.')
args = parser.parse_args()

# Load model
p_wct = PhotoWCT(device=args.device)
p_wct.load_state_dict(torch.load(args.model))

if args.fast:
    from photo_gif import GIFSmoothing
    p_pro = GIFSmoothing(r=35, eps=0.001)
else:
    from photo_smooth import Propagator
    p_pro = Propagator()
if args.cuda:
    p_wct.cuda(args.device)

process_stylization.stylization(stylization_module=p_wct,
                                smoothing_module=p_pro,
                                content_image_path=args.content_image_path,
                                style_image_path=args.style_image_path,
Example no. 11
                        style_image_path,
                        stylized_image_path,
                        reversed_image_path,
                        cuda,
                        smoothing_module=smoothing_module,
                        do_smoothing=do_smoothing)
    print('Stylized image', stylized_image_path)
    print('Reversed image', reversed_image_path)
    print("MSE loss between content and reversed image:",
          mse_loss_images(content_image_path, reversed_image_path).item())
    print("Content loss between content and reversed image:",
          content_loss_images(content_image_path, reversed_image_path).item())
    print("=" * 15)


p_wct = PhotoWCT()
p_pro = Propagator(beta=0.7)
cuda = torch.cuda.is_available()
do_smoothing = False  # set to True to apply the smoothing module

if cuda:
    p_wct.cuda(0)

p_wct_paths = [
    './PhotoWCTModels/cyclic_photo_wct_mse.pth',
    './PhotoWCTModels/cyclic_photo_wct_content.pth',
    './PhotoWCTModels/cyclic_photo_wct_mse_content.pth'
]
content_image_paths = [
    './images/opernhaus.jpg', './images/forest_summer.jpg',
    './images/pyramid_egypt.jpg'
Example no. 12
                    help='Path to the decoder3')
parser.add_argument('--decoder2',
                    default='./models/feature_invertor_conv2_1_mask.t7',
                    help='Path to the decoder2')
parser.add_argument('--decoder1',
                    default='./models/feature_invertor_conv1_1_mask.t7',
                    help='Path to the decoder1')
parser.add_argument('--content_image_path', default='./images/content1.png')
parser.add_argument('--content_seg_path', default=[])
parser.add_argument('--style_image_path', default='./images/style1.png')
parser.add_argument('--style_seg_path', default=[])
parser.add_argument('--output_image_path', default='./results/example1.png')
args = parser.parse_args()

# Load model
p_wct = PhotoWCT(args)
p_pro = Propagator()
p_wct.cuda(0)

content_image_path = args.content_image_path
content_seg_path = args.content_seg_path
style_image_path = args.style_image_path
style_seg_path = args.style_seg_path
output_image_path = args.output_image_path

# Load image
cont_img = Image.open(content_image_path).convert('RGB')
styl_img = Image.open(style_image_path).convert('RGB')
try:
    cont_seg = Image.open(content_seg_path)
    styl_seg = Image.open(style_seg_path)
Example no. 13
from __future__ import print_function
from model import StyleTransferModel
from io import BytesIO
from telegram import ReplyKeyboardMarkup, ReplyKeyboardRemove
from telegram.ext.dispatcher import run_async
import torch
from PIL import Image
import process_stylization
from photo_wct import PhotoWCT
from photo_gif import GIFSmoothing

model = StyleTransferModel()
p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load('photo_wct.pth'))
p_pro = GIFSmoothing(r=35, eps=0.001)
first_image_file = {}
storage = []
storage1 = []


@run_async
def send_prediction_on_photo(update, context):
    global storage
    global storage1
    # We need two images to perform the style transfer, but each image arrives in a separate
    # update, so in the simplest case we store the id of the first image in memory and, once
    # the second one arrives, load both images into memory and process them.
    # Definitely room for improvement, I would
    bot = context.bot
    if update.message.text == '/style':
        storage.append('1')
Example no. 14
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

root_dir = 'dataset/'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_data = datasets.ImageFolder(root_dir, image_transform)
data_loader = torch.utils.data.DataLoader(train_data,
                                          batch_size=32,
                                          shuffle=True)

# define & initialize model
model = PhotoWCT()
model.load_state_dict(torch.load('./PhotoWCTModels/photo_wct.pth'))

# set encoders requires_grad=False, i.e. freeze them during training
set_parameter_requires_grad(model.e1, False)
set_parameter_requires_grad(model.e2, False)
set_parameter_requires_grad(model.e3, False)
set_parameter_requires_grad(model.e4, False)

# transfer model to GPU if you have one
model.to(device)

# set criterion to reconstruction loss & define optimizer
criterion = WeightedMseContentLoss(content_loss_weight=1, mse_loss_weight=1700)
#criterion = ContentLoss() # MSELoss() # alternatives to the above dissimilarity constraint
Example no. 15
def __init__(self, model_path=MODEL_PATH, use_cuda=True):
    self.p_wct = PhotoWCT()
    self.p_wct.load_state_dict(torch.load(model_path))
    self.use_cuda = use_cuda
    if use_cuda:
        self.p_wct.cuda(0)
Example no. 16
import os

import six
import torch
from PIL import Image
from flask import Flask, request
from flask import jsonify

import process_stylization
from image_utils import base64_png_image_to_pillow_image, get_apt_image_size, get_temp_png_file_path
from photo_wct import PhotoWCT

app = Flask(__name__)

MAX_NUM_PIXELS = 1024 * 512

# Load model
p_wct = PhotoWCT()
try:
    p_wct.load_state_dict(
        torch.load(
            os.path.join(os.path.dirname(__file__), 'PhotoWCTModels',
                         'photo_wct.pth')))
except Exception:
    print("Fail to load PhotoWCT models. PhotoWCT submodule not updated?")
    exit()


@app.route("/stylize/", methods=['POST'])
def stylize():
    content_image_base64 = request.json.get('content_image_base64', None)
    if content_image_base64 is None:
        raise Exception('content_image_base64 cannot be None')
Example no. 17
            'conv2_1': 9,
            'conv2_2': 12,
            'conv3_1': 16,
            'conv3_2': 19,
            'conv3_3': 22,
            'conv3_4': 25,
            'conv4_1': 29,
        })
    torch.save(e4.state_dict(), 'pth_models/vgg_normalised_conv4.pth')

    ## VGGDecoder4
    inv4 = load_lua('models/feature_invertor_conv4_1_mask.t7')
    d4 = VGGDecoder(4)
    weight_assign(
        inv4, d4, {
            'conv4_1': 1,
            'conv3_4': 5,
            'conv3_3': 8,
            'conv3_2': 11,
            'conv3_1': 14,
            'conv2_2': 18,
            'conv2_1': 21,
            'conv1_2': 25,
            'conv1_1': 28,
        })
    torch.save(d4.state_dict(), 'pth_models/feature_invertor_conv4.pth')

    p_wct = PhotoWCT()
    photo_wct_loader(p_wct)
    torch.save(p_wct.state_dict(), 'PhotoWCTModels/photo_wct.pth')
Example no. 18
        'conv1_2': 5,
        'conv2_1': 9,
        'conv2_2': 12,
        'conv3_1': 16,
        'conv3_2': 19,
        'conv3_3': 22,
        'conv3_4': 25,
        'conv4_1': 29,
    })
    torch.save(e4.state_dict(), 'pth_models/vgg_normalised_conv4.pth')
    
    ## VGGDecoder4
    inv4 = load_lua('models/feature_invertor_conv4_1_mask.t7')
    d4 = VGGDecoder(4)
    weight_assign(inv4, d4, {
        'conv4_1': 1,
        'conv3_4': 5,
        'conv3_3': 8,
        'conv3_2': 11,
        'conv3_1': 14,
        'conv2_2': 18,
        'conv2_1': 21,
        'conv1_2': 25,
        'conv1_1': 28,
    })
    torch.save(d4.state_dict(), 'pth_models/feature_invertor_conv4.pth')
    
    p_wct = PhotoWCT()
    photo_wct_loader(p_wct)
    torch.save(p_wct.state_dict(), 'PhotoWCTModels/photo_wct.pth')
Example no. 19
def __init__(self):
    super(CyclicPhotoWCT, self).__init__()
    self.fw = PhotoWCT()
    self.bw = PhotoWCT()
Example no. 20
args = parser.parse_args()

folder = 'examples'
cont_img_folder = os.path.join(folder, 'content_img')
cont_seg_folder = os.path.join(folder, 'content_seg')
styl_img_folder = os.path.join(folder, 'style_img')
styl_seg_folder = os.path.join(folder, 'style_seg')
outp_img_folder = os.path.join(folder, 'results')
cont_img_list = [
    f for f in os.listdir(cont_img_folder)
    if os.path.isfile(os.path.join(cont_img_folder, f))
]
cont_img_list.sort()

# Load model
p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load(args.model))
p_wct.cuda(0)

for f in cont_img_list:
    print("Process " + f)

    content_image_path = os.path.join(cont_img_folder, f)
    content_seg_path = os.path.join(cont_seg_folder, f).replace(".png", ".pgm")
    style_image_path = os.path.join(styl_img_folder, f)
    style_seg_path = os.path.join(styl_seg_folder, f).replace(".png", ".pgm")
    output_image_path = os.path.join(outp_img_folder, f)

    process_stylization.stylization(
        p_wct=p_wct,
        content_image_path=content_image_path,
Example no. 21
import os
from os import path as osp

import torch

from photo_gif import GIFSmoothing
from photo_smooth import Propagator
from photo_wct import PhotoWCT


SEED = 1984
N_IMGS = 100
FAST = True
SPLIT = "val"
MODEL = "./PhotoWCTModels/photo_wct.pth"
STYLE_PATH = "/data/datasets/style_transfer_amos/styles_sub_10"
CONTENT_PATH = "/data/datasets/phototourism"
OUTPUT_ST_PATH = osp.join(CONTENT_PATH, "style_transfer_all")
STYLES = ["cloudy", "dusk", "mist", "night", "rainy", "snow"]
#STYLES = ["snow"]

p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load(MODEL))
p_pro = GIFSmoothing(r=35, eps=0.001) if FAST else Propagator()
p_wct.cuda(0)

with open(osp.join(CONTENT_PATH, SPLIT + "_phototourism_ms.txt"), "r") as f:
    content_fnames = [line.rstrip('\n') for line in f]

for style in STYLES:
    print("Style: {:s}".format(style))
    style_fnames = [img for img in os.listdir(osp.join(STYLE_PATH, style)) if img[-3:] in ["png", "jpg"]]
    for style_fname in style_fnames:
        k_cont = 0
        for content_fname in content_fnames:
            scene = content_fname.split('/')[0]
            output_path = osp.join(CONTENT_PATH, "style_transfer_all", scene, style)