def stylize(args):
    """Stylize a content image (or a Raspberry Pi camera capture) and
    report the inference time in milliseconds."""
    if args.content_type == "pi":
        # Camera capture is disabled; a previously captured frame is reused.
        #camera.start_preview()
        #sleep(5)
        #camera.capture('/home/pi/Desktop/image.jpg')
        #camera.stop_preview()
        content_image = '/home/pi/Desktop/image.jpg'
    else:
        content_image = utils.load_image(args.content_image,
                                         scale=args.content_scale)
    t0 = time.time()
    to_tensor = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda t: t.mul(255))])
    batch = to_tensor(content_image).unsqueeze(0)
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(args.model)
        # drop deprecated InstanceNorm running-stat keys from old checkpoints
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        output = net(batch)
    utils.save_image(args.output_image, output[0])
    t1 = time.time()
    print("Inference time : " + str(1000 * (t1 - t0)) + " ms")
def stylize(args):
    """Stylize a content image with a saved TransformerNet checkpoint,
    dispatching to the ONNX/Caffe2 backend or optionally exporting to ONNX."""
    device = torch.device("cuda" if args.cuda else "cpu")
    image = utils.load_image(args.content_image, scale=args.content_scale)
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    batch = to_tensor(image).unsqueeze(0).to(device)
    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(batch, args)
    else:
        with torch.no_grad():
            net = TransformerNet()
            weights = torch.load(args.model)
            # drop deprecated InstanceNorm running-stat keys from old checkpoints
            for key in [k for k in weights
                        if re.search(r'in\d+\.running_(mean|var)$', k)]:
                del weights[key]
            net.load_state_dict(weights)
            net.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(net, batch, args.export_onnx).cpu()
            else:
                output = net(batch).cpu()
    utils.save_image(args.output_image, output[0])
def convert(self, orig_image):
    """Stylize ``orig_image``, alpha-blend the result with the original
    frame, and return the blend converted to BGR for OpenCV."""
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    batch = to_tensor(orig_image).unsqueeze(0).to(self.device)
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(self.model)
        # strip deprecated InstanceNorm running-stat keys from old checkpoints
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        net.to(self.device)
        rospy.loginfo('stylizing image ...')
        styled = net(batch).cpu()
    arr = styled[0].clone().clamp(0, 255).numpy()
    arr = arr.transpose(1, 2, 0).astype("uint8")
    # blend: crop the stylized output to the original frame's extent first
    blended = cv2.addWeighted(
        orig_image, self.alpha,
        arr[0:orig_image.shape[0], 0:orig_image.shape[1]],
        1 - self.alpha, 0.0)
    return cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)
def stylize_one(style_model_path, target_image):
    """Stylize ``target_image`` (downscaled 4x) with the given checkpoint,
    dump the raw result to ./1.png, and return it as a PIL image."""
    content_image = utils.load_image(target_image, scale=4)
    print('content_image', content_image)
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    batch = to_tensor(content_image).unsqueeze(0)
    with torch.no_grad():
        net = TransformerNet()
        # NOTE: unlike the sibling loaders, deprecated InstanceNorm
        # running_* keys are not stripped here before loading.
        net.load_state_dict(torch.load(style_model_path))
        data = net(batch)[0]
    torchvision.utils.save_image(data, './1.png', normalize=True)
    arr = data.clone().clamp(0, 255).numpy()
    arr = arr.transpose(1, 2, 0).astype("uint8")
    return Image.fromarray(arr)
def stylize_one(style_model_path, target_image):
    """Stylize ``target_image`` and return both a PIL image and the
    clamped output tensor (still on ``device``)."""
    content_image = utils.load_image(target_image)
    to_tensor = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda t: t.mul(255))])
    batch = to_tensor(content_image).unsqueeze(0).to(device)
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(style_model_path)
        # drop deprecated InstanceNorm running-stat keys from old checkpoints
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        net.to(device)
        data = net(batch)[0].clamp(0, 255)
    arr = data.cpu().clone().clamp(0, 255).numpy()
    arr = arr.transpose(1, 2, 0).astype("uint8")
    return Image.fromarray(arr), data
def stylize(args):
    """Stylize ``args.content_image`` with a TransformerNet checkpoint.

    ONNX models are delegated to the Caffe2 backend; optionally the loaded
    PyTorch model can be exported to ONNX instead of running inference.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    image = utils.load_image(args.content_image, scale=args.content_scale)
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    batch = pipeline(image).unsqueeze(0).to(device)
    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(batch, args)
    else:
        with torch.no_grad():
            net = TransformerNet()
            weights = torch.load(args.model)
            # old checkpoints carry deprecated InstanceNorm running stats
            stale = [k for k in weights
                     if re.search(r'in\d+\.running_(mean|var)$', k)]
            for k in stale:
                del weights[k]
            net.load_state_dict(weights)
            net.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(net, batch, args.export_onnx).cpu()
            else:
                output = net(batch).cpu()
    utils.save_image(args.output_image, output[0])
def stylize(**kwargs):
    """Generate a stylized picture from ``opt.content_path``.

    Keyword arguments override fields on the global ``opt`` config; the
    result is written to ``opt.result_path``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),                  # to [0, 1]
        tv.transforms.Lambda(lambda x: x * 255),   # to [0, 255]
    ])
    content_image = content_transform(content_image).unsqueeze(0)
    style_model = TransformerNet().eval()  # inference mode
    style_model.load_state_dict(
        t.load(opt.model_path, map_location=lambda _s, _: _s))
    if opt.use_gpu:  # fixed: was `== True`; truthiness is the idiomatic test
        content_image = content_image.cuda()
        style_model.cuda()
    # `Variable(..., volatile=True)` is a no-op since PyTorch 0.4;
    # no_grad() is the supported way to disable autograd for inference.
    with t.no_grad():
        output = style_model(content_image)
    output_data = output.cpu().data[0]
    # map back to [0, 1] before saving
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1),
                        opt.result_path)
def stylize(**kwargs):
    """Stylize the configured content image and save the result.

    Keyword arguments override fields on a fresh ``Config``; the stylized
    image is written to ``opt.result_path``.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    # Preprocess: PIL image -> [0, 1] tensor -> [0, 255] batch of one.
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255)),
    ])
    content_image = content_transform(content_image).unsqueeze(0)
    # Model in eval mode; weights mapped onto the CPU first.
    style_model = TransformerNet().eval()
    style_model.load_state_dict(
        t.load(opt.model_path, map_location=lambda _s, _: _s))
    if opt.use_gpu:
        content_image = content_image.cuda()
        style_model.cuda()
    # `Variable(..., volatile=True)` is a no-op since PyTorch 0.4;
    # no_grad() is the supported replacement for inference.
    with t.no_grad():
        output = style_model(content_image)
    output_data = output.cpu().data[0]
    # Map back to [0, 1] before saving.
    tv.utils.save_image(((output_data / 255)).clamp(min=0, max=1),
                        opt.result_path)
def stylize(args):
    """Stylize ``args.content_image`` and save it to ``args.output_image``.

    Old checkpoints may carry deprecated InstanceNorm running stats; those
    keys are stripped and the state dict is loaded non-strictly.
    """
    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image).unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()
    with torch.no_grad():
        # hotfix for PyTorch > 0.4.0: the deprecated `Variable` wrapper was
        # removed (tensors are Variables now), and iterating a snapshot of
        # the keys avoids the needless full-dict copy while mutating.
        model_dict = torch.load(args.model)
        for key in list(model_dict.keys()):
            if key.endswith(('running_mean', 'running_var')):
                del model_dict[key]
        style_model = TransformerNet()
        # non-strict load tolerates any remaining key mismatches
        style_model.load_state_dict(model_dict, False)
        if args.cuda:
            style_model.cuda()
        output = style_model(content_image)
    if args.cuda:
        output = output.cpu()
    utils.save_image(args.output_image, output.data[0])
def stylize(args):
    """Grab 150 webcam frames, stylize each one, and display the running
    result in a pylab window."""
    img = None
    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               scale=args.content_scale)
    content_image = content_image.unsqueeze(0)
    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    cam = cv2.VideoCapture(0)
    for _ in range(150):
        ret_val, frame = cam.read()
        batch = utils.tensor_load_rgbimage_cam(frame, scale=args.content_scale)
        batch = batch.unsqueeze(0)
        if args.cuda:
            batch = batch.cuda()
        batch_var = Variable(utils.preprocess_batch(batch), volatile=True)
        if args.cuda:
            style_model.cuda()
        output = style_model(batch_var)
        im = utils.tensor_ret_bgrimage(output.data[0], args.output_image,
                                       args.cuda)
        if img is None:
            img = pl.imshow(im)
        else:
            img.set_data(im)
        pl.pause(.1)
        pl.draw()
def stylize(**kwargs):
    """Stylize ``opt.content_path`` and save the result to ``opt.result_path``.

    Keyword arguments override fields on a fresh ``Config``.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    # Preprocess: [0, 1] tensor scaled to [0, 255], batch of one.
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image).unsqueeze(0)
    # Model in eval mode; weights mapped onto CPU first.
    style_model = TransformerNet().eval()
    style_model.load_state_dict(t.load(opt.model_path,
                                       map_location=lambda _s, _: _s))
    if opt.use_gpu:
        content_image = content_image.cuda()
        style_model.cuda()
    # `Variable(..., volatile=True)` is a no-op since PyTorch 0.4;
    # no_grad() is the supported way to disable autograd here.
    with t.no_grad():
        output = style_model(content_image)
    output_data = output.cpu().data[0]
    # Map back to [0, 1] before saving.
    tv.utils.save_image(((output_data / 255)).clamp(min=0, max=1),
                        opt.result_path)
