Esempio n. 1
0
 def __init__(self):
     """Create the PhotoWCT stylizer (pretrained weights, on GPU 0) and a Propagator."""
     stylizer = PhotoWCT()
     checkpoint = torch.load('./PhotoWCTModels/photo_wct.pth')
     stylizer.load_state_dict(checkpoint)
     stylizer.cuda(0)
     self.p_wct = stylizer
     self.p_pro = Propagator()
def process_image(content: str, style: str, output: str):
    """Stylize a content image, or every image in a content folder, with a style image.

    Each result is written into ``output`` as ``processed_<original name>`` and
    then shown with ``show_image``. If an argument is invalid, the reason is
    printed and the function returns ``None`` without loading the model.

    Args:
        content: Path to a content image (.jpg/.jpeg/.png) or to a folder of them.
        style: Path to the style image (.jpg/.jpeg/.png).
        output: Existing folder that receives the stylized images.
    """
    import os

    class StylizationException(Exception):
        def __init__(self, msg):
            super().__init__(msg)
            self.msg = msg

    img_ext = ('.jpg', '.jpeg', '.png')

    def is_image(path):
        # Case-insensitive extension check only.
        return os.path.splitext(path)[-1].lower() in img_ext

    try:
        # Validate everything up front so bad input never pays for model loading.
        if not os.path.isdir(output):
            raise StylizationException("output path is not a folder")
        if not (os.path.isfile(style) and is_image(style)):
            raise StylizationException("style image doesn't exist")

        if os.path.isdir(content):
            names = sorted(
                name for name in os.listdir(content)
                if is_image(name) and os.path.isfile(os.path.join(content, name))
            )
            if not names:
                raise StylizationException("no image in content folder")
            # os.path.join works whether or not the folders end with a separator.
            jobs = [
                (os.path.join(content, name), os.path.join(output, f"processed_{name}"))
                for name in names
            ]
        else:
            if not (os.path.isfile(content) and is_image(content)):
                raise StylizationException("content image doesn't exist")
            name = os.path.basename(content)
            jobs = [(content, os.path.join(output, f"processed_{name}"))]
    except StylizationException as se:
        # Report and stop: retrying with the same arguments can never succeed.
        print(se.msg)
        return None

    # Load model
    p_wct = PhotoWCT()
    p_wct.load_state_dict(torch.load('./PhotoWCTModels/photo_wct.pth'))

    from photo_smooth import Propagator
    p_pro = Propagator()

    for content_path, output_path in jobs:
        stylization(
            stylization_module=p_wct,
            smoothing_module=p_pro,
            content_image_path=content_path,
            style_image_path=style,
            content_seg_path=[],
            style_seg_path=[],
            output_image_path=output_path,
        )
        # Show the produced image file (not the output folder).
        show_image(output_path)
def load_model(model='./PhotoWCTModels/photo_wct.pth', fast=True, cuda=1):
    """Load the PhotoWCT stylizer and build a smoothing module.

    Args:
        model: Path to the PhotoWCT state-dict checkpoint.
        fast: If true, use the lighter ``GIFSmoothing(r=35, eps=0.001)``;
            otherwise ``photo_smooth.Propagator()``.
        cuda: If truthy, move the stylizer to GPU 0. If falsy, checkpoint
            tensors are mapped to the CPU while loading, so a checkpoint saved
            from GPU can still be loaded on a machine without CUDA.

    Returns:
        Tuple ``(p_wct, p_pro)``: the stylization network and the smoothing module.
    """
    p_wct = PhotoWCT()
    # map_location=None keeps torch.load's default placement when CUDA is used.
    p_wct.load_state_dict(torch.load(model, map_location=None if cuda else 'cpu'))

    if fast:
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()
    if cuda:
        p_wct.cuda(0)

    return p_wct, p_pro
