def train(decoder, content_loader, style_loader):
    """Train the decoder of an FPnet style-transfer network.

    Only the decoder's parameters are handed to the optimizer (Adam with
    ``args.lr``); the learning rate is adjusted every iteration by
    ``adjust_learning_rate``.

    Args:
        decoder: decoder module wrapped by ``FPnet``; its state dict is
            saved to ``args.save_dir`` when training finishes.
        content_loader: iterator yielding content image batches
            (consumed with ``next()``).
        style_loader: iterator yielding style image batches
            (consumed with ``next()``).

    NOTE(review): ``args.save_dir`` is used here as the checkpoint *file*
    path passed to ``torch.save``, despite its name.
    """
    print('Start training...')
    tic = time.time()
    fpnet = FPnet(decoder).to(device)
    fpnet.train()
    # Adam optimizer over the decoder's parameters only.
    # (An SGD alternative, lr=0.001 / momentum=0.9, was tried previously.)
    optimizer = torch.optim.Adam(fpnet.decoder.parameters(), lr=args.lr)
    with alive_bar(args.iter_times) as bar:
        for epoch in range(args.iter_times):
            bar()
            adjust_learning_rate(optimizer, epoch)
            try:
                content = next(content_loader).to(device)
                style = next(style_loader).to(device)
            except Exception:
                # Best-effort: skip an iteration whose batch cannot be fetched
                # (e.g. an exhausted iterator or an unreadable image).
                # Unlike a bare `except:`, this lets KeyboardInterrupt and
                # SystemExit propagate, so training can still be stopped.
                continue
            loss, _ = fpnet(content, style, lamda=args.lamda, alpha=args.alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.save(fpnet.decoder.state_dict(), args.save_dir)
    toc = time.time()
    print('TRIANING COMPLETED.')
    print('Time cost: {}s.'.format(toc - tic))
    print('model saved as: ' + args.save_dir)
def test(contentpath, stylepath, multi=False):
    """Obtain the stylized image with a single forward pass (512 px).

    The result is written to ``<args.save_dir>/<content>+<style>.jpg``.

    NOTE(review): this file also defines a later
    ``test(contentpath, stylepath, pixel, multi=False)``, which rebinds the
    name ``test`` at import time, so this version may be unreachable.
    Confirm which one callers expect.

    Args:
        contentpath: path to the content image (str or path-like).
        stylepath: path to the style image (str or path-like).
        multi: if truthy, base names are taken after the last ``'\\'``
            (Windows-style paths); otherwise after the last ``'/'``.
    """
    contentpath, stylepath = str(contentpath), str(stylepath)
    # Base name up to the first '.', split on the separator the caller uses.
    sep = '\\' if multi else '/'
    content_name = contentpath.split(sep)[-1].split('.')[0]
    style_name = stylepath.split(sep)[-1].split('.')[0]

    transfer = test_transform(512)
    # `with` closes the underlying file; convert() returns an independent copy.
    with Image.open(contentpath) as img:
        contentimg = img.convert('RGB')
    with Image.open(stylepath) as img:
        styleimg = img.convert('RGB')
    if args.preserve_color:
        styleimg = change_color(styleimg, contentimg)
    # Inputs must live on the same device as the network (train() does the same).
    contentimg = transfer(contentimg).unsqueeze(0).to(device)
    styleimg = transfer(styleimg).unsqueeze(0).to(device)

    decoder = Decoder().to(device).eval()
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    decoder.load_state_dict(torch.load(args.model_path, map_location=device))
    fbnet = FPnet(decoder, True).to(device).eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = fbnet(contentimg, styleimg, alpha=args.alpha,
                       lamda=args.lamda, require_loss=False)
    image_name = args.save_dir + '/' + content_name + '+' + style_name + '.jpg'
    save_image(output.cpu(), image_name)
    print('image saved as: ' + image_name)
def test(contentpath, stylepath, pixel, multi=False):
    """Obtain the stylized image with a single forward pass.

    The result is written to
    ``<args.save_dir>/<content>+<style>-<pixel>-<args.alpha>.jpg``.

    Args:
        contentpath: path to the content image (str or path-like).
        stylepath: path to the style image (str or path-like).
        pixel: size passed to ``test_transform`` for both images.
        multi: if truthy, base names are taken after the last ``'\\'``
            (Windows-style paths); otherwise after the last ``'/'``.
    """
    contentpath, stylepath = str(contentpath), str(stylepath)
    # Base name up to the first '.', split on the separator the caller uses.
    sep = '\\' if multi else '/'
    content_name = contentpath.split(sep)[-1].split('.')[0]
    style_name = stylepath.split(sep)[-1].split('.')[0]

    mytransfer = test_transform(pixel)
    # `with` guarantees the files are closed even if decoding fails;
    # convert() loads the pixels, so the images stay usable afterwards.
    with open(contentpath, 'rb') as contentfile:
        contentimg = Image.open(contentfile).convert('RGB')
    with open(stylepath, 'rb') as stylefile:
        styleimg = Image.open(stylefile).convert('RGB')
    if args.preserve_color:
        styleimg = change_color(styleimg, contentimg)
    # Inputs must live on the same device as the network (train() does the same).
    contentimg = mytransfer(contentimg).unsqueeze(0).to(device)
    styleimg = mytransfer(styleimg).unsqueeze(0).to(device)

    decoder = Decoder().to(device).eval()
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    decoder.load_state_dict(torch.load(args.model_path, map_location=device))
    fbnet = FPnet(decoder, True).to(device).eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = fbnet(contentimg, styleimg, alpha=args.alpha,
                       lamda=args.lamda, require_loss=False)
    image_name = (args.save_dir + '/' + content_name + '+' + style_name
                  + '-' + str(pixel) + '-' + str(args.alpha) + '.jpg')
    save_image(output.cpu(), image_name)
    print('image saved as: ' + image_name)
def transfer(contentpath, stylepath, converted, pixel=512,
             model_path='static/20200522decoder100000_1.pth',
             alpha=1.0, lamda=1.0, device=None):
    """Obtain the stylized image with a single forward pass and save it.

    Args:
        contentpath: path to the content image.
        stylepath: path to the style image.
        converted: output path passed to ``save_image``.
        pixel: size passed to ``test_transform`` for both images.
        model_path: decoder checkpoint to load.
        alpha: forwarded to ``FPnet`` as ``alpha`` (previously fixed at 1.0).
        lamda: forwarded to ``FPnet`` as ``lamda`` (previously fixed at 1.0).
        device: torch device (or device string) to run on; defaults to CPU,
            matching the previous hard-coded behaviour.
    """
    if device is None:
        device = torch.device('cpu')
    mytransfer = test_transform(pixel)
    # `with` guarantees the files are closed even if decoding fails;
    # convert() loads the pixels, so the images stay usable afterwards.
    with open(str(contentpath), 'rb') as contentfile:
        contentimg = Image.open(contentfile).convert('RGB')
    with open(str(stylepath), 'rb') as stylefile:
        styleimg = Image.open(stylefile).convert('RGB')
    contentimg = mytransfer(contentimg).unsqueeze(0).to(device)
    styleimg = mytransfer(styleimg).unsqueeze(0).to(device)

    decoder = Decoder().to(device).eval()
    # map_location lets a GPU-saved checkpoint load on whichever device is used.
    decoder.load_state_dict(torch.load(model_path, map_location=device))
    fbnet = FPnet(decoder, True).to(device).eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = fbnet(contentimg, styleimg, alpha=alpha, lamda=lamda,
                       require_loss=False)
    save_image(output.cpu(), converted)
# Directory containing this module; model weights are resolved relative to it.
PATH = os.path.dirname(__file__)
SCALE = (400,320)  # Frame size
pre_style=0  # Current style
# Load models.
# NOTE(review): `.cuda()` runs at import time, so importing this module
# requires a CUDA device.
# Weights are loaded onto the CPU (map_location="cpu") and then copied into
# the already-CUDA-resident modules by load_state_dict.
encoder=vgg.eval().cuda()
encoder.load_state_dict(torch.load(os.path.join(PATH,'model/vgg_normalised.pth'),map_location="cpu"))
decoder=testdecoder.eval().cuda()
decoder.load_state_dict(torch.load(os.path.join(PATH,"model/decoder.pth"),map_location="cpu"))
# Freeze both networks: no parameter of either receives gradients.
for param in encoder.parameters():
    param.requires_grad = False
for param in decoder.parameters():
    param.requires_grad = False
fbnet=FPnet(encoder,decoder).eval().cuda()
# OpenCV JPEG-encoding parameters (quality 90); presumably passed to
# cv2.imencode/imwrite elsewhere — TODO confirm.
encode_param=[int(cv2.IMWRITE_JPEG_QUALITY),90]
# Initialize style features
# NOTE(review): no style-feature initialization follows this comment here;
# it may be stale.
def test_transform(size):
    """Build the inference preprocessing pipeline: Resize(size) then ToTensor().

    Args:
        size: forwarded to ``transforms.Resize``. Per the torchvision docs, an
            int resizes the shorter edge to ``size`` and keeps the aspect ratio.

    Returns:
        A ``transforms.Compose`` applying the two steps in order.
    """
    transform_list = []
    transform_list.append(transforms.Resize(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform