def stylize(**kwargs):
    """Stylize one content image with a trained TransformerNet and save it.

    Keyword arguments override attributes of ``Config`` (e.g. ``content_path``,
    ``model_path``, ``result_path``, ``use_gpu``).
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # Image preprocessing: image -> float tensor scaled to 0-255, plus a batch dim.
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)

    # Model: map_location keeps every tensor on the device it is deserialized to
    # (CPU), so GPU-saved checkpoints load on any host.
    style_model = TransformerNet().eval()
    style_model.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))

    if opt.use_gpu:
        content_image = content_image.cuda()
        style_model.cuda()

    # Style transfer and save. Variable(..., volatile=True) is a no-op since
    # PyTorch 0.4, so no_grad is what actually stops autograd from building a graph.
    with t.no_grad():
        output = style_model(content_image)
    output_data = output.cpu()[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1), opt.result_path)
def stylize(args):
    """Stylize ``args.content_image`` with a PyTorch or ONNX style model.

    The image is loaded (optionally rescaled by ``args.content_scale``) and
    turned into a 0-255 float batch. Processing depends on the model type:

    - When ``args.model`` ends in ``.onnx``, the batch goes to
      ``stylize_onnx_caffe2``.
    - Otherwise a ``TransformerNet`` checkpoint is loaded and run on CUDA
      if ``args.cuda``.
    - With ``args.export_onnx``, the PyTorch model is exported to that path
      instead of being run directly.

    The first image of the result is written to ``args.output_image``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    to_scaled_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    picture = utils.load_image(args.content_image, scale=args.content_scale)
    batch = to_scaled_tensor(picture).unsqueeze(0).to(device)

    if not args.model.endswith(".onnx"):
        with torch.no_grad():
            net = TransformerNet()
            weights = torch.load(args.model)
            # Older checkpoints carry InstanceNorm running_mean/var entries that
            # must be dropped (in place) before loading.
            stale = [key for key in weights if re.search(r'in\d+\.running_(mean|var)$', key)]
            for key in stale:
                del weights[key]
            net.load_state_dict(weights)
            net.to(device)
            if not args.export_onnx:
                output = net(batch).cpu()
            else:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(net, batch, args.export_onnx).cpu()
    else:
        output = stylize_onnx_caffe2(batch, args)

    utils.save_image(args.output_image, output[0])
def stylize(args):
    """Stylize ``args.content_image`` with the checkpoint at ``args.model``.

    The image is loaded (optionally rescaled by ``args.content_scale``),
    converted to a 0-255 float tensor, run through ``TransformerNet`` (on CUDA
    when ``args.cuda`` is set) and written to ``args.output_image``.
    """
    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()

    # Variable(..., volatile=True) no longer disables autograd (PyTorch >= 0.4);
    # no_grad is the supported way to run inference without building a graph.
    with torch.no_grad():
        output = style_model(content_image)
    output_data = output.cpu()[0]
    utils.save_image(args.output_image, output_data)
def stylize(args):
    """Run style transfer on ``args.content_image`` and save ``args.output_image``.

    ``.onnx`` models are delegated to ``stylize_onnx_caffe2``. Any other model
    is a ``TransformerNet`` checkpoint. That checkpoint is run directly, or
    exported to ``args.export_onnx`` when that option is set.
    """
    target = torch.device("cuda" if args.cuda else "cpu")

    image = utils.load_image(args.content_image, scale=args.content_scale)
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    tensor = preprocess(image)
    tensor = tensor.unsqueeze(0).to(target)

    use_onnx_backend = args.model.endswith(".onnx")
    if use_onnx_backend:
        result = stylize_onnx_caffe2(tensor, args)
    else:
        with torch.no_grad():
            model = TransformerNet()
            checkpoint = torch.load(args.model)
            # Strip InstanceNorm running statistics saved by older PyTorch releases.
            for name in [k for k in checkpoint if re.search(r'in\d+\.running_(mean|var)$', k)]:
                del checkpoint[name]
            model.load_state_dict(checkpoint)
            model.to(target)

            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                exported = torch.onnx._export(model, tensor, args.export_onnx)
                result = exported.cpu()
            else:
                result = model(tensor).cpu()

    utils.save_image(args.output_image, result[0])