Esempio n. 4
0
def setup(opts):
    """Load the pretrained PhotoWCT stylizer and build the smoothing module.

    Args:
        opts: Mapping with key ``'propagation_mode'``; ``'fast'`` selects
            ``GIFSmoothing(r=35, eps=0.001)``, any other value selects
            ``photo_smooth.Propagator()``.

    Returns:
        Dict with keys ``'p_wct'`` (stylizer, moved to GPU 0 when CUDA is
        available) and ``'p_pro'`` (smoothing module).
    """
    use_cuda = torch.cuda.is_available()

    p_wct = PhotoWCT()
    # Without CUDA, remap GPU-saved checkpoint tensors to the CPU so loading succeeds.
    p_wct.load_state_dict(
        torch.load(PRETRAINED_MODEL_PATH, map_location=None if use_cuda else 'cpu'))

    if opts['propagation_mode'] == 'fast':
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()
    if use_cuda:
        p_wct.cuda(0)

    return {
        'p_wct': p_wct,
        'p_pro': p_pro,
    }
Esempio n. 5
0
# Command-line options; `parser` is an argparse parser created above this excerpt.
# Decoder weight files (presumably consumed by PhotoWCT(args) below — confirm).
parser.add_argument('--decoder2',
                    default='./models/feature_invertor_conv2_1_mask.t7',
                    help='Path to the decoder2')
parser.add_argument('--decoder1',
                    default='./models/feature_invertor_conv1_1_mask.t7',
                    help='Path to the decoder1')
parser.add_argument('--content_image_path', default='./images/content1.png')
# NOTE(review): default=[] is not a path; Image.open on it fails, which the
# bare try/except after the image loads appears to rely on to mean "no mask".
parser.add_argument('--content_seg_path', default=[])
parser.add_argument('--style_image_path', default='./images/style1.png')
parser.add_argument('--style_seg_path', default=[])
parser.add_argument('--output_image_path', default='./results/example1.png')
args = parser.parse_args()

# Load model: PhotoWCT is built from the parsed args, and the stylizer goes to GPU 0.
p_wct = PhotoWCT(args)
p_pro = Propagator()
p_wct.cuda(0)

# Unpack the paths into plain variables.
content_image_path = args.content_image_path
content_seg_path = args.content_seg_path
style_image_path = args.style_image_path
style_seg_path = args.style_seg_path
output_image_path = args.output_image_path

# Load content and style images as 3-channel RGB.
cont_img = Image.open(content_image_path).convert('RGB')
styl_img = Image.open(style_image_path).convert('RGB')
try:
    cont_seg = Image.open(content_seg_path)
    styl_seg = Image.open(style_seg_path)
except:
Esempio n. 6
0
parser.add_argument('--output_image_path', default='./results/example1.png')
parser.add_argument('--save_intermediate', action='store_true', default=False)
parser.add_argument('--fast', action='store_true', default=False)
parser.add_argument('--no_post', action='store_true', default=False)
parser.add_argument('--cuda', type=int, default=1, help='Enable CUDA.')
args = parser.parse_args()

# Load model. With --cuda 0, checkpoint tensors are mapped to the CPU so a
# GPU-saved checkpoint can be loaded on a machine without CUDA.
p_wct = PhotoWCT()
p_wct.load_state_dict(
    torch.load(args.model, map_location=None if args.cuda else 'cpu'))

# Smoothing module: lightweight GIFSmoothing with --fast, else Propagator.
if args.fast:
    from photo_gif import GIFSmoothing
    p_pro = GIFSmoothing(r=35, eps=0.001)
else:
    from photo_smooth import Propagator
    p_pro = Propagator()
if args.cuda:
    p_wct.cuda(0)

process_stylization.stylization(stylization_module=p_wct,
                                smoothing_module=p_pro,
                                content_image_path=args.content_image_path,
                                style_image_path=args.style_image_path,
                                content_seg_path=args.content_seg_path,
                                style_seg_path=args.style_seg_path,
                                output_image_path=args.output_image_path,
                                cuda=args.cuda,
                                save_intermediate=args.save_intermediate,
                                no_post=args.no_post)
# Names of the regular files (not sub-folders) in the content folder, sorted.
cont_img_list = sorted(
    name for name in os.listdir(cont_img_folder)
    if os.path.isfile(os.path.join(cont_img_folder, name))
)

# Stylizer on args.device, restored from the args.model checkpoint.
p_wct = PhotoWCT(device=args.device)
p_wct.load_state_dict(torch.load(args.model))

# Smoothing module: Propagator(args.beta) unless args.fast selects GIFSmoothing.
# NOTE(review): eps=0.01 here, while other call sites use eps=0.001 — confirm intent.
if not args.fast:
    from photo_smooth import Propagator
    p_pro = Propagator(args.beta)