def stylize(content_image_path, pathout, model):
    r"""Stylize the image at ``content_image_path`` and save it to ``pathout``.

    :param content_image_path: the path of content image
    :param pathout: destination file for the stylized image
    :param model: path of the model, default:./saved_models/starry-night.model
    :return: saved stylize_image
    """
    device = torch.device("cpu")
    to_tensor = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda t: t.mul(255))])
    batch = to_tensor(Image.open(content_image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(model)
        # drop deprecated InstanceNorm running-stat keys from old checkpoints
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        net.to(device)
        styled = net(batch).cpu()
    arr = styled[0].clone().clamp(0, 255).numpy()
    arr = arr.transpose(1, 2, 0).astype("uint8")
    Image.fromarray(arr).save(pathout)
def run_feedforward_texture_transfer(args):
    """Apply a trained feed-forward style-transfer network to a content image
    and save the stylized result, optionally restoring the original colors."""
    print('running feedforward neural style transfer...')
    content_image = load_image(args.content_image, mask=False,
                               size=args.image_size, scale=None, square=False)
    content_image = preprocess(content_image)
    input_image = content_image
    in_channels = 3
    stylizing_net = TransformerNet(in_channels)
    state_dict = torch.load(args.style_model)
    # Remove deprecated InstanceNorm running-stat keys from old checkpoints.
    # BUG FIX: checkpoint keys are dot-separated ("in1.running_mean"), so the
    # pattern must match r'\.'; the previous r'\/' never matched, leaving the
    # stale keys in place and breaking load_state_dict on old checkpoints.
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    stylizing_net.load_state_dict(state_dict)
    stylizing_net = stylizing_net.to(device)
    output = stylizing_net(input_image)
    if args.original_colors == 1:
        output = original_colors(content_image.cpu(), output)
    save_image(filename=args.output_image, data=output.detach())
def stylize(args):
    """Stylize ``args.content_image`` with a saved TransformerNet checkpoint
    and save the result to ``args.output_image``.

    Cleanup: the original leaked debug tracing to stdout (bare counters
    ``print(1)``..``print(6)`` and tensor shapes); those are removed.
    """
    device = torch.device("cuda" if args.is_cuda else "cpu")
    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    content_image = content_transform(content_image).unsqueeze(0).to(device)
    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # NOTE(review): unlike sibling loaders, deprecated InstanceNorm
        # running_* keys are not stripped here — presumably the checkpoint
        # is recent enough not to contain them; confirm before reuse.
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
def stylize(args):
    """Stylize ``args['content_image']`` on the CPU and save the result.

    ``args`` is a dict ('content_image', 'model', 'output_image'); ONNX
    models are dispatched to the Caffe2 backend.
    """
    device = torch.device("cpu")
    image = utils.load_image(args['content_image'], scale=None)
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.mul(255)),
    ])
    batch = to_tensor(image).unsqueeze(0).to(device)
    if args['model'].endswith(".onnx"):
        output = stylize_onnx_caffe2(batch, args)
    else:
        with torch.no_grad():
            net = TransformerNet()
            weights = torch.load(args['model'])
            # drop deprecated InstanceNorm running-stat keys from old checkpoints
            for key in [k for k in weights
                        if re.search(r'in\d+\.running_(mean|var)$', k)]:
                del weights[key]
            net.load_state_dict(weights)
            net.to(device)
            output = net(batch).cpu()
    utils.save_image(args["output_image"], output[0])
