def test_preconditioners():
    """Check that each preconditioner lets CG converge within its iteration bound.

    Builds the matting linear system for one low-resolution training image
    (``GT01`` at 20% scale) and its trimap, then solves it with conjugate
    gradients using several preconditioners. For each preconditioner the test
    asserts that:

    * the final residual norm ``||b - A x||`` is at most ``atol``, and
    * the number of iterations counted by ``CounterCallback`` does not exceed
      the recorded upper bound for that preconditioner.
    """
    atol = 1e-6
    index = 1
    scale = 0.2

    name = f"GT{index:02d}"

    image_dir = "data"

    image = load_image(
        f"{image_dir}/input_training_lowres/{name}.png", "rgb", scale, "bilinear"
    )
    # Nearest-neighbour resampling avoids interpolating between trimap values.
    trimap = load_image(
        f"{image_dir}/trimap_training_lowres/Trimap1/{name}.png",
        "gray",
        scale,
        "nearest",
    )

    A, b = make_linear_system(cf_laplacian(image), trimap)

    # (name, factory building the preconditioner M from A, max CG iterations).
    # Keeping each iteration bound next to its factory prevents the two from
    # drifting out of sync.
    cases = [
        ("no", lambda A: None, 532),
        ("jacobi", jacobi, 250),
        ("icholt", lambda A: ichol(A, max_nnz=500000), 3),
        ("vcycle", lambda A: vcycle(A, trimap.shape), 88),
    ]

    for preconditioner_name, make_preconditioner, n_expected in cases:
        callback = CounterCallback()

        M = make_preconditioner(A)

        # rtol=0 makes atol the only stopping criterion, matching the
        # residual check below.
        x = cg(A, b, M=M, atol=atol, rtol=0, maxiter=10000, callback=callback)

        norm_r = np.linalg.norm(b - A.dot(x))

        assert norm_r <= atol, (
            f"{preconditioner_name}: residual norm {norm_r:e} exceeds "
            f"atol {atol:e}"
        )
        assert callback.n <= n_expected, (
            f"{preconditioner_name}: unexpected number of iterations. "
            f"Expected at most {n_expected}, but got {callback.n}"
        )
# Example 2
def _load_resized(path, mode, size, resample):
    """Load *path*, convert it to *mode*, resize it to *size* with *resample*,
    and return it as a float array divided by 255 (8-bit values -> [0, 1])."""
    # The context manager closes the file handle promptly instead of leaving
    # it open until garbage collection.
    with Image.open(path) as img:
        return np.array(img.convert(mode).resize(size, resample)) / 255.0


def main():
    """Compute and display an alpha matte for the lemur example image.

    Loads ``images/lemur.png`` and its trimap and builds the sparse system
    ``(L + C) x = b``, where ``L`` is the Laplacian from ``cf_laplacian`` and
    ``C`` weights the known trimap pixels. It solves the system with
    conjugate gradients preconditioned by ``vcycle``. It then shows the image,
    the trimap and the clipped alpha side by side with matplotlib.
    """
    print("loading images")

    # (width, height) as expected by PIL's resize. A much smaller size such
    # as (34, 22) is handy for quick debugging runs.
    size = (680, 440)
    image = _load_resized("images/lemur.png", "RGB", size, Image.BOX)
    trimap = _load_resized("images/lemur_trimap.png", "L", size, Image.NEAREST)

    # Trimap pixels equal to 1.0 are foreground and 0.0 are background;
    # everything else is treated as unknown.
    is_fg = trimap == 1.0
    is_bg = trimap == 0.0
    is_known = is_fg | is_bg

    # Known pixels get weight 100 in C, and b is 100 for foreground (0 for
    # background). (L + C) x = b therefore softly pins known pixels to their
    # trimap value, and the unknown pixels are determined by L.
    b = 100.0 * is_fg.flatten()
    c = 100.0 * is_known.flatten()

    h, w = trimap.shape

    L = cf_laplacian(image)
    C = scipy.sparse.diags(c)
    A = L + C

    M = vcycle(A, (h, w))

    x = cg(A, b, M=M, callback=ProgressCallback())

    # Reference result from an earlier run, printed for comparison with the
    # progress output above.
    print("\nbaseline:")
    print("iteration      69 - 2.690681e-03 (0.00269068076571873623)")

    alpha = np.clip(x, 0, 1).reshape(h, w)

    import matplotlib.pyplot as plt

    for i, img in enumerate([image, trimap, alpha]):
        plt.subplot(1, 3, 1 + i)
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    plt.show()