Example #1
class ObjectiveImageDenoising:
    """
    Class to do image denoising
    """

    def __init__(self, mesh, trueImage, parameters=[]):
        """
        Inputs:
            mesh = Fenics mesh
            trueImage = object from class Image
            parameters = dict
        """
        # Mesh
        self.mesh = mesh
        self.V = dl.FunctionSpace(self.mesh, "Lagrange", 1)
        self.xx = self.V.dofmap().tabulate_all_coordinates(self.mesh)
        self.dimV = self.V.dim()
        self.test, self.trial = dl.TestFunction(self.V), dl.TrialFunction(self.V)
        self.f_true = dl.interpolate(trueImage, self.V)
        self.g, self.dg, self.gtmp = dl.Function(self.V), dl.Function(self.V), dl.Function(self.V)
        self.Grad = dl.Function(self.V)
        self.Gradnorm0 = None
        # mass matrix
        self.Mweak = dl.inner(self.test, self.trial) * dl.dx
        self.M = dl.assemble(self.Mweak)
        self.solverM = dl.LUSolver("petsc")
        self.solverM.parameters["symmetric"] = True
        self.solverM.parameters["reuse_factorization"] = True
        self.solverM.set_operator(self.M)
        # identity matrix
        self.I = dl.assemble(self.Mweak)
        self.I.zero()
        self.I.set_diagonal(dl.interpolate(dl.Constant(1), self.V).vector())
        # self.targetnorm = np.sqrt((self.M*self.f_true.vector()).inner(self.f_true.vector()))
        self.targetnorm = np.sqrt((self.f_true.vector()).inner(self.f_true.vector()))
        # line search parameters
        self.parameters = {"alpha0": 1.0, "rho": 0.5, "c": 5e-5, "max_backtrack": 12}
        # regularization
        self.parameters.update({"eps": 1e-4, "k": 1.0, "regularization": "TV", "mode": "primaldual"})
        self.parameters.update(parameters)
        self.define_regularization()
        self.regparam = 1.0
        # plots:
        filename, ext = os.path.splitext(sys.argv[0])
        if os.path.isdir(filename + "/"):
            shutil.rmtree(filename + "/")
        self.myplot = PlotFenics(filename)

    def generatedata(self, noisepercent):
        """ compute data and add noisepercent (%) of noise """
        sigma = noisepercent * np.linalg.norm(self.f_true.vector().array()) / np.sqrt(self.dimV)
        print "sigma_noise = ", sigma
        np.random.seed(11)  # TODO: tmp
        eta = sigma * np.random.randn(self.dimV)
        self.dn = dl.Function(self.V)
        setfct(self.dn, eta)
        self.dn.vector().axpy(1.0, self.f_true.vector())
        print "min(true)={}, max(true)={}".format(
            np.amin(self.f_true.vector().array()), np.amax(self.f_true.vector().array())
        )
        print "min(noisy)={}, max(noisy)={}".format(
            np.amin(self.dn.vector().array()), np.amax(self.dn.vector().array())
        )

    def define_regularization(self, parameters=None):
        if not parameters == None:
            self.parameters.update(parameters)
        regularization = self.parameters["regularization"]
        if regularization == "tikhonov":
            gamma = self.parameters["gamma"]
            beta = self.parameters["beta"]
            self.Reg = LaplacianPrior({"gamma": gamma, "beta": beta, "Vm": self.V})
            self.inexact = False
        elif regularization == "TV":
            eps = self.parameters["eps"]
            k = self.parameters["k"]
            mode = self.parameters["mode"]
            if mode == "primaldual":
                self.Reg = self.Reg = TVPD({"eps": eps, "k": k, "Vm": self.V, "GNhessian": False})
            elif mode == "full":
                self.Reg = TV({"eps": eps, "k": k, "Vm": self.V, "GNhessian": False})
            else:
                self.Reg = TV({"eps": eps, "k": k, "Vm": self.V, "GNhessian": True})
            self.inexact = False

    ### COST and DERIVATIVES
    def computecost(self, f=None):
        """ Compute cost functional at f """
        if f == None:
            f = self.g
        df = f.vector() - self.dn.vector()
        # self.misfit = 0.5 * (self.M*df).inner(df)
        self.misfit = 0.5 * df.inner(df)
        self.reg = self.Reg.cost(f)
        self.cost = self.misfit + self.regparam * self.reg
        return self.cost

    def gradient(self, f=None):
        """ Compute M.g (discrete gradient) at a given point f """
        if f == None:
            f = self.g
        df = f.vector() - self.dn.vector()
        # self.MGk = self.M*df
        self.MGk = df
        self.MGr = self.Reg.grad(f)
        self.MG = self.MGk + self.MGr * self.regparam
        self.solverM.solve(self.Grad.vector(), self.MG)
        self.Gradnorm = np.sqrt((self.MG).inner(self.Grad.vector()))
        if self.Gradnorm0 == None:
            self.Gradnorm0 = self.Gradnorm

    def Hessian(self, f=None):
        """ Assemble Hessian at f """
        if f == None:
            f = self.g
        regularization = self.parameters["regularization"]
        if regularization == "TV":
            self.Reg.assemble_hessian(f)
            self.Hess = self.I + self.Reg.H * self.regparam
            # self.Hess = self.M + self.Reg.H*self.regparam
        elif regularization == "tikhonov":
            self.Hess = self.M + self.Reg.Minvprior * self.regparam

    ### SOLVER
    def searchdirection(self):
        """ Compute search direction """
        self.gradient()
        self.Hessian()
        solver = dl.PETScKrylovSolver("cg", "petsc_amg")
        solver.parameters["nonzero_initial_guess"] = False
        # Inexact CG:
        if self.inexact:
            self.cgtol = min(0.5, np.sqrt(self.Gradnorm / self.Gradnorm0))
        else:
            self.cgtol = 1e-8
        solver.parameters["relative_tolerance"] = self.cgtol
        solver.set_operator(self.Hess)
        self.cgiter = solver.solve(self.dg.vector(), -1.0 * self.MG)
        if (self.MG).inner(self.dg.vector()) > 0.0:
            print "*** WARNING: NOT a descent direction"

    def linesearch(self):
        """ Perform inexact backtracking line search """
        regularization = self.parameters["regularization"]
        # compute new direction for dual variables
        if regularization == "TV" and self.Reg.isPD():
            self.Reg.compute_dw(self.dg)
        # line search for primal variable
        self.alpha = self.parameters["alpha0"]
        rho = self.parameters["rho"]
        c = self.parameters["c"]
        self.computecost()
        costref = self.cost
        cdJdf = ((self.MG).inner(self.dg.vector())) * c
        self.LS = False
        for ii in xrange(self.parameters["max_backtrack"]):
            setfct(self.gtmp, self.g.vector() + self.dg.vector() * self.alpha)
            if self.computecost(self.gtmp) < costref + self.alpha * cdJdf:
                self.g.vector().axpy(self.alpha, self.dg.vector())
                self.LS = True
                break
            else:
                self.alpha *= rho
        # update dual variable
        if regularization == "TV" and self.Reg.isPD():
            self.Reg.update_w(self.alpha)

    def solve(self, plot=False):
        """ Solve image denoising pb """
        regularization = self.parameters["regularization"]
        print "\t{:12s} {:12s} {:12s} {:12s} {:12s} {:12s} {:12s}\t{:12s} {:12s}".format(
            "a_reg", "cost", "misfit", "reg", "||G||", "a_LS", "medmisfit", "tol_cg", "n_cg"
        )
        #
        if regularization == "tikhonov":
            # pb is linear with tikhonov regularization
            self.searchdirection()
            self.g.vector().axpy(1.0, self.dg.vector())
            self.computecost()
            self.alpha = 1.0
            self.printout()
        else:
            self.computecost()
            cost = self.cost
            # initial printout
            df = self.f_true.vector() - self.g.vector()
            self.medmisfit = np.sqrt(df.inner(df))
            # self.medmisfit = np.sqrt((self.M*df).inner(df))
            self.relmedmisfit = self.medmisfit / self.targetnorm
            print ("{:12.1e} {:12.4e} {:12.4e} {:12.4e} {:12s} {:12s} {:12.2e}" + " ({:.3f})").format(
                self.regparam, self.cost, self.misfit, self.reg, "", "", self.medmisfit ** 2, self.relmedmisfit
            )
            # iterate
            for ii in xrange(1000):
                self.searchdirection()
                self.linesearch()
                print ii + 1,
                self.printout()
                # Check termination conditions:
                if not self.LS:
                    print "Line search failed"
                    break
                if self.Gradnorm < min(1e-12, 1e-10 * self.Gradnorm0):
                    print "gradient sufficiently reduced -- optimization converged"
                    break
                elif np.abs(cost - self.cost) / cost < 1e-12:
                    print "cost functional stagnates -- optimization converged"
                    break
                cost = self.cost

    ### OUTPUT
    def printout(self):
        """ Print results """
        df = self.f_true.vector() - self.g.vector()
        self.medmisfit = np.sqrt((self.M * df).inner(df))
        self.relmedmisfit = self.medmisfit / self.targetnorm
        print ("{:12.1e} {:12.4e} {:12.4e} {:12.4e} {:12.4e} {:12.2e} {:12.2e}" + " ({:.3f}) {:12.2e} {:6d}").format(
            self.regparam,
            self.cost,
            self.misfit,
            self.reg,
            self.Gradnorm,
            self.alpha,
            self.medmisfit ** 2,
            self.relmedmisfit,
            self.cgtol,
            self.cgiter,
        )

    def plot(self, index=0, add=""):
        """ Plot target (w/ noise 0, or w/o noise 1) or current iterate (2) """
        if index == 0:
            self.myplot.set_varname("target" + add)
            self.myplot.plot_vtk(self.f_true)
        elif index == 1:
            self.myplot.set_varname("data" + add)
            self.myplot.plot_vtk(self.dn)
        elif index == 2:
            self.myplot.set_varname("solution" + add)
            self.myplot.plot_vtk(self.g)

    ### TESTS
    def test_gradient(self, f=None, n=5):
        """ test gradient with FD approx around point f """
        if f == None:
            f = self.f_true
        pm = [1.0, -1.0]
        eps = 1e-5
        self.gradient(f)
        for nn in xrange(1, n + 1):
            expr = dl.Expression("sin(n*pi*x[0]/200)*sin(n*pi*x[1]/100)", n=nn)
            df = dl.interpolate(expr, self.V)
            MGdf = self.MG.inner(df.vector())
            cost = []
            for sign in pm:
                setfct(self.g, f)
                self.g.vector().axpy(sign * eps, df.vector())
                cost.append(self.computecost(self.g))
            MGFD = (cost[0] - cost[1]) / (2 * eps)
            print "n={}:\tMGFD={:.5e}, MGdf={:.5e}, error={:.2e}".format(
                nn, MGFD, MGdf, np.abs(MGdf - MGFD) / np.abs(MGdf)
            )

    def test_hessian(self, f=None, n=5):
        """ test Hessian with FD approx around point f """
        if f == None:
            f = self.f_true
        pm = [1.0, -1.0]
        eps = 1e-5
        self.Hessian(f)
        for nn in xrange(1, n + 1):
            expr = dl.Expression("sin(n*pi*x[0]/200)*sin(n*pi*x[1]/100)", n=nn)
            df = dl.interpolate(expr, self.V)
            Hdf = (self.Hess * df.vector()).array()
            MG = []
            for sign in pm:
                setfct(self.g, f)
                self.g.vector().axpy(sign * eps, df.vector())
                self.gradient(self.g)
                MG.append(self.MG.array())
            HFD = (MG[0] - MG[1]) / (2 * eps)
            print "n={}:\tHFD={:.5e}, Hdf={:.5e}, error={:.2e}".format(
                nn, np.linalg.norm(HFD), np.linalg.norm(Hdf), np.linalg.norm(Hdf - HFD) / np.linalg.norm(Hdf)
            )