else:
    from photo_gif import GIFSmoothing
    p_pro = GIFSmoothing(r=35, eps=0.01)
# For each content file name f, derive the matching content mask, style image,
# style mask and output paths (all files share the name f, apart from the
# mask extensions). The loop body appears truncated in this excerpt.
for f in cont_img_list:
    content_image_path = os.path.join(cont_img_folder, f)
    # NOTE(review): str.replace rewrites every occurrence of the image
    # extension in the joined path (folder names included), not only the suffix.
    content_seg_path = os.path.join(cont_seg_folder,
                                    f).replace(args.cont_img_ext,
                                               args.cont_seg_ext)
    style_image_path = os.path.join(styl_img_folder, f)
    style_seg_path = os.path.join(styl_seg_folder,
                                  f).replace(args.styl_img_ext,
                                             args.styl_seg_ext)
    output_image_path = os.path.join(outp_img_folder, f)

    print("Content image: " + content_image_path)
    # Masks are optional: only reported when the file exists.
    if os.path.isfile(content_seg_path):
        print("Content mask: " + content_seg_path)
Esempio n. 8
0
                        stylized_image_path,
                        reversed_image_path,
                        cuda,
                        smoothing_module=smoothing_module,
                        do_smoothing=do_smoothing)
    print('Stylized image', stylized_image_path)
    print('Reversed image', reversed_image_path)
    print("MSE loss between content and reversed image:",
          mse_loss_images(content_image_path, reversed_image_path).item())
    print("Content loss between content and reversed image:",
          content_loss_images(content_image_path, reversed_image_path).item())
    print("=" * 15)


p_wct = PhotoWCT()
p_pro = Propagator(beta=0.7)

# Put the stylization network on the first GPU whenever CUDA is available.
cuda = torch.cuda.is_available()
if cuda:
    p_wct.cuda(0)

# Set to True to apply the smoothing module, False to skip it.
do_smoothing = False

# Checkpoints to compare (presumably loaded into p_wct one by one further on).
p_wct_paths = [
    './PhotoWCTModels/cyclic_photo_wct_mse.pth',
    './PhotoWCTModels/cyclic_photo_wct_content.pth',
    './PhotoWCTModels/cyclic_photo_wct_mse_content.pth',
]
# Content images used for the evaluation.
content_image_paths = [
    './images/opernhaus.jpg',
    './images/forest_summer.jpg',
    './images/pyramid_egypt.jpg',
]
Esempio n. 9
0

# Run configuration.
SEED = 1984
N_IMGS = 100
FAST = True  # True -> GIFSmoothing smoother, False -> Propagator
SPLIT = "val"
MODEL = "./PhotoWCTModels/photo_wct.pth"

# Dataset locations.
STYLE_PATH = "/data/datasets/style_transfer_amos/styles_sub_10"
CONTENT_PATH = "/data/datasets/phototourism"
OUTPUT_ST_PATH = osp.join(CONTENT_PATH, "style_transfer_all")

# Style categories to render; use e.g. ["snow"] to run a single category.
STYLES = ["cloudy", "dusk", "mist", "night", "rainy", "snow"]

# Stylizer with pretrained weights, then the smoothing module, then move the stylizer to GPU 0.
p_wct = PhotoWCT()
p_wct.load_state_dict(torch.load(MODEL))
if FAST:
    p_pro = GIFSmoothing(r=35, eps=0.001)
else:
    p_pro = Propagator()
p_wct.cuda(0)

# Content file names for the split: one entry per line, newline stripped.
with open(osp.join(CONTENT_PATH, SPLIT + "_phototourism_ms.txt"), "r") as f:
    content_fnames = [entry.rstrip('\n') for entry in f]

for style in STYLES:
    print("Style: {:s}".format(style))
    style_fnames = [img for img in os.listdir(osp.join(STYLE_PATH, style)) if img[-3:] in ["png", "jpg"]]
    for style_fname in style_fnames:
        k_cont = 0
        for content_fname in content_fnames:
            scene = content_fname.split('/')[0]
            output_path = osp.join(CONTENT_PATH, "style_transfer_all", scene, style)
            if not osp.isdir(output_path):
                os.makedirs(output_path)