def generate_stylized_image(args):
    """
    Creates stylized image based on the arguments passed
    :param args: (content_image, style_image, model, etc.)
    :return: void
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    # BUG FIX: `content_scale` was an unqualified name (NameError unless a
    # same-named global happened to exist); read it from args like every
    # other option in this function.
    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image).unsqueeze(0).to(device)
    with torch.no_grad():
        style_model = TransformerNet(style_number=args.style_num)
        state_dict = torch.load(args.model)
        # non-strict load: checkpoint may contain extra or missing keys
        style_model.load_state_dict(state_dict, strict=False)
        style_model.to(device)
        output = style_model(content_image, style_id=[args.style_id]).cpu()
    utils.save_image(
        'output/' + args.output_image + '_style' + str(args.style_id) + '.jpg',
        output[0])
def evaluate(args):
    """Run the reconstruction model over every image under ``args.input_dir``
    and write the (length-preserving) results into ``args.output_dir``."""
    model = TransformerNet()
    state_dict = torch.load(args.model)
    # checkpoint was saved from a DataParallel-wrapped model, so wrap first
    if args.gpus is not None:
        model = nn.DataParallel(model, device_ids=args.gpus)
    else:
        model = nn.DataParallel(model)
    model.load_state_dict(state_dict)
    if args.cuda:
        model.cuda()
    with torch.no_grad():
        for root, dirs, filenames in os.walk(args.input_dir):
            for filename in filenames:
                if not utils.is_image_file(filename):
                    continue
                impath = osp.join(root, filename)
                img = utils.load_image(impath)
                img = img.unsqueeze(0)
                if args.cuda:
                    # BUG FIX: Tensor.cuda() is not in-place; the returned
                    # tensor must be rebound, otherwise inference runs on a
                    # CPU tensor while the model sits on the GPU.
                    img = img.cuda()
                rec_img = model(img)
                if args.cuda:
                    rec_img = rec_img.cpu()
                    img = img.cpu()
                save_path = osp.join(args.output_dir, filename)
                utils.save_image_preserv_length(rec_img[0], img[0], save_path)
def st_fns():
    """Stylize the currently loaded image with the model selected in the GUI,
    save and display the result, and report the inference time."""
    t0 = time.time()
    to_tensor = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda t: t.mul(255))])
    batch = to_tensor(load_img()).unsqueeze(0)
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load("../saved_models/" + combo2.value + ".pth")
        # drop deprecated InstanceNorm running-stat keys from old checkpoints
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        output = net(batch)
    shot_time = time.strftime("-%Y%m%d-%H%M%S")
    utilsIm.save_image(input_box1.value + "/image_st_" + shot_time + ".jpg",
                       output[0])
    image_st = cv2.imread(input_box1.value + "/image_st_" + shot_time + ".jpg")
    window_name = "image_st_" + shot_time + ".jpg"
    cv2.imshow(window_name, image_st)
    t1 = time.time()
    print("Inference time : " + str(1000 * (t1 - t0)) + " ms")
def stylize(**kwargs):
    """Stylize ``opt.content_path`` and write the result to ``opt.result_path``.

    Keyword arguments override fields on a fresh ``Config``.
    """
    opt = Config()
    for key, value in kwargs.items():
        setattr(opt, key, value)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    # preprocessing: PIL image -> [0, 255] float tensor batch of one
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    pipeline = tv.transforms.Compose(
        [tv.transforms.ToTensor(), tv.transforms.Lambda(lambda x: x.mul(255))])
    batch = pipeline(content_image).unsqueeze(0).to(device).detach()
    # model: eval mode, weights mapped onto CPU then moved to device
    net = TransformerNet().eval()
    net.load_state_dict(
        t.load(opt.model_path, map_location=lambda _s, _: _s))
    net.to(device)
    # style transfer, then map back to [0, 1] and save
    result = net(batch).cpu().data[0]
    tv.utils.save_image(((result / 255)).clamp(min=0, max=1), opt.result_path)
def stylize():
    """Flask endpoint: stylize the requested image with the requested model.

    Reads 'input_img' and 'style' from the query string, writes the output
    under static/images/ with a random suffix, and re-renders the index
    template pointing at it.
    """
    device = torch.device("cpu")
    input_img = request.args.get('input_img')
    model_get = request.args.get('style')
    print(model_get)
    content_image = utils.load_image(input_img)
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    batch = pipeline(content_image).unsqueeze(0).to(device)
    model_get = str(model_get)
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(model_get)
        # drop deprecated InstanceNorm running-stat keys from old checkpoints
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        net.to(device)
        output = net(batch).cpu()
    a = random.randint(1, 101)
    img_path = str("static/images/output_{}.jpg".format(a))
    utils.save_image(img_path, output[0])
    image_k = str("output_{}.jpg".format(a))
    get_image = os.path.join(app.config['UPLOAD_FOLDER'], image_k)
    print("Done")
    return render_template("index.html", get_image=get_image)
def stylize(args):
    """Stylize a content image with a multi-style TransformerNet, print the
    inference time, and save the result (optionally exporting to ONNX)."""
    device = torch.device("cuda" if args.cuda else "cpu")
    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image).unsqueeze(0).to(device)
    style_id = torch.LongTensor([args.style_id]).to(device)
    with torch.no_grad():
        import time
        start = time.time()
        style_model = TransformerNet(style_num=args.style_num)
        state_dict = torch.load(args.model)
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        # the model takes [image batch, style index] as one list argument
        output = style_model([content_image, style_id]).cpu()
        end = time.time()
        print('Time={}'.format(end - start))
        if args.export_onnx:
            assert args.export_onnx.endswith(
                ".onnx"), "Export model file should end with .onnx"
            # BUG FIX: the export call referenced undefined names
            # (content_image_t / style_t); export with the same inputs that
            # were just used for inference.
            output = torch.onnx._export(style_model, [content_image, style_id],
                                        args.export_onnx,
                                        input_names=['input_image',
                                                     'style_index'],
                                        output_names=['output_image']).cpu()
    utils.save_image(
        'output/' + args.output_image + '_style' + str(args.style_id) + '.jpg',
        output[0])
def stylize(args):
    """Stylize a content image with an old-format TransformerNet checkpoint.

    Old checkpoints store InstanceNorm affine parameters under
    ``scale``/``shift`` names; they are renamed here to the modern
    ``weight``/``bias`` keys before loading. Supports CUDA, optional fp16
    inference, and ONNX export.
    """
    if args.model.endswith(".onnx"):
        # ONNX models are handled by the Caffe2 backend instead of PyTorch.
        return stylize_onnx_caffe2(args)
    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               scale=args.content_scale)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()
    content_image = Variable(utils.preprocess_batch(content_image),
                             requires_grad=False)
    style_model = TransformerNet()
    state_dict = torch.load(args.model)
    # removed_modules = ['in2']
    # Explicit list of every legacy InstanceNorm parameter key expected
    # in the checkpoint (top-level in1..in5 plus the residual blocks).
    in_names = [
        "in1.scale", "in1.shift", "in2.scale", "in2.shift", "in3.scale",
        "in3.shift", "res1.in1.scale", "res1.in1.shift", "res1.in2.scale",
        "res1.in2.shift", "res2.in1.scale", "res2.in1.shift",
        "res2.in2.scale", "res2.in2.shift", "res3.in1.scale",
        "res3.in1.shift", "res3.in2.scale", "res3.in2.shift",
        "res4.in1.scale", "res4.in1.shift", "res4.in2.scale",
        "res4.in2.shift", "res5.in1.scale", "res5.in1.shift",
        "res5.in2.scale", "res5.in2.shift", "in4.scale", "in4.shift",
        "in5.scale", "in5.shift"
    ]
    # kl = list(state_dict.keys())
    # for k in kl:
    for k in in_names:
        # scale -> weight, shift -> bias (modern nn.InstanceNorm2d names);
        # pop() removes the legacy key while inserting the renamed one.
        state_dict[k.replace("scale", "weight").replace("shift",
                                                        "bias")] = \
            state_dict.pop(k)
    style_model.load_state_dict(state_dict)
    if args.cuda:
        style_model.cuda()
    if args.half:
        # fp16 inference: model and input must be halved together.
        style_model.half()
        content_image = content_image.half()
    if args.export_onnx:
        assert args.export_onnx.endswith(
            ".onnx"), "Export model file should end with .onnx"
        output = torch.onnx._export(style_model, content_image,
                                    args.export_onnx)
    else:
        output = style_model(content_image)
    if args.half:
        # convert back to fp32 before the save path handles the tensor
        output = output.float()
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
def vectorize(args):
    """Predict a vector field for a grayscale content image, score it against
    a target field with cosine-embedding loss, and render both fields via LIC.

    BUG FIX: a leftover ``pdb.set_trace()`` breakpoint inside the checkpoint
    clean-up loop halted execution whenever a deprecated key was found.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    content_image = Image.open(args.content_image).convert('L')
    content_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image).unsqueeze(0)
    content_image = utils.subtract_imagenet_mean_batch(content_image)
    content_image = content_image.to(device)
    with torch.no_grad():
        vectorize_model = TransformerNet()
        state_dict = torch.load(args.saved_model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        vectorize_model.load_state_dict(state_dict)
        vectorize_model.to(device)
        output = vectorize_model(content_image)
        target = dataset.hdf5_loader(args.target_vector)
        target_transform = transforms.ToTensor()
        target = target_transform(target).unsqueeze(0).to(device)
        cosine_loss = torch.nn.CosineEmbeddingLoss()
        # label of +1 everywhere: every position should match the target
        label = torch.ones(1, 1, args.size, args.size).to(device)
        loss = cosine_loss(output, target, label)
        print(loss.item())
    # render predicted and target vector fields with line integral convolution
    output = output.cpu().clone().numpy()[0].transpose(1, 2, 0)
    output = NormalizVectrs(output)
    lic(output, "output.jpg")
    target = target.cpu().clone().numpy()[0].transpose(1, 2, 0)
    lic(target, "target.jpg")
class Tracer():
    """Loads a trained TransformerNet and serializes it with TorchScript."""

    def __init__(self):
        # checkpoint produced by training, loaded eagerly at construction
        self.model = 'input-trained.model'
        self.style_model = TransformerNet()
        self.style_model.load_state_dict(torch.load(self.model))

    def trace(self):
        """Trace the model on a dummy 1x3x224x224 batch and save the script."""
        dummy = torch.rand(1, 3, 224, 224)
        traced = torch.jit.trace(self.style_model, dummy)
        traced.save('serialized.pt')
        print('traced model saved to serialized.pt')
def stylize(args):
    """Stylize an image with a OneFlow TransformerNet and write it to disk."""
    content_image = load_image_eval(args.content_image)
    with flow.no_grad():
        net = TransformerNet()
        net.load_state_dict(flow.load(args.model))
        net.to("cuda")
        # clamp to valid pixel range before feeding the network on the GPU
        styled = net(flow.Tensor(content_image).clamp(0, 255).to("cuda"))
    print(args.output_image)
    cv2.imwrite(args.output_image, recover_image(styled.numpy()))
def stylize(**kwargs):
    """Stylize every frame of ``opt.video_path`` into an mp4.

    Each frame is resized to 640x360, written to disk, stylized, saved again
    and appended to the output video at the source frame rate.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    style_model = TransformerNet().cuda()
    style_model.load_state_dict(t.load(opt.model_path, ))
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    video = cv2.VideoCapture(opt.video_path)
    frames = list()
    fps = video.get(cv2.CAP_PROP_FPS)  # source frames per second
    frame_width = int(640)
    frame_height = int(360)
    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    out = cv2.VideoWriter('./ablation_4_16_.mp4', fourcc, fps,
                          (frame_width, frame_height))
    n = 0
    while video.isOpened():
        ret, frame = video.read()
        if not ret:
            break
        n += 1
        frame = cv2.resize(frame, (640, 360))
        cv2.imwrite('./ablation/ori/temp%d.jpg' % (n), frame)
        content_image = tv.datasets.folder.default_loader(
            './ablation/ori/temp%d.jpg' % (n))
        content_image = transform(content_image).unsqueeze(0).cuda()
        output = style_model(content_image)
        # undo ImageNet normalization before saving the stylized frame
        tv.utils.save_image(
            (output.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1),
            './ablation/4/temp%d.jpg' % (n))
        image = cv2.imread('./ablation/4/temp%d.jpg' % (n))
        out.write(image)
        sys.stdout.write('\r>> Converting image %d/%d' % (n, frame_count))
        sys.stdout.flush()
    video.release()
    cv2.destroyAllWindows()
def multi_style(path, width=320, device=device):
    """Stream webcam frames through a rotating set of style models.

    Rotates to the next model in ``path`` when 'n' is pressed or after
    ``rotate_time`` seconds; 'q' quits the loop.
    """
    model_iter = itertools.cycle(os.listdir(path))
    model_file = next(model_iter)
    print(f'Using {model_file} ')
    model = TransformerNet()
    model.load_state_dict(read_state_dict(os.path.join(path, model_file)))
    model.to(device)
    monitor_width = get_monitors()[0].width
    monitor_height = get_monitors()[0].height
    vs = VideoStream(src=0).start()
    time.sleep(2.0)
    timer = Timer()
    last_update = int(time.time())
    while True:
        frame = vs.read()
        if frame is None:
            # no camera frame yet: feed noise so the window stays alive
            frame = np.random.randint(0, 255, (int(width / 1.5), width, 3),
                                      dtype=np.uint8)
        frame = cv2.flip(frame, 1)
        frame = resize(frame, width=width)
        # stylize and scale up to full screen (RGB -> BGR via ::-1)
        styled = style_frame(frame, model, device).numpy()
        styled = np.clip(styled, 0, 255).astype(np.uint8).transpose(1, 2, 0)
        styled = cv2.resize(styled[:, :, ::-1],
                            (monitor_width, monitor_height))
        cv2.imshow("Output", styled)
        timer()
        key = cv2.waitKey(1) & 0xFF
        elapsed = int(time.time()) - last_update
        # rotate on keypress 'n' or when the rotation timeout expires
        if (elapsed >= rotate_time) or (key == ord("n")):
            model_file = next(model_iter)
            print(f'Using {model_file} ')
            model.load_state_dict(read_state_dict(os.path.join(path,
                                                               model_file)))
            model.to(device)
            last_update = int(time.time())
        elif key == ord("q"):
            break
def get_model(model_name):
    """Returns model

    Args:
        model_name: name of the model to load

    Returns:
        pytorch model
    """
    net = TransformerNet()
    net.load_state_dict(torch.load(MODEL_PATH[model_name]))
    return net.eval().to(device)
def load_model(model_path):
    """Load a TransformerNet checkpoint; return it in eval mode on ``device``."""
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(model_path)
        # old checkpoints carry deprecated InstanceNorm running stats; drop them
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        net.to(device)
        net.eval()
        return net
def get_output(trained_model, content_image):
    """Run ``content_image`` through the checkpoint at ``trained_model``
    and return the stylized batch on the CPU."""
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(trained_model)
        # strip deprecated InstanceNorm running-stat keys from old checkpoints
        for key in [k for k in weights
                    if re.search(r'in\d+\.running_(mean|var)$', k)]:
            del weights[key]
        net.load_state_dict(weights)
        net.to(device)
        return net(content_image).cpu()
def init():
    """Load the picasso checkpoint into the global ``model`` for inference."""
    global model
    model_path = Model.get_model_path('picasso.pth')
    model = TransformerNet()
    weights = torch.load(model_path)
    # drop deprecated InstanceNorm running-stat keys from old checkpoints
    stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
    for k in stale:
        del weights[k]
    model.load_state_dict(weights)
    model.eval()
def load_model(model_path):
    """Load a TransformerNet checkpoint and return it in eval mode on
    ``device``."""
    print('cargando modelo')
    with torch.no_grad():
        net = TransformerNet()
        weights = torch.load(model_path)
        # remove the deprecated InstanceNorm running_* keys saved in old
        # checkpoints before loading the state dict
        stale = [k for k in weights if re.search(r'in\d+\.running_(mean|var)$', k)]
        for k in stale:
            del weights[k]
        net.load_state_dict(weights)
        net.to(device)
        net.eval()
        return net
def stylize(args):
    """Stylize ``args.content_image`` and write the result to
    ``args.output_image``."""
    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    content_image = content_transform(content_image).unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()
    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()
    # `Variable(..., volatile=True)` is a no-op since PyTorch 0.4;
    # torch.no_grad() is the supported way to disable autograd here.
    with torch.no_grad():
        output = style_model(content_image)
    if args.cuda:
        output = output.cpu()
    utils.save_image(args.output_image, output.data[0])
def train(**kwargs):
    """Train a TransformerNet for fast neural style transfer.

    Keyword arguments override fields on a fresh ``Config``. Losses are
    tracked with meters, plotted to visdom each ``opt.plot_every`` steps,
    and a checkpoint is saved after every epoch.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    vis = utils.Visualizer(opt.env)

    # Data loading: resize/crop to image_size, scale pixels to [0, 255].
    transfroms = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transfroms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Transformation network (optionally resumed from a checkpoint).
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))

    # Loss network: a frozen VGG16 feature extractor.
    vgg = Vgg16().eval()

    # Optimizer over the transformation network only.
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style image data (de-normalized for display in visdom).
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))
    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image's VGG features (fixed targets).
    style_v = Variable(style, volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    # Running averages of the two loss terms.
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()
        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
            # Training step
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            x = Variable(x)
            y = transformer(x)
            # Both images are ImageNet-normalized before the VGG pass.
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss: relu2_2 feature distance
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # style loss: Gram-matrix distance summed over all feature maps
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smoothing for plotting.
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])

            if (ii + 1) % opt.plot_every == 0:
                # drop into the debugger when the debug file exists
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # Visualization: plot losses; de-normalize (reverse of
                # utils.normalize_batch) before showing images.
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input',
                        (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))

        # Persist visdom state and the per-epoch checkpoint.
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)