def multi_style(path, width=320, device=device):
    """Show a live stylized camera feed, cycling through the models in ``path``.

    Frames come from ``PiVideoStream``. When no frame is available, a random
    noise frame is used instead. Key ``n`` loads the next model file in
    ``path`` (wrapping around); key ``q`` quits.

    Raises:
        ValueError: if ``path`` contains no entries.
    """
    # Sorted so the cycling order is deterministic (os.listdir order is arbitrary).
    model_files = sorted(os.listdir(path))
    if not model_files:
        raise ValueError(f'No model files found in {path}')
    model_iter = itertools.cycle(model_files)

    model_file = next(model_iter)
    print(f'Using {model_file} ')
    model_path = os.path.join(path, model_file)
    model = TransformerNet()
    model.load_state_dict(read_state_dict(model_path))
    model.to(device)

    vs = PiVideoStream().start()
    time.sleep(2.0)  # give the camera time to warm up
    timer = Timer()
    try:
        while True:
            frame = vs.read()
            if frame is None:
                frame = np.random.randint(0, 255, (int(width / 1.5), width, 3), dtype=np.uint8)
            frame = cv2.flip(frame, 1)  # mirror horizontally
            frame = resize(frame, width=width)

            # Style the frame: channel-first float output -> HWC uint8.
            img = style_frame(frame, model, device).numpy()
            img = np.clip(img, 0, 255)
            img = img.astype(np.uint8)
            img = img.transpose(1, 2, 0)
            # Reverse the channel order (presumably RGB -> BGR for OpenCV display).
            img = cv2.resize(img[:, :, ::-1], (640, 480))
            cv2.imshow("Output", img)
            timer()

            key = cv2.waitKey(1) & 0xFF
            if key == ord("n"):
                model_file = next(model_iter)
                print(f'Using {model_file} ')
                model_path = os.path.join(path, model_file)
                model.load_state_dict(read_state_dict(model_path))
                model.to(device)
            elif key == ord("q"):
                break
    finally:
        # Release the camera thread and the display window even on error / Ctrl-C.
        vs.stop()
        cv2.destroyAllWindows()
def __init__(self, image_queue, output_queue, engine_id, log_flag=True):
    """Initialise the style server with the default "the_scream" model.

    Loads ``<cwd>/../models/the_scream.model`` onto CUDA and builds the input
    transform (tensor scaled to 0-255). It also loads the watermark overlay
    from ``../wtrMrk.png``, splitting its grey value from its alpha channel.

    Raises:
        IOError: if the watermark image cannot be read.
        ValueError: if the watermark image does not have 4 channels.
    """
    super(StyleServer, self).__init__(image_queue, output_queue, engine_id)
    self.log_flag = log_flag
    self.is_first_image = True
    self.dir_path = os.getcwd()
    self.model = self.dir_path + '/../models/the_scream.model'
    self.path = self.dir_path + '/../models/'
    print('MODEL PATH {}'.format(self.path))

    # initialize model
    self.style_model = TransformerNet()
    self.style_model.load_state_dict(torch.load(self.model))
    self.style_model.cuda()
    self.style_type = "the_scream"

    self.content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    # -1 (IMREAD_UNCHANGED) keeps the alpha channel; imread returns None on failure.
    wtr_mrk4 = cv2.imread('../wtrMrk.png', -1)
    if wtr_mrk4 is None:
        raise IOError('Could not read watermark image ../wtrMrk.png')
    if wtr_mrk4.ndim != 3 or wtr_mrk4.shape[2] != 4:
        raise ValueError('Watermark image ../wtrMrk.png must have 4 channels (with alpha)')
    # The waterMark is of dimension 30x120
    self.mrk, _, _, mrk_alpha = cv2.split(wtr_mrk4)  # The RGB channels are equivalent
    self.alpha = mrk_alpha.astype(float) / 255
    print('FINISHED INITIALISATION')
