def run(image, mask, ambient_intensity, light_intensity, light_source_height,
        gamma_correction, stroke_density_clipping, light_color_red,
        light_color_green, light_color_blue, enabling_multiple_channel_effects,
        x, y):
    """Relight a painting for a single light position given in pixels.

    Pipeline: resize and clean the input (SRCNN), estimate a stroke-density
    map from rays cast from the mean colour to each pixel colour against the
    colour convex hull, derive per-channel lighting effects, then shade the
    image for a light placed from ``(x, y)`` and ``light_source_height``.

    Side effects: writes ``stroke_density.png`` and ``target.png`` to the
    current working directory and prints progress / memory usage.

    Args:
        image: input image passed to ``min_resize``. The light colour is
            assembled in B, G, R order below, so the image is presumably BGR
            (confirm with the caller).
        mask: optional mask image, or None. Its channel mean divided by 255
            weights the pixels used for the palette / density estimation only;
            the unmasked image is what gets rendered.
        ambient_intensity: constant term added to the directional lighting.
        light_intensity: scale of the directional lighting term.
        light_source_height: first component of the light location vector.
        gamma_correction: exponent applied to the image normalised by 255.
        stroke_density_clipping: gain applied before clipping density to [0, 1].
        light_color_red, light_color_green, light_color_blue: per-channel
            light colour multipliers (combined in B, G, R order).
        enabling_multiple_channel_effects: if False, the lighting effect is
            averaged across channels so every channel is lit equally.
        x, y: light position in pixels; wrapped modulo the width / height.

    Returns:
        None. The result is written to ``target.png``.
    """
    # Some pre-processing to resize images and remove input JPEG artifacts.
    raw_image = min_resize(image, 512)
    print("Preprocess srcnn...")
    memory_use()
    raw_image = run_srcnn(raw_image)
    print("Completed.")
    memory_use()
    raw_image = min_resize(raw_image, 512)
    raw_image = raw_image.astype(np.float32)
    unmasked_image = raw_image.copy()

    if mask is not None:
        alpha = np.mean(d_resize(mask, raw_image.shape).astype(np.float32) /
                        255.0,
                        axis=2,
                        keepdims=True)
        # The mask only weights the palette / density estimation; the
        # unmasked image is restored before rendering.
        raw_image = unmasked_image * alpha

    # Compute the convex-hull-like palette.
    h, w, c = raw_image.shape
    flattened_raw_image = raw_image.reshape((h * w, c))
    raw_image_center = np.mean(flattened_raw_image, axis=0)
    hull = ConvexHull(flattened_raw_image)

    # Estimate the stroke density map: one ray per pixel, starting at the
    # mean colour and pointing towards the pixel colour.
    print("Preprocess trimesh...")
    memory_use()
    intersector = trimesh.Trimesh(faces=hull.simplices,
                                  vertices=hull.points).ray
    start = np.tile(raw_image_center[None, :], [h * w, 1])
    direction = flattened_raw_image - start
    print("Completed.")
    memory_use()

    print('Begin ray intersecting ...')
    _, index_ray, locations = intersector.intersects_id(
        start, direction, return_locations=True, multiple_hits=True)
    print('Intersecting finished.')

    # Average every hull hit of each ray. np.add.at accumulates repeated ray
    # indices correctly (a fancy-indexed += would drop duplicates) and avoids
    # the per-hit Python loop, which also overwrote the channel count `c`.
    intersections = np.zeros(shape=(h * w, c), dtype=np.float32)
    intersection_count = np.zeros(shape=(h * w, 1), dtype=np.float32)
    if index_ray.size:
        np.add.at(intersection_count, index_ray, 1.0)
        np.add.at(intersections, index_ray, locations)
    intersections = (intersections + 1e-10) / (intersection_count + 1e-10)
    intersections = intersections.reshape((h, w, c))
    intersection_count = intersection_count.reshape((h, w))
    # Rays that hit nothing fall back to the pixel itself (distance ratio 1).
    no_hit = intersection_count < 1
    intersections[no_hit] = raw_image[no_hit]

    intersection_distance = np.sqrt(
        np.sum(np.square(intersections - raw_image_center[None, None, :]),
               axis=2,
               keepdims=True))
    pixel_distance = np.sqrt(
        np.sum(np.square(raw_image - raw_image_center[None, None, :]),
               axis=2,
               keepdims=True))
    # Guard 0/0 (a pixel exactly at the palette centre with no hull hit); the
    # resulting NaN would otherwise spread through the guided filter below.
    distance_ratio = pixel_distance / np.maximum(intersection_distance, 1e-10)
    # Density peaks when the pixel lies on the hull surface (ratio == 1).
    stroke_density = (
        (1.0 - np.abs(1.0 - distance_ratio)) *
        stroke_density_clipping).clip(0, 1) * 255

    # A trick to improve the quality of the stroke density map.
    # It uses guided filter to remove some possible artifacts.
    # You can remove these codes if you like sharper effects.
    guided_filter = createGuidedFilter(
        pixel_distance.clip(0, 255).astype(np.uint8), 1, 0.01)
    for _ in range(4):
        stroke_density = guided_filter.filter(stroke_density)

    # Visualize the estimated stroke density.
    cv2.imwrite('stroke_density.png',
                stroke_density.clip(0, 255).astype(np.uint8))

    # Then generate the lighting effects from the unmasked image, one
    # colour channel at a time.
    raw_image = unmasked_image.copy()
    lighting_effect = np.stack([
        generate_lighting_effects(stroke_density, raw_image[:, :, channel])
        for channel in range(3)
    ], axis=2)

    # Map the pixel position into (-1, 1], decreasing as x / y grow.
    gx = -float(x % w) / float(w) * 2.0 + 1.0
    gy = -float(y % h) / float(h) * 2.0 + 1.0

    light_source_color = np.array(
        [light_color_blue, light_color_green, light_color_red])
    light_source_location = np.array([[[light_source_height, gy, gx]]],
                                     dtype=np.float32)
    light_source_direction = light_source_location / np.sqrt(
        np.sum(np.square(light_source_location)))
    # Dot product of the lighting effect with the unit light direction.
    final_effect = np.sum(lighting_effect * light_source_direction,
                          axis=3).clip(0, 1)
    if not enabling_multiple_channel_effects:
        final_effect = np.mean(final_effect, axis=2, keepdims=True)
    rendered_image = (ambient_intensity + final_effect *
                      light_intensity) * light_source_color * raw_image
    rendered_image = ((rendered_image / 255.0)**gamma_correction) * 255.0
    cv2.imwrite("target.png", rendered_image)
    print('Completed.')
