Example #1
def stylize(**kwargs):
    opt = Config()

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # Image preprocessing
    content_image = tv.datasets.folder.default_loader(opt.content_path)
    content_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    content_image = Variable(content_image, volatile=True)

    # Load the model
    style_model = TransformerNet().eval()
    style_model.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))

    if opt.use_gpu:
        content_image = content_image.cuda()
        style_model.cuda()

    # Style transfer and save the result
    output = style_model(content_image)
    output_data = output.cpu().data[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1), opt.result_path)
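A minimal invocation sketch for this kwargs-based variant (hedged: the paths below are hypothetical, and Config is assumed to supply any remaining defaults):

stylize(content_path='input.jpg',    # hypothetical input path
        model_path='style.pth',      # hypothetical checkpoint path
        result_path='output.jpg',
        use_gpu=True)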
Example #2
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
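Most of the stylize(args) variants in these examples expect an argparse-style namespace; a minimal sketch of a direct call (hedged: the field names follow this example, but the values and paths are hypothetical):

from argparse import Namespace
import torch

args = Namespace(
    cuda=torch.cuda.is_available(),
    content_image='input.jpg',       # hypothetical path
    content_scale=None,
    model='saved_models/candy.pth',  # hypothetical checkpoint
    output_image='output.jpg',
    export_onnx=None,
)
stylize(args)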
Example #3
def stylize(args):
    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()
    content_image = Variable(content_image, volatile=True)

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()
    output = style_model(content_image)
    if args.cuda:
        output = output.cpu()
    output_data = output.data[0]
    utils.save_image(args.output_image, output_data)
Example #4
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image,
                                            args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
Example #5
def multi_style(path, width=320, device=device):
    model_iter = itertools.cycle(os.listdir(path))
    model_file = next(model_iter)
    print(f'Using {model_file}')
    model_path = os.path.join(path, model_file)
    model = TransformerNet()
    model.load_state_dict(read_state_dict(model_path))
    model.to(device)

    vs = PiVideoStream().start()
    time.sleep(2.0)
    timer = Timer()
    while True:
        frame = vs.read()
        if frame is None:
            frame = np.random.randint(0, 255, (int(width / 1.5), width, 3), dtype=np.uint8)
        frame = cv2.flip(frame, 1)
        frame = resize(frame, width=width)

        # Style the frame
        img = style_frame(frame, model, device).numpy()
        img = np.clip(img, 0, 255)
        img = img.astype(np.uint8)

        img = img.transpose(1, 2, 0)
        img = cv2.resize(img[:, :, ::-1], (640, 480))

        cv2.imshow("Output", img)
        timer()
        key = cv2.waitKey(1) & 0xFF
        if key == ord("n"):
            model_file = next(model_iter)
            print(f'Using {model_file}')
            model_path = os.path.join(path, model_file)
            model.load_state_dict(read_state_dict(model_path))
            model.to(device)
        elif key == ord("q"):
            break
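The read_state_dict helper used above is not shown in this example; a plausible sketch, assuming it performs the same deprecated InstanceNorm running_* key cleanup that the other examples inline:

import re
import torch

def read_state_dict(model_path):
    # Load a checkpoint and drop deprecated InstanceNorm running_* keys (assumption)
    state_dict = torch.load(model_path)
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    return state_dict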
Example #6
    def __init__(self, image_queue, output_queue, engine_id, log_flag = True):
        super(StyleServer, self).__init__(image_queue, output_queue, engine_id)
        self.log_flag = log_flag
        self.is_first_image = True
        self.dir_path = os.getcwd()
        self.model = self.dir_path+'/../models/the_scream.model'
        self.path = self.dir_path+'/../models/'
        print('MODEL PATH {}'.format(self.path))  

        # initialize model
        self.style_model = TransformerNet()
        self.style_model.load_state_dict(torch.load(self.model))
        self.style_model.cuda()
        self.style_type = "the_scream"
        self.content_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
        wtr_mrk4 = cv2.imread('../wtrMrk.png',-1) # The waterMark is of dimension 30x120
        self.mrk,_,_,mrk_alpha = cv2.split(wtr_mrk4) # The RGB channels are equivalent
        self.alpha = mrk_alpha.astype(float)/255
        print('FINISHED INITIALISATION')
Example #7
def vectorize(args):
    size = args.size
    # vectors = np.zeros((size, size, 2), dtype=np.float32)
    # for y in range(size):
    #     for x in range(size):
    #         xx = float(x - size / 2)
    #         yy = float(y - size / 2)
    #         rsq = xx ** 2 + yy ** 2
    #         if (rsq == 0):
    #             vectors[y, x, 0] = 1
    #             vectors[y, x, 1] = 1
    #         else:
    #             vectors[y, x, 0] = -yy / rsq
    #             vectors[y, x, 1] = xx / rsq
    # vectors = NormalizVectrs(vectors)

    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = Image.open(args.content_image).convert('L')
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    content_image = utils.subtract_imagenet_mean_batch(content_image)
    content_image = content_image.to(device)

    with torch.no_grad():
        vectorize_model = TransformerNet()
        state_dict = torch.load(args.saved_model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        vectorize_model.load_state_dict(state_dict)
        vectorize_model.to(device)
        output = vectorize_model(content_image)

    target = dataset.hdf5_loader(args.target_vector)
    target_transform = transforms.ToTensor()
    target = target_transform(target)
    target = target.unsqueeze(0).to(device)

    cosine_loss = torch.nn.CosineEmbeddingLoss()
    label = torch.ones(1, 1, args.size, args.size).to(device)
    loss = cosine_loss(output, target, label)
    print(loss.item())

    output = output.cpu().clone().numpy()[0].transpose(1, 2, 0)
    output = NormalizVectrs(output)
    lic(output, "output.jpg")

    target = target.cpu().clone().numpy()[0].transpose(1, 2, 0)
    lic(target, "target.jpg")
Example #8
def stylize(args):
    content_image = load_image_eval(args.content_image)
    with flow.no_grad():
        style_model = TransformerNet()
        state_dict = flow.load(args.model)
        style_model.load_state_dict(state_dict)
        style_model.to("cuda")
        output = style_model(
            flow.Tensor(content_image).clamp(0, 255).to("cuda"))
    print(args.output_image)
    cv2.imwrite(args.output_image, recover_image(output.numpy()))
Example #9
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image_input = utils.load_image(args.content_image,
                                           size=args.output_size,
                                           scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(256)),
        transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),
        transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])
    ])
    content_image = content_transform(content_image_input)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = torch.jit.script(TransformerNet())
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.eval().to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                inputs = ['images']
                outputs = ['scores']
                dynamic_axes = {'images': {0: 'batch'}, 'scores': {0: 'batch'}}
                output = torch.onnx._export(style_model,
                                            content_image,
                                            args.export_onnx,
                                            verbose=True).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image_vgg19(
        args.output_image,
        output[0],
        args.original_colors,
        content_image_input,
    )
Example #10
def load_style(model):
    device = "cuda"

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
    return style_model
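A hedged usage sketch for load_style; the preprocessing mirrors the other examples here, and the paths are hypothetical:

from PIL import Image
from torchvision import transforms
import torch

model = load_style('saved_models/mosaic.pth')  # hypothetical path
image = Image.open('input.jpg')                # hypothetical path
x = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t.mul(255)),
])(image).unsqueeze(0).to("cuda")
with torch.no_grad():
    y = model(x).cpu()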
Example #11
def init():
    global model
    #model_path = os.path.join('picasso.pth')
    model_path = Model.get_model_path('picasso.pth')

    model = TransformerNet()
    state_dict = torch.load(model_path)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    model.load_state_dict(state_dict)
    model.eval()
Example #12
def get_output(trained_model, content_image):
    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(trained_model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
        # utils.save_image(args.output_image, output[0])
    return output
Example #13
def stylize(args):
    content_image = utils.tensor_load_rgbimage(args.content_image, scale=args.content_scale)
    content_image = content_image.unsqueeze(0)

    if args.cuda:
        content_image = content_image.cuda()
    content_image = Variable(utils.preprocess_batch(content_image), volatile=True)
    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))

    if args.cuda:
        style_model.cuda()

    output = style_model(content_image)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
Example #14
    def transfer_style(self, content_img_stream):
        content_image = self._process_image(content_img_stream)

        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load('rain_princess.pth')
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(self.device)
            output = style_model(content_image).cpu()
            output = numpy.array(output.squeeze(0))
            output = output.transpose(1, 2, 0).astype("uint8")
            return misc.toimage(output)
Example #15
def stylize(content_image, model, content_scale=None, cuda=0):
    content_image = utils.tensor_load_rgbimage(content_image,
                                               scale=content_scale)
    content_image = content_image.unsqueeze(0)

    if cuda:
        content_image = content_image.cuda()
    content_image = Variable(utils.preprocess_batch(content_image),
                             volatile=True)
    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(model))

    if cuda:
        style_model.cuda()

    output = style_model(content_image)
    return utils.tensor_to_Image(output, cuda)
Example #16
def stylize(model_path):
    content_image = utils.load_image(enums.content_image,
                                     scale=enums.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if enums.cuda:
        content_image = content_image.cuda()
    content_image = Variable(content_image, volatile=True)

    model_state = torch.load(model_path)
    style_model = TransformerNet()
    style_model.load_state_dict(model_state['state_dict'])
    style_model.eval()

    if enums.cuda:
        style_model.cuda()
    output = style_model(content_image)
    if enums.cuda:
        output = output.cpu()
    output_data = output.data[0]
    utils.save_image(enums.output_image, output_data)
Example #17
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")
    content_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x.mul(255))
                        ])


    print('Reading bag file '+ os.getcwd() + '/' + args.content_image
          + ' topic ' + args.topic)
    readbag = rosbag.Bag(args.content_image,'r')
    bridge = CvBridge()
    cv_img = []

    with rosbag.Bag(args.output_image,'w') as outbag:
        for topic, msg, dtime in readbag.read_messages():

            if topic == args.topic:
                cv_img = bridge.imgmsg_to_cv2(msg, desired_encoding="rgb8")
                (h, w, d) = cv_img.shape

                nh = h / args.content_scale
                nw = w / args.content_scale
                cv_img = cv2.resize(cv_img, (int(nw), int(nh)))
                content_image = content_transform(cv_img).unsqueeze(0).to(device)

                with torch.no_grad():
                    style_model = TransformerNet()
                    state_dict = torch.load(args.model)
                    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
                    for k in list(state_dict.keys()):
                        if re.search(r'in\d+\.running_(mean|var)$', k):
                            del state_dict[k]
                    style_model.load_state_dict(state_dict)
                    style_model.to(device)
                    if args.export_onnx:
                        assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                        output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
                    else:
                        print('stylizing image ...')
                        output = style_model(content_image).cpu()
            else:
                outbag.write(topic, msg, msg.header.stamp if msg._has_header else dtime)

    readbag.close()
Example #18
def stylize(args):
    # content_image = utils.tensor_load_rgbimage(args.content_image, scale=args.content_scale)
    # content_image = content_image.unsqueeze(0)
    content_image = np.loadtxt(args.content_image)
    upsample_ratio = 8
    batch_size = 100
    #content_image = content_image[:, :200]
    num_images = content_image.shape[1]

    num_batch = int(content_image.shape[1] / batch_size)

    output_model_total = []
    for batch_id in range(num_batch):
        print('[{}]/[{}] iters '.format(batch_id + 1, num_batch))
        x = content_image[:, batch_id * batch_size:(batch_id + 1) * batch_size]

        x = x.transpose()
        x = x.reshape((-1, 1, args.image_size_x, args.image_size_y))
        x = torch.from_numpy(x).float()
        if args.cuda:
            x = x.cuda()
        x = Variable(x, volatile=True)
        style_model = TransformerNet()
        style_model.load_state_dict(torch.load(args.model))

        if args.cuda:
            style_model.cuda()

        output_model = style_model(x)
        # output_image = output_image.numpy()
        output_model = output_model.data
        output_image = output_model.repeat(1, 3, 1, 1)

        output_model = output_model.cpu().numpy().astype(float)
        output_model = output_model.reshape(
            (batch_size,
             args.image_size_x * args.image_size_y * upsample_ratio**2))
        output_model = output_model.transpose()

        output_model_total.append(output_model)
    output_model_total = np.hstack(output_model_total)

    np.savetxt(args.output_model, output_model_total)
Example #19
    def transfer_style(self, content_img_stream, model_name):
        device = torch.device("cpu")

        content_image = self.process_image(content_img_stream)

        with torch.no_grad():
            style_model = TransformerNet()
            # specify the path to the folder containing the models below
            # base_dir = './saved_models/'
            base_dir = '/Users/yanadm/Documents/Style Transfer Bot/saved_models/'
            filename = model_name
            path_to_model = os.path.join(base_dir, filename)
            state_dict = torch.load(path_to_model)
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            output = style_model(content_image).cpu()    
        return misc.toimage(output[0])
Example #20
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image_local(args.content_image)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
Example #21
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)


    with torch.no_grad():
        style_model = TransformerNet(style_num=args.style_num)
        state_dict = torch.load(args.model)
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image, style_id = [args.style_id]).cpu()

    utils.save_image('output/'+args.output_image+'_style'+str(args.style_id)+'.jpg', output[0])
Example #22
def style_data():
    style = request.form.get('style')
    modelname = {
        'X': 'hiphop.pth',
        'A': 'rain_princess.pth',
        'B': 'starry-night.model',
        'C': 'style6.pth',
        'D': 'style8.pth',
        'E': 'style9.pth'
    }
    txt = []
    txt.append(session['title'])
    txt.append(session['comment_1'] + '\n' + session['comment_2'])

    pos = choosetemplate('./static/' + session["imagepath_2"],
                         './static/' + session["imagepath_1"], txt,
                         "./static/img/" + style + "results_notext.jpg")
    if style != 'Z':
        device = torch.device("cuda")
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load('../ST/saved_models/' + modelname[style])
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
        stylize("./static/img/" + style + "results_notext.jpg",
                "./static/img/" + style + "results_notext_S.png",
                model=style_model,
                device=device)
        torch.cuda.empty_cache()
        # stylize("./static/img/"+style+"results_notext.jpg", "./static/img/"+style+"results_notext_S.png", model='../ST/saved_models/'+modelname[style])
        addallimage("./static/img/" + style + "results_notext_S.png", pos,
                    "./static/img/results.jpg")
    else:
        addallimage("./static/img/" + style + "results_notext.jpg", pos,
                    "./static/img/results.jpg")
    torch.cuda.empty_cache()
    return render_template("style.html", style_img=session['style_img'])
Example #23
def stylize(img_stream, style_type):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_image = load_image(img_stream, scale=None)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        models_dir = './saved_models/'
        state_dict = torch.load(models_dir + '{}'.format(style_type))
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    return misc.toimage(output[0])
Example #24
def stylize(input_image, style, cuda):
    device = torch.device("cuda" if cuda else "cpu")

    content_image = utils.load_image(input_image, scale=None)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load("saved_models/" + style + ".pth")
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()

    utils.save_image("results/" + style + '_out.jpg', output[0])
Example #25
def stylize(args):
    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if args.cuda:
        content_image = content_image.cuda()
    content_image = Variable(content_image, volatile=True)

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(args.model))
    if args.cuda:
        style_model.cuda()
    output = style_model(content_image)
    if args.cuda:
        output = output.cpu()
    output_data = output.data[0]
    utils.save_image(args.output_image, output_data)
    writer.add_image('output', output_data)
Example #26
    def run(self, content_img_arr, saved_style_model_path):
        '''
        Implementation of the Fast Neural Style Transfer algorithm. Takes a content image (numpy
        array with shape (h, w, 3)) and trained model weights for a particular style (.pth file),
        and returns the output image as a numpy array of shape (3, h, w).
        '''

        content_image = Image.fromarray(np.uint8(content_img_arr))

        content_image = self.content_transform(content_image)
        content_image = content_image.unsqueeze(0).to(self.device)

        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(saved_style_model_path)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(self.device)
            output = style_model(content_image).cpu()
        return (output.numpy()[0])
Example #27
def stylize(args):  # Provides a test: once the model is trained, use this function to generate stylized images
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
        utils.save_image(args.output_image, output[0])
Example #28
def stylize(args):
    if torch.cuda.is_available():
        print('CUDA available, using GPU.')
        device = torch.device('cuda')
    else:
        print('GPU training unavailable... using CPU.')
        device = torch.device('cpu')

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    # NOTE: Remove UNSQUEEZE, move to TransformerNet for CoreML UIImage input...
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model,
                                            content_image,
                                            args.export_onnx,
                                            input_names=['inputImage']).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
Example #29
File: trans.py  Project: jklhj222/bin
def stylize(DC):
    trans_gpu_id = DC.trans_gpu_id
    device = t.device('cuda', trans_gpu_id) if DC.use_gpu else t.device('cpu')

    # Image preprocessing
    content_img = tv.datasets.folder.default_loader(DC.low_resol_img)
    content_transform = T.Compose(
        [T.ToTensor(), T.Lambda(lambda x: x.mul(255))])

    content_img = content_transform(content_img)
    content_img = content_img.unsqueeze(0).to(device).detach()

    # Load model
    trans_model = TransformerNet().eval()
    trans_model.load_state_dict(
        t.load(DC.trans_model, map_location=lambda storage, loc: storage))

    trans_model.to(device)

    # Save the transformer image
    output = trans_model(content_img)
    output_data = output.cpu().data[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1), DC.result_img)
Example #30
def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    content_image = Image.open(args.content)
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content = transform(content_image)
    content = content.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)

        # TODO: remove saved deprecated running_* keys in InstanceNorm
        style_model.load_state_dict(state_dict)
        style_model.to(device)

        # Forward through Image Transformation Network
        out = style_model(content).cpu()

    # Save result image
    save_image(out[0], args.out)
Example #31
def stylize_NYU(frame, settype):
    size = len(frame)
    print("total: ", size)
    stylized_set = []
    depth_set = []
    #models = ["candy" , "mosaic", "rain_princess", "udnie"]
    models = ["mosaic"]

    for model in models:
        style_model = TransformerNet()
        modelpath = "saved_models/" + model + ".pth"
        state_dict = torch.load(modelpath)
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)

        for idx in range(size):
            #if idx % 1000 == 0:
            if idx % 1 == 0:
                print("iter: ", idx)
            image_name = frame.iloc[idx, 0]
            depth_name = frame.iloc[idx, 1]

            image = matplotlib.image.imread(image_name, format="jpg")
            depth = matplotlib.image.imread(depth_name)

            with torch.no_grad():
                output = single_stylize(style_model, image)
            #visualize_trio(image, depth, output)
            #stylized_set.append(output)
            #depth_set.append(depth)
            #break
            save_stylized(output, depth, image_name, depth_name, settype)
    return stylized_set, depth_set
Example #32
def stylize(frame):
    device = torch.device("cuda")

    content_image = frame
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load('../saved_models/udnie.pth')
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    return output.clamp(0,255)
Example #33
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_image = content_image.convert('RGB')
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
Example #34
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #35
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    vis = utils.Visualizer(opt.env)

    # Data loading
    transform = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transform)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Transformer network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))

    # Loss network: Vgg16
    vgg = Vgg16().eval()

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image
    style_v = Variable(style, volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    # Loss statistics
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            x = Variable(x)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smoothing
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # Visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were normalized (utils.normalize_batch), so denormalize them for display
                vis.img('output', (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))

        # Save the visdom state and the model checkpoint
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)