def vectorize(args):
    """Predict a vector field for a grayscale image and compare it to a target.

    The image at ``args.content_image`` is converted to a 0-255 tensor. The
    ImageNet mean is subtracted via ``utils.subtract_imagenet_mean_batch``.
    The result is run through the ``TransformerNet`` at ``args.saved_model``.

    The prediction is compared against the HDF5 target ``args.target_vector``
    with a cosine embedding loss, and the loss is printed. Both fields are
    rendered with ``lic`` to ``output.jpg`` and ``target.jpg``; the prediction
    is first normalised with ``NormalizVectrs``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = Image.open(args.content_image).convert('L')
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    content_image = utils.subtract_imagenet_mean_batch(content_image)
    content_image = content_image.to(device)

    with torch.no_grad():
        vectorize_model = TransformerNet()
        state_dict = torch.load(args.saved_model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        vectorize_model.load_state_dict(state_dict)
        vectorize_model.to(device)
        # Inside no_grad so the output can be converted with .numpy() below.
        output = vectorize_model(content_image)

        target = dataset.hdf5_loader(args.target_vector)
        target_transform = transforms.ToTensor()
        target = target_transform(target)
        target = target.unsqueeze(0).to(device)

        cosine_loss = torch.nn.CosineEmbeddingLoss()
        label = torch.ones(1, 1, args.size, args.size).to(device)
        loss = cosine_loss(output, target, label)
    print(loss.item())

    # Channel-first -> channel-last for the LIC renderer.
    output = output.cpu().clone().numpy()[0].transpose(1, 2, 0)
    output = NormalizVectrs(output)
    lic(output, "output.jpg")

    target = target.cpu().clone().numpy()[0].transpose(1, 2, 0)
    lic(target, "target.jpg")
def stylize(args):
    """Stylize ``args.content_image`` with a OneFlow TransformerNet on CUDA.

    The path of the output is printed. The image returned by
    ``recover_image`` is written to ``args.output_image`` with OpenCV.
    """
    image = load_image_eval(args.content_image)
    with flow.no_grad():
        net = TransformerNet()
        net.load_state_dict(flow.load(args.model))
        net.to("cuda")
        # Clamp input pixel values to [0, 255] before inference.
        batch = flow.Tensor(image).clamp(0, 255).to("cuda")
        stylized = net(batch)
        print(args.output_image)
        cv2.imwrite(args.output_image, recover_image(stylized.numpy()))
def stylize(args):
    """Stylize ``args.content_image`` with a VGG19-preprocessed TransformerNet.

    The image is loaded at ``args.output_size`` / ``args.content_scale`` and
    given Caffe-style VGG preprocessing. Processing then depends on the model:

    - When ``args.model`` ends in ``.onnx``, it runs through
      ``stylize_onnx_caffe2``.
    - Otherwise the TorchScript model is exported to ``args.export_onnx``
      (when set) or run directly.

    ``utils.save_image_vgg19`` writes the result. It also receives
    ``args.original_colors`` and the unprocessed input image.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image_input = utils.load_image(args.content_image, size=args.output_size,
                                           scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): sibling stylize variants scale by 255; 256 presumably
        # matches how this model was trained -- confirm before changing.
        transforms.Lambda(lambda x: x.mul(256)),
        # Reverse the channel order (presumably RGB -> BGR for Caffe VGG weights).
        transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),
        # Subtract per-channel means (the widely used ImageNet BGR means).
        transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])
    ])
    content_image = content_transform(content_image_input)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = torch.jit.script(TransformerNet())
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.eval().to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                # Name the graph's input/output and mark dim 0 as dynamic so the
                # exported model accepts any batch size (these settings were
                # previously built but never passed to the exporter).
                inputs = ['images']
                outputs = ['scores']
                dynamic_axes = {'images': {0: 'batch'}, 'scores': {0: 'batch'}}
                output = torch.onnx._export(style_model, content_image, args.export_onnx,
                                            verbose=True,
                                            input_names=inputs,
                                            output_names=outputs,
                                            dynamic_axes=dynamic_axes).cpu()
            else:
                output = style_model(content_image).cpu()

    utils.save_image_vgg19(
        args.output_image,
        output[0],
        args.original_colors,
        content_image_input,
    )
