Example #1
def test(contentpath, stylepath, multi=False):
    '''Stylize an image with a single forward pass.'''
    # `multi` only selects which path separator is used to extract the file names.
    if not multi:
        content_name = (contentpath.split('/')[-1]).split('.')[0]
        style_name = (stylepath.split('/')[-1]).split('.')[0]
    else:
        content_name = (contentpath.split('\\')[-1]).split('.')[0]
        style_name = (stylepath.split('\\')[-1]).split('.')[0]

    transfer = test_transform(512)

    contentimg = Image.open(str(contentpath)).convert('RGB')
    styleimg = Image.open(str(stylepath)).convert('RGB')
    if args.preserve_color: styleimg = change_color(styleimg, contentimg)
    contentimg = transfer(contentimg).unsqueeze(0)
    styleimg = transfer(styleimg).unsqueeze(0)

    #if args.preserve_color: styleimg = coral(styleimg, contentimg)

    decoder = Decoder().to(device).eval()
    decoder.load_state_dict(torch.load(args.model_path))

    fbnet = FPnet(decoder, True).to(device).eval()
    output = fbnet(contentimg,
                   styleimg,
                   alpha=args.alpha,
                   lamda=args.lamda,
                   require_loss=False)

    image_name = args.save_dir + '/' + content_name + '+' + style_name + '.jpg'
    save_image(output.cpu(), image_name)
    print('image saved as: ' + image_name)
    # Note: detach() returns a new tensor and does not modify in place,
    # so these two calls have no lasting effect here.
    contentimg.detach()
    styleimg.detach()
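Example #1 reads several attributes from a module-level args namespace and a global device. A minimal, hypothetical argparse setup that would supply exactly the fields this function uses (the default values below are assumptions, not the repository's) could look like this:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='model/decoder.pth')  # assumed default
parser.add_argument('--save_dir', type=str, default='output')               # assumed default
parser.add_argument('--alpha', type=float, default=1.0)   # content/style trade-off
parser.add_argument('--lamda', type=float, default=1.0)
parser.add_argument('--preserve_color', action='store_true')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')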
Example #2
def test(contentpath, stylepath, pixel, multi=False):
    '''Stylize an image with a single forward pass.'''
    if not multi:
        content_name = (contentpath.split('/')[-1]).split('.')[0]
        style_name = (stylepath.split('/')[-1]).split('.')[0]
    else:
        content_name = (contentpath.split('\\')[-1]).split('.')[0]
        style_name = (stylepath.split('\\')[-1]).split('.')[0]

    mytransfer = test_transform(pixel)

    contentfile = open(str(contentpath),'rb')
    stylefile = open(str(stylepath),'rb')

    contentimg = Image.open(contentfile).convert('RGB')
    styleimg = Image.open(stylefile).convert('RGB')

    contentfile.close()
    stylefile.close()
    
    if args.preserve_color: styleimg = change_color(styleimg, contentimg)
    #if args.preserve_color: contentimg,styleimg,contentH,contentS = lumi_only(contentimg,styleimg,mytransfer)

    contentimg = mytransfer(contentimg).unsqueeze(0)
    styleimg = mytransfer(styleimg).unsqueeze(0)


    decoder = Decoder().to(device).eval()
    decoder.load_state_dict(torch.load(args.model_path))

    fbnet = FPnet(decoder, True).to(device).eval()
    output = fbnet(contentimg,
                   styleimg,
                   alpha=args.alpha,
                   lamda=args.lamda,
                   require_loss=False)

    #if args.preserve_color:output=recover_color(output,contentH,contentS)

    image_name = (args.save_dir + '/' + content_name + '+' + style_name
                  + '-' + str(pixel) + '-' + str(args.alpha) + '.jpg')

    save_image(output.cpu(), image_name)
    print('image saved as: ' + image_name)

    contentimg.detach()
    styleimg.detach()
    output.detach()
Example #3
def transfer(contentpath,
             stylepath,
             converted,
             pixel=512,
             model_path='static/20200522decoder100000_1.pth'):
    '''Stylize an image with a single forward pass.'''
    mytransfer = test_transform(pixel)

    contentfile = open(str(contentpath), 'rb')
    stylefile = open(str(stylepath), 'rb')

    contentimg = Image.open(contentfile).convert('RGB')
    styleimg = Image.open(stylefile).convert('RGB')

    contentfile.close()
    stylefile.close()

    contentimg = mytransfer(contentimg).unsqueeze(0)
    styleimg = mytransfer(styleimg).unsqueeze(0)

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")

    decoder = Decoder().to(device).eval()
    decoder.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    # decoder = decoder.module
    # decoder.load_state_dict(torch.load(model_path))

    fbnet = FPnet(decoder, True).to(device).eval()
    output = fbnet(contentimg,
                   styleimg,
                   alpha=1.0,
                   lamda=1.0,
                   require_loss=False)

    save_image(output.cpu(), converted)
    contentimg.detach()
    styleimg.detach()
    output.detach()
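A minimal usage sketch for Example #3, assuming the image files and the bundled decoder checkpoint exist at the paths shown (the upload and output paths are placeholders, not paths from the repository):

# Hypothetical paths, for illustration only.
transfer('static/uploads/content.jpg',
         'static/uploads/style.jpg',
         'static/converted/result.jpg',
         pixel=512,
         model_path='static/20200522decoder100000_1.pth')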
Example #4
def main():

    # Build the data loaders
    transform = Transform()

    content_set = CSDataset(transform=transform, root=args.content_dir)
    style_set = CSDataset(transform=transform, root=args.style_dir)

    content_loader = iter(data.DataLoader(content_set,
                                          batch_size=args.batch_size,
                                          sampler=RecurrentSampler(content_set),
                                          num_workers=args.n_threads,
                                          pin_memory=True,
                                          drop_last=True))

    style_loader = iter(data.DataLoader(style_set,
                                        batch_size=args.batch_size,
                                        sampler=RecurrentSampler(style_set),
                                        num_workers=args.n_threads,
                                        pin_memory=True,
                                        drop_last=True))

    # Initialize the model
    decoder = Decoder()
    decoder = decoder.to(device)
    decoder.load_state_dict(torch.load('model/decoder.pth'))
    #decoder.zero_grad()

    # Train the network
    train(decoder, content_loader, style_loader)
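Example #4 wraps both DataLoaders in iter() before handing them to train(), which suggests the training loop draws batches with next() for a fixed number of steps, with RecurrentSampler keeping the iterators from being exhausted. The repository's train() is not shown here, so the following skeleton is only an assumption about that pattern (the optimizer, learning rate, and max_iter are placeholders):

def train(decoder, content_loader, style_loader, max_iter=160000):
    # Hypothetical skeleton; the real loss computation lives elsewhere in the repository.
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
    for step in range(max_iter):
        content = next(content_loader).to(device)  # never exhausted with a recurrent sampler
        style = next(style_loader).to(device)
        # ... forward pass, compute the style-transfer loss ...
        # optimizer.zero_grad(); loss.backward(); optimizer.step()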
Example #5
from net import Decoder
import torch
import matplotlib.pyplot as plt

input_data = torch.tensor([48.2976, -2.0373, -29.1018, 12.2312])
# input_data = torch.tensor([23.9500, -11.1840, 36.5084, -33.9963])
# input_data = torch.tensor([9.9927, -0.2395, -1.1840, -2.4094])
# input_data = torch.tensor([20.2314, -10.3747, 11.7729, -8.3415])
# input_data = torch.tensor([6.8456, 5.0325, -4.7677, -0.9231])
# input_data = torch.tensor([12.7765, -8.3813, 9.4535, -3.7266])
# input_data = torch.tensor([5.8932, 0.2757, -5.0638, 6.0578])
# input_data = torch.tensor([7.7216, 11.7913, 1.8647, -13.9440])
# input_data = torch.tensor([8.0073, -0.2398, 7.3321, -9.5089])
# input_data = torch.tensor([2.0922, 4.1933, -0.1578, -3.3220])
# input_data = (torch.randn(4))*50

print(input_data)
decoder = Decoder().eval()
# Keep only the pretrained weights whose keys exist in the current model,
# then merge them into the state dict before loading.
model_dict = decoder.state_dict()
pretrained_dict = torch.load('model/net.pth')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
decoder.load_state_dict(model_dict)
with torch.no_grad():
    output = decoder(input_data)
    plt.imshow(output[0], cmap='gray')
    plt.show()