Example #1
    def __init__(self, lin_op, offset, diag=None, freq_diag=None,
                 freq_dims=None, implem=Impl['numpy'], **kwargs):
        self.K = CompGraph(lin_op)
        self.offset = offset
        self.diag = diag
        # TODO: freq diag is supposed to be True/False. What is going on below?
        self.freq_diag = freq_diag
        self.orig_freq_diag = freq_diag
        self.freq_dims = freq_dims
        self.orig_freq_dims = freq_dims
        # Get shape for frequency inversion var
        if self.freq_diag is not None:
            if len(self.K.orig_end.variables()) > 1:
                raise Exception("Diagonal frequency inversion supports only one var currently.")

            self.freq_shape = self.K.orig_end.variables()[0].shape
            self.freq_diag = np.reshape(self.freq_diag, self.freq_shape)
            if implem == Impl['halide'] and \
                    (len(self.freq_shape) == 2 or (len(self.freq_shape) == 3 and
                                                   self.freq_dims == 2)):
                # TODO: FIX REAL TO IMAG
                hsize = self.freq_shape if len(self.freq_shape) == 3 else (
                    self.freq_shape[0], self.freq_shape[1], 1)
                hsizehalide = ((hsize[0] + 1) // 2 + 1, hsize[1], hsize[2], 2)

                self.hsizehalide = hsizehalide
                self.ftmp_halide = np.zeros(hsizehalide, dtype=np.float32, order='F')
                self.ftmp_halide_out = np.zeros(hsize, dtype=np.float32, order='F')
                self.freq_diag = np.reshape(self.freq_diag[0:hsizehalide[0], ...],
                                            hsizehalide[0:3])

        super(least_squares, self).__init__(lin_op, implem=implem, **kwargs)
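
The Halide branch above sizes its scratch buffers for a half spectrum with an explicit trailing real/imaginary channel, relying on the fact that the FFT of a real image is conjugate-symmetric. A minimal NumPy sketch of the same storage idea, using np.fft.rfft2 (this only illustrates the half-spectrum shape, not the Halide memory layout or axis ordering):

import numpy as np

h, w = 8, 6
img = np.random.rand(h, w).astype(np.float32)

# Real-to-complex FFT keeps only the non-redundant half of the spectrum.
spec = np.fft.rfft2(img)                 # complex, shape (h, w // 2 + 1)
rec = np.fft.irfft2(spec, s=img.shape)   # round trip recovers the real input

assert spec.shape == (h, w // 2 + 1)
assert np.allclose(img, rec, atol=1e-6)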
Example #2
    def test_combo(self):
        """Test subsampling followed by convolution.
        """
        # Forward.
        var = Variable((2, 3))
        kernel = np.array([[1, 2, 3]])  # 1x3
        fn = vstack([conv(kernel, subsample(var, (2, 1)))])
        fn = CompGraph(fn)
        x = np.arange(6) * 1.0
        x = np.reshape(x, (2, 3))
        out = np.zeros(fn.output_size)
        fn.forward(x.flatten(), out)
        y = np.zeros((1, 3))

        xsub = x[::2, ::1]
        y = ndimage.convolve(xsub, kernel, mode='wrap')

        self.assertItemsAlmostEqual(np.reshape(out, y.shape), y)

        # Adjoint.
        x = np.arange(3) * 1.0
        x = np.reshape(x, (1, 3))
        out = np.zeros(var.size)
        fn.adjoint(x.flatten(), out)

        y = ndimage.correlate(x, kernel, mode='wrap')
        y2 = np.zeros((2, 3))
        y2[::2, :] = y

        self.assertItemsAlmostEqual(np.reshape(out, y2.shape), y2)
        out = np.zeros(var.size)
        fn.adjoint(x.flatten(), out)
        self.assertItemsAlmostEqual(np.reshape(out, y2.shape), y2)
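
The adjoint half of this test relies on the adjoint of row subsampling being zero-filled upsampling (the y2[::2, :] = y step). A self-contained NumPy check of that pairing via the inner-product identity, independent of the proximal API:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
y = rng.standard_normal((2, 3))

Ax = x[::2, :]            # forward: keep every other row
Aty = np.zeros_like(x)    # adjoint: scatter y back into the kept rows
Aty[::2, :] = y

# <A x, y> == <x, A^T y> up to floating-point round-off.
assert np.isclose(np.sum(Ax * y), np.sum(x * Aty))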
Example #3
    def __init__(self, lin_op, offset, diag=None, freq_diag=None,
                 freq_dims=None, implem=Impl['numpy'], **kwargs):
        self.K = CompGraph(lin_op)
        self.offset = offset
        self.diag = diag
        # TODO: freq diag is supposed to be True/False. What is going on below?
        self.freq_diag = freq_diag
        self.orig_freq_diag = freq_diag
        self.freq_dims = freq_dims
        self.orig_freq_dims = freq_dims
        # Get shape for frequency inversion var
        if self.freq_diag is not None:
            if len(self.K.orig_end.variables()) > 1:
                raise Exception("Diagonal frequency inversion supports only one var currently.")

            self.freq_shape = self.K.orig_end.variables()[0].shape
            self.freq_diag = np.reshape(self.freq_diag, self.freq_shape)
            if implem == Impl['halide'] and \
                    (len(self.freq_shape) == 2 or (len(self.freq_shape) == 3 and
                                                   self.freq_dims == 2)):
                print "hello"
                # TODO: FIX REAL TO IMAG
                hsize = self.freq_shape if len(self.freq_shape) == 3 else (
                    self.freq_shape[0], self.freq_shape[1], 1)
                hsizehalide = ((hsize[0] + 1) // 2 + 1, hsize[1], hsize[2], 2)

                self.hsizehalide = hsizehalide
                self.ftmp_halide = np.zeros(hsizehalide, dtype=np.float32, order='F')
                self.ftmp_halide_out = np.zeros(hsize, dtype=np.float32, order='F')
                self.freq_diag = np.reshape(self.freq_diag[0:hsizehalide[0], ...],
                                            hsizehalide[0:3])

        super(least_squares, self).__init__(lin_op, implem=implem, **kwargs)
Example #4
    def solve(self, solver=None, *args, **kwargs):
        if solver is None:
            solver = self.solver

        if len(self.omega_fns + self.psi_fns) == 0:
            prox_fns = self.prox_fns
        else:
            prox_fns = self.omega_fns + self.psi_fns
        # Absorb lin ops if desired.
        if self.absorb:
            prox_fns = absorb.absorb_all_lin_ops(prox_fns)

        # Merge prox fns.
        if self.merge:
            prox_fns = merge.merge_all(prox_fns)
        # Absorb offsets.
        prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns]
        # TODO more analysis of what solver to use.
        # Short circuit with one function.
        if len(prox_fns) == 1 and type(prox_fns[0].lin_op) == Variable:
            fn = prox_fns[0]
            var = fn.lin_op
            var.value = fn.prox(0, np.zeros(fn.lin_op.shape))
            return fn.value
        elif solver in NAME_TO_SOLVER:
            module = NAME_TO_SOLVER[solver]
            if len(self.omega_fns + self.psi_fns) == 0:
                if self.try_split and len(prox_fns) > 1 and len(self.variables()) == 1:
                    psi_fns, omega_fns = module.partition(prox_fns,
                                                          self.try_diagonalize)
                else:
                    psi_fns = prox_fns
                    omega_fns = []
            else:
                psi_fns = self.psi_fns
                omega_fns = self.omega_fns
            # Scale the problem.
            if self.scale:
                K = CompGraph(vstack([fn.lin_op for fn in psi_fns]),
                              implem=self.implem)
                Knorm = est_CompGraph_norm(K, try_fast_norm=self.try_fast_norm)
                for idx, fn in enumerate(psi_fns):
                    psi_fns[idx] = fn.copy(fn.lin_op / Knorm,
                                           beta=fn.beta * np.sqrt(Knorm),
                                           implem=self.implem)
                for idx, fn in enumerate(omega_fns):
                    omega_fns[idx] = fn.copy(beta=fn.beta / np.sqrt(Knorm),
                                             implem=self.implem)
            opt_val = module.solve(psi_fns, omega_fns,
                                   lin_solver=self.lin_solver,
                                   try_diagonalize=self.try_diagonalize,
                                   try_fast_norm=self.try_fast_norm,
                                   scaled=self.scale,
                                   *args, **kwargs)
            # Unscale the variables.
            if self.scale:
                for var in self.variables():
                    var.value /= np.sqrt(Knorm)
            return opt_val
        else:
            raise Exception("Unknown solver.")
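
The scaling step above divides every lin_op by an estimate of the operator norm ||K||_2 (est_CompGraph_norm), so the stacked operator is roughly unit-norm before it is handed to the solver. For a dense matrix the same quantity can be estimated by power iteration on K^T*K; a rough sketch of that idea only, not the library's estimator:

import numpy as np

def spectral_norm_estimate(A, iters=100, seed=0):
    """Estimate ||A||_2 by power iteration on A^T A."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(A.shape[1])
    for _ in range(iters):
        x = A.T @ (A @ x)
        x /= np.linalg.norm(x)
    return np.linalg.norm(A @ x)

A = np.random.default_rng(1).standard_normal((40, 20))
print(spectral_norm_estimate(A), np.linalg.norm(A, 2))  # should closely match the exact spectral norm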
Example #5
    def test_combo(self):
        """Test subsampling followed by convolution.
        """
        # Forward.
        var = Variable((2, 3))
        kernel = np.array([[1, 2, 3]])  # 1x3
        fn = vstack([conv(kernel, subsample(var, (2, 1)))])
        fn = CompGraph(fn)
        x = np.arange(6) * 1.0
        x = np.reshape(x, (2, 3))
        out = np.zeros(fn.output_size)
        fn.forward(x.flatten(), out)
        y = np.zeros((1, 3))

        xsub = x[::2, ::1]
        y = ndimage.convolve(xsub, kernel, mode='wrap')

        self.assertItemsAlmostEqual(np.reshape(out, y.shape), y)

        # Adjoint.
        x = np.arange(3) * 1.0
        x = np.reshape(x, (1, 3))
        out = np.zeros(var.size)
        fn.adjoint(x.flatten(), out)

        y = ndimage.correlate(x, kernel, mode='wrap')
        y2 = np.zeros((2, 3))
        y2[::2, :] = y

        self.assertItemsAlmostEqual(np.reshape(out, y2.shape), y2)
        out = np.zeros(var.size)
        fn.adjoint(x.flatten(), out)
        self.assertItemsAlmostEqual(np.reshape(out, y2.shape), y2)
Example #6
class least_squares(sum_squares):
    """The function ||K*x||_2^2.

       Here K is a computation graph (vector to vector lin op).
    """

    def __init__(self, lin_op, offset, diag=None, freq_diag=None,
                 freq_dims=None, implem=Impl['numpy'], **kwargs):
        self.K = CompGraph(lin_op)
        self.offset = offset
        self.diag = diag
        # TODO: freq diag is supposed to be True/False. What is going on below?
        self.freq_diag = freq_diag
        self.orig_freq_diag = freq_diag
        self.freq_dims = freq_dims
        self.orig_freq_dims = freq_dims
        # Get shape for frequency inversion var
        if self.freq_diag is not None:
            if len(self.K.orig_end.variables()) > 1:
                raise Exception("Diagonal frequency inversion supports only one var currently.")

            self.freq_shape = self.K.orig_end.variables()[0].shape
            self.freq_diag = np.reshape(self.freq_diag, self.freq_shape)
            if implem == Impl['halide'] and \
                    (len(self.freq_shape) == 2 or (len(self.freq_shape) == 3 and
                                                   self.freq_dims == 2)):
                print "hello"
                # TODO: FIX REAL TO IMAG
                hsize = self.freq_shape if len(self.freq_shape) == 3 else (
                    self.freq_shape[0], self.freq_shape[1], 1)
                hsizehalide = ((hsize[0] + 1) // 2 + 1, hsize[1], hsize[2], 2)

                self.hsizehalide = hsizehalide
                self.ftmp_halide = np.zeros(hsizehalide, dtype=np.float32, order='F')
                self.ftmp_halide_out = np.zeros(hsize, dtype=np.float32, order='F')
                self.freq_diag = np.reshape(self.freq_diag[0:hsizehalide[0], ...],
                                            hsizehalide[0:3])

        super(least_squares, self).__init__(lin_op, implem=implem, **kwargs)

    def get_data(self):
        """Returns info needed to reconstruct the object besides the args.

        Returns
        -------
        list
        """
        return [self.offset, self.diag, self.orig_freq_diag, self.orig_freq_dims]

    def _prox(self, rho, v, b=None, lin_solver="cg", *args, **kwargs):
        """x = argmin_x ||K*x - self.offset - b||_2^2 + (rho/2)||x-v||_2^2.
        """
        if b is None:
            offset = self.offset
        else:
            offset = self.offset + b
        return self.solve(offset, rho=rho, v=v, lin_solver=lin_solver, *args, **kwargs)

    def _eval(self, v):
        """Evaluate the function on v (ignoring parameters).
        """
        Kv = np.zeros(self.K.output_size)
        self.K.forward(v.ravel(), Kv)
        return super(least_squares, self)._eval(Kv - self.offset)

    def solve(self, b, rho=None, v=None, lin_solver="lsqr", *args, **kwargs):
        # KtK Operator is diagonal
        if self.diag is not None:

            Ktb = np.zeros(self.K.input_size)
            self.K.adjoint(b, Ktb)
            if rho is None:
                Ktb /= self.diag
            else:
                Ktb += (rho / 2.) * v
                Ktb /= (self.diag + rho / 2.)

            return Ktb

        # KtK operator is diagonal in frequency domain.
        elif self.freq_diag is not None:
            Ktb = np.zeros(self.K.input_size)
            self.K.adjoint(b, Ktb)

            # Frequency inversion
            if self.implementation == Impl['halide'] and \
                    (len(self.freq_shape) == 2 or
                     (len(self.freq_shape) == 3 and self.freq_dims == 2)):

                Halide('fft2_r2c.cpp').fft2_r2c(np.asfortranarray(np.reshape(
                    Ktb.astype(np.float32), self.freq_shape)), 0, 0, self.ftmp_halide)

                Ktb = 1j * self.ftmp_halide[..., 1]
                Ktb += self.ftmp_halide[..., 0]

                if rho is None:
                    Ktb /= self.freq_diag
                else:
                    Halide('fft2_r2c.cpp').fft2_r2c(np.asfortranarray(np.reshape(
                        v.astype(np.float32), self.freq_shape)), 0, 0, self.ftmp_halide)

                    vhat = self.ftmp_halide[..., 0] + 1j * self.ftmp_halide[..., 1]
                    Ktb *= 1.0 / rho
                    Ktb += vhat
                    Ktb /= (1.0 / rho * self.freq_diag + 1.0)

                # Do inverse transform
                Ktb = np.asfortranarray(np.stack((Ktb.real, Ktb.imag), axis=-1))
                Halide('ifft2_c2r.cpp').ifft2_c2r(Ktb, self.ftmp_halide_out)

                return self.ftmp_halide_out.ravel()

            else:

                # General frequency inversion
                Ktb = fftd(np.reshape(Ktb, self.freq_shape), self.freq_dims)

                if rho is None:
                    Ktb /= self.freq_diag
                else:
                    Ktb *= 2.0 / rho
                    Ktb += fftd(np.reshape(v, self.freq_shape), self.freq_dims)
                    Ktb /= (2.0 / rho * self.freq_diag + 1.0)

                return (ifftd(Ktb, self.freq_dims).real).ravel()

        elif lin_solver == "lsqr":
            return self.solve_lsqr(b, rho, v, *args, **kwargs)
        elif lin_solver == "cg":
            return self.solve_cg(b, rho, v, *args, **kwargs)
        else:
            raise Exception("Unknown least squares solver.")

    def solve_lsqr(self, b, rho=None, v=None, x_init=None, options=None):
        """Solve ||K*x - b||^2_2 + (rho/2)||x-v||_2^2.
        """

        # Add additional linear terms for the rho terms
        sizev = 0
        if rho is not None:
            vf = v.flatten() * np.sqrt(rho / 2.0)
            sizeb = self.K.input_size
            sizev = np.prod(v.shape)
            b = np.hstack((b, vf))

        input_data = np.zeros(self.K.input_size)
        output_data = np.zeros(self.K.output_size + sizev)

        def matvec(x, output_data):
            if rho is None:
                # Traverse compgraph
                self.K.forward(x, output_data)
            else:
                # Compgraph and additional terms
                self.K.forward(x, output_data[0:0 + sizeb])
                np.copyto(output_data[sizeb:sizeb + sizev], x * np.sqrt(rho / 2.0))

            return output_data

        def rmatvec(y, input_data):
            if rho is None:
                self.K.adjoint(y, input_data)
            else:
                self.K.adjoint(y[0:0 + sizeb], input_data)
                input_data += y[sizeb:sizeb + sizev] * np.sqrt(rho / 2.0)

            return input_data

        # Define linear operator
        def matvecComp(x): return matvec(x, output_data)

        def rmatvecComp(y): return rmatvec(y, input_data)

        K = LinearOperator((self.K.output_size + sizev, self.K.input_size),
                           matvecComp, rmatvecComp)

        # Options
        if options is None:
            # Default options
            return lsqr(K, b)[0]
        else:
            if not isinstance(options, lsqr_options):
                raise Exception("Invalid LSQR options.")
            return lsqr(K, b, atol=options.atol, btol=options.btol,
                        show=options.show, iter_lim=options.iter_lim)[0]

    def solve_cg(self, b, rho=None, v=None, x_init=None, options=None):
        """Solve ||K*x - b||^2_2 + (rho/2)||x-v||_2^2.
        """
        output_data = np.zeros(self.K.output_size)

        def KtK(x, r):
            self.K.forward(x, output_data)
            self.K.adjoint(output_data, r)
            if rho is not None:
                r += rho * x
            return r

        # Compute Ktb
        Ktb = np.zeros(self.K.input_size)
        self.K.adjoint(b, Ktb)
        if rho is not None:
            Ktb += rho * v

        # Options
        if options is None:
            # Default options
            options = cg_options()
        elif not isinstance(options, cg_options):
            raise Exception("Invalid CG options.")

        return cg(KtK, Ktb, options.tol, options.num_iters,
                  options.verbose, x_init, self.implementation)
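
The diagonal and frequency-diagonal branches of solve() implement the closed form of argmin_x ||K*x - b||_2^2 + (rho/2)*||x - v||_2^2, whose normal equations can be written ((2/rho)*K^T*K + I)*x = (2/rho)*K^T*b + v; freq_diag plays the role of the eigenvalues of K^T*K in the Fourier domain. A small NumPy check of that formula for a circular convolution, which the DFT diagonalizes exactly (an illustrative sketch, not library code):

import numpy as np

rng = np.random.default_rng(0)
n, rho = 16, 0.7
kernel = rng.standard_normal(n)
khat = np.fft.fft(kernel)
b, v = rng.standard_normal(n), rng.standard_normal(n)

# Dense reference: solve (K^T K + (rho/2) I) x = K^T b + (rho/2) v directly.
K = np.stack([np.roll(kernel, j) for j in range(n)], axis=1)  # circulant matrix
x_ref = np.linalg.solve(K.T @ K + rho / 2 * np.eye(n), K.T @ b + rho / 2 * v)

# Frequency-domain formula from the code above, with freq_diag = |khat|^2.
Ktb_hat = np.fft.fft(K.T @ b)
x_hat = (2.0 / rho * Ktb_hat + np.fft.fft(v)) / (2.0 / rho * np.abs(khat) ** 2 + 1.0)
assert np.allclose(x_ref, np.real(np.fft.ifft(x_hat)))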
Example #7
    def test_op_overloading(self):
        """Test operator overloading.
        """
        # Multiplying by a scalar.

        # Forward.
        var = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = -2 * mul_elemwise(W, var)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(x.shape)
        fn.forward(x.flatten(), out)
        self.assertItemsAlmostEqual(out, -2 * W * W)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, -2 * W * W)

        # Forward.
        var = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, var) * 0.5
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(x.shape)
        fn.forward(x.flatten(), out)
        self.assertItemsAlmostEqual(out, W * W / 2.)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, W * W / 2.)

        # Dividing by a scalar.
        # Forward.
        var = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, var) / 2
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(x.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W / 2.)

        # Adding lin ops.
        # Forward.
        x = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, x)
        fn = fn + x + x
        self.assertEqual(len(fn.input_nodes), 3)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(fn.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)

        # Adding in a constant.
        # CompGraph should ignore the constant.
        x = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, x)
        fn = fn + x + W
        self.assertEqual(len(fn.input_nodes), 3)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(fn.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W + W)

        # Subtracting lin ops.
        # Forward.
        x = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = -mul_elemwise(W, x)
        fn = x + x - fn
        self.assertEqual(len(fn.input_nodes), 3)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(fn.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)
Example #8
def solve(psi_fns, omega_fns, tau=None, sigma=None, theta=None,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=False, scaled=True,
          metric=None, convlog=None, verbose=0):
    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    v = np.zeros(K.input_size)
    # Select optimal parameters if wanted
    if tau is None or sigma is None or theta is None:
        tau, sigma, theta = est_params_pc(K, tau, sigma, verbose, scaled, try_fast_norm)

    # Initialize
    x = np.zeros(K.input_size)
    y = np.zeros(K.output_size)
    xbar = np.zeros(K.input_size)
    u = np.zeros(K.output_size)
    z = np.zeros(K.output_size)

    if x0 is not None:
        x[:] = np.reshape(x0, K.input_size)
        K.forward(x, y)
        xbar[:] = x

    # Buffers.
    Kxbar = np.zeros(K.output_size)
    Kx = np.zeros(K.output_size)
    KTy = np.zeros(K.input_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    prev_x = x.copy()
    prev_Kx = Kx.copy()
    prev_z = z.copy()
    prev_u = u.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("PC iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Keep track of previous iterates
        np.copyto(prev_x, x)
        np.copyto(prev_z, z)
        np.copyto(prev_u, u)
        np.copyto(prev_Kx, Kx)

        # Compute z
        K.forward(xbar, Kxbar)
        z = y + sigma * Kxbar

        # Update y.
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            z_slc = np.reshape(z[slc], fn.lin_op.shape)

            # Moreau identity: apply and time prox.
            prox_log[fn].tic()
            y[slc] = (z_slc - sigma * fn.prox(sigma, z_slc / sigma, i)).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size
        y[offset:] = 0

        # Update x
        K.adjoint(y, KTy)
        x -= tau * KTy

        if len(omega_fns) > 0:
            xtmp = np.reshape(x, omega_fns[0].lin_op.shape)
            x[:] = omega_fns[0].prox(1.0 / tau, xtmp, x_init=prev_x,
                                     lin_solver=lin_solver, options=lin_solver_options).flatten()

        # Update xbar
        np.copyto(xbar, x)
        xbar += theta * (x - prev_x)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(x)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        """ Old convergence check
        #Very basic convergence check.
        r_x = np.linalg.norm(x - prev_x)
        r_xbar = np.linalg.norm(xbar - prev_xbar)
        r_ybar = np.linalg.norm(y - prev_y)
        error = r_x + r_xbar + r_ybar
        """

        # Residual based convergence check
        K.forward(x, Kx)
        u = 1.0 / sigma * y + theta * (Kx - prev_Kx)
        z = prev_u + prev_Kx - 1.0 / sigma * y

        # Iteration order is different than
        # lin-admm (--> start checking at iteration 1)
        if i > 0:

            # Check convergence
            r = prev_Kx - z
            K.adjoint(sigma * (z - prev_z), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
                max([np.linalg.norm(prev_Kx), np.linalg.norm(z)])

            K.adjoint(u, KTu)
            eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) / sigma

            # Progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

                """ Old convergence check
                #Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format( metric.message(x.copy()) )
                print "iter [%04d]:" \
                      "||x - x_prev||_2 = %02.02e " \
                      "||xbar - xbar_prev||_2 = %02.02e " \
                      "||y - y_prev||_2 = %02.02e " \
                      "SUM = %02.02e (eps=%02.03e)%s%s" \
                        % (i, r_x, r_xbar, r_ybar, error, eps, objstr, metstr)
                """

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(v))
                print(
                    "iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s"
                    % (i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr)
                )

            iter_timing.toc()
            if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
                break

        else:
            iter_timing.toc()

        """ Old convergence check
        if error <= eps:
            break
        """

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)

    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
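
The dual update in this solver uses the Moreau decomposition, prox_{sigma*f*}(z) = z - sigma * prox_{f/sigma}(z / sigma), where fn.prox(rho, v) computes argmin_x f(x) + (rho/2)*||x - v||_2^2. A quick numerical check for f = ||.||_1, whose conjugate prox is projection onto the unit l_inf ball (an illustration only, independent of the solver):

import numpy as np

def prox_l1(rho, v):
    """argmin_x ||x||_1 + (rho/2)||x - v||_2^2, i.e. soft-thresholding at 1/rho."""
    return np.sign(v) * np.maximum(np.abs(v) - 1.0 / rho, 0.0)

rng = np.random.default_rng(0)
z = 3.0 * rng.standard_normal(1000)
sigma = 0.7

moreau = z - sigma * prox_l1(sigma, z / sigma)   # prox of sigma * f^* via Moreau
direct = np.clip(z, -1.0, 1.0)                   # projection onto {||y||_inf <= 1}
assert np.allclose(moreau, direct)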
Example #9
    def solve(self,
              solver=None,
              test_adjoints=False,
              test_norm=False,
              show_graph=False,
              *args,
              **kwargs):
        if solver is None:
            solver = self.solver

        if len(self.omega_fns + self.psi_fns) == 0:
            prox_fns = self.prox_fns
        else:
            prox_fns = self.omega_fns + self.psi_fns
        # Absorb lin ops if desired.
        if self.absorb:
            prox_fns = absorb.absorb_all_lin_ops(prox_fns)

        # Merge prox fns.
        if self.merge:
            prox_fns = merge.merge_all(prox_fns)
        # Absorb offsets.
        prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns]
        # TODO more analysis of what solver to use.

        if show_graph:
            print("Computational graph before optimizing:")
            graph_visualize(
                prox_fns,
                filename=show_graph if type(show_graph) is str else None)

        # Short circuit with one function.
        if len(prox_fns) == 1 and type(prox_fns[0].lin_op) == Variable:
            fn = prox_fns[0]
            var = fn.lin_op
            var.value = fn.prox(0, np.zeros(fn.lin_op.shape))
            return fn.value
        elif solver in NAME_TO_SOLVER:
            module = NAME_TO_SOLVER[solver]
            if len(self.omega_fns + self.psi_fns) == 0:
                if self.try_split and len(prox_fns) > 1 and len(
                        self.variables()) == 1:
                    psi_fns, omega_fns = module.partition(
                        prox_fns, self.try_diagonalize)
                else:
                    psi_fns = prox_fns
                    omega_fns = []
            else:
                psi_fns = self.psi_fns
                omega_fns = self.omega_fns
            if test_norm:
                L = CompGraph(vstack([fn.lin_op for fn in psi_fns]))
                from numpy.random import random

                output_mags = [NotImplemented]
                L.norm_bound(output_mags)
                if not NotImplemented in output_mags:
                    assert len(output_mags) == 1

                    x = random(L.input_size)
                    x = x / LA.norm(x)
                    y = np.zeros(L.output_size)
                    y = L.forward(x, y)
                    ny = LA.norm(y)
                    nL2 = est_CompGraph_norm(L, try_fast_norm=False)
                    if ny > output_mags[0]:
                        raise RuntimeError("wrong implementation of norm!")
                    print("%.3f <= ||K|| = %.3f (%.3f)" %
                          (ny, output_mags[0], nL2))

            # Scale the problem.
            if self.scale:
                K = CompGraph(vstack([fn.lin_op for fn in psi_fns]),
                              implem=self.implem)
                Knorm = est_CompGraph_norm(K, try_fast_norm=self.try_fast_norm)
                for idx, fn in enumerate(psi_fns):
                    psi_fns[idx] = fn.copy(fn.lin_op / Knorm,
                                           beta=fn.beta * np.sqrt(Knorm),
                                           implem=self.implem)
                for idx, fn in enumerate(omega_fns):
                    omega_fns[idx] = fn.copy(beta=fn.beta / np.sqrt(Knorm),
                                             implem=self.implem)
                for v in K.orig_end.variables():
                    if v.initval is not None:
                        v.initval *= np.sqrt(Knorm)
            if not test_adjoints in [False, None]:
                if test_adjoints is True:
                    test_adjoints = 1e-6
                # test adjoints
                L = CompGraph(vstack([fn.lin_op for fn in psi_fns]))
                from numpy.random import random

                x = random(L.input_size)
                yt = np.zeros(L.output_size)
                #print("x=", x)
                yt = L.forward(x, yt)
                #print("yt=", yt)
                #print("x=", x)
                y = random(L.output_size)
                #print("y=", y)
                xt = np.zeros(L.input_size)
                xt = L.adjoint(y, xt)
                #print("xt=", xt)
                #print("y=", y)
                r = np.abs(
                    np.dot(np.ravel(y), np.ravel(yt)) -
                    np.dot(np.ravel(x), np.ravel(xt)))
                #print( x.shape, y.shape, xt.shape, yt.shape)
                if r > test_adjoints:
                    #print("yt=", yt)
                    #print("y =", y)
                    #print("xt=", xt)
                    #print("x =", x)
                    raise RuntimeError("Unmatched adjoints: " + str(r))
                else:
                    print("Adjoint test passed.", r)

            if self.implem == Impl['pycuda']:
                kwargs['adapter'] = PyCudaAdapter()
            opt_val = module.solve(psi_fns,
                                   omega_fns,
                                   lin_solver=self.lin_solver,
                                   try_diagonalize=self.try_diagonalize,
                                   try_fast_norm=self.try_fast_norm,
                                   scaled=self.scale,
                                   *args,
                                   **kwargs)
            # Unscale the variables.
            if self.scale:
                for var in self.variables():
                    var.value /= np.sqrt(Knorm)
            return opt_val
        else:
            raise Exception("Unknown solver.")
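
The test_norm and test_adjoints blocks above reduce to two linear-algebra identities: for a unit-norm x, ||K*x||_2 is a lower bound on ||K||_2, so it must not exceed a valid upper bound such as the one returned by norm_bound; and a correct forward/adjoint pair satisfies <K*x, y> = <x, K^T*y>. The same checks for a dense matrix, with NumPy only:

import numpy as np

rng = np.random.default_rng(0)
K = rng.standard_normal((30, 20))
x, y = rng.standard_normal(20), rng.standard_normal(30)

# Adjoint test: the mismatch |<y, K x> - <x, K^T y>| should be round-off sized.
r = abs(np.dot(y, K @ x) - np.dot(x, K.T @ y))
assert r < 1e-9

# Norm test: a random unit vector yields a lower bound on the spectral norm.
x_unit = x / np.linalg.norm(x)
assert np.linalg.norm(K @ x_unit) <= np.linalg.norm(K, 2) + 1e-12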
Example #10
def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=True, scaled=False,
          metric=None, convlog=None, verbose=0):

    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Select optimal parameters if wanted
    if lmb is None or mu is None:
        lmb, mu = est_params_lin_admm(K, lmb, verbose, scaled, try_fast_norm)

    # Initialize everything to zero.
    v = np.zeros(K.input_size)
    z = np.zeros(K.output_size)
    u = np.zeros(K.output_size)

    # Buffers.
    Kv = np.zeros(K.output_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    Kvzu = np.zeros(K.output_size)
    v_prev = np.zeros(K.input_size)
    z_prev = np.zeros(K.output_size)

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("LIN-ADMM iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        v_prev[:] = v
        z_prev[:] = z

        # Update v
        K.forward(v, Kv)
        Kvzu[:] = Kv - z + u
        K.adjoint(Kvzu, v)
        v[:] = v_prev - (mu / lmb) * v

        if len(omega_fns) > 0:
            v[:] = omega_fns[0].prox(1.0 / mu, v, x_init=v_prev.copy(),
                                     lin_solver=lin_solver, options=lin_solver_options)

        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            prox_log[fn].tic()
            z[slc] = fn.prox(1.0 / lmb, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size

        # Update u.
        u += Kv - z
        K.adjoint(u, KTu)

        # Check convergence.
        r = Kv - z
        K.adjoint((1.0 / lmb) * (z - z_prev), s)
        eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
            max([np.linalg.norm(Kv), np.linalg.norm(z)])
        eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) / (1.0 / lmb)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0:
            # Evaluate objective only if required (expensive !)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            print "iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr)

        iter_timing.toc()
        if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)

    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
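
The stopping test in this solver closely follows the standard primal/dual residual criterion for ADMM-type splitting (Boyd et al.); restated from the code, with the penalty weight written as rho = 1/lmb:

    r^k      = K*v^k - z^k
    s^k      = rho * K^T * (z^k - z^{k-1})
    eps_pri  = sqrt(K.output_size) * eps_abs + eps_rel * max(||K*v^k||_2, ||z^k||_2)
    eps_dual = sqrt(K.input_size)  * eps_abs + eps_rel * ||K^T*u^k||_2 / rho

and the loop exits once ||r^k||_2 <= eps_pri and ||s^k||_2 <= eps_dual.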
Example #11
def solve(psi_fns,
          omega_fns,
          rho=1.0,
          max_iters=1000,
          eps_abs=1e-1,
          eps_rel=1e-3,
          x0=None,
          lin_solver="cg",
          lin_solver_options=None,
          try_diagonalize=True,
          try_fast_norm=False,
          scaled=True,
          metric=None,
          convlog=None,
          verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Rescale so (rho/2)||x - b||^2_2
    rescaling = np.sqrt(2. / rho)
    quad_ops = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        const_terms.append(fn.b.flatten() * rescaling)
    # Check for fast inverse.
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    v_update = get_least_squares_inverse(op_list, None, try_diagonalize,
                                         verbose)

    # Initialize everything to zero.
    input_size = K.input_size
    output_size = K.output_size
    v = np.zeros(input_size)
    z = np.zeros(output_size)
    u = np.zeros(output_size)
    N_z = len(z[:])
    print(input_size)
    print(output_size)

    # Initialize
    if x0 is not None:
        v[:] = np.reshape(x0, input_size)
        K.forward(v, z)

    # Buffers.

    Kv = np.zeros(output_size)
    KTu = np.zeros(input_size)
    s = np.zeros(input_size)
    Kv_pre = Kv.copy()
    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("ADMM iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)
    res_pre = 9e20
    res = 0

    curr_time = 0
    total_time = []
    Combine_res = []
    # ------------------------------------------------------------------------------------
    for i in range(max_iters):
        # iter_timing.tic()
        t1 = time.time()
        if convlog is not None:
            convlog.tic()
        K.forward(v, Kv)
        Kv_pre = Kv.copy()
        # z_prev = z.copy()
        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            tmp = np.hstack([z - u] + const_terms)
            v = v_update.solve(tmp,
                               x_init=v,
                               lin_solver=lin_solver,
                               options=lin_solver_options)
            K.forward(v, Kv)
            Kv_u = Kv + u
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            z_pre = z.copy()
            prox_log[fn].tic()
            z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size
        # Update v.

        # Check convergence.
        r = Kv - z
        # Update u.
        u += r
        K.adjoint(u, KTu)

        # K.adjoint(rho * (z - z_prev), s)
        s = z - z_pre
        t2 = time.time()
        curr_time += t2 - t1

        res = np.linalg.norm(r)**2 + np.linalg.norm(s)**2

        # K.adjoint((z-z_prev),s)
        # eps_pri = np.sqrt(output_size) * eps_abs + eps_rel * \
        #   max([np.linalg.norm(Kv), np.linalg.norm(z)])
        # eps_dual = np.sqrt(input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) / rho

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0:
            # Evaluate objective only if required (expensive !)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum(
                    [fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            # print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
            #     i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))
            print("iter %d: combine residual = %.8f" % (i, res))

        #curr_time = curr_time + iter_timing.toc()

        Combine_res.append(np.sqrt(rho * res / N_z))
        total_time.append(curr_time)
        # Exit if converged.
        if (res) < eps_abs:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)
    # Return optimal value.
    # return sum([fn.value for fn in prox_fns])
    return total_time, Combine_res
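
Unlike the other solvers here, this variant returns per-iteration timings plus a combined residual history rather than the final objective value. Reading the quantities off the code above:

    res            = ||K*v - z||_2^2 + ||z - z_pre||_2^2    (z_pre: copy of z taken just before the last prox update)
    Combine_res[i] = sqrt(rho * res / N_z)

so Combine_res is an RMS-style residual over the N_z entries of z, and the loop stops as soon as res drops below eps_abs.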
Example #12
def solve(psi_fns,
          omega_fns,
          rho_0=1.0,
          rho_scale=math.sqrt(2.0) * 2.0,
          rho_max=2**8,
          max_iters=-1,
          max_inner_iters=100,
          x0=None,
          eps_rel=1e-3,
          eps_abs=1e-3,
          lin_solver="cg",
          lin_solver_options=None,
          try_diagonalize=True,
          scaled=False,
          try_fast_norm=False,
          metric=None,
          convlog=None,
          verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Rescale so (1/2)||x - b||^2_2
    rescaling = np.sqrt(2.)
    quad_ops = []
    quad_weights = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        quad_weights.append(rescaling * fn.beta)
        const_terms.append(fn.b.flatten() * rescaling)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)
    x_update = get_least_squares_inverse(op_list, None, try_diagonalize,
                                         verbose)

    # Initialize
    if x0 is not None:
        x = np.reshape(x0, K.input_size)
    else:
        x = np.zeros(K.input_size)

    Kx = np.zeros(K.output_size)
    w = Kx.copy()

    # Temporary iteration counts
    x_prev = x.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("HQS iteration")
    inner_iter_timing = TimingsEntry("HQS inner iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    # Rho schedule
    rho = rho_0
    i = 0
    while rho < rho_max and i < max_iters:
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Update rho for quadratics
        for idx, op in enumerate(quad_ops):
            op.scalar = quad_weights[idx] / np.sqrt(rho)
        x_update = get_least_squares_inverse(op_list, CompGraph(stacked_ops),
                                             try_diagonalize, verbose)

        for ii in range(max_inner_iters):
            inner_iter_timing.tic()
            # Update Kx.
            K.forward(x, Kx)

            # Prox update to get w.
            offset = 0
            w_prev = w.copy()
            for fn in psi_fns:
                slc = slice(offset, offset + fn.lin_op.size, None)
                # Apply and time prox.
                prox_log[fn].tic()
                w[slc] = fn.prox(rho, np.reshape(Kx[slc], fn.lin_op.shape),
                                 ii).flatten()
                prox_log[fn].toc()
                offset += fn.lin_op.size

            # Update x.
            x_prev[:] = x
            tmp = np.hstack([w] +
                            [cterm / np.sqrt(rho) for cterm in const_terms])
            x = x_update.solve(tmp,
                               x_init=x,
                               lin_solver=lin_solver,
                               options=lin_solver_options)

            # Very basic convergence check.
            r_x = np.linalg.norm(x_prev - x)
            eps_x = eps_rel * np.prod(K.input_size)

            r_w = np.linalg.norm(w_prev - w)
            eps_w = eps_rel * np.prod(K.output_size)

            # Convergence log
            if convlog is not None:
                convlog.toc()
                K.update_vars(x)
                objval = sum([fn.value for fn in prox_fns])
                convlog.record_objective(objval)

            # Show progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum(
                        [fn.value for fn in prox_fns])

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(
                    metric.message(x))
                print("iter [%02d (rho=%2.1e) || %02d]:"
                      "||w - w_prev||_2 = %02.02e (eps=%02.03e)"
                      "||x - x_prev||_2 = %02.02e (eps=%02.03e)%s%s" %
                      (i, rho, ii, r_w, eps_w, r_x, eps_x, objstr, metstr))

            inner_iter_timing.toc()
            if r_x < eps_x and r_w < eps_w:
                break

        # Update rho
        rho = np.minimum(rho * rho_scale, rho_max)
        i += 1
        iter_timing.toc()

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print(inner_iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)

    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
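
Written out, the half-quadratic splitting above alternates, for a fixed rho,

    w^{k+1} = argmin_w f(w) + (rho/2)*||w - K*x^k||_2^2                           (the psi_fns prox step)
    x^{k+1} = argmin_x (rho/2)*||K*x - w^{k+1}||_2^2 + (quadratic omega terms)    (the x_update least-squares solve)

with the omega weights rescaled by 1/sqrt(rho) so both terms fit into a single least-squares system, and rho increased geometrically from rho_0 by rho_scale up to rho_max. This is a reading of the code above rather than separate documentation of the solver.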
Example #13
def solve(psi_fns,
          omega_fns,
          tau=None,
          sigma=None,
          theta=None,
          max_iters=1000,
          eps_abs=1e-3,
          eps_rel=1e-3,
          x0=None,
          lin_solver="cg",
          lin_solver_options=None,
          conv_check=100,
          try_diagonalize=True,
          try_fast_norm=False,
          scaled=True,
          implem=None,
          metric=None,
          convlog=None,
          verbose=0,
          callback=None,
          adapter=NumpyAdapter()):

    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops, implem=implem)

    #graph_visualize(prox_fns)

    if adapter.implem() == 'numpy':
        K_forward = K.forward
        K_adjoint = K.adjoint
        prox_off_and_fac = lambda offset, factor, fn, *args, **kw: ne.evaluate(
            'x*a+b', {
                'x': fn.prox(*args, **kw),
                'a': factor,
                'b': offset
            })
        prox = lambda fn, *args, **kw: fn.prox(*args, **kw)
    elif adapter.implem() == 'pycuda':
        K_forward = K.forward_cuda
        K_adjoint = K.adjoint_cuda
        prox_off_and_fac = lambda offset, factor, fn, *args, **kw: fn.prox_cuda(
            *args, offset=offset, factor=factor, **kw)
        prox = lambda fn, *args, **kw: fn.prox_cuda(*args, **kw)
    else:
        raise RuntimeError("Implementation %s unknown" % adapter.implem())
    # Select optimal parameters if wanted
    if tau is None or sigma is None or theta is None:
        tau, sigma, theta = est_params_pc(K, tau, sigma, verbose, scaled,
                                          try_fast_norm)
    elif callable(tau) or callable(sigma) or callable(theta):
        if scaled:
            L = 1
        else:
            L = est_CompGraph_norm(K, try_fast_norm)

    # Initialize
    x = adapter.zeros(K.input_size)
    y = adapter.zeros(K.output_size)
    xbar = adapter.zeros(K.input_size)
    u = adapter.zeros(K.output_size)
    z = adapter.zeros(K.output_size)

    if x0 is not None:
        x[:] = adapter.reshape(adapter.from_np(x0), K.input_size)
    else:
        x[:] = adapter.from_np(K.x0())

    K_forward(x, y)
    xbar[:] = x

    # Buffers.
    Kxbar = adapter.zeros(K.output_size)
    Kx = adapter.zeros(K.output_size)
    KTy = adapter.zeros(K.input_size)
    KTu = adapter.zeros(K.input_size)
    s = adapter.zeros(K.input_size)

    prev_x = x.copy()
    prev_Kx = Kx.copy()
    prev_z = z.copy()
    prev_u = u.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    prox_log_tot = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsLog([
        "pc_iteration_tot", "copyprev", "calcz", "calcx", "omega_fn", "xbar",
        "conv_check"
    ])

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(adapter.to_np(x))
        objval = 0.0
        for f in prox_fns:
            evp = f.value
            #print(str(f), '->', f.value)
            objval += evp
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing["pc_iteration_tot"].tic()
        if convlog is not None:
            convlog.tic()

        if callable(sigma):
            csigma = sigma(i, L)
        else:
            csigma = sigma
        if callable(tau):
            ctau = tau(i, L)
        else:
            ctau = tau
        if callable(theta):
            ctheta = theta(i, L)
        else:
            ctheta = theta

        csigma = adapter.scalar(csigma)
        ctau = adapter.scalar(ctau)
        ctheta = adapter.scalar(ctheta)

        # Keep track of previous iterates
        iter_timing["copyprev"].tic()
        adapter.copyto(prev_x, x)
        adapter.copyto(prev_z, z)
        adapter.copyto(prev_u, u)
        adapter.copyto(prev_Kx, Kx)
        iter_timing["copyprev"].toc()

        # Compute z
        iter_timing["calcz"].tic()
        K_forward(xbar, Kxbar)
        ne.evaluate('y + csigma * Kxbar', out=z)
        iter_timing["calcz"].toc()

        # Update y.
        offset = 0
        for fn in psi_fns:
            prox_log_tot[fn].tic()
            slc = slice(offset, offset + fn.lin_op.size, None)
            z_slc = adapter.reshape(z[slc], fn.lin_op.shape)
            # Moreau identity: apply and time prox.
            prox_log[fn].tic()
            y[slc] = adapter.flatten(
                prox_off_and_fac(z_slc, -csigma, fn, csigma, z_slc / csigma,
                                 i))
            prox_log[fn].toc()
            offset += fn.lin_op.size
            prox_log_tot[fn].toc()

        iter_timing["calcx"].tic()
        if offset < y.shape[0]:
            y[offset:] = 0
        # Update x
        K_adjoint(y, KTy)
        ne.evaluate('x - ctau * KTy', out=x)
        iter_timing["calcx"].toc()

        iter_timing["omega_fn"].tic()
        if len(omega_fns) > 0:
            fn = omega_fns[0]
            prox_log_tot[fn].tic()
            xtmp = adapter.reshape(x, fn.lin_op.shape)
            prox_log[fn].tic()
            if adapter.implem() == 'numpy':
                # ravel() avoids a redundant memcpy
                x[:] = prox(fn,
                            1.0 / ctau,
                            xtmp,
                            x_init=prev_x,
                            lin_solver=lin_solver,
                            options=lin_solver_options).ravel()
            else:
                x[:] = adapter.flatten(
                    prox(fn,
                         1.0 / ctau,
                         xtmp,
                         x_init=prev_x,
                         lin_solver=lin_solver,
                         options=lin_solver_options))

            prox_log[fn].toc()
            prox_log_tot[fn].toc()
        iter_timing["omega_fn"].toc()

        iter_timing["xbar"].tic()
        # Update xbar
        ne.evaluate('x + ctheta * (x - prev_x)', out=xbar)
        iter_timing["xbar"].toc()

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(adapter.to_np(x))
            objval = list([fn.value for fn in prox_fns])
            objval = sum(objval)
            convlog.record_objective(objval)

        # Residual based convergence check
        if i % conv_check in [0, conv_check - 1]:
            iter_timing["conv_check"].tic()
            K_forward(x, Kx)
            ne.evaluate('y / csigma + ctheta * (Kx - prev_Kx)',
                        out=u,
                        casting='unsafe')
            ne.evaluate('prev_u + prev_Kx - y / csigma',
                        out=z,
                        casting='unsafe')
            iter_timing["conv_check"].toc()

        # Iteration order is different than
        # lin-admm (--> start checking at iteration 1)
        if i > 0 and i % conv_check == 0:

            # Check convergence
            r = ne.evaluate('prev_Kx - z')
            dz = ne.evaluate('csigma * (z - prev_z)')
            K_adjoint(dz, s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
                max([np.linalg.norm(prev_Kx), np.linalg.norm(z)])

            K_adjoint(u, KTu)
            eps_dual = np.sqrt(
                K.input_size) * eps_abs + eps_rel * np.linalg.norm(
                    KTu) / csigma

            if not callback is None or verbose == 2:
                K.update_vars(adapter.to_np(x))
            if not callback is None:
                callback(adapter.to_np(x))

            # Progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    ov = list([fn.value for fn in prox_fns])
                    objval = sum(ov)
                    objstr = ", obj_val = %02.03e [%s] " % (objval, ", ".join(
                        "%02.03e" % x for x in ov))

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(
                    metric.message(adapter.to_np(x)))
                print(
                    "iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s"
                    % (i, np.linalg.norm(adapter.to_np(r)), eps_pri,
                       np.linalg.norm(
                           adapter.to_np(s)), eps_dual, objstr, metstr))

            iter_timing["pc_iteration_tot"].toc()
            if np.linalg.norm(adapter.to_np(r)) <= eps_pri and np.linalg.norm(
                    adapter.to_np(s)) <= eps_dual:
                break

        else:
            iter_timing["pc_iteration_tot"].toc()

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs total:")
        print(prox_log_tot)
        print("prox funcs inner:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(adapter.to_np(x))
    if not callback is None:
        callback(adapter.to_np(x))
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
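
This variant hides the array backend behind an adapter object (NumpyAdapter or PyCudaAdapter). Judging only from the calls made above, a NumPy-backed adapter needs roughly the interface sketched below; this is inferred from the usage here and is not ProxImaL's actual NumpyAdapter:

import numpy as np

class SketchNumpyAdapter:
    """Minimal NumPy backend exposing the methods the solver above relies on."""

    def implem(self):
        return 'numpy'

    def zeros(self, size):
        return np.zeros(size)

    def scalar(self, value):
        return value

    def from_np(self, array):
        return np.asarray(array)

    def to_np(self, array):
        return np.asarray(array)

    def reshape(self, array, shape):
        return np.reshape(array, shape)

    def flatten(self, array):
        return np.ravel(array)

    def copyto(self, dst, src):
        np.copyto(dst, src)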
Example #14
class least_squares(sum_squares):
    """The function ||K*x||_2^2.

       Here K is a computation graph (vector to vector lin op).
    """
    def __init__(self,
                 lin_op,
                 offset,
                 diag=None,
                 freq_diag=None,
                 freq_dims=None,
                 implem=Impl['numpy'],
                 **kwargs):
        self.K = CompGraph(lin_op)
        self.offset = offset
        self.diag = diag
        # TODO: freq diag is supposed to be True/False. What is going on below?
        self.freq_diag = freq_diag
        self.orig_freq_diag = freq_diag
        self.freq_dims = freq_dims
        self.orig_freq_dims = freq_dims
        # Get shape for frequency inversion var
        if self.freq_diag is not None:
            if len(self.K.orig_end.variables()) > 1:
                raise Exception(
                    "Diagonal frequency inversion supports only one var currently."
                )

            self.freq_shape = self.K.orig_end.variables()[0].shape
            self.freq_diag = np.reshape(self.freq_diag, self.freq_shape)
            if implem == Impl['halide'] and \
                    (len(self.freq_shape) == 2 or (len(self.freq_shape) == 3 and
                                                   self.freq_dims == 2)):
                # TODO: FIX REAL TO IMAG
                hsize = self.freq_shape if len(
                    self.freq_shape) == 3 else (self.freq_shape[0],
                                                self.freq_shape[1], 1)
                hsizehalide = ((hsize[0] + 1) // 2 + 1, hsize[1], hsize[2], 2)

                self.hsizehalide = hsizehalide
                self.ftmp_halide = np.zeros(hsizehalide,
                                            dtype=np.float32,
                                            order='F')
                self.ftmp_halide_out = np.zeros(hsize,
                                                dtype=np.float32,
                                                order='F')
                self.freq_diag = np.reshape(
                    self.freq_diag[0:hsizehalide[0], ...], hsizehalide[0:3])

        super(least_squares, self).__init__(lin_op, implem=implem, **kwargs)

    def get_data(self):
        """Returns info needed to reconstruct the object besides the args.

        Returns
        -------
        list
        """
        return [
            self.offset, self.diag, self.orig_freq_diag, self.orig_freq_dims
        ]

    def _prox(self, rho, v, b=None, lin_solver="cg", *args, **kwargs):
        """x = argmin_x ||K*x - self.offset - b||_2^2 + (rho/2)||x-v||_2^2.
        """
        if b is None:
            offset = self.offset
        else:
            offset = self.offset + b
        return self.solve(offset,
                          rho=rho,
                          v=v,
                          lin_solver=lin_solver,
                          *args,
                          **kwargs)

    def _eval(self, v):
        """Evaluate the function on v (ignoring parameters).
        """
        Kv = np.zeros(self.K.output_size)
        self.K.forward(v.ravel(), Kv)
        return super(least_squares, self)._eval(Kv - self.offset)

    def solve(self, b, rho=None, v=None, lin_solver="lsqr", *args, **kwargs):
        # KtK Operator is diagonal
        if self.diag is not None:

            Ktb = np.zeros(self.K.input_size)
            self.K.adjoint(b, Ktb)
            if rho is None:
                Ktb /= self.diag
            else:
                Ktb += (rho / 2.) * v
                Ktb /= (self.diag + rho / 2.)

            return Ktb

        # KtK operator is diagonal in frequency domain.
        elif self.freq_diag is not None:
            Ktb = np.zeros(self.K.input_size)
            self.K.adjoint(b, Ktb)

            # Frequency inversion
            if self.implementation == Impl['halide'] and \
                    (len(self.freq_shape) == 2 or
                     (len(self.freq_shape) == 2 and self.freq_dims == 2)):

                Halide('fft2_r2c.cpp').fft2_r2c(
                    np.asfortranarray(
                        np.reshape(Ktb.astype(np.float32), self.freq_shape)),
                    0, 0, self.ftmp_halide)

                Ktb = 1j * self.ftmp_halide[..., 1]
                Ktb += self.ftmp_halide[..., 0]

                if rho is None:
                    Ktb /= self.freq_diag
                else:
                    Halide('fft2_r2c.cpp').fft2_r2c(
                        np.asfortranarray(
                            np.reshape(v.astype(np.float32), self.freq_shape)),
                        0, 0, self.ftmp_halide)

                    vhat = (self.ftmp_halide[..., 0]
                            + 1j * self.ftmp_halide[..., 1])
                    Ktb *= 1.0 / rho
                    Ktb += vhat
                    Ktb /= (1.0 / rho * self.freq_diag + 1.0)

                # Do inverse transform
                Ktb = np.asfortranarray(np.stack((Ktb.real, Ktb.imag),
                                                 axis=-1))
                Halide('ifft2_c2r.cpp').ifft2_c2r(Ktb, self.ftmp_halide_out)

                return self.ftmp_halide_out.ravel()

            else:

                # General frequency inversion
                Ktb = fftd(np.reshape(Ktb, self.freq_shape), self.freq_dims)

                if rho is None:
                    Ktb /= self.freq_diag
                else:
                    Ktb *= 2.0 / rho
                    Ktb += fftd(np.reshape(v, self.freq_shape), self.freq_dims)
                    Ktb /= (2.0 / rho * self.freq_diag + 1.0)

                return (ifftd(Ktb, self.freq_dims).real).ravel()

        elif lin_solver == "lsqr":
            return self.solve_lsqr(b, rho, v, *args, **kwargs)
        elif lin_solver == "cg":
            return self.solve_cg(b, rho, v, *args, **kwargs)
        else:
            raise Exception("Unknown least squares solver.")

    def solve_lsqr(self, b, rho=None, v=None, x_init=None, options=None):
        """Solve ||K*x - b||^2_2 + (rho/2)||x-v||_2^2.
        """

        # Add additional linear terms for the rho terms
        sizev = 0
        if rho is not None:
            vf = v.flatten() * np.sqrt(rho / 2.0)
            # b lives in the output space of K, so the first stacked block has output_size entries.
            sizeb = self.K.output_size
            sizev = np.prod(v.shape)
            b = np.hstack((b, vf))

        input_data = np.zeros(self.K.input_size)
        output_data = np.zeros(self.K.output_size + sizev)

        def matvec(x, output_data):
            if rho is None:
                # Traverse compgraph
                self.K.forward(x, output_data)
            else:
                # Compgraph and additional terms
                self.K.forward(x, output_data[0:0 + sizeb])
                np.copyto(output_data[sizeb:sizeb + sizev],
                          x * np.sqrt(rho / 2.0))

            return output_data

        def rmatvec(y, input_data):
            if rho is None:
                self.K.adjoint(y, input_data)
            else:
                self.K.adjoint(y[0:0 + sizeb], input_data)
                input_data += y[sizeb:sizeb + sizev] * np.sqrt(rho / 2.0)

            return input_data

        # Define linear operator
        def matvecComp(x):
            return matvec(x, output_data)

        def rmatvecComp(y):
            return rmatvec(y, input_data)

        K = LinearOperator((self.K.output_size + sizev, self.K.input_size),
                           matvecComp, rmatvecComp)

        # Options
        if options is None:
            # Default options
            return lsqr(K, b)[0]
        else:
            if not isinstance(options, lsqr_options):
                raise Exception("Invalid LSQR options.")
            return lsqr(K,
                        b,
                        atol=options.atol,
                        btol=options.btol,
                        show=options.show,
                        iter_lim=options.iter_lim)[0]

    def solve_cg(self, b, rho=None, v=None, x_init=None, options=None):
        """Solve ||K*x - b||^2_2 + (rho/2)||x-v||_2^2.
        """
        output_data = np.zeros(self.K.output_size)

        def KtK(x, r):
            self.K.forward(x, output_data)
            self.K.adjoint(output_data, r)
            if rho is not None:
                r += rho * x
            return r

        # Compute Ktb
        Ktb = np.zeros(self.K.input_size)
        self.K.adjoint(b, Ktb)
        if rho is not None:
            Ktb += rho * v

        # Options
        if options is None:
            # Default options
            options = cg_options()
        elif not isinstance(options, cg_options):
            raise Exception("Invalid CG options.")

        return cg(KtK, Ktb, options.tol, options.num_iters, options.verbose,
                  x_init, self.implementation)
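
solve_lsqr above handles the (rho/2)||x - v||_2^2 term by stacking sqrt(rho/2)*I under K and sqrt(rho/2)*v under b, then handing the stacked operator to scipy's lsqr. A self-contained sketch of the same trick with a small dense matrix standing in for the CompGraph; the names and sizes below are illustrative only.

import numpy as np
from scipy.sparse.linalg import LinearOperator, lsqr

rng = np.random.default_rng(0)
m, n = 8, 5
K = rng.standard_normal((m, n))     # dense stand-in for the computation graph
b = rng.standard_normal(m)
v = rng.standard_normal(n)
rho = 2.0
w = np.sqrt(rho / 2.0)

def matvec(x):
    # Stacked operator [K; w*I] applied to x.
    return np.concatenate([K @ x, w * x])

def rmatvec(y):
    # Adjoint of the stacked operator: K^T y_top + w * y_bottom.
    return K.T @ y[:m] + w * y[m:]

A = LinearOperator((m + n, n), matvec=matvec, rmatvec=rmatvec)
x_lsqr = lsqr(A, np.concatenate([b, w * v]))[0]

# Check against the normal equations (2 K^T K + rho I) x = 2 K^T b + rho v.
x_ref = np.linalg.solve(2 * K.T @ K + rho * np.eye(n), 2 * K.T @ b + rho * v)
print(np.allclose(x_lsqr, x_ref, atol=1e-4))
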
def solve(psi_fns, omega_fns, tau=None, sigma=None, theta=None,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None, conv_check=100,
          try_diagonalize=True, try_fast_norm=False, scaled=True,
          metric=None, convlog=None, verbose=0, callback=None, adapter=NumpyAdapter()):

    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    #graph_visualize(prox_fns)

    if adapter.implem() == 'numpy':
        K_forward = K.forward
        K_adjoint = K.adjoint
        prox_off_and_fac = lambda offset, factor, fn, *args, **kw: offset + factor*fn.prox(*args, **kw)
        prox = lambda fn, *args, **kw: fn.prox(*args, **kw)
    elif adapter.implem() == 'pycuda':
        K_forward = K.forward_cuda
        K_adjoint = K.adjoint_cuda
        prox_off_and_fac = lambda offset, factor, fn, *args, **kw: fn.prox_cuda(*args, offset=offset, factor=factor, **kw)
        prox = lambda fn, *args, **kw: fn.prox_cuda(*args, **kw)
    else:
        raise RuntimeError("Implementation %s unknown" % adapter.implem())
    # Select optimal parameters if wanted
    if tau is None or sigma is None or theta is None:
        tau, sigma, theta = est_params_pc(K, tau, sigma, verbose, scaled, try_fast_norm)
    elif callable(tau) or callable(sigma) or callable(theta):
        if scaled:
            L = 1
        else:
            L = est_CompGraph_norm(K, try_fast_norm)

    # Initialize
    x = adapter.zeros(K.input_size)
    y = adapter.zeros(K.output_size)
    xbar = adapter.zeros(K.input_size)
    u = adapter.zeros(K.output_size)
    z = adapter.zeros(K.output_size)

    if x0 is not None:
        x[:] = adapter.reshape(adapter.from_np(x0), K.input_size)
    else:
        x[:] = adapter.from_np(K.x0())

    K_forward(x, y)
    xbar[:] = x

    # Buffers.
    Kxbar = adapter.zeros(K.output_size)
    Kx = adapter.zeros(K.output_size)
    KTy = adapter.zeros(K.input_size)
    KTu = adapter.zeros(K.input_size)
    s = adapter.zeros(K.input_size)

    prev_x = x.copy()
    prev_Kx = Kx.copy()
    prev_z = z.copy()
    prev_u = u.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    prox_log_tot = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsLog(["pc_iteration_tot",
                              "copyprev",
                              "calcz",
                              "calcx",
                              "omega_fn",
                              "xbar",
                              "conv_check"])

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(adapter.to_np(x))
        objval = 0.0
        for f in prox_fns:
            evp = f.value
            #print(str(f), '->', f.value)
            objval += evp
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing["pc_iteration_tot"].tic()
        if convlog is not None:
            convlog.tic()

        if callable(sigma):
            csigma = sigma(i, L)
        else:
            csigma = sigma
        if callable(tau):
            ctau = tau(i, L)
        else:
            ctau = tau
        if callable(theta):
            ctheta = theta(i, L)
        else:
            ctheta = theta

        csigma = adapter.scalar(csigma)
        ctau = adapter.scalar(ctau)
        ctheta = adapter.scalar(ctheta)

        # Keep track of previous iterates
        iter_timing["copyprev"].tic()
        adapter.copyto(prev_x, x)
        adapter.copyto(prev_z, z)
        adapter.copyto(prev_u, u)
        adapter.copyto(prev_Kx, Kx)
        iter_timing["copyprev"].toc()

        # Compute z
        iter_timing["calcz"].tic()
        K_forward(xbar, Kxbar)
        z = y + csigma * Kxbar
        iter_timing["calcz"].toc()

        # Update y.
        offset = 0
        for fn in psi_fns:
            prox_log_tot[fn].tic()
            slc = slice(offset, offset + fn.lin_op.size, None)
            z_slc = adapter.reshape(z[slc], fn.lin_op.shape)
            # Moreau identity: apply and time prox.
            prox_log[fn].tic()
            y[slc] = adapter.flatten(prox_off_and_fac(z_slc, -csigma, fn, csigma, z_slc / csigma, i))
            prox_log[fn].toc()
            offset += fn.lin_op.size
            prox_log_tot[fn].toc()

        iter_timing["calcx"].tic()
        if offset < y.shape[0]:
            y[offset:] = 0
        # Update x
        K_adjoint(y, KTy)
        x -= ctau * KTy
        iter_timing["calcx"].toc()

        iter_timing["omega_fn"].tic()
        if len(omega_fns) > 0:
            fn = omega_fns[0]
            prox_log_tot[fn].tic()
            xtmp = adapter.reshape(x, fn.lin_op.shape)
            prox_log[fn].tic()
            x[:] = adapter.flatten(prox(fn, adapter.scalar(1.0) / ctau, xtmp, x_init=prev_x,
                                        lin_solver=lin_solver, options=lin_solver_options))
            prox_log[fn].toc()
            prox_log_tot[fn].toc()
        iter_timing["omega_fn"].toc()

        iter_timing["xbar"].tic()
        # Update xbar
        adapter.copyto(xbar, x)
        xbar += ctheta * (x - prev_x)
        iter_timing["xbar"].toc()

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(adapter.to_np(x))
            objval = list([fn.value for fn in prox_fns])
            objval = sum(objval)
            convlog.record_objective(objval)

        """ Old convergence check
        #Very basic convergence check.
        r_x = np.linalg.norm(x - prev_x)
        r_xbar = np.linalg.norm(xbar - prev_xbar)
        r_ybar = np.linalg.norm(y - prev_y)
        error = r_x + r_xbar + r_ybar
        """

        # Residual based convergence check
        if i % conv_check in [0, conv_check-1]:
            iter_timing["conv_check"].tic()
            K_forward(x, Kx)
            u = adapter.scalar(1.0) / csigma * y + ctheta * (Kx - prev_Kx)
            z = prev_u + prev_Kx - adapter.scalar(1.0) / csigma * y
            iter_timing["conv_check"].toc()

        # Iteration order is different than
        # lin-admm (--> start checking at iteration 1)
        if i > 0 and i % conv_check == 0:

            # Check convergence
            r = prev_Kx - z
            K_adjoint(csigma * (z - prev_z), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
                max([np.linalg.norm(adapter.to_np(prev_Kx)), np.linalg.norm(adapter.to_np(z))])

            K_adjoint(u, KTu)
            eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(adapter.to_np(KTu)) / csigma

            if callback is not None or verbose == 2:
                K.update_vars(adapter.to_np(x))
            if callback is not None:
                callback(adapter.to_np(x))

            # Progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    ov = list([fn.value for fn in prox_fns])
                    objval = sum(ov)
                    objstr = ", obj_val = %02.03e [%s] " % (objval, ", ".join("%02.03e" % x for x in ov))

                """ Old convergence check
                #Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format( metric.message(x.copy()) )
                print "iter [%04d]:" \
                      "||x - x_prev||_2 = %02.02e " \
                      "||xbar - xbar_prev||_2 = %02.02e " \
                      "||y - y_prev||_2 = %02.02e " \
                      "SUM = %02.02e (eps=%02.03e)%s%s" \
                        % (i, r_x, r_xbar, r_ybar, error, eps, objstr, metstr)
                """

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(adapter.to_np(x)))
                print(
                    "iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s"
                    % (i, np.linalg.norm(adapter.to_np(r)), eps_pri, np.linalg.norm(adapter.to_np(s)), eps_dual, objstr, metstr)
                )

            iter_timing["pc_iteration_tot"].toc()
            if np.linalg.norm(adapter.to_np(r)) <= eps_pri and np.linalg.norm(adapter.to_np(s)) <= eps_dual:
                break

        else:
            iter_timing["pc_iteration_tot"].toc()

        """ Old convergence check
        if error <= eps:
            break
        """

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs total:")
        print(prox_log_tot)
        print("prox funcs inner:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(adapter.to_np(x))
    if callback is not None:
        callback(adapter.to_np(x))
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
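
The y-update above relies on the Moreau identity, prox_{sigma*f*}(z) = z - sigma * prox_{f/sigma}(z/sigma), so the solver only ever evaluates the prox of f itself, never of its conjugate. A standalone numerical check of the identity for f = ||.||_1, whose prox is soft-thresholding and whose conjugate prox is a box projection; the helper names are illustrative and not part of ProxImaL.

import numpy as np

def prox_l1(x, t):
    # prox of t*||.||_1, i.e. soft-thresholding.
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)

def prox_l1_conjugate(x):
    # prox of the conjugate of ||.||_1: projection onto the unit infinity-norm ball.
    return np.clip(x, -1.0, 1.0)

rng = np.random.default_rng(0)
z = rng.standard_normal(10)
sigma = 0.7

lhs = prox_l1_conjugate(z)                          # prox_{sigma*f*}(z); f* is an indicator, so sigma drops out
rhs = z - sigma * prox_l1(z / sigma, 1.0 / sigma)   # the Moreau identity, as used in the y-update
print(np.allclose(lhs, rhs))
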
Exemple #16
0
def solve(psi_fns, omega_fns, rho=1.0,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=False,
          scaled=True, conv_check=100,
          metric=None, convlog=None, verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Rescale so (rho/2)||x - b||^2_2
    rescaling = np.sqrt(2. / rho)
    quad_ops = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        const_terms.append(fn.b.flatten() * rescaling)
    # Check for fast inverse.
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    v_update = get_least_squares_inverse(op_list, None, try_diagonalize, verbose)

    # Initialize everything to zero.
    v = np.zeros(K.input_size)
    z = np.zeros(K.output_size)
    u = np.zeros(K.output_size)

    # Initialize
    if x0 is not None:
        v[:] = np.reshape(x0, K.input_size)
        K.forward(v, z)

    # Buffers.
    Kv = np.zeros(K.output_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("ADMM iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        z_prev = z.copy()

        # Update v.
        tmp = np.hstack([z - u] + const_terms)
        v = v_update.solve(tmp, x_init=v, lin_solver=lin_solver, options=lin_solver_options)

        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            prox_log[fn].tic()
            z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size
        # Update u.
        u += Kv - z

        # Check convergence.
        if i % conv_check == 0:
            r = Kv - z
            K.adjoint(u, KTu)
            K.adjoint(rho * (z - z_prev), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
                max([np.linalg.norm(Kv), np.linalg.norm(z)])
            eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) * rho

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0 and i % conv_check == 0:
            # Evaluate objective only if required (expensive !)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))

        iter_timing.toc()
        # Exit if converged.
        if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
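
The v/z/u update order above is the usual scaled-ADMM pattern: a quadratic (omega) solve against z - u, a prox (psi) step on Kv + u, then a dual ascent on u. A toy run of that pattern for min_x ||x||_1 + (1/2)||x - c||_2^2 with K = I, whose minimizer is soft-thresholding of c; everything below is illustrative and independent of ProxImaL's classes.

import numpy as np

def soft(x, t):
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)

rng = np.random.default_rng(0)
c = 3.0 * rng.standard_normal(20)
rho = 1.0
v = np.zeros_like(c)
z = np.zeros_like(c)
u = np.zeros_like(c)

for _ in range(200):
    v = (c + rho * (z - u)) / (1.0 + rho)   # quadratic (omega) update
    z = soft(v + u, 1.0 / rho)              # prox of ||.||_1 with weight 1/rho, applied to Kv + u
    u += v - z                              # scaled dual update

print(np.allclose(v, soft(c, 1.0), atol=1e-6))
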
Exemple #17
0
    def solve(self, solver=None, test_adjoints=False, test_norm=False, show_graph=False, *args, **kwargs):
        if solver is None:
            solver = self.solver

        if len(self.omega_fns + self.psi_fns) == 0:
            prox_fns = self.prox_fns
        else:
            prox_fns = self.omega_fns + self.psi_fns
        # Absorb lin ops if desired.
        if self.absorb:
            prox_fns = absorb.absorb_all_lin_ops(prox_fns)

        # Merge prox fns.
        if self.merge:
            prox_fns = merge.merge_all(prox_fns)
        # Absorb offsets.
        prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns]
        # TODO more analysis of what solver to use.
        
        if show_graph:
            print("Computational graph before optimizing:")
            graph_visualize(prox_fns, filename=show_graph if isinstance(show_graph, str) else None)
        
        # Short circuit with one function.
        if len(prox_fns) == 1 and type(prox_fns[0].lin_op) == Variable:
            fn = prox_fns[0]
            var = fn.lin_op
            var.value = fn.prox(0, np.zeros(fn.lin_op.shape))
            return fn.value
        elif solver in NAME_TO_SOLVER:
            module = NAME_TO_SOLVER[solver]
            if len(self.omega_fns + self.psi_fns) == 0:
                if self.try_split and len(prox_fns) > 1 and len(self.variables()) == 1:
                    psi_fns, omega_fns = module.partition(prox_fns,
                                                          self.try_diagonalize)
                else:
                    psi_fns = prox_fns
                    omega_fns = []
            else:
                psi_fns = self.psi_fns
                omega_fns = self.omega_fns
            if test_norm:
                L = CompGraph(vstack([fn.lin_op for fn in psi_fns]))
                from numpy.random import random

                output_mags = [NotImplemented]
                L.norm_bound(output_mags)
                if NotImplemented not in output_mags:
                    assert len(output_mags) == 1
                
                    x = random(L.input_size)
                    x = x / LA.norm(x)
                    y = np.zeros(L.output_size)
                    y = L.forward(x, y)
                    ny = LA.norm(y)
                    nL2 = est_CompGraph_norm(L, try_fast_norm=False)
                    if ny > output_mags[0]:
                        raise RuntimeError("wrong implementation of norm!")
                    print("%.3f <= ||K|| = %.3f (%.3f)" % (ny, output_mags[0], nL2))
                
            # Scale the problem.
            if self.scale:
                K = CompGraph(vstack([fn.lin_op for fn in psi_fns]),
                              implem=self.implem)
                Knorm = est_CompGraph_norm(K, try_fast_norm=self.try_fast_norm)
                for idx, fn in enumerate(psi_fns):
                    psi_fns[idx] = fn.copy(fn.lin_op / Knorm,
                                           beta=fn.beta * np.sqrt(Knorm),
                                           implem=self.implem)
                for idx, fn in enumerate(omega_fns):
                    omega_fns[idx] = fn.copy(beta=fn.beta / np.sqrt(Knorm),
                                             implem=self.implem)
                for v in K.orig_end.variables():
                    if v.initval is not None:
                        v.initval *= np.sqrt(Knorm)
            if test_adjoints not in [False, None]:
                if test_adjoints is True:
                    test_adjoints = 1e-6
                # test adjoints
                L = CompGraph(vstack([fn.lin_op for fn in psi_fns]))
                from numpy.random import random
                
                x = random(L.input_size)
                yt = np.zeros(L.output_size)
                #print("x=", x)
                yt = L.forward(x, yt)
                #print("yt=", yt)
                #print("x=", x)
                y = random(L.output_size)
                #print("y=", y)
                xt = np.zeros(L.input_size)
                xt = L.adjoint(y, xt)
                #print("xt=", xt)
                #print("y=", y)
                r = np.abs(np.dot(np.ravel(y), np.ravel(yt)) - np.dot(np.ravel(x), np.ravel(xt)))
                #print( x.shape, y.shape, xt.shape, yt.shape)
                if r > test_adjoints:
                    #print("yt=", yt)
                    #print("y =", y)
                    #print("xt=", xt)
                    #print("x =", x)
                    raise RuntimeError("Unmatched adjoints: " + str(r))
                else:
                    print("Adjoint test passed.", r)
                                    
            if self.implem == Impl['pycuda']:
                kwargs['adapter'] = PyCudaAdapter()
            opt_val = module.solve(psi_fns, omega_fns,
                                   lin_solver=self.lin_solver,
                                   try_diagonalize=self.try_diagonalize,
                                   try_fast_norm=self.try_fast_norm,
                                   scaled=self.scale,
                                   *args, **kwargs)
            # Unscale the variables.
            if self.scale:
                for var in self.variables():
                    var.value /= np.sqrt(Knorm)
            return opt_val
        else:
            raise Exception("Unknown solver.")
Exemple #18
0
def solve(psi_fns,
          omega_fns,
          rho=1.0,
          max_iters=1000,
          eps_abs=1e-10,
          eps_rel=1e-3,
          x0=None,
          lin_solver="cg",
          lin_solver_options=None,
          try_diagonalize=True,
          try_fast_norm=False,
          scaled=True,
          metric=None,
          convlog=None,
          verbose=0):
    # C=np.array([[1,0],[0,0]]);
    # b=np.array([2,0]);
    # print(np.linalg.lstsq(C,b,rcond=None)[0])
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Rescale so (rho/2)||x - b||^2_2
    rescaling = np.sqrt(2. / rho)
    quad_ops = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        const_terms.append(fn.b.flatten() * rescaling)
    # Check for fast inverse.
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    v_update = get_least_squares_inverse(op_list, None, try_diagonalize,
                                         verbose)

    # Initialize everything to zero.
    input_size = K.input_size
    output_size = K.output_size
    v = np.zeros(input_size)
    z = np.zeros(output_size)
    u = np.zeros(output_size)

    print(output_size)

    # Initialize
    if x0 is not None:
        v[:] = np.reshape(x0, input_size)
        K.forward(v, z)

    # Buffers.
    v0 = v.copy()
    z0 = z.copy()
    u0 = u.copy()
    N_z = len(z[:])
    Kv = np.zeros(output_size)
    KTu = np.zeros(input_size)
    s = np.zeros(input_size)
    Kv_pre = Kv.copy()
    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("ADMM iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    # --------------------------------------------------------------------------------------------------
    print("Anderson Acceleration:")
    for andersonmk in range(6, 7):
        v = v0.copy()
        u = u0.copy()
        v_d = v.copy()
        u_d = u.copy()
        res_pre = 9e20
        total_energy = []
        total_time = []
        Combine_res = []
        reset = False
        sca_z = 1
        size = v.flatten().shape[0]
        total_size = (u.flatten()).shape[0] + size
        print(size)
        sign = 0
        curr_time = 0
        AA_compute_time = 0
        acc1 = Anderson(
            np.concatenate((v.flatten(), sca_z * u.flatten()), axis=0),
            total_size, andersonmk)
        for i in range(max_iters):
            t1 = time.time()
            if convlog is not None:
                convlog.tic()

            K.forward(v, Kv)
            # Update z.
            Kv_pre = Kv.copy()
            K.forward(v, Kv)
            Kv_u = Kv + u
            offset = 0
            for fn in psi_fns:
                tmp = np.hstack([z - u] + const_terms)
                v = v_update.solve(tmp,
                                   x_init=v,
                                   lin_solver=lin_solver,
                                   options=lin_solver_options)
                K.forward(v, Kv)
                Kv_u = Kv + u
                slc = slice(offset, offset + fn.lin_op.size, None)
                Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
                # Apply and time prox.
                z_pre = z.copy()
                prox_log[fn].tic()
                z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
                prox_log[fn].toc()
                offset += fn.lin_op.size
            # Update u.
            r = Kv - z
            u += r
            K.adjoint(u, KTu)
            # print(np.linalg.norm(u))

            # Check convergence.

            # K.adjoint(rho * (z - z_prev), s)
            s = z - z_pre
            res = np.linalg.norm(r)**2 + np.linalg.norm(s)**2
            # K.adjoint((z-z_prev),s)
            # eps_pri = np.sqrt(output_size) * eps_abs + eps_rel * \
            #           max([np.linalg.norm(Kv), np.linalg.norm(z)])
            # eps_dual = np.sqrt(input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) / rho

            t3 = time.time()
            if res < res_pre or reset:
                v_d = v.copy()
                u_d = u.copy()
                res_pre = res
                reset = False
                tt = acc1.compute(
                    np.concatenate((v.flatten(), sca_z * u.flatten()), axis=0))
                v = tt[0:size].reshape(v.shape)
                u = tt[size:].reshape(u.shape) / sca_z
            else:
                sign = sign + 1
                v = v_d.copy()
                u = u_d.copy()
                reset = True
                acc1.reset(
                    np.concatenate((v.flatten(), sca_z * u.flatten()), axis=0))
            t4 = time.time()
            AA_compute_time += t4 - t3

            t2 = time.time()
            curr_time += t2 - t1

            # Convergence log
            if convlog is not None:
                convlog.toc()
                K.update_vars(v)
                objval = sum([fn.value for fn in prox_fns])
                convlog.record_objective(objval)

            # Show progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    K.update_vars(v)
                    objstr = ", obj_val = %02.03e" % sum(
                        [fn.value for fn in prox_fns])

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(
                    metric.message(v))
                # print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                #     i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))
                print("iter %d: combine residual = %.8f" % (i, res))

            Combine_res.append(np.sqrt(rho * res_pre / N_z))
            total_time.append(curr_time)
            # Exit if converged.
            if (res) < eps_abs:
                break
        print("current time: %.6f, AA compute: %.6f, sign: %d" %
              (curr_time, AA_compute_time, sign))

        hm_src_path = 'residual-' + str(andersonmk) + '.txt'
        iter_num = []
        iter_num.append(len(total_time))
        iter_num.append(len(Combine_res))
        with open(hm_src_path, 'w') as f:
            for i in range(0, min(iter_num)):
                f.write('%f\t%.20f\n' % (total_time[i], Combine_res[i]))
    print("Anderson Acceleration with Douglas-Rachford splitting:")
    for andersonmk in range(6, 7):
        v = v0.copy()
        u = u0.copy()
        K.forward(v, Kv)
        v_d = v.copy()
        u_d = u.copy()
        d_s = z0.copy()
        d_u = d_s.copy()
        d_s_d = d_s.copy()
        d_v = d_s.copy()
        d_unew = d_u.copy()
        res_pre = 9e20
        r_com = 0
        r_com_pre = r_com
        total_energy = []
        total_time = []
        Combine_res = []
        reset = False
        size = v.flatten().shape[0]
        sign = 0
        curr_time = 0
        acc1 = Anderson(d_s.flatten(), size, andersonmk)
        for i in range(max_iters):
            t1 = time.time()
            if convlog is not None:
                convlog.tic()
            # K.forward(v, Kv)
            # Update v.
            Kv_u = d_s.copy()
            offset = 0
            for fn in psi_fns:
                slc = slice(offset, offset + fn.lin_op.size, None)
                Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
                # Apply and time prox.
                prox_log[fn].tic()
                z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
                prox_log[fn].toc()
                offset += fn.lin_op.size

            d_u = z.copy()
            temp = 2 * d_u - d_s
            tmp = np.hstack([temp] + const_terms)
            v = v_update.solve(tmp,
                               x_init=v,
                               lin_solver=lin_solver,
                               options=lin_solver_options)
            K.forward(v, d_v)
            # z_prev = z.copy()
            # Update z.
            # Update d_s
            r = d_v - d_u
            d_s += r
            res = np.linalg.norm(r)**2
            t2 = time.time()
            curr_time += t2 - t1
            # print(np.linalg.norm(u))
            Kv_u = d_s.copy()
            offset = 0
            for fn in psi_fns:
                slc = slice(offset, offset + fn.lin_op.size, None)
                Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
                # Apply and time prox.
                prox_log[fn].tic()
                z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
                prox_log[fn].toc()
                offset += fn.lin_op.size
            d_unew = z.copy()
            # Check convergence.
            # K.adjoint(rho * (z - z_prev),

            r_com = (np.linalg.norm(d_unew - d_v)**2
                     + np.linalg.norm(d_unew - d_u)**2)
            # K.adjoint((z-z_prev),s)
            # eps_pri = np.sqrt(output_size) * eps_abs + eps_rel * \
            #           max([np.linalg.norm(Kv), np.linalg.norm(z)])
            # eps_dual = np.sqrt(input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) / rho

            # Convergence log
            if convlog is not None:
                convlog.toc()
                K.update_vars(v)
                objval = sum([fn.value for fn in prox_fns])
                convlog.record_objective(objval)
            t1 = time.time()
            if res < res_pre or reset:
                d_s_d = d_s.copy()
                res_pre = res
                r_com_pre = r_com
                reset = False
                tt = acc1.compute(d_s.flatten())
                d_s = tt.reshape(d_s.shape)
            else:
                sign = sign + 1
                d_s = d_s_d.copy()
                reset = True
                acc1.reset(d_s.flatten())
            # Show progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    K.update_vars(v)
                    objstr = ", obj_val = %02.03e" % sum(
                        [fn.value for fn in prox_fns])

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(
                    metric.message(v))
                # print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                #     i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))
                print("iter %d: combine residual = %.8f" % (i, r_com))
            t2 = time.time()
            curr_time += t2 - t1
            Combine_res.append(np.sqrt(rho * r_com_pre / N_z))
            total_time.append(curr_time)
            # Exit if converged.
            # if (res) < eps_abs:
            #     break
        hm_src_path = 'dr-' + str(andersonmk) + '.txt'
        iter_num = []
        iter_num.append(len(total_time))
        iter_num.append(len(Combine_res))
        with open(hm_src_path, 'w') as f:
            for i in range(0, min(iter_num)):
                f.write('%f\t%.20f\n' % (total_time[i], Combine_res[i]))
    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)
    # Return optimal value.
    # return sum([fn.value for fn in prox_fns])
    return total_time, Combine_res
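
The Anderson(...) helper used above is not shown here. For reference, this is a compact, standalone sketch of windowed type-II Anderson acceleration for a generic fixed-point map g, checked on a toy linear contraction; it is illustrative only and does not claim to match the helper's interface.

import numpy as np

def anderson_fixed_point(g, x0, m=6, iters=40):
    # Type-II Anderson acceleration (no damping) for the fixed point of g.
    xs, gs = [x0], [g(x0)]
    fs = [gs[0] - xs[0]]                 # residuals f_k = g(x_k) - x_k
    x = gs[0]                            # plain Picard step for k = 0
    for k in range(1, iters):
        xs.append(x)
        gs.append(g(x))
        fs.append(gs[-1] - xs[-1])
        mk = min(m, k)
        dF = np.column_stack([fs[-i] - fs[-i - 1] for i in range(1, mk + 1)])
        dG = np.column_stack([gs[-i] - gs[-i - 1] for i in range(1, mk + 1)])
        gamma = np.linalg.lstsq(dF, fs[-1], rcond=None)[0]
        x = gs[-1] - dG @ gamma          # extrapolated next iterate
    return x

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 6))
A = 0.9 * A / np.linalg.norm(A, 2)       # spectral norm 0.9, so g is a contraction
b = rng.standard_normal(6)
g = lambda x: A @ x + b                  # toy fixed-point map
x_star = np.linalg.solve(np.eye(6) - A, b)
print(np.linalg.norm(anderson_fixed_point(g, np.zeros(6)) - x_star) < 1e-6)
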
Exemple #19
0
    def test_op_overloading(self):
        """Test operator overloading.
        """
        # Multiplying by a scalar.

        # Forward.
        var = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = -2 * mul_elemwise(W, var)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(x.shape)
        fn.forward(x.flatten(), out)
        self.assertItemsAlmostEqual(out, -2 * W * W)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, -2 * W * W)

        # Forward.
        var = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, var) * 0.5
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(x.shape)
        fn.forward(x.flatten(), out)
        self.assertItemsAlmostEqual(out, W * W / 2.)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, W * W / 2.)

        # Dividing by a scalar.
        # Forward.
        var = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, var) / 2
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(x.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W / 2.)

        # Adding lin ops.
        # Forward.
        x = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, x)
        fn = fn + x + x
        self.assertEqual(len(fn.input_nodes), 3)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(fn.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)

        # Adding in a constant.
        # CompGraph should ignore the constant.
        x = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = mul_elemwise(W, x)
        fn = fn + x + W
        self.assertEqual(len(fn.input_nodes), 3)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(fn.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W + W)

        # Subtracting lin ops.
        # Forward.
        x = Variable((2, 5))
        W = np.arange(10)
        W = np.reshape(W, (2, 5))
        fn = -mul_elemwise(W, x)
        fn = x + x - fn
        self.assertEqual(len(fn.input_nodes), 3)
        fn = CompGraph(fn)
        x = W.copy()
        out = np.zeros(fn.shape)
        fn.forward(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)

        # Adjoint.
        x = W.copy()
        out = np.zeros(x.shape).flatten()
        fn.adjoint(x, out)
        self.assertItemsAlmostEqual(out, W * W + 2 * W)
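
A note on why the forward and adjoint results coincide in the checks above: mul_elemwise(W, .) is a diagonal linear operator and therefore self-adjoint, and scaling or adding identity terms preserves that symmetry. A dense-matrix sanity check for the first case (illustrative only).

import numpy as np

W = np.arange(10, dtype=float).reshape(2, 5)
D = np.diag(W.flatten())        # mul_elemwise(W, .) written as a matrix
A = -2 * D                      # the overloaded -2 * mul_elemwise(W, var)
print(np.allclose(A, A.T))      # self-adjoint, so forward and adjoint agree
print(np.allclose(A @ W.flatten(), (-2 * W * W).flatten()))
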
Exemple #20
0
def solve(psi_fns,
          omega_fns,
          tau=None,
          sigma=None,
          theta=None,
          max_iters=1000,
          eps_abs=1e-3,
          eps_rel=1e-3,
          x0=None,
          lin_solver="cg",
          lin_solver_options=None,
          conv_check=100,
          try_diagonalize=True,
          try_fast_norm=False,
          scaled=True,
          metric=None,
          convlog=None,
          verbose=0):
    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    v = np.zeros(K.input_size)
    # Select optimal parameters if wanted
    if tau is None or sigma is None or theta is None:
        tau, sigma, theta = est_params_pc(K, tau, sigma, verbose, scaled,
                                          try_fast_norm)

    # Initialize
    x = np.zeros(K.input_size)
    y = np.zeros(K.output_size)
    xbar = np.zeros(K.input_size)
    u = np.zeros(K.output_size)
    z = np.zeros(K.output_size)

    if x0 is not None:
        x[:] = np.reshape(x0, K.input_size)
        K.forward(x, y)
        xbar[:] = x

    # Buffers.
    Kxbar = np.zeros(K.output_size)
    Kx = np.zeros(K.output_size)
    KTy = np.zeros(K.input_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    prev_x = x.copy()
    prev_Kx = Kx.copy()
    prev_z = z.copy()
    prev_u = u.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("PC iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Keep track of previous iterates
        np.copyto(prev_x, x)
        np.copyto(prev_z, z)
        np.copyto(prev_u, u)
        np.copyto(prev_Kx, Kx)

        # Compute z
        K.forward(xbar, Kxbar)
        z = y + sigma * Kxbar

        # Update y.
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            z_slc = np.reshape(z[slc], fn.lin_op.shape)

            # Moreau identity: apply and time prox.
            prox_log[fn].tic()
            y[slc] = (z_slc -
                      sigma * fn.prox(sigma, z_slc / sigma, i)).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size
        y[offset:] = 0

        # Update x
        K.adjoint(y, KTy)
        x -= tau * KTy

        if len(omega_fns) > 0:
            xtmp = np.reshape(x, omega_fns[0].lin_op.shape)
            x[:] = omega_fns[0].prox(1.0 / tau,
                                     xtmp,
                                     x_init=prev_x,
                                     lin_solver=lin_solver,
                                     options=lin_solver_options).flatten()

        # Update xbar
        np.copyto(xbar, x)
        xbar += theta * (x - prev_x)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(x)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)
        """ Old convergence check
        #Very basic convergence check.
        r_x = np.linalg.norm(x - prev_x)
        r_xbar = np.linalg.norm(xbar - prev_xbar)
        r_ybar = np.linalg.norm(y - prev_y)
        error = r_x + r_xbar + r_ybar
        """

        # Residual based convergence check
        K.forward(x, Kx)
        u = 1.0 / sigma * y + theta * (Kx - prev_Kx)
        z = prev_u + prev_Kx - 1.0 / sigma * y

        # Iteration order is different than
        # lin-admm (--> start checking at iteration 1)
        if i > 0 and i % conv_check == 0:

            # Check convergence
            r = prev_Kx - z
            K.adjoint(sigma * (z - prev_z), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
                max([np.linalg.norm(prev_Kx), np.linalg.norm(z)])

            K.adjoint(u, KTu)
            eps_dual = np.sqrt(
                K.input_size) * eps_abs + eps_rel * np.linalg.norm(KTu) / sigma

            # Progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum(
                        [fn.value for fn in prox_fns])
                """ Old convergence check
                #Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format( metric.message(x.copy()) )
                print "iter [%04d]:" \
                      "||x - x_prev||_2 = %02.02e " \
                      "||xbar - xbar_prev||_2 = %02.02e " \
                      "||y - y_prev||_2 = %02.02e " \
                      "SUM = %02.02e (eps=%02.03e)%s%s" \
                        % (i, r_x, r_xbar, r_ybar, error, eps, objstr, metstr)
                """

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(
                    metric.message(x.copy()))
                print(
                    "iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s"
                    % (i, np.linalg.norm(r), eps_pri, np.linalg.norm(s),
                       eps_dual, objstr, metstr))

            iter_timing.toc()
            if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
                break

        else:
            iter_timing.toc()
        """ Old convergence check
        if error <= eps:
            break
        """

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)

    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
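
est_params_pc is not shown here; the usual requirement for the primal-dual iteration above is tau * sigma * ||K||^2 <= 1 (with theta = 1). A hedged sketch of choosing tau = sigma = 1/||K|| from a few power iterations on K^T K, with a dense matrix standing in for the CompGraph.

import numpy as np

rng = np.random.default_rng(0)
K = rng.standard_normal((30, 20))    # dense stand-in for the computation graph

x = rng.standard_normal(20)
for _ in range(50):                  # power iteration on K^T K
    x = K.T @ (K @ x)
    x /= np.linalg.norm(x)
Knorm = np.sqrt(np.linalg.norm(K.T @ (K @ x)))   # estimate of the largest singular value

tau = sigma = 1.0 / Knorm
theta = 1.0
print(tau * sigma * Knorm ** 2 <= 1.0 + 1e-12)   # step-size condition holds by construction
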
Exemple #21
0
def solve(psi_fns,
          omega_fns,
          lmb=1.0,
          mu=None,
          quad_funcs=None,
          max_iters=1000,
          eps_abs=1e-3,
          eps_rel=1e-3,
          lin_solver="cg",
          lin_solver_options=None,
          try_diagonalize=True,
          try_fast_norm=True,
          scaled=False,
          metric=None,
          convlog=None,
          verbose=0):
    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Select optimal parameters if wanted
    if lmb is None or mu is None:
        lmb, mu = est_params_lin_admm(K, lmb, verbose, scaled, try_fast_norm)

    # Initialize everything to zero.
    v = np.zeros(K.input_size)
    z = np.zeros(K.output_size)
    u = np.zeros(K.output_size)

    # Buffers.
    Kv = np.zeros(K.output_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    Kvzu = np.zeros(K.output_size)
    v_prev = np.zeros(K.input_size)
    z_prev = np.zeros(K.output_size)

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("LIN-ADMM iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        v_prev[:] = v
        z_prev[:] = z

        # Update v
        K.forward(v, Kv)
        Kvzu[:] = Kv - z + u
        K.adjoint(Kvzu, v)
        v[:] = v_prev - (mu / lmb) * v

        if len(omega_fns) > 0:
            v[:] = omega_fns[0].prox(1.0 / mu,
                                     v,
                                     x_init=v_prev.copy(),
                                     lin_solver=lin_solver,
                                     options=lin_solver_options)

        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            prox_log[fn].tic()
            z[slc] = fn.prox(1.0 / lmb, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size

        # Update u.
        u += Kv - z
        K.adjoint(u, KTu)

        # Check convergence.
        r = Kv - z
        K.adjoint((1.0 / lmb) * (z - z_prev), s)
        eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
                  max([np.linalg.norm(Kv), np.linalg.norm(z)])
        eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(
            KTu) / (1.0 / lmb)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0:
            # Evaluate objective only if required (expensive !)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum(
                    [fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            print(
                "iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s"
                % (i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual,
                   objstr, metstr))

        iter_timing.toc()
        if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)

    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
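
In the linearized-ADMM v-update above, K.adjoint(Kv - z + u) is the gradient of (1/2)||K v - z + u||_2^2, so v <- v_prev - (mu/lmb) * gradient is a plain gradient step on the coupling term. A finite-difference check of that gradient with a dense stand-in for K (names illustrative).

import numpy as np

rng = np.random.default_rng(0)
K = rng.standard_normal((9, 4))
v = rng.standard_normal(4)
z = rng.standard_normal(9)
u = rng.standard_normal(9)

def h(v):
    return 0.5 * np.linalg.norm(K @ v - z + u) ** 2

grad = K.T @ (K @ v - z + u)        # what K.adjoint(Kvzu, v) computes above
eps = 1e-6
fd = np.array([(h(v + eps * e) - h(v - eps * e)) / (2 * eps) for e in np.eye(4)])
print(np.allclose(grad, fd, atol=1e-4))
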
def solve(psi_fns, omega_fns,
          rho_0=1.0, rho_scale=math.sqrt(2.0) * 2.0, rho_max=2**8,
          max_iters=-1, max_inner_iters=100, x0=None,
          eps_rel=1e-3, eps_abs=1e-3,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, scaled=False, try_fast_norm=False,
          metric=None, convlog=None, verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)
    # Rescale so (1/2)||x - b||^2_2
    rescaling = np.sqrt(2.)
    quad_ops = []
    quad_weights = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        quad_weights.append(rescaling * fn.beta)
        const_terms.append(fn.b.flatten() * rescaling)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)
    x_update = get_least_squares_inverse(op_list, None,
                                         try_diagonalize, verbose)

    # Initialize
    if x0 is not None:
        x = np.reshape(x0, K.input_size)
    else:
        x = np.zeros(K.input_size)

    Kx = np.zeros(K.output_size)
    w = Kx.copy()

    # Temporary iteration counts
    x_prev = x.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("HQS iteration")
    inner_iter_timing = TimingsEntry("HQS inner iteration")
    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    # Rho schedule
    rho = rho_0
    i = 0
    while rho < rho_max and i < max_iters:
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Update rho for quadratics
        for idx, op in enumerate(quad_ops):
            op.scalar = quad_weights[idx] / np.sqrt(rho)
        x_update = get_least_squares_inverse(op_list, CompGraph(stacked_ops),
                                             try_diagonalize, verbose)

        for ii in range(max_inner_iters):
            inner_iter_timing.tic()
            # Update Kx.
            K.forward(x, Kx)

            # Prox update to get w.
            offset = 0
            w_prev = w.copy()
            for fn in psi_fns:
                slc = slice(offset, offset + fn.lin_op.size, None)
                # Apply and time prox.
                prox_log[fn].tic()
                w[slc] = fn.prox(rho, np.reshape(Kx[slc], fn.lin_op.shape), ii).flatten()
                prox_log[fn].toc()
                offset += fn.lin_op.size

            # Update x.
            x_prev[:] = x
            tmp = np.hstack([w] + [cterm / np.sqrt(rho) for cterm in const_terms])
            x = x_update.solve(tmp, x_init=x, lin_solver=lin_solver, options=lin_solver_options)

            # Very basic convergence check.
            r_x = np.linalg.norm(x_prev - x)
            eps_x = eps_rel * np.prod(K.input_size)

            r_w = np.linalg.norm(w_prev - w)
            eps_w = eps_rel * np.prod(K.output_size)

            # Convergence log
            if convlog is not None:
                convlog.toc()
                K.update_vars(x)
                objval = sum([fn.value for fn in prox_fns])
                convlog.record_objective(objval)

            # Show progress
            if verbose > 0:
                # Evaluate objective only if required (expensive !)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(x))
                print("iter [%02d (rho=%2.1e) || %02d]:"
                      "||w - w_prev||_2 = %02.02e (eps=%02.03e)"
                      "||x - x_prev||_2 = %02.02e (eps=%02.03e)%s%s"
                      % (i, rho, ii, r_x, eps_x, r_w, eps_w, objstr, metstr))

            inner_iter_timing.toc()
            if r_x < eps_x and r_w < eps_w:
                break

        # Update rho
        rho = np.minimum(rho * rho_scale, rho_max)
        i += 1
        iter_timing.toc()

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print(inner_iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)

    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
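
The outer loop above follows a geometric rho schedule: start at rho_0, multiply by rho_scale after each inner solve, and stop once rho_max is reached. A tiny sketch of that schedule with the defaults shown above (the max_iters guard is ignored here).

import math
import numpy as np

rho_0, rho_scale, rho_max = 1.0, math.sqrt(2.0) * 2.0, 2 ** 8
rho, schedule = rho_0, []
while rho < rho_max:
    schedule.append(rho)
    rho = np.minimum(rho * rho_scale, rho_max)
print(schedule)   # roughly [1.0, 2.83, 8.0, 22.6, 64.0, 181.0]
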