Example #2
0
def refine_image(image, sketch, origin, *, verbose=False):
    """Refine a colour image against its sketch and return four variants.

    The image is resized to the sketch, then refined along two paths:

    * "smoothed": sparse refinement -> guided filtering -> HDR matching.
    * "flat": bright-sparse refinement -> L0 smoothing -> sparse refinement
      -> guided filtering -> HDR matching.

    Each path is then blended with ``origin`` in the luminance (Y) channel.

    Relies on module-level objects not visible here: ``session``,
    ``tf_sparse_op_H``, ``tf_sparse_op_L``, ``ipsp3`` and ``ipsp9`` (presumably
    a TensorFlow session, ops and placeholders; confirm).

    Args:
        image: colour image to refine. It is resized to ``sketch.shape`` via
            ``d_resize`` and treated as BGR when blending.
        sketch: guide image. It is converted to float32 and used to build the
            sparse matrices and the guided filter.
        origin: image written into / multiplied with the Y channel during
            blending. Presumably single-channel with the sketch's height and
            width (confirm with the caller).
        verbose: when True, print the name of each intermediate, show it in a
            window (blocking on a key press) and write it to ``cv_log.png``.
            The default False matches the previous hard-coded behaviour.

    Returns:
        Tuple ``(smoothed, flat, blended_smoothed, blended_flat)``. The two
        blended results are uint8 BGR images.
    """

    def cv_log(name, img):
        # Debug-only visualisation of an intermediate result.
        if verbose:
            print(name)
            cv2.imshow('cv_log', img.clip(0, 255).astype(np.uint8))
            cv2.imwrite('cv_log.png', img.clip(0, 255).astype(np.uint8))
            cv2.waitKey(0)

    print('Building Sparse Matrix ...')
    sketch = sketch.astype(np.float32)
    sparse_matrix = build_sketch_sparse(sketch, True)
    # High-pass of the sketch (sketch minus its Gaussian blur) drives the
    # "bright" refinement.
    bright_matrix = build_sketch_sparse(sketch - cv2.GaussianBlur(sketch, (0, 0), 3.0), False)
    guided_matrix = createGuidedFilter(sketch.clip(0, 255).astype(np.uint8), 1, 0.01)
    # HDR statistics of the input, used as targets by go_hdr.
    HDRL, HDRM, HDRH = get_hdr(image)

    def go_guide(x):
        # Unsharp-mask style sharpening, then 4 passes of the sketch-guided filter.
        y = x + (x - cv2.GaussianBlur(x, (0, 0), 1)) * 2.0
        for _ in tqdm(range(4)):
            y = guided_matrix.filter(y)
        return y

    def go_refine_sparse(x):
        return session.run(tf_sparse_op_H, feed_dict={ipsp3: x, ipsp9: sparse_matrix})

    def go_refine_bright(x):
        return session.run(tf_sparse_op_L, feed_dict={ipsp3: x, ipsp9: bright_matrix})

    def go_flat(x):
        # L0 smoothing at half resolution with a reflect-padded border
        # (trimmed afterwards), then upsampled back to the input size.
        pia = 32
        y = x.clip(0, 255).astype(np.uint8)
        y = cv2.resize(y, (x.shape[1] // 2, x.shape[0] // 2), interpolation=cv2.INTER_AREA)
        y = np.pad(y, ((pia, pia), (pia, pia), (0, 0)), 'reflect')
        y = l0Smooth(y, None, 0.01)
        y = y[pia:-pia, pia:-pia, :]
        y = cv2.resize(y, (x.shape[1], x.shape[0]), interpolation=cv2.INTER_CUBIC)
        return y

    def go_hdr(x):
        # Match x's HDR statistics to those of the original input image.
        xl, xm, xh = get_hdr(x)
        y = f2(xl, xm, xh, HDRL, HDRM, HDRH, x)
        return y.clip(0, 255)

    def go_blend(BGR, X, m):
        # Replace the Y (luma) channel: multiply by X/255 when m is truthy,
        # otherwise take the per-pixel minimum of Y and X.
        BGR = BGR.clip(0, 255).astype(np.uint8)
        X = X.clip(0, 255).astype(np.uint8)
        YUV = cv2.cvtColor(BGR, cv2.COLOR_BGR2YUV)
        s_l = YUV[:, :, 0].astype(np.float32)
        t_l = X.astype(np.float32)
        r_l = (s_l * t_l / 255.0) if m else np.minimum(s_l, t_l)
        YUV[:, :, 0] = r_l.clip(0, 255).astype(np.uint8)
        return cv2.cvtColor(YUV, cv2.COLOR_YUV2BGR)

    print('Getting Target ...')
    smoothed = d_resize(image, sketch.shape)
    print('Global Optimization ...')
    cv_log('smoothed', smoothed)
    sparse_smoothed = go_refine_sparse(smoothed)
    cv_log('smoothed', sparse_smoothed)
    smoothed = go_guide(sparse_smoothed)
    cv_log('smoothed', smoothed)
    smoothed = go_hdr(smoothed)
    cv_log('smoothed', smoothed)
    print('Decomposition Optimization ...')
    # The flat path starts from the sparse-refined image, not the guided one.
    flat = sparse_smoothed.copy()
    cv_log('flat', flat)
    flat = go_refine_bright(flat)
    cv_log('flat', flat)
    flat = go_flat(flat)
    cv_log('flat', flat)
    flat = go_refine_sparse(flat)
    cv_log('flat', flat)
    flat = go_guide(flat)
    cv_log('flat', flat)
    flat = go_hdr(flat)
    cv_log('flat', flat)
    print('Blending Optimization ...')
    cv_log('origin', origin)
    blended_smoothed = go_blend(smoothed, origin, False)
    cv_log('blended_smoothed', blended_smoothed)
    blended_flat = go_blend(flat, origin, True)
    cv_log('blended_flat', blended_flat)
    print('Optimization finished.')
    return smoothed, flat, blended_smoothed, blended_flat
Example #3
0
def run(image, mask, ambient_intensity, light_intensity, light_source_height, gamma_correction, stroke_density_clipping, light_color_red, light_color_green, light_color_blue, enabling_multiple_channel_effects):
    """Relight a painting along several light paths and log results to W&B.

    Starts a Weights & Biases run and estimates the stroke-density map. It
    then renders the image for three light paths:

    * a horizontal sweep (random fixed gy);
    * a vertical sweep (random fixed gx);
    * a circular path of radius 0.7.

    Every frame is logged as an image and every path as a GIF.

    Side effects: calls ``wandb.init``; writes ``stroke_density.png`` and
    ``light.gif`` to the working directory; prints progress.

    Args:
        image: input image passed to ``min_resize``. The light colour is
            assembled in B, G, R order and frames are converted BGR->RGB
            before logging, so the image is presumably BGR (confirm).
        mask: optional mask image, or None. Its channel mean divided by 255
            weights the pixels used for the density estimation only.
        ambient_intensity: constant term added to the directional lighting.
        light_intensity: scale of the directional lighting term.
        light_source_height: first component of the light location vector.
        gamma_correction: exponent applied to the image normalised by 255.
        stroke_density_clipping: gain applied before clipping density to [0, 1].
        light_color_red, light_color_green, light_color_blue: per-channel
            light colour multipliers.
        enabling_multiple_channel_effects: if False, the lighting effect is
            averaged across channels.

    Returns:
        None.
    """
    wandb.init(entity='ayush-thakur', project='paintlight')

    # Some pre-processing to resize images and remove input JPEG artifacts.
    raw_image = min_resize(image, 512)
    raw_image = run_srcnn(raw_image)
    raw_image = min_resize(raw_image, 512)
    raw_image = raw_image.astype(np.float32)
    unmasked_image = raw_image.copy()

    if mask is not None:
        alpha = np.mean(d_resize(mask, raw_image.shape).astype(np.float32) / 255.0, axis=2, keepdims=True)
        # The mask only weights the density estimation; the unmasked image is
        # restored before rendering.
        raw_image = unmasked_image * alpha

    # Compute the convex-hull-like palette.
    h, w, c = raw_image.shape
    flattened_raw_image = raw_image.reshape((h * w, c))
    raw_image_center = np.mean(flattened_raw_image, axis=0)
    hull = ConvexHull(flattened_raw_image)

    # Estimate the stroke density map: one ray per pixel, starting at the mean
    # colour and pointing towards the pixel colour.
    intersector = trimesh.Trimesh(faces=hull.simplices, vertices=hull.points).ray
    start = np.tile(raw_image_center[None, :], [h * w, 1])
    direction = flattened_raw_image - start
    print('Begin ray intersecting ...')
    _, index_ray, locations = intersector.intersects_id(start, direction, return_locations=True, multiple_hits=True)
    print('Intersecting finished.')

    # Average every hull hit of each ray; np.add.at accumulates repeated ray
    # indices and replaces the per-hit loop (which also overwrote `c`).
    intersections = np.zeros(shape=(h * w, c), dtype=np.float32)
    intersection_count = np.zeros(shape=(h * w, 1), dtype=np.float32)
    if index_ray.size:
        np.add.at(intersection_count, index_ray, 1.0)
        np.add.at(intersections, index_ray, locations)
    intersections = (intersections + 1e-10) / (intersection_count + 1e-10)
    intersections = intersections.reshape((h, w, c))
    intersection_count = intersection_count.reshape((h, w))
    # Rays that hit nothing fall back to the pixel itself (distance ratio 1).
    no_hit = intersection_count < 1
    intersections[no_hit] = raw_image[no_hit]
    intersection_distance = np.sqrt(np.sum(np.square(intersections - raw_image_center[None, None, :]), axis=2, keepdims=True))
    pixel_distance = np.sqrt(np.sum(np.square(raw_image - raw_image_center[None, None, :]), axis=2, keepdims=True))
    # Guard 0/0 so a pixel exactly at the palette centre does not yield NaN.
    distance_ratio = pixel_distance / np.maximum(intersection_distance, 1e-10)
    stroke_density = ((1.0 - np.abs(1.0 - distance_ratio)) * stroke_density_clipping).clip(0, 1) * 255

    # A trick to improve the quality of the stroke density map.
    # It uses guided filter to remove some possible artifacts.
    # You can remove these codes if you like sharper effects.
    guided_filter = createGuidedFilter(pixel_distance.clip(0, 255).astype(np.uint8), 1, 0.01)
    for _ in range(4):
        stroke_density = guided_filter.filter(stroke_density)

    # Visualize the estimated stroke density.
    cv2.imwrite('stroke_density.png', stroke_density.clip(0, 255).astype(np.uint8))

    # Then generate the lighting effects from the unmasked image.
    raw_image = unmasked_image.copy()
    lighting_effect = np.stack([
        generate_lighting_effects(stroke_density, raw_image[:, :, channel])
        for channel in range(3)
    ], axis=2)

    light_source_color = np.array([light_color_blue, light_color_green, light_color_red])

    def PointsInCircum(r, n=10):
        """Return n + 1 points on a circle of radius r (first point repeated last)."""
        step = 2 * math.pi / n
        return [(math.cos(step * k) * r, math.sin(step * k) * r) for k in range(0, n + 1)]

    def log_gif(ims, log_name):
        """Save PIL frames as 'light.gif' and log it to W&B under log_name."""
        ims[0].save('light.gif', save_all=True, append_images=ims[1:], duration=40, loop=0)
        wandb.log({"{}".format(log_name): wandb.Video('light.gif', fps=4, format="gif")})

    def apply_light(gx, gy, log_name):
        """Render the image lit from (gx, gy), log it, and return it as RGB uint8."""
        light_source_location = np.array([[[light_source_height, gy, gx]]], dtype=np.float32)
        light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location)))
        final_effect = np.sum(lighting_effect * light_source_direction, axis=3).clip(0, 1)
        if not enabling_multiple_channel_effects:
            final_effect = np.mean(final_effect, axis=2, keepdims=True)
        rendered_image = (ambient_intensity + final_effect * light_intensity) * light_source_color * raw_image
        rendered_image = ((rendered_image / 255.0) ** gamma_correction) * 255.0

        canvas = rendered_image.clip(0, 255).astype(np.uint8)
        canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

        print("[INFO] gx is: {} | gy is: {}".format(gx, gy))
        print("[INFO] Logging image to wandb")
        wandb.log({'{}'.format(log_name): [wandb.Image(canvas)]})

        return canvas

    def sweep(points, frame_log_name, gif_log_name):
        """Render and log one frame per (gx, gy) point, then log them as a GIF."""
        frames = [Image.fromarray(apply_light(gx, gy, frame_log_name)) for gx, gy in points]
        log_gif(frames, gif_log_name)

    ## Move across x-axis
    gx_samples_horizontal = np.arange(-0.99, 0.99, 0.1)
    gy_samples_horizontal = np.repeat(np.random.uniform(-0.35, 0.65, 1), len(gx_samples_horizontal))

    ## Move across y-axis
    gy_samples_vertical = np.arange(-0.99, 0.99, 0.1)
    gx_samples_vertical = np.repeat(np.random.uniform(-0.35, 0.65, 1), len(gy_samples_vertical))

    ## Move in circular motion
    circlepoints = PointsInCircum(r=0.7, n=20)

    ## Original Image and Stroke Density
    original_image = raw_image.clip(0, 255).astype(np.uint8)
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # The density map is derived from the one-channel pixel_distance, so a
    # BGR->RGB conversion (which needs 3/4 channels) would fail; log it as
    # grayscale instead, converting only if it somehow has colour channels.
    stroke_density_log = stroke_density.clip(0, 255).astype(np.uint8)
    if stroke_density_log.ndim == 3 and stroke_density_log.shape[2] == 1:
        stroke_density_log = stroke_density_log[:, :, 0]
    if stroke_density_log.ndim == 3:
        stroke_density_log = cv2.cvtColor(stroke_density_log, cv2.COLOR_BGR2RGB)

    wandb.log({"original-image": [wandb.Image(original_image)]})
    wandb.log({"stroke-density": [wandb.Image(stroke_density_log)]})

    ## Apply light horizontally
    sweep(zip(gx_samples_horizontal, gy_samples_horizontal),
          'swipe_across_horizontallyn_015', 'swipe_across_horizontallyn_gif_015')

    ## Apply light vertically (previously reused the horizontal samples)
    sweep(zip(gx_samples_vertical, gy_samples_vertical),
          'swipe_across_verticallyn_015', 'swipe_across_verticallyn_gif_015')

    ## Apply light in circular manner (previously passed gy for both axes)
    sweep(circlepoints, 'move_circlen_015', 'move_circlen_gif_015')