def load_style(model, device="cuda"):
    """Load a TransformerNet checkpoint and return the model on ``device``.

    Args:
        model: checkpoint location accepted by ``torch.load``.
        device: target device. It defaults to ``"cuda"``, matching the
            previous hard-coded behaviour. Pass ``"cpu"`` on hosts without a
            GPU.

    Returns:
        The ``TransformerNet`` with the checkpoint weights, moved to ``device``.
    """
    with torch.no_grad():
        style_model = TransformerNet()
        # map_location lets GPU-saved checkpoints load on the requested device.
        state_dict = torch.load(model, map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
    return style_model
def init():
    """Load the Picasso style model into the module-level global ``model``.

    The path to ``picasso.pth`` is resolved with ``Model.get_model_path``
    (presumably the scoring-script ``init`` hook of a model deployment). The
    stored model is a ``TransformerNet`` on the CPU in eval mode.
    """
    global model
    model_path = Model.get_model_path('picasso.pth')
    model = TransformerNet()
    # The model is never moved off the CPU, so load tensors there too: without
    # map_location a GPU-saved checkpoint fails to load on a CPU-only host.
    state_dict = torch.load(model_path, map_location='cpu')
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    model.load_state_dict(state_dict)
    model.eval()
def get_output(trained_model, content_image):
    """Return the stylized output of ``content_image`` as a CPU tensor.

    ``trained_model`` is the checkpoint passed to ``torch.load``. The network
    is placed on the module-level ``device``. ``content_image`` is presumably
    an already-batched tensor on that device; confirm with callers.
    """
    with torch.no_grad():
        net = TransformerNet()
        checkpoint = torch.load(trained_model)
        # Drop InstanceNorm running statistics saved by older PyTorch versions.
        obsolete = [name for name in checkpoint
                    if re.search(r'in\d+\.running_(mean|var)$', name)]
        for name in obsolete:
            del checkpoint[name]
        net.load_state_dict(checkpoint)
        net.to(device)
        return net(content_image).cpu()
def stylize(args):
    """Stylize ``args.content_image`` and save it as a BGR image.

    The image is loaded with ``utils.tensor_load_rgbimage`` (optionally
    rescaled by ``args.content_scale``) and preprocessed with
    ``utils.preprocess_batch``. It is run through the ``TransformerNet`` at
    ``args.model`` (CUDA when ``args.cuda``). The result is written with
    ``utils.tensor_save_bgrimage`` to ``args.output_image``.
    """
    content_image = utils.tensor_load_rgbimage(args.content_image, scale=args.content_scale)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()
    content_image = utils.preprocess_batch(content_image)

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()

    # Replaces Variable(..., volatile=True), which no longer disables autograd.
    with torch.no_grad():
        output = style_model(content_image)
    utils.tensor_save_bgrimage(output[0], args.output_image, args.cuda)
def transfer_style(self, content_img_stream):
    """Stylize an image stream with the ``rain_princess.pth`` model.

    Args:
        content_img_stream: input passed to ``self._process_image``.

    Returns:
        The result of ``misc.toimage`` on the HxWx3 uint8 stylized array.
    """
    content_image = self._process_image(content_img_stream)
    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load('rain_princess.pth', map_location=self.device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(self.device)
        # Run the model whose weights were just loaded.
        output = style_model(content_image).cpu()
    # Clamp before the uint8 cast so out-of-range values saturate instead of wrapping.
    output = numpy.array(output.squeeze(0).clamp(0, 255))
    output = output.transpose(1, 2, 0).astype("uint8")
    return misc.toimage(output)
def stylize(content_image, model, content_scale=None, cuda=0):
    """Stylize an image file and return the result as an image object.

    Args:
        content_image: path given to ``utils.tensor_load_rgbimage``.
        model: checkpoint passed to ``torch.load`` for ``TransformerNet``.
        content_scale: optional downscale factor for the input image.
        cuda: truthy to run on the GPU.

    Returns:
        Whatever ``utils.tensor_to_Image`` produces from the model output.
    """
    content_image = utils.tensor_load_rgbimage(content_image, scale=content_scale)
    content_image = content_image.unsqueeze(0)
    if cuda:
        content_image = content_image.cuda()
    content_image = utils.preprocess_batch(content_image)

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(model))
    if cuda:
        style_model.cuda()

    # Replaces Variable(..., volatile=True), which no longer disables autograd.
    with torch.no_grad():
        output = style_model(content_image)
    return utils.tensor_to_Image(output, cuda)
def stylize(model_path):
    """Stylize ``enums.content_image`` with the checkpoint at ``model_path``.

    The checkpoint is a dict whose ``'state_dict'`` entry holds the weights.
    Input/output paths, scale and CUDA usage come from ``enums``. The result is
    written to ``enums.output_image``.
    """
    content_image = utils.load_image(enums.content_image, scale=enums.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if enums.cuda:
        content_image = content_image.cuda()

    model_state = torch.load(model_path)
    style_model = TransformerNet()
    style_model.load_state_dict(model_state['state_dict'])
    style_model.eval()
    if enums.cuda:
        style_model.cuda()

    # Replaces Variable(..., volatile=True), which no longer disables autograd.
    with torch.no_grad():
        output = style_model(content_image)
    output_data = output.cpu()[0]
    utils.save_image(enums.output_image, output_data)
def stylize(args):
    """Stylize every image message on ``args.topic`` in a ROS bag.

    The input bag is ``args.content_image`` and the output bag is
    ``args.output_image``. Matching messages are decoded as ``rgb8`` and
    optionally downscaled by ``args.content_scale``. They are run through
    ``TransformerNet`` and written back as ``rgb8`` image messages that keep
    the original header.

    When ``args.export_onnx`` is set, each frame is also exported to that
    path. All other messages are copied through unchanged.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    print('Reading bag file ' + os.getcwd() + '/' + args.content_image + ' topic ' + args.topic)

    # Load the model once rather than once per message.
    style_model = TransformerNet()
    state_dict = torch.load(args.model)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    style_model.load_state_dict(state_dict)
    style_model.to(device)

    bridge = CvBridge()
    with rosbag.Bag(args.content_image, 'r') as readbag, \
            rosbag.Bag(args.output_image, 'w') as outbag:
        for topic, msg, dtime in readbag.read_messages():
            # Keep the recorded header stamp when present, else the bag time.
            stamp = msg.header.stamp if msg._has_header else dtime
            if topic != args.topic:
                outbag.write(topic, msg, stamp)
                continue

            cv_img = bridge.imgmsg_to_cv2(msg, desired_encoding="rgb8")
            if args.content_scale:
                h, w = cv_img.shape[:2]
                # cv2.resize takes dsize as (width, height).
                cv_img = cv2.resize(cv_img, (int(w / args.content_scale),
                                             int(h / args.content_scale)))
            content_image = content_transform(cv_img).unsqueeze(0).to(device)

            with torch.no_grad():
                if args.export_onnx:
                    assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                    output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
                else:
                    print('stylizing image ...')
                    output = style_model(content_image).cpu()

            # Channel-first float output -> HWC uint8 for cv_bridge.
            out_img = output[0].clamp(0, 255).permute(1, 2, 0).contiguous().byte().numpy()
            out_msg = bridge.cv2_to_imgmsg(out_img, encoding="rgb8")
            out_msg.header = msg.header
            outbag.write(topic, out_msg, stamp)
def stylize(args):
    """Run a matrix of flattened single-channel images through TransformerNet.

    ``args.content_image`` is a text matrix read with ``np.loadtxt``. Each
    column is one flattened ``image_size_x`` x ``image_size_y`` image.

    Images are processed in batches of 100, including a trailing partial
    batch. Each model output is flattened back into a column. All columns are
    saved to ``args.output_model`` with ``np.savetxt``.
    """
    content_image = np.loadtxt(args.content_image)
    batch_size = 100
    num_images = content_image.shape[1]
    # Ceiling division so a trailing partial batch is processed, not dropped.
    num_batch = (num_images + batch_size - 1) // batch_size

    # Load the model once instead of once per batch.
    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()

    output_model_total = []
    # Replaces Variable(..., volatile=True), which no longer disables autograd.
    with torch.no_grad():
        for batch_id in range(num_batch):
            print('[{}]/[{}] iters '.format(batch_id + 1, num_batch))
            x = content_image[:, batch_id * batch_size:(batch_id + 1) * batch_size]
            x = x.transpose()
            x = x.reshape((-1, 1, args.image_size_x, args.image_size_y))
            x = torch.from_numpy(x).float()
            if args.cuda:
                x = x.cuda()

            output_model = style_model(x).cpu().numpy().astype(float)
            # One row per image in this batch. The row length follows from the
            # model output (image_size_x * image_size_y * 8**2 for the 8x
            # upsampler this was written for).
            output_model = output_model.reshape((output_model.shape[0], -1))
            output_model_total.append(output_model.transpose())

    np.savetxt(args.output_model, np.hstack(output_model_total))
def transfer_style(self, content_img_stream, model_name,
                   base_dir='/Users/yanadm/Documents/Style Transfer Bot/saved_models/'):
    """Stylize an image stream on the CPU with the model file ``model_name``.

    Args:
        content_img_stream: input passed to ``self.process_image``.
        model_name: checkpoint file name inside ``base_dir``.
        base_dir: directory holding the saved models. It defaults to the
            previously hard-coded location; override it on other machines
            (e.g. ``'./saved_models/'``).

    Returns:
        The result of ``misc.toimage`` on the first stylized image.
    """
    device = torch.device("cpu")
    content_image = self.process_image(content_img_stream)
    with torch.no_grad():
        style_model = TransformerNet()
        path_to_model = os.path.join(base_dir, model_name)
        # Inference runs on the CPU; map_location lets GPU-saved checkpoints load.
        state_dict = torch.load(path_to_model, map_location=device)
        # Remove deprecated InstanceNorm running_* keys from older checkpoints.
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    return misc.toimage(output[0])
def stylize(args):
    """Stylize ``args.content_image`` with the checkpoint at ``args.model``.

    The image is read with ``utils.load_image_local`` and processed on CUDA
    when ``args.cuda`` is set. The first output image is written to
    ``args.output_image``.
    """
    target_device = torch.device("cuda" if args.cuda else "cpu")
    picture = utils.load_image_local(args.content_image)
    prepare = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    batch = prepare(picture).unsqueeze(0).to(target_device)

    with torch.no_grad():
        net = TransformerNet()
        params = torch.load(args.model)
        outdated = [name for name in params
                    if re.search(r'in\d+\.running_(mean|var)$', name)]
        for name in outdated:
            del params[name]
        net.load_state_dict(params)
        net.to(target_device)
        stylized = net(batch).cpu()
    utils.save_image(args.output_image, stylized[0])
def stylize(args):
    """Apply one style of a multi-style TransformerNet to ``args.content_image``.

    The network holds ``args.style_num`` styles, and ``args.style_id`` picks
    the style to apply. The result is saved as
    ``output/<output_image>_style<style_id>.jpg``.
    """
    run_on = torch.device("cuda" if args.cuda else "cpu")

    source = utils.load_image(args.content_image, scale=args.content_scale)
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    batch = to_tensor(source).unsqueeze(0).to(run_on)

    with torch.no_grad():
        net = TransformerNet(style_num=args.style_num)
        net.load_state_dict(torch.load(args.model))
        net.to(run_on)
        result = net(batch, style_id=[args.style_id]).cpu()

    destination = 'output/' + args.output_image + '_style' + str(args.style_id) + '.jpg'
    utils.save_image(destination, result[0])
def style_data():
    """Build the result image for the posted ``style`` and render ``style.html``.

    The session's title and two comments are laid out on a template via
    ``choosetemplate``. Style ``'Z'`` uses that image unstyled. Any other key
    selects a checkpoint from the table and stylizes the image on CUDA.
    Either way, ``addallimage`` composes ``./static/img/results.jpg``.
    """
    style = request.form.get('style')
    checkpoints = {
        'X': 'hiphop.pth',
        'A': 'rain_princess.pth',
        'B': 'starry-night.model',
        'C': 'style6.pth',
        'D': 'style8.pth',
        'E': 'style9.pth',
    }
    txt = [session['title'], session['comment_1'] + '\n' + session['comment_2']]
    pos = choosetemplate('./static/' + session["imagepath_2"],
                         './static/' + session["imagepath_1"],
                         txt,
                         "./static/img/" + style + "results_notext.jpg")

    if style == 'Z':
        addallimage("./static/img/" + style + "results_notext.jpg", pos, "./static/img/results.jpg")
    else:
        device = torch.device("cuda")
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load('../ST/saved_models/' + checkpoints[style])
            # Drop deprecated InstanceNorm running_* entries from older checkpoints.
            stale_keys = [key for key in state_dict
                          if re.search(r'in\d+\.running_(mean|var)$', key)]
            for key in stale_keys:
                del state_dict[key]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            stylize("./static/img/" + style + "results_notext.jpg",
                    "./static/img/" + style + "results_notext_S.png",
                    model=style_model, device=device)
        torch.cuda.empty_cache()
        addallimage("./static/img/" + style + "results_notext_S.png", pos, "./static/img/results.jpg")

    torch.cuda.empty_cache()
    return render_template("style.html", style_img=session['style_img'])
def stylize(img_stream, style_type):
    """Stylize an image stream with ``./saved_models/<style_type>``.

    Runs on CUDA when available. Returns ``misc.toimage`` of the first
    stylized image.
    """
    compute = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    picture = load_image(img_stream, scale=None)
    prepare = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    batch = prepare(picture).unsqueeze(0).to(compute)

    with torch.no_grad():
        net = TransformerNet()
        # NOTE(review): if style_type comes from a request, validate it --
        # torch.load unpickles the file and a crafted name could escape saved_models/.
        weights = torch.load('./saved_models/{}'.format(style_type))
        for name in [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]:
            del weights[name]
        net.load_state_dict(weights)
        net.to(compute)
        result = net(batch).cpu()
    return misc.toimage(result[0])
def stylize(input_image, style, cuda):
    """Stylize ``input_image`` with ``saved_models/<style>.pth``.

    Runs on CUDA when ``cuda`` is truthy. Writes ``results/<style>_out.jpg``.
    """
    device = torch.device("cuda" if cuda else "cpu")
    loaded = utils.load_image(input_image, scale=None)
    scale_to_255 = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    batch = scale_to_255(loaded).unsqueeze(0).to(device)

    with torch.no_grad():
        net = TransformerNet()
        checkpoint = torch.load("saved_models/" + style + ".pth")
        # Strip running_* buffers that older PyTorch saved for InstanceNorm.
        legacy = [key for key in checkpoint if re.search(r'in\d+\.running_(mean|var)$', key)]
        for key in legacy:
            del checkpoint[key]
        net.load_state_dict(checkpoint)
        net.to(device)
        rendered = net(batch).cpu()
    utils.save_image("results/" + style + '_out.jpg', rendered[0])
def stylize(args):
    """Stylize ``args.content_image`` and save it to disk and to ``writer``.

    The image is converted to a 0-255 float tensor and run through the
    ``TransformerNet`` at ``args.model`` (CUDA when ``args.cuda``). The result
    is saved to ``args.output_image`` and logged with ``writer.add_image`` as
    ``'output'``.
    """
    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()

    # Replaces Variable(..., volatile=True), which no longer disables autograd.
    with torch.no_grad():
        output = style_model(content_image)
    output_data = output.cpu()[0]
    utils.save_image(args.output_image, output_data)
    # NOTE(review): output_data is on a ~0-255 scale; if ``writer`` is a
    # TensorBoard SummaryWriter, float images are expected in [0, 1] -- confirm.
    writer.add_image('output', output_data)
def run(self, content_img_arr, saved_style_model_path):
    '''Apply Fast Neural Style Transfer to a content image.

    Args:
        content_img_arr: content image as an array-like of shape (h, w, 3).
            It is cast to uint8 and converted to a PIL image before
            ``self.content_transform`` is applied.
        saved_style_model_path: trained TransformerNet weights (.pth) for one style.

    Returns:
        numpy array holding the first image of the model output (shape (3, h, w)).
    '''
    pil_image = Image.fromarray(np.uint8(content_img_arr))
    batch = self.content_transform(pil_image).unsqueeze(0).to(self.device)

    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(saved_style_model_path)
        # Drop InstanceNorm running_* entries that older checkpoints still carry.
        for key in [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]:
            del weights[key]
        net.load_state_dict(weights)
        net.to(self.device)
        stylized = net(batch).cpu()

    return stylized.numpy()[0]
def stylize(args):
    """Generate a stylized image with a trained model (test/inference entry point).

    Loads ``args.content_image`` (optionally rescaled by ``args.content_scale``)
    and runs the ``TransformerNet`` checkpoint at ``args.model`` (on CUDA when
    ``args.cuda``). Writes the result to ``args.output_image``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    raw = utils.load_image(args.content_image, scale=args.content_scale)
    tensorize = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    batch = tensorize(raw).unsqueeze(0).to(device)

    with torch.no_grad():
        model = TransformerNet()
        checkpoint = torch.load(args.model)
        # Remove the deprecated running_* keys saved for InstanceNorm from the checkpoint.
        doomed = [key for key in checkpoint if re.search(r'in\d+\.running_(mean|var)$', key)]
        for key in doomed:
            del checkpoint[key]
        model.load_state_dict(checkpoint)
        model.to(device)
        result = model(batch).cpu()
    utils.save_image(args.output_image, result[0])
def stylize(args):
    """Stylize ``args.content_image`` with a PyTorch or ONNX style model.

    Uses the GPU whenever CUDA is available. Processing depends on the model:

    - ``.onnx`` models are handled by ``stylize_onnx_caffe2``.
    - Otherwise the ``TransformerNet`` checkpoint is run, or exported to
      ``args.export_onnx`` with the graph input named ``inputImage``.

    The result is written to ``args.output_image``.
    """
    if torch.cuda.is_available():
        print('CUDA available, using GPU.')
        device = torch.device('cuda')
    else:
        print('GPU training unavailable... using CPU.')
        device = torch.device('cpu')

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # The transform must run before unsqueeze: load_image's result is what
    # ToTensor consumes, not a tensor, so it has no .unsqueeze().
    # NOTE(review): an earlier experiment moved the unsqueeze into TransformerNet
    # for CoreML input; this keeps the batched 4-D input used by the other paths.
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image, args.export_onnx,
                                            input_names=['inputImage']).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
def stylize(DC):
    """Stylize ``DC.low_resol_img`` with ``DC.trans_model`` and save ``DC.result_img``.

    Runs on GPU ``DC.trans_gpu_id`` when ``DC.use_gpu`` is set, otherwise on
    the CPU.
    """
    trans_gpu_id = DC.trans_gpu_id
    device = t.device('cuda', trans_gpu_id) if DC.use_gpu else t.device('cpu')

    # Image preprocessing: image -> 0-255 float tensor with a batch dimension.
    content_img = tv.datasets.folder.default_loader(DC.low_resol_img)
    content_transform = T.Compose(
        [T.ToTensor(), T.Lambda(lambda x: x.mul(255))])
    content_img = content_transform(content_img)
    content_img = content_img.unsqueeze(0).to(device)

    # Load model (tensors deserialized onto the CPU first, then moved).
    trans_model = TransformerNet().eval()
    trans_model.load_state_dict(
        t.load(DC.trans_model, map_location=lambda storage, loc: storage))
    trans_model.to(device)

    # Detaching only the input does not stop autograd from tracking the
    # model's parameters; no_grad skips building the graph entirely.
    with t.no_grad():
        output = trans_model(content_img)
    output_data = output.cpu()[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1), DC.result_img)
