Esempio n. 1
0
def train(decoder, content_loader, style_loader):
    '''Train the decoder of an FPnet style-transfer model.

    Runs ``args.iter_times`` optimisation steps. Each step pulls one batch
    from ``content_loader`` and one from ``style_loader`` with ``next()``, so
    both must be iterators (presumably infinite ones — TODO confirm). Only the
    parameters of ``fpnet.decoder`` are optimised, with Adam at ``args.lr``.
    When the loop finishes, the decoder's ``state_dict`` is saved to
    ``args.save_dir``.

    Args:
        decoder: decoder module that ``FPnet`` wraps.
        content_loader: iterator yielding content image batches.
        style_loader: iterator yielding style image batches.
    '''
    print('Start training...')
    tic = time.time()

    fpnet = FPnet(decoder).to(device)
    fpnet.train()
    # Only the decoder is optimised; whatever else FPnet holds is left untouched.
    optimizer = torch.optim.Adam(fpnet.decoder.parameters(), lr=args.lr)

    skipped = 0
    with alive_bar(args.iter_times) as bar:
        for epoch in range(args.iter_times):
            bar()
            adjust_learning_rate(optimizer, epoch)

            try:
                content = next(content_loader).to(device)
                style = next(style_loader).to(device)
            except Exception:
                # Best-effort: skip a batch that could not be fetched (bad
                # sample, exhausted iterator, ...). Unlike the former bare
                # ``except:``, this no longer swallows KeyboardInterrupt or
                # SystemExit.
                skipped += 1
                continue

            loss, _ = fpnet(content, style, lamda=args.lamda, alpha=args.alpha)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.save(fpnet.decoder.state_dict(), args.save_dir)
        toc = time.time()

    if skipped:
        # Make silently-skipped iterations visible instead of hiding them.
        print('Skipped {} of {} iterations (batch could not be loaded).'.format(
            skipped, args.iter_times))
    print('TRAINING COMPLETED.')
    print('Time cost: {}s.'.format(toc - tic))
    print('model saved as:  ' + args.save_dir)
Esempio n. 2
0
def test(contentpath, stylepath, multi=False):
    '''Stylize one content image with one style image in a single forward pass.

    The result is saved as ``<args.save_dir>/<content>+<style>.jpg``, where
    ``<content>`` and ``<style>`` are the input file names up to their first
    ``'.'``.

    Args:
        contentpath: path to the content image.
        stylepath: path to the style image.
        multi: kept for backward compatibility. It used to choose whether
            paths were split on backslashes (True) or forward slashes
            (False). Both separators are now always accepted, so mixed paths
            (e.g. from ``glob`` on Windows) yield a clean base name.
    '''
    def _base_name(path):
        # File name without directories, cut at the first '.', for either separator.
        return str(path).replace('\\', '/').split('/')[-1].split('.')[0]

    content_name = _base_name(contentpath)
    style_name = _base_name(stylepath)

    preprocess = test_transform(512)

    # Context managers guarantee the underlying files are closed.
    with Image.open(str(contentpath)) as im:
        contentimg = im.convert('RGB')
    with Image.open(str(stylepath)) as im:
        styleimg = im.convert('RGB')
    if args.preserve_color:
        styleimg = change_color(styleimg, contentimg)
    contentimg = preprocess(contentimg).unsqueeze(0)
    styleimg = preprocess(styleimg).unsqueeze(0)

    decoder = Decoder().to(device).eval()
    # map_location lets a checkpoint saved on another device load here.
    decoder.load_state_dict(torch.load(args.model_path, map_location=device))

    fbnet = FPnet(decoder, True).to(device).eval()
    # Pure inference: don't build an autograd graph (saves memory).
    with torch.no_grad():
        output = fbnet(contentimg,
                       styleimg,
                       alpha=args.alpha,
                       lamda=args.lamda,
                       require_loss=False)

    image_name = args.save_dir + '/' + content_name + '+' + style_name + '.jpg'
    save_image(output.cpu(), image_name)
    print('image saved  as:  ' + image_name)
Esempio n. 3
0
def test(contentpath, stylepath, pixel, multi=False):
    '''Stylize one content image with one style image in a single forward pass.

    Both images go through ``test_transform(pixel)``. The result is saved as
    ``<args.save_dir>/<content>+<style>-<pixel>-<args.alpha>.jpg``, where
    ``<content>`` and ``<style>`` are the input file names up to their first
    ``'.'``.

    Args:
        contentpath: path to the content image.
        stylepath: path to the style image.
        pixel: size argument passed to ``test_transform``.
        multi: kept for backward compatibility. It used to choose whether
            paths were split on backslashes (True) or forward slashes
            (False). Both separators are now always accepted.
    '''
    def _base_name(path):
        # File name without directories, cut at the first '.', for either separator.
        return str(path).replace('\\', '/').split('/')[-1].split('.')[0]

    content_name = _base_name(contentpath)
    style_name = _base_name(stylepath)

    preprocess = test_transform(pixel)

    # ``with`` closes the files even if decoding fails. The manual
    # open()/close() pair leaked the handles on error.
    with Image.open(str(contentpath)) as im:
        contentimg = im.convert('RGB')
    with Image.open(str(stylepath)) as im:
        styleimg = im.convert('RGB')

    if args.preserve_color:
        styleimg = change_color(styleimg, contentimg)

    contentimg = preprocess(contentimg).unsqueeze(0)
    styleimg = preprocess(styleimg).unsqueeze(0)

    decoder = Decoder().to(device).eval()
    # map_location lets a checkpoint saved on another device load here.
    decoder.load_state_dict(torch.load(args.model_path, map_location=device))

    fbnet = FPnet(decoder, True).to(device).eval()
    # Pure inference: don't build an autograd graph (saves memory).
    with torch.no_grad():
        output = fbnet(contentimg, styleimg, alpha=args.alpha, lamda=args.lamda,
                       require_loss=False)

    image_name = args.save_dir+'/'+content_name+'+'+style_name+'-'+str(pixel)+'-'+str(args.alpha)+'.jpg'

    save_image(output.cpu(), image_name)
    print('image saved  as:  '+image_name)
Esempio n. 4
0
def transfer(contentpath,
             stylepath,
             converted,
             pixel=512,
             model_path='static/20200522decoder100000_1.pth'):
    '''Stylize ``contentpath`` with ``stylepath`` and save the result to ``converted``.

    Runs a single forward pass on the CPU with ``alpha=1.0`` and
    ``lamda=1.0``, using decoder weights loaded from ``model_path``.

    Args:
        contentpath: path to the content image.
        stylepath: path to the style image.
        converted: output path passed to ``save_image``.
        pixel: size argument passed to ``test_transform``.
        model_path: decoder checkpoint to load.
    '''
    preprocess = test_transform(pixel)

    # ``with`` closes the files even if decoding fails. The manual
    # open()/close() pair leaked the handles on error.
    with Image.open(str(contentpath)) as im:
        contentimg = im.convert('RGB')
    with Image.open(str(stylepath)) as im:
        styleimg = im.convert('RGB')

    contentimg = preprocess(contentimg).unsqueeze(0)
    styleimg = preprocess(styleimg).unsqueeze(0)

    # Deliberately pinned to the CPU. CUDA auto-selection was disabled upstream.
    device = torch.device("cpu")

    decoder = Decoder().to(device).eval()
    decoder.load_state_dict(torch.load(model_path, map_location=device))

    fbnet = FPnet(decoder, True).to(device).eval()
    # Pure inference: don't build an autograd graph (saves memory).
    with torch.no_grad():
        output = fbnet(contentimg,
                       styleimg,
                       alpha=1.0,
                       lamda=1.0,
                       require_loss=False)

    save_image(output.cpu(), converted)
Esempio n. 5
0
PATH = os.path.dirname(__file__)
SCALE = (400, 320)  # frame size
pre_style = 0  # current style


# Load the pretrained networks onto the GPU. Checkpoints are deserialized on
# the CPU (map_location="cpu") and then copied into the CUDA modules by
# load_state_dict.
encoder = vgg.eval().cuda()
encoder.load_state_dict(
    torch.load(os.path.join(PATH, 'model/vgg_normalised.pth'), map_location="cpu")
)
decoder = testdecoder.eval().cuda()
decoder.load_state_dict(
    torch.load(os.path.join(PATH, "model/decoder.pth"), map_location="cpu")
)

# Freeze every parameter of both networks (no gradients are tracked for them).
for param in [*encoder.parameters(), *decoder.parameters()]:
    param.requires_grad = False

fbnet = FPnet(encoder, decoder).eval().cuda()

# OpenCV encoding parameters: JPEG quality 90.
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 90]



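# Initialize style features. NOTE(review): test_transform below only builds the input image transform; this heading may refer to other code.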
def test_transform(size):
    '''Return the preprocessing pipeline for input images.

    The pipeline applies ``transforms.Resize(size)`` followed by
    ``transforms.ToTensor()``.
    '''
    steps = [
        transforms.Resize(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)