def run(image, mask, ambient_intensity, light_intensity, light_source_height,
        stroke_density_clipping, light_color_red, light_color_green,
        light_color_blue):
    """Interactive relighting preview driven by the mouse.

    Estimates a stroke-density map, derives per-channel lighting effects,
    then loops forever. Each frame shows the input and the relit image side
    by side in a 'Result' window; the mouse position (via ``update_mouse``)
    moves the light through the module-level ``gx`` / ``gy`` globals.

    Unlike the other pipelines, this one has no SRCNN cleanup and no gamma
    correction. It also always uses per-channel lighting (no channel
    averaging).

    Side effects: writes ``stroke_density.png``, opens an OpenCV window, and
    never returns.

    NOTE(review): ``gx`` / ``gy`` are never assigned inside this function
    before the first frame is rendered (the mouse callback is registered only
    after the first ``imshow``). The first iteration therefore raises
    NameError unless the module defines ``gx`` / ``gy`` at top level.
    Confirm they are initialised there.
    """
    raw_image = min_resize(image, 512)
    raw_image = raw_image.astype(np.float32)
    unmasked_image = raw_image.copy()

    if mask is not None:
        alpha = np.mean(d_resize(mask, raw_image.shape).astype(np.float32) /
                        255.0,
                        axis=2,
                        keepdims=True)
        # Mask only weights the density estimation; restored before rendering.
        raw_image = unmasked_image * alpha

    # Convex hull of all pixel colours, and the mean colour as ray origin.
    h, w, c = raw_image.shape
    flattened_raw_image = raw_image.reshape(h * w, c)
    raw_image_center = np.mean(flattened_raw_image, axis=0)
    hull = ConvexHull(flattened_raw_image)

    # One ray per pixel: from the mean colour towards the pixel colour.
    intersector = trimesh.Trimesh(faces=hull.simplices,
                                  vertices=hull.points).ray
    start = np.tile(raw_image_center[:], [h * w, 1])
    direction = flattened_raw_image - start
    print('Begin ray intersection ...')
    index_tri, index_ray, locations = intersector.intersects_id(
        start, direction, return_locations=True, multiple_hits=True)
    print('Intersecting finished.')
    # Accumulate and average all hits per ray. Note the loop index reuses
    # (and overwrites) the channel count `c`; later code hard-codes 3.
    intersections = np.zeros((h * w, c), dtype=np.float32)
    intersection_count = np.zeros((h * w, 1), dtype=np.float32)
    CI = index_ray.shape[0]
    for c in range(CI):
        i = index_ray[c]
        intersection_count[i] += 1
        intersections[i] += locations[c]
    intersections = (intersections + 1e-10) / (intersection_count + 1e-10)
    intersections = intersections.reshape((h, w, 3))
    intersection_count = intersection_count.reshape((h, w))
    # Rays with no hit fall back to the pixel itself.
    intersections[intersection_count < 1] = raw_image[intersection_count < 1]
    intersection_distance = np.sqrt(
        np.sum(np.square(intersections - raw_image_center[None, None, :]),
               axis=2,
               keepdims=True))
    pixel_distance = np.sqrt(
        np.sum(np.square(raw_image - raw_image_center[None, None, :]),
               axis=2,
               keepdims=True))
    # NOTE(review): density here is the plain ratio |pixel / intersection|,
    # not the 1 - |1 - ratio| form used by the other pipelines; confirm this
    # is intended. A 0/0 ratio (pixel exactly at the centre) yields NaN.
    stroke_density = (np.abs(pixel_distance / intersection_distance) *
                      stroke_density_clipping).clip(0, 1) * 255

    # Guided filtering (guide = pixel distance) to suppress artifacts.
    guided_filter = createGuidedFilter(
        pixel_distance.clip(0, 255).astype(np.uint8), 1, 0.01)
    for _ in range(4):
        stroke_density = guided_filter.filter(stroke_density)

    cv2.imwrite('stroke_density.png',
                stroke_density.clip(0, 255).astype(np.uint8))

    # Lighting effects are computed from the unmasked image, per channel.
    raw_image = unmasked_image.copy()
    lighting_effect = np.stack([
        generate_lighting_effects(stroke_density, raw_image[:, :, 0]),
        generate_lighting_effects(stroke_density, raw_image[:, :, 1]),
        generate_lighting_effects(stroke_density, raw_image[:, :, 2])
    ],
                               axis=2)

    def update_mouse(event, x, y, flags, param):
        """OpenCV mouse callback: map the cursor position into the light globals."""
        global gx
        global gy
        # NOTE(review): horizontal x drives gy and vertical y drives gx, the
        # reverse of the x->gx / y->gy mapping used elsewhere; confirm intended.
        gy = -float(x % w) / float(w) * 2.0 + 1.0
        gx = -float(y % h) / float(h) * 2.0 + 1.0
        return

    # Light colour assembled in B, G, R order.
    light_source_color = np.array(
        [light_color_blue, light_color_green, light_color_red])

    # Read the module-level light position written by update_mouse.
    global gx
    global gy

    # Render loop; it has no exit condition. waitKey also pumps GUI events so
    # the mouse callback can fire.
    while True:
        light_source_location = np.array([[[light_source_height, gy, gx]]],
                                         dtype=np.float32)
        light_source_direction = light_source_location / np.sqrt(
            np.sum(np.square(light_source_location)))
        final_effect = np.sum(lighting_effect * light_source_direction,
                              axis=3).clip(0, 1)
        rendered_image = (ambient_intensity + final_effect *
                          light_intensity) * light_source_color * raw_image
        # Original (left) and relit (right) side by side.
        canvas = np.concatenate([raw_image, rendered_image],
                                axis=1).clip(0, 255).astype(np.uint8)
        # cv2.namedWindow('Result', cv2.WINDOW_NORMAL)
        cv2.imshow('Result', canvas)
        # Re-registered every frame; imshow above guarantees the window exists.
        cv2.setMouseCallback('Result', update_mouse)
        cv2.waitKey(10)