Example #1
def paint_until_object_detected(img,
                                actor_fn,
                                renderer_fn,
                                true_class,
                                classifier,
                                div=1,
                                max_big_strokes=750,
                                max_small_strokes=750,
                                img_width=128,
                                white_canvas=True,
                                built_in_features=False):
    global Decoder, width, divide, canvas_cnt, origin_shape
    divide = div

    from sketchy.classifier import SketchyClassifier

    width = img_width

    max_big_step = int(max_big_strokes / 5)
    max_small_step = int(max_small_strokes / 5 / divide**2)
    max_step = max_big_step + max_small_step

    Decoder = FCN()
    Decoder.load_state_dict(torch.load(renderer_fn))

    # 10 input planes when the built-in L2 mask is stacked into the state,
    # else 9 (canvas + target image + stepnum + coord)
    actor = ResNet(10 if built_in_features else 9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
    actor.load_state_dict(torch.load(actor_fn))
    actor = actor.to(device).eval()
    Decoder = Decoder.to(device).eval()

    imgid = 0
    c = None  # most recent rendered frame (guards the final return if no stroke runs)
    canvas_cnt = divide * divide
    T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)
    img = cv2.imread(img, cv2.IMREAD_COLOR)
    origin_shape = (img.shape[1], img.shape[0])

    coord = torch.zeros([1, 2, width, width])
    for i in range(width):
        for j in range(width):
            coord[0, 0, i, j] = i / (width - 1.)
            coord[0, 1, i, j] = j / (width - 1.)
    coord = coord.to(device)  # Coordconv

    canvas = torch.zeros([1, 3, width, width]).to(device)
    if white_canvas:
        canvas = torch.ones([1, 3, width, width]).to(device)

    patch_img = cv2.resize(img, (width * divide, width * divide))
    patch_img = large2small(patch_img)
    patch_img = np.transpose(patch_img, (0, 3, 1, 2))
    patch_img = torch.tensor(patch_img).to(device).float() / 255.

    img = cv2.resize(img, (width, width))
    mask = None
    if built_in_features:
        mask = get_l2_mask(
            torch.unsqueeze(
                torch.tensor(np.transpose(img.astype('float32'),
                                          (2, 0, 1))), 0) / 255)[:, 0, :, :]
        mask = mask.unsqueeze(0)
    img = img.reshape(1, width, width, 3)
    img = np.transpose(img, (0, 3, 1, 2))
    img = torch.tensor(img).to(device).float() / 255.

    with torch.no_grad():
        for i in range(max_big_step):
            stepnum = T * i / max_step
            if built_in_features:
                state = torch.cat(
                    [canvas, img, mask.float(), stepnum, coord], 1)
            else:
                state = torch.cat([canvas, img, stepnum, coord], 1)
            actions = actor(state)
            canvas, res = decode(actions, canvas, discrete_colors=False)
            for j in range(5):
                imgid += 1
                c = res_to_img(res[j], imgid)[..., ::-1]
                if ((imgid % 5) == 0) and (imgid > 10):
                    c_norm = cv2.resize(c,
                                        (classifier.width, classifier.width))
                    c_norm = classifier.normalize(c_norm)
                    pred_class, confidence = classifier.classify(
                        c_norm.unsqueeze(0).to(device))
                    # Stop as soon as the classifier predicts the target
                    # class above a small confidence floor.
                    if pred_class[0] == true_class and confidence[0] > 0.013:
                        return imgid, c
        if divide != 1:
            canvas = canvas[0].detach().cpu().numpy()
            canvas = np.transpose(canvas, (1, 2, 0))
            canvas = cv2.resize(canvas, (width * divide, width * divide))
            canvas = large2small(canvas)
            canvas = np.transpose(canvas, (0, 3, 1, 2))
            canvas = torch.tensor(canvas).to(device).float()
            coord = coord.expand(canvas_cnt, 2, width, width)

            T = T.expand(canvas_cnt, 1, width, width)
            for i in range(max_small_step):
                stepnum = T * i / max_step
                if built_in_features:
                    # Broadcast the whole-image mask across the patch batch
                    # (torch.cat needs matching sizes along dim 0).
                    state = torch.cat(
                        [canvas, patch_img,
                         mask.expand(canvas_cnt, -1, -1, -1),
                         stepnum, coord], 1)
                else:
                    state = torch.cat([canvas, patch_img, stepnum, coord], 1)
                actions = actor(state)
                canvas, res = decode(actions, canvas, discrete_colors=False)
                for j in range(5):
                    imgid += divide**2
                    c = res_to_img(res[j], imgid, divide=True)[..., ::-1]
                    c_norm = cv2.resize(c,
                                        (classifier.width, classifier.width))
                    c_norm = classifier.normalize(c_norm)
                    pred_class, confidence = classifier.classify(
                        c_norm.unsqueeze(0).to(device))
                    if pred_class[0] == true_class and confidence[0] > 0.013:
                        return imgid, c
    return None, c
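
A minimal driver for Example #1 might look like the following sketch. The checkpoint paths, the classifier construction, and the target label are assumptions for illustration; only the keyword arguments come from the signature above.

# Hypothetical usage sketch for Example #1 -- the checkpoint paths,
# classifier construction, and target label are assumptions.
from sketchy.classifier import SketchyClassifier

classifier = SketchyClassifier()  # assumed default constructor
stroke_id, frame = paint_until_object_detected(
    img='../image/vangogh.png',   # sample path taken from Example #3
    actor_fn='actor.pkl',         # assumed actor checkpoint
    renderer_fn='renderer.pkl',   # assumed renderer checkpoint
    true_class='cat',             # assumed target label
    classifier=classifier,
    div=4)
if stroke_id is not None:
    print('object detected after stroke', stroke_id)
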
Example #2
patch_img = np.transpose(patch_img, (0, 3, 1, 2))
patch_img = torch.tensor(patch_img).to(device).float() / 255.

img = cv2.resize(img, (width, width))
img = img.reshape(1, width, width, 3)
img = np.transpose(img, (0, 3, 1, 2))
img = torch.tensor(img).to(device).float() / 255.

os.makedirs('output', exist_ok=True)

with torch.no_grad():
    if args.divide != 1:
        args.max_step = args.max_step // 2
    for i in range(args.max_step):
        stepnum = T * i / args.max_step
        actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
        canvas, res = decode(actions, canvas)
        print('canvas step {}, L2Loss = {}'.format(i, ((canvas - img) ** 2).mean()))
        for j in range(5):
            save_img(res[j], args.imgid)
            args.imgid += 1
    if args.divide != 1:
        canvas = canvas[0].detach().cpu().numpy()
        canvas = np.transpose(canvas, (1, 2, 0))
        canvas = cv2.resize(canvas, (width * args.divide, width * args.divide))
        canvas = large2small(canvas)
        canvas = np.transpose(canvas, (0, 3, 1, 2))
        canvas = torch.tensor(canvas).to(device).float()
        coord = coord.expand(canvas_cnt, 2, width, width)
        T = T.expand(canvas_cnt, 1, width, width)
        for i in range(args.max_step):
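
All three examples call a large2small helper whose definition is not shown here. A minimal sketch consistent with how it is used (an HWC numpy image of side width * divide in, a batch of divide * divide patches out, relying on the module-level width, divide, and canvas_cnt globals) could be:

import numpy as np

def large2small(img):
    # Assumed reconstruction: split a (width*divide, width*divide, 3) image
    # into a batch of divide*divide patches of shape (width, width, 3).
    img = img.reshape(divide, width, divide, width, 3)
    img = np.transpose(img, (0, 2, 1, 3, 4))  # group patch rows and columns
    return img.reshape(canvas_cnt, width, width, 3)
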
Example #3
def paint(actor_fn,
          renderer_fn,
          max_step=40,
          div=5,
          img_width=128,
          img='../image/vangogh.png',
          discrete_colors=True,
          n_colors=10,
          white_canvas=True,
          built_in_features=False,
          use_multiple_renderers=False):
    global Decoder, width, divide, canvas_cnt, origin_shape
    width = img_width
    divide = div

    if not use_multiple_renderers:
        Decoder = FCN()
        Decoder.load_state_dict(torch.load(renderer_fn))
        Decoder = Decoder.to(device).eval()
    else:
        global decoders, decoder_cutoff
        from DRL.ddpg import decoders, decoder_cutoff

    # 10 input planes when the built-in L2 mask is stacked into the state,
    # else 9 (canvas + target image + stepnum + coord)
    actor = ResNet(10 if built_in_features else 9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
    actor.load_state_dict(torch.load(actor_fn))
    actor = actor.to(device).eval()

    # Get the allowed colors if it's supposed to be discrete
    if discrete_colors:
        color_cluster(img, n_colors)

    imgid = 0
    canvas_cnt = divide * divide
    T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)
    img = cv2.imread(img, cv2.IMREAD_COLOR)
    origin_shape = (img.shape[1], img.shape[0])

    coord = torch.zeros([1, 2, width, width])
    for i in range(width):
        for j in range(width):
            coord[0, 0, i, j] = i / (width - 1.)
            coord[0, 1, i, j] = j / (width - 1.)
    coord = coord.to(device)  # Coordconv

    canvas = torch.zeros([1, 3, width, width]).to(device)
    if white_canvas:
        canvas = torch.ones([1, 3, width, width]).to(device)
    canvas_discrete = canvas.detach().clone()

    patch_img = cv2.resize(img, (width * divide, width * divide))
    patch_img = large2small(patch_img)
    patch_img = np.transpose(patch_img, (0, 3, 1, 2))
    patch_img = torch.tensor(patch_img).to(device).float() / 255.

    img = cv2.resize(img, (width, width))
    mask = None
    if built_in_features:
        mask = get_l2_mask(
            torch.unsqueeze(
                torch.tensor(np.transpose(img.astype('float32'),
                                          (2, 0, 1))), 0) / 255)[:, 0, :, :]
        mask = mask.unsqueeze(0)
    img = img.reshape(1, width, width, 3)
    img = np.transpose(img, (0, 3, 1, 2))
    img = torch.tensor(img).to(device).float() / 255.

    actions_whole = None
    actions_divided = None
    all_canvases = []

    with torch.no_grad():
        if divide != 1:
            max_step = max_step // 2
        for i in range(max_step):
            stepnum = T * i / max_step

            if built_in_features:
                state = torch.cat(
                    [canvas, img, mask.float(), stepnum, coord], 1)
            else:
                state = torch.cat([canvas, img, stepnum, coord], 1)
            actions = actor(state)

            # Use the non-discrete canvas for acting, but save the discrete canvas when painting with finite colors
            if use_multiple_renderers:
                canvas_discrete, res_discrete = decode_multiple_renderers(
                    actions,
                    canvas_discrete,
                    i,
                    discrete_colors=discrete_colors)
            else:
                canvas_discrete, res_discrete = decode(
                    actions, canvas_discrete, discrete_colors=discrete_colors)

            if use_multiple_renderers:
                canvas, res = decode_multiple_renderers(actions,
                                                        canvas,
                                                        i,
                                                        discrete_colors=False)
            else:
                canvas, res = decode(actions, canvas, discrete_colors=False)

            if actions_whole is None:
                actions_whole = actions
            else:
                actions_whole = torch.cat([actions_whole, actions], 1)
            for j in range(5):
                # save_img(res[j], imgid)
                # plot_canvas(res[j], imgid)
                all_canvases.append(
                    res_to_img(res_discrete[j], imgid)[..., ::-1])
                imgid += 1
        if divide != 1:
            canvas = canvas[0].detach().cpu().numpy()
            canvas = np.transpose(canvas, (1, 2, 0))
            canvas = cv2.resize(canvas, (width * divide, width * divide))
            canvas = large2small(canvas)
            canvas = np.transpose(canvas, (0, 3, 1, 2))
            canvas = torch.tensor(canvas).to(device).float()
            coord = coord.expand(canvas_cnt, 2, width, width)

            canvas_discrete = canvas_discrete[0].detach().cpu().numpy()
            canvas_discrete = np.transpose(canvas_discrete, (1, 2, 0))
            canvas_discrete = cv2.resize(canvas_discrete,
                                         (width * divide, width * divide))
            canvas_discrete = large2small(canvas_discrete)
            canvas_discrete = np.transpose(canvas_discrete, (0, 3, 1, 2))
            canvas_discrete = torch.tensor(canvas_discrete).to(device).float()

            T = T.expand(canvas_cnt, 1, width, width)
            for i in range(max_step):
                stepnum = T * i / max_step
                if built_in_features:
                    # Broadcast the whole-image mask across the patch batch
                    # (torch.cat needs matching sizes along dim 0).
                    state = torch.cat(
                        [canvas, patch_img,
                         mask.expand(canvas_cnt, -1, -1, -1),
                         stepnum, coord], 1)
                else:
                    state = torch.cat([canvas, patch_img, stepnum, coord], 1)
                actions = actor(state)
                canvas_discrete, res_discrete = decode(
                    actions, canvas_discrete, discrete_colors=discrete_colors)
                canvas, res = decode(actions, canvas, discrete_colors=False)
                if actions_divided is None:
                    actions_divided = actions
                else:
                    actions_divided = torch.cat([actions_divided, actions], 1)

                for j in range(5):
                    # save_img(res[j], imgid, True)
                    # plot_canvas(res[j], imgid, True)
                    all_canvases.append(
                        res_to_img(res_discrete[j], imgid, True)[..., ::-1])
                    imgid += 1

        final_result = res_to_img(res_discrete[-1], imgid, True)[..., ::-1]
    return actions_whole, actions_divided, all_canvases, final_result
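
The sketch below shows one way Example #3 could be driven end to end. The checkpoint filenames are assumptions; the argument names and the four return values come from the signature above, and the final write-out assumes res_to_img yields uint8 frames whose [..., ::-1] flip above converted OpenCV's BGR order to RGB.

# Hypothetical usage sketch for Example #3 -- checkpoint paths are assumptions.
import cv2
import numpy as np

actions_whole, actions_divided, canvases, final = paint(
    actor_fn='actor.pkl',        # assumed actor checkpoint
    renderer_fn='renderer.pkl',  # assumed renderer checkpoint
    max_step=40,
    div=5,
    img='../image/vangogh.png',  # default from the signature
    discrete_colors=True,
    n_colors=10)

print('collected', len(canvases), 'intermediate canvases')
# Flip back to BGR (OpenCV's channel order) before writing to disk.
cv2.imwrite('final.png', np.ascontiguousarray(final[..., ::-1]))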