def main(args):
    """Stylize ``args.content`` with the TransformerNet at ``args.model``.

    The result is saved to ``args.out``. Runs on CUDA when available.
    """
    import re

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # convert('RGB') so grayscale / RGBA / palette inputs still give 3 channels.
    content_image = Image.open(args.content).convert('RGB')
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content = transform(content_image)
    content = content.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # Remove saved deprecated running_* keys in InstanceNorm from the
        # checkpoint; the current TransformerNet presumably no longer defines
        # them, and strict load_state_dict rejects unexpected keys.
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)

        # Forward through Image Transformation Network
        out = style_model(content).cpu()

    # Save result image. save_image(tensor, path) is the torchvision
    # signature, which expects values in [0, 1]; the model outputs ~0-255.
    # NOTE(review): confirm save_image is torchvision.utils.save_image.
    save_image((out[0] / 255).clamp(0, 1), args.out)
def stylize_NYU(frame, settype):
    """Stylize every image listed in ``frame`` and save it with its depth map.

    ``frame`` is indexed as ``frame.iloc[idx, 0]`` (image path) and
    ``frame.iloc[idx, 1]`` (depth path); presumably a pandas DataFrame. Only
    the "mosaic" model is used. Results go through ``save_stylized`` with
    ``settype``.

    Returns:
        ``(stylized_set, depth_set)``. Both lists are currently always empty
        because nothing is appended to them.
    """
    size = len(frame)
    print("total: ", size)
    stylized_set = []
    depth_set = []

    models = ["mosaic"]
    for model_name in models:
        style_model = TransformerNet()
        state_dict = torch.load("saved_models/" + model_name + ".pth")
        stale = [key for key in state_dict if re.search(r'in\d+\.running_(mean|var)$', key)]
        for key in stale:
            del state_dict[key]
        style_model.load_state_dict(state_dict)
        style_model.to(device)

        for idx in range(size):
            print("iter: ", idx)  # progress is reported for every image
            image_name = frame.iloc[idx, 0]
            depth_name = frame.iloc[idx, 1]
            image = matplotlib.image.imread(image_name, format="jpg")
            depth = matplotlib.image.imread(depth_name)
            with torch.no_grad():
                output = single_stylize(style_model, image)
            save_stylized(output, depth, image_name, depth_name, settype)

    return stylized_set, depth_set
