Example #1
0
    def test_equil(self):
        """Compare ``px.equil`` with the ``newton_equil`` reference.

        A comp graph is built that multiplies element-wise by ``wr``,
        convolves with the kernel ``[1, 1, 1] / sqrt(3)`` and multiplies
        element-wise by ``wl``.  The scalings returned by ``px.equil`` for
        that graph and by ``newton_equil`` for the explicit matrix
        ``wl * kernel_mat * wr`` are plugged into the same objective, and
        the relative difference of the two values must vanish to 3 places.
        """
        from proximal.algorithms.equil import newton_equil
        np.random.seed(1)

        root3 = np.sqrt(3)
        kernel = np.array([1, 1, 1]) / root3
        kernel_mat = np.ones((3, 3)) / root3
        wr = np.array([10, 5, 7])
        wl = np.array([100, 50, 3])

        # Assemble the operator chain: scale by wr -> convolve -> scale by wl.
        x = px.Variable(3)
        scaled_in = px.mul_elemwise(wr, x)
        blurred = px.conv(kernel, scaled_in)
        scaled_out = px.mul_elemwise(wl, blurred)
        graph = px.CompGraph(scaled_out)

        gamma = 1e-1

        def objective(d_vec, e_vec):
            # Objective value for a pair of scaling vectors; the arithmetic
            # order mirrors the expression used for both candidates.
            scaled = d_vec * wl * kernel_mat * wr * e_vec
            log_d, log_e = np.log(d_vec), np.log(e_vec)
            penalty = gamma * (np.linalg.norm(log_e) ** 2 +
                               np.linalg.norm(log_d) ** 2)
            return np.square(scaled).sum() / 2 - log_d.sum() - log_e.sum() + \
                penalty

        # Equilibrate
        obj_val = objective(*px.equil(graph, 1000, gamma=gamma, M=5))

        # Reference solution computed on the explicit matrix.
        sltn_val = objective(*newton_equil(wl * kernel_mat * wr, gamma, 100))

        rel_gap = (obj_val - sltn_val) / sltn_val
        self.assertAlmostEqual(rel_gap, 0, places=3)
def solver(f, x0, metric, cnn_func, elemental):
    """
    Reconstructs a demosaicked image from the corrupted observation.

    The objective is assembled from a least-squares data term on the
    unknown multiplied by the mask from ``bayer_mask``, an optional L1
    penalty on the gradient of the unknown (``alpha_tv > 0``), an optional
    L1 cross term coupling the unknown with ``x0`` and their gradients
    (``alpha_cross > 0``), and the prior returned by
    ``init_denoising_prior``.

    :param f: Corrupted input image
    :type f: np.ndarray
    :param x0: Predemosaicked initialization image
    :type x0: np.ndarray
    :param metric: Preinitialized metric
    :type metric: proximal.utils.metrics
    :param cnn_func: Preinitialized deployment CNN
    :type cnn_func: function
    :param elemental: General experiment configuration parameters; the keys
        ``alpha_data``, ``alpha_tv``, ``alpha_cross``, ``sigma`` and
        ``sigma_scale`` are read here
    :type elemental: Dict

    :returns: Reconstructed output image, clipped to ``[0.0, 1.0]``
    :rtype: np.ndarray
    """
    # pylint:disable=no-value-for-parameter
    cg_opts = px.cg_options(tol=1e-4, num_iters=100, verbose=True)

    unknown = px.Variable(f.shape)
    mask = bayer_mask(f.shape)
    masked_unknown = px.mul_elemwise(mask, unknown)

    # Data fidelity term.
    half_alpha_data = elemental['alpha_data'] / 2.0
    objective = px.sum_squares(masked_unknown - f, alpha=half_alpha_data)

    # Optional L1 penalty on the gradient of the unknown.
    if elemental['alpha_tv'] > 0.0:
        weighted_grad = elemental['alpha_tv'] * px.grad(unknown, dims=2)
        objective += px.norm1(weighted_grad)

    # Optional cross term between the unknown and the initialization.
    if elemental['alpha_cross'] > 0.0:
        shifts = (1, 2)
        stacked_shape = x0.shape + (2, )

        grad_unknown = px.grad(unknown, dims=2)
        grad_init = px.grad(x0, dims=2).value
        init_pair = np.array([x0, x0]).reshape(stacked_shape)
        unknown_pair = px.reshape(px.hstack([unknown, unknown]),
                                  stacked_shape)

        # Shifted copies of the initialization weight the unknown's gradient.
        init_times_grad = px.vstack([
            px.mul_elemwise(np.roll(init_pair, shift, axis=2), grad_unknown)
            for shift in shifts
        ])
        # Shifted copies of the initialization's gradient weight the unknown.
        grad_times_unknown = px.vstack([
            px.mul_elemwise(np.roll(grad_init, shift, axis=2), unknown_pair)
            for shift in shifts
        ])

        cross_weight = 0.5 * elemental['alpha_cross']
        objective += px.norm1(cross_weight *
                              (init_times_grad - grad_times_unknown))

    objective += init_denoising_prior(unknown,
                                      cnn_func,
                                      sigma=elemental['sigma'],
                                      sigma_scale=elemental['sigma_scale'])

    problem = init_problem(objective)
    solve_problem(problem,
                  x0=x0,
                  metric=metric,
                  sigma=elemental['sigma'],
                  lin_solver_options=cg_opts)

    reconstruction = unknown.value
    return np.clip(reconstruction, 0.0, 1.0)
Example #3
0
def bayerify_proximal(x, mask):
    """Return ``x`` multiplied element-wise by ``mask``.

    The product is built with ``proximal.mul_elemwise``.  A single mask is
    applied; separate per-channel (r, g, b) products are not summed here.

    :param x: Operand passed as the second argument of ``mul_elemwise``
    :param mask: Weights passed as the first argument of ``mul_elemwise``
    :returns: The result of ``proximal.mul_elemwise(mask, x)``
    """
    masked = proximal.mul_elemwise(mask, x)
    return masked