import torch
from torch.autograd import Variable  # legacy pre-0.4 API kept by this example

# utils and TransformerNet are the fast-neural-style project modules.
import utils
from transformer_net import TransformerNet


def stylize(args):
    if args.model.endswith(".onnx"):
        return stylize_onnx_caffe2(args)

    content_image = utils.tensor_load_rgbimage(args.content_image,
                                               scale=args.content_scale)
    content_image = content_image.unsqueeze(0)

    if args.cuda:
        content_image = content_image.cuda()
    content_image = Variable(utils.preprocess_batch(content_image),
                             requires_grad=False)
    style_model = TransformerNet()
    state_dict = torch.load(args.model)

    # Checkpoints saved by the old fast-neural-style code store the
    # InstanceNorm2d parameters as "scale"/"shift"; torch.nn.InstanceNorm2d
    # expects "weight"/"bias", so those keys are renamed before loading.
    in_names = [
        "in1.scale", "in1.shift", "in2.scale", "in2.shift", "in3.scale",
        "in3.shift", "res1.in1.scale", "res1.in1.shift", "res1.in2.scale",
        "res1.in2.shift", "res2.in1.scale", "res2.in1.shift", "res2.in2.scale",
        "res2.in2.shift", "res3.in1.scale", "res3.in1.shift", "res3.in2.scale",
        "res3.in2.shift", "res4.in1.scale", "res4.in1.shift", "res4.in2.scale",
        "res4.in2.shift", "res5.in1.scale", "res5.in1.shift", "res5.in2.scale",
        "res5.in2.shift", "in4.scale", "in4.shift", "in5.scale", "in5.shift"
    ]

    for k in in_names:
        state_dict[k.replace("scale",
                             "weight").replace("shift",
                                               "bias")] = state_dict.pop(k)

    style_model.load_state_dict(state_dict)

    if args.cuda:
        style_model.cuda()

    if args.half:
        style_model.half()
        content_image = content_image.half()

    if args.export_onnx:
        assert args.export_onnx.endswith(
            ".onnx"), "Export model file should end with .onnx"
        # torch.onnx._export is the legacy internal export hook used by the
        # original example; it also returns the model output, so the
        # stylized image can still be saved below.
        output = torch.onnx._export(style_model, content_image,
                                    args.export_onnx)
    else:
        output = style_model(content_image)

    if args.half:
        output = output.float()

    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
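
A minimal way to drive stylize is an argparse-style namespace; the field
names below mirror what the function reads, and the checkpoint and image
paths are hypothetical placeholders:

from argparse import Namespace

args = Namespace(model="saved_models/candy.pth",      # hypothetical checkpoint
                 content_image="images/amber.jpg",    # hypothetical input
                 content_scale=None,
                 output_image="images/amber-candy.jpg",
                 cuda=0,
                 half=0,
                 export_onnx=None)
stylize(args)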
Example #2
import itertools
import os
import time

import cv2
import numpy as np
from imutils import resize
from imutils.video import VideoStream

# TransformerNet, read_state_dict, style_frame, Timer and the default
# `device` come from the surrounding project.


def multi_style(path,
                width=320,
                device=device,
                cycle_length=np.inf,
                half_precision=False,
                rotate=0):
    model_iter = itertools.cycle(os.listdir(path))
    model_file = next(model_iter)
    print(f'Using {model_file} ')
    model_path = os.path.join(path, model_file)
    model = TransformerNet()
    model.load_state_dict(read_state_dict(model_path))
    model.to(device)
    if half_precision:
        model.half()
    vs = VideoStream(src=0).start()
    time.sleep(2.0)
    timer = Timer()
    cycle_begin = time.time()
    while True:
        frame = vs.read()
        if frame is None:
            # No frame from the camera yet: fall back to random noise
            # instead of crashing.
            frame = np.random.randint(0,
                                      255, (int(width / 1.5), width, 3),
                                      dtype=np.uint8)
        frame = cv2.flip(frame, 1)
        frame = resize(frame, width=width)

        # Style the frame
        img = style_frame(frame, model, device, half_precision)
        img = cv2.resize(img[:, :, ::-1], (640, 480))

        # rotate
        if rotate > 0:
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        elif rotate < 0:
            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        # print(img.shape)
        cv2.imshow("Output", img)
        timer()
        key = cv2.waitKey(1) & 0xFF
        # "n" advances to the next style; exceeding cycle_length seconds does too.
        if key == ord("n") or (time.time() - cycle_begin) > cycle_length:
            model_file = next(model_iter)
            print(f'Using {model_file} ')
            model_path = os.path.join(path, model_file)
            model.load_state_dict(read_state_dict(model_path))
            model.to(device)
            cycle_begin = time.time()
        elif key == ord("q"):
            break
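
A sketch of how this loop is typically launched, assuming a saved_models/
directory of TransformerNet checkpoints (the directory name is a placeholder):

# Cycle through every checkpoint, switching styles every 30 seconds;
# press "n" to skip ahead and "q" to quit.
multi_style("saved_models", width=320, cycle_length=30)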
Example #3
# Same imports and project helpers as Example #2; here `path` is expected to
# be a pathlib.Path, and gstreamer_pipeline is a project-local Jetson helper.
def multi_style(path,
                width=320,
                device=device,
                cycle_length=np.inf,
                half_precision=False,
                rotate=0,
                camera=0,
                cutoff=0):
    if path.is_file():
        model_iter = itertools.cycle([os.path.basename(path)])
        path = os.path.dirname(path)
    else:
        model_iter = itertools.cycle(os.listdir(path))
    model_file = next(model_iter)
    print(f'Using {model_file} ')
    model_path = os.path.join(path, model_file)
    model = TransformerNet()
    model.load_state_dict(read_state_dict(model_path))
    model.to(device)
    if half_precision:
        model.half()

    # A rotated display needs a wider capture so the post-rotation crop
    # below still fills the frame.
    if rotate != 0:
        width = int(width / .75)

    height = int(width * .75)
    if camera < 0:
        # A negative index selects the CSI camera on a Jetson Nano, opened
        # through a GStreamer pipeline rather than imutils.VideoStream.
        print('Using CSI camera')
        vs = cv2.VideoCapture(
            gstreamer_pipeline(capture_width=width,
                               capture_height=height,
                               display_width=width,
                               display_height=height), cv2.CAP_GSTREAMER)
        time.sleep(2.0)
        ret, frame = vs.read()
        assert frame is not None

    else:
        print('Using USB camera')
        vs = VideoStream(src=camera, resolution=(width, height)).start()
        time.sleep(2.0)
    if rotate != 0:
        width = int(width * .75)  # restore the display width after the wider capture

    timer = Timer()
    cycle_begin = time.time()
    while True:
        frame = vs.read()
        if frame is None:
            frame = np.random.randint(0,
                                      255, (int(width / 1.5), width, 3),
                                      dtype=np.uint8)

        # cv2.VideoCapture.read() returns a (ret, frame) tuple; unwrap it.
        if isinstance(frame, tuple):
            frame = frame[1]

        frame = cv2.flip(frame, 1)
        frame = resize(frame, width=width)
        # Style the frame
        img = style_frame(frame, model, device, half_precision).numpy()
        img = np.clip(img, 0, 255)
        img = img.astype(np.uint8)

        img = img.transpose(1, 2, 0)  # CHW -> HWC
        img = img[:, :, ::-1]  # RGB -> BGR for OpenCV
        # Crop the width so the aspect ratio is inverted for a display that
        # is physically rotated 90 degrees (no in-software rotation here).
        if rotate != 0:
            h, w, _ = img.shape
            margin = int(w - h * h / w) // 2
            if margin > 0:
                img = img[:, margin:-margin, :]

        if cutoff > 0:
            # Trim cutoff * width off the sides, split evenly.
            margin = int(cutoff * img.shape[1]) // 2
            if margin > 0:
                img = img[:, margin:-margin, :]
        elif cutoff < 0:
            # Negative cutoff trims the top and bottom instead.
            margin = int(-cutoff * img.shape[0]) // 2
            if margin > 0:
                img = img[margin:-margin, :, :]
        # print(img.shape)
        cv2.imshow("Output", img)
        timer()
        key = cv2.waitKey(1) & 0xFF
        if key == ord("n") or (time.time() - cycle_begin) > cycle_length:
            model_file = next(model_iter)
            print(f'Using {model_file} ')
            model_path = os.path.join(path, model_file)
            model.load_state_dict(read_state_dict(model_path))
            model.to(device)
            cycle_begin = time.time()
        elif key == ord("q"):
            break
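
Since this variant calls path.is_file(), it expects a pathlib.Path rather
than a string. A sketch under that assumption, with a placeholder checkpoint:

from pathlib import Path

# Run one fixed style on a Jetson CSI camera (camera=-1) and trim 10% of
# the output width.
multi_style(Path("saved_models/mosaic.pth"), camera=-1, cutoff=0.1)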