def stylize(frame):
    """Stylize one frame with ``../saved_models/udnie.pth`` on CUDA.

    The checkpoint is re-read on every call.

    Returns:
        The CPU output tensor, clamped to [0, 255].
    """
    gpu = torch.device("cuda")
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    batch = to_tensor(frame).unsqueeze(0).to(gpu)

    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load('../saved_models/udnie.pth')
        # Discard InstanceNorm running_* entries left over from older checkpoints.
        for key in [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]:
            del weights[key]
        net.load_state_dict(weights)
        net.to(gpu)
        result = net(batch).cpu()
    return result.clamp(0, 255)
def stylize(args):
    """Stylize ``args.content_image`` (forced to RGB) with ``args.model``.

    Runs on CUDA when ``args.cuda`` is set. Writes ``args.output_image``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    rgb = utils.load_image(args.content_image, scale=args.content_scale).convert('RGB')
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    batch = pipeline(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        model = TransformerNet()
        checkpoint = torch.load(args.model)
        # Older checkpoints store InstanceNorm running stats; drop them in place.
        obsolete = [name for name in checkpoint if re.search(r'in\d+\.running_(mean|var)$', name)]
        for name in obsolete:
            del checkpoint[name]
        model.load_state_dict(checkpoint)
        model.to(device)
        stylized = model(batch).cpu()
    utils.save_image(args.output_image, stylized[0])
def train(args):
    """Train a TransformerNet for fast neural style transfer.

    The losses are:

    - Content loss: MSE between the VGG16 ``relu2_2`` features of the input
      and of the stylized batch.
    - Style loss: MSE between Gram matrices of every VGG16 feature map and
      those of ``args.style_image``.

    Progress is printed every ``args.log_interval`` batches. Optional
    checkpoints go to ``args.checkpoint_model_dir`` every
    ``args.checkpoint_interval`` batches. The final weights are saved to
    ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, train_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    vgg = Vgg16(requires_grad=False).to(device)

    # Target Gram matrices: computed once from the style image tiled to a full batch.
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    style = style_transform(utils.load_image(args.style_image, size=args.style_size))
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    gram_style = [utils.gram_matrix(feat) for feat in vgg(utils.normalize_batch(style))]

    for epoch in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = utils.normalize_batch(transformer(x))
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                # The final batch may be smaller than args.batch_size.
                style_loss += mse_loss(utils.gram_matrix(ft_y), gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            seen = batch_id + 1
            if seen % args.log_interval == 0:
                print("{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch + 1, count, len(train_dataset),
                    agg_content_loss / seen,
                    agg_style_loss / seen,
                    (agg_content_loss + agg_style_loss) / seen,
                ))

            if args.checkpoint_model_dir is not None and seen % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_name = "ckpt_epoch_" + str(epoch) + "_batch_id_" + str(seen) + ".pth"
                torch.save(transformer.state_dict(), os.path.join(args.checkpoint_model_dir, ckpt_name))
                transformer.to(device).train()

    # Final weights, saved from the CPU in eval mode.
    transformer.eval().cpu()
    save_model_filename = "_".join([
        "epoch",
        str(args.epochs),
        str(time.ctime()).replace(' ', '_'),
        str(args.content_weight),
        str(args.style_weight),
    ]) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(**kwargs):
    """Train a TransformerNet, plotting losses and sample images to Visdom.

    Keyword arguments override attributes of ``Config``. Every
    ``opt.plot_every`` batches the smoothed losses and de-normalized
    input/output images are sent to Visdom. When ``opt.debug_file`` exists,
    an ``ipdb`` breakpoint is hit. After each epoch the Visdom state is saved
    and the weights are written to ``checkpoints/<epoch>_style.pth``.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    vis = utils.Visualizer(opt.env)

    # Data loading. Resize is the current name of the deprecated Scale transform.
    train_transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, train_transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Transformation network, optionally resumed from opt.model_path.
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))

    # Loss network: Vgg16
    vgg = Vgg16().eval()

    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style image; shown de-normalized (x * 0.225 + 0.45) in Visdom.
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image are fixed targets. Variable(volatile=True)
    # no longer disables autograd, so compute them under no_grad.
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Loss statistics (running averages per epoch).
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
            # Training step
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smoothing. .data[0] raises on 0-dim tensors (PyTorch >= 0.4); use item().
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                # Drop into the debugger when the debug file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # Visualisation
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were normalized by utils.normalize_batch, so undo it for display.
                vis.img('output', (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))

        # Save the Visdom state and the model after each epoch.
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)