Exemplo n.º 1
0
    def ftimeincradj(self, tt):
        """ Compute rhs for incremental adjoint at time tt

        Looks up tt in self.PDE.times (forward order -> indexf) and in the
        reversed times (-> indexa, used for soladji, solppadji, solpadji),
        then accumulates into self.qhat:
          * B* B phat, with phat the incremental forward solution at indexf
            (via self.obsop.incradj);
          * unless self.GN: self.C applied to soladji[indexa],
            self.get_incra applied to solppadji[indexa] and, when the 'abc'
            PDE parameter is set, -0.5*self.Dp applied to solpadji[indexa].
        Returns -1.0 * self.qhat.vector().
        Prints a diagnostic and exits with status 1 if tt does not match
        exactly one entry of self.PDE.times.
        """
        times = self.PDE.times
        # Exactly one time index must match; anything else is a caller error.
        matchf = np.flatnonzero(isequal(times, tt, 1e-14))
        matcha = np.flatnonzero(isequal(times[::-1], tt, 1e-14))
        if len(matchf) != 1 or len(matcha) != 1:
            print('Error in ftimeincradj at time {}'.format(tt))
            print(np.min(np.abs(times-tt)))
            sys.exit(1)
        indexf, indexa = int(matchf[0]), int(matcha[0])

        # B* B phat
        setfct(self.phat, self.solincrfwd[indexf][0])
        self.qhat.vector().zero()
        self.qhat.vector().axpy(1.0, self.obsop.incradj(self.phat, tt))

        if not self.GN:
            # bhat: bhat*grad(ptilde).grad(v)
            setfct(self.q, self.soladji[indexa][0])
            self.qhat.vector().axpy(1.0, self.C*self.q.vector())

            # ahat: ahat*ptilde*q'':
            setfct(self.q, self.solppadji[indexa][0])
            self.qhat.vector().axpy(1.0, self.get_incra(self.q.vector()))

            # ABC (self.phat is reused as scratch space here):
            if self.PDE.parameters['abc']:
                setfct(self.phat, self.solpadji[indexa][0])
                self.qhat.vector().axpy(-0.5, self.Dp*self.phat.vector())

        return -1.0*self.qhat.vector()
Exemplo n.º 2
0
 def mult(self, lamhat, y):
     """
     mult(self, lamhat, y): return y = Hessian * lamhat
     inputs:
         y, lamhat = Function(V).vector()

     Solves the incremental forward (phat) and incremental adjoint (vhat)
     problems for the direction lamhat, then overwrites y with the
     time-integrated Hessian form (weighted by self.factors), the
     absorbing-boundary (ABC) contribution when self.PDE.abc is set, and
     alpha_reg times the regularization Hessian applied to lamhat.
     """
     self.regularization.assemble_hessian(lamhat)
     setfct(self.lamhat, lamhat)
     self.C = assemble(self.wkformrhsincr)
     if self.PDE.abc:
         self.Dp = assemble(self.wkformDprime)
     # solve for phat
     self.PDE.set_fwd()
     self.PDE.ftime = self.ftimeincrfwd
     self.solincrfwd, _ = self.PDE.solve()
     # solve for vhat
     self.PDE.set_adj()
     self.PDE.ftime = self.ftimeincradj
     self.solincradj, _ = self.PDE.solve()
     # Compute Hessian*lamhat
     y.zero()
     if self.PDE.abc:
         for buf in (self.vD, self.vhatD, self.p1D, self.p2D,
                     self.p1hatD, self.p2hatD):
             buf.vector().zero()
     # Adjoint-type solutions are stored backwards in time, hence reversed().
     steps = zip(self.solfwd, reversed(self.soladj),
                 self.solincrfwd, reversed(self.solincradj), self.factors)
     for index, (fwd, adj, incrfwd, incradj, fact) in enumerate(steps):
         ttf, tta, ttf2 = incrfwd[1], incradj[1], fwd[1]
         assert isequal(ttf, tta, 1e-16), 'tfwd={}, tadj={}, reldiff={}'.\
             format(ttf, tta, abs(ttf-tta)/ttf)
         assert isequal(ttf, ttf2, 1e-16), 'tincrfwd={}, tfwd={}, reldiff={}'.\
             format(ttf, ttf2, abs(ttf-ttf2)/ttf)
         setfct(self.p, fwd[0])
         setfct(self.v, adj[0])
         setfct(self.phat, incrfwd[0])
         setfct(self.vhat, incradj[0])
         y.axpy(fact, assemble(self.wkformhess))
         if self.PDE.abc:
             # Two alternating buffer pairs: after the axpy calls below the
             # active pair holds p_index - p_{index-2} (just p_index for the
             # first two steps); vD/vhatD still hold the previous step's values.
             if index % 2 == 0:
                 pD, phatD = self.p2D, self.p2hatD
             else:
                 pD, phatD = self.p1D, self.p1hatD
             pD.vector().axpy(1.0, self.p.vector())
             phatD.vector().axpy(1.0, self.phat.vector())
             setfct(self.dp, pD)
             setfct(self.dph, phatD)
             # Weight by the time-quadrature factor like every other
             # time-integrated term (the gradient's ABC term uses fact too).
             y.axpy(fact*0.5*self.invDt, assemble(self.wkformhessD))
             setfct(pD, -1.0*self.p.vector())
             setfct(phatD, -1.0*self.phat.vector())
             setfct(self.vD, self.v)
             setfct(self.vhatD, self.vhat)
     # add regularization term
     y.axpy(self.alpha_reg, self.regularization.hessian(lamhat))
Exemplo n.º 3
0
    def solve(self):
        """ General solver method

        Time-steps the PDE with self.timestepper ('backward' or 'centered';
        anything else prints a message and exits with status 1).
        u0 is copied from self.u0init. u1 is copied from self.u1init when it
        is given; otherwise it is built from self.utinit by the second-order
        Taylor step
            u1 = u0 + fwdadj*Dt*ut + 0.5*Dt**2 * M^{-1}(f(t) - fwdadj*D*ut - K*u0).
        The remaining steps come from the selected iteration method.

        Returns (solout, self.computeerror()), where solout is a list of
        [u.vector().array(), t] pairs for time steps 0..Nt.
        """
        if self.timestepper == 'backward':
            iterate = self.iteration_backward
        elif self.timestepper == 'centered':
            iterate = self.iteration_centered
        else:
            print("Time stepper not implemented")
            sys.exit(1)

        if self.verbose:
            print('Compute solution')
        solout = []  # Store computed solution
        # u0:
        tt = self.get_tt(0)
        if self.verbose:
            print('Compute solution -- time {}'.format(tt))
        setfct(self.u0, self.u0init)
        solout.append([self.u0.vector().array(), tt])
        # Compute u1:
        if self.u1init is not None:
            # Copy values instead of rebinding self.u1: the time loop below
            # writes into self.u1 and would otherwise overwrite u1init.
            setfct(self.u1, self.u1init)
        else:
            assert self.utinit is not None, "Need either u1init or utinit"
            setfct(self.rhs, self.ftime(tt))
            self.rhs.vector().axpy(-self.fwdadj, self.D*self.utinit.vector())
            self.rhs.vector().axpy(-1.0, self.K*self.u0.vector())
            if self.bc is not None:
                self.bc.apply(self.rhs.vector())
            self.solverM.solve(self.sol.vector(), self.rhs.vector())
            setfct(self.u1, self.u0)
            self.u1.vector().axpy(self.fwdadj*self.Dt, self.utinit.vector())
            self.u1.vector().axpy(0.5*self.Dt**2, self.sol.vector())
        tt = self.get_tt(1)
        if self.verbose:
            print('Compute solution -- time {}'.format(tt))
        solout.append([self.u1.vector().array(), tt])
        # Iteration
        for nn in xrange(2, self.Nt+1):
            iterate(tt)
            # Advance to next time step
            setfct(self.u0, self.u1)
            setfct(self.u1, self.u2)
            tt = self.get_tt(nn)
            if self.verbose:
                print('Compute solution -- time {}, rhs {}'.
                      format(tt, np.max(np.abs(self.ftime(tt)))))
            solout.append([self.u1.vector().array(), tt])
        if self.fwdadj > 0.0:
            assert isequal(tt, self.Tf, 1e-16), \
                'tt={}, Tf={}, reldiff={}'.format(tt, self.Tf, abs(tt-self.Tf)/self.Tf)
        else:
            assert isequal(tt, self.t0, 1e-16), \
                'tt={}, t0={}, reldiff={}'.format(tt, self.t0, abs(tt-self.t0))
        return solout, self.computeerror()
Exemplo n.º 4
0
    def ftimeincrfwd(self, tt):
        """ Compute rhs for incremental forward at time tt

        Accumulates into self.q, at the time index of tt in self.PDE.times:
          * self.C applied to solfwdi[index];
          * self.get_incra applied to solppfwdi[index];
          * when the 'abc' PDE parameter is set, 0.5*self.Dp applied to
            solpfwdi[index].
        Returns -1.0 * self.q.vector().
        Prints a diagnostic and exits with status 1 if tt does not match
        exactly one entry of self.PDE.times.
        """
        times = self.PDE.times
        # Exactly one time index must match; anything else is a caller error.
        match = np.flatnonzero(isequal(times, tt, 1e-14))
        if len(match) != 1:
            print('Error in ftimeincrfwd at time {}'.format(tt))
            print(np.min(np.abs(times-tt)))
            sys.exit(1)
        index = int(match[0])

        # bhat: bhat*grad(p).grad(qtilde)
        setfct(self.p, self.solfwdi[index][0])
        self.q.vector().zero()
        self.q.vector().axpy(1.0, self.C*self.p.vector())

        # ahat: ahat*p''*qtilde:
        setfct(self.p, self.solppfwdi[index][0])
        self.q.vector().axpy(1.0, self.get_incra(self.p.vector()))

        # ABC (self.phat is reused as scratch space here):
        if self.PDE.parameters['abc']:
            setfct(self.phat, self.solpfwdi[index][0])
            self.q.vector().axpy(0.5, self.Dp*self.phat.vector())

        return -1.0*self.q.vector()
Exemplo n.º 5
0
 def ftimeincrfwd(self, tt):
     """ Compute rhs for incremental forward at time tt

     Returns -1.0 times the array of self.C applied to the forward solution
     at tt. When self.PDE.abc is set and tt is an interior time step, it adds
     0.5*self.invDt*self.Dp applied to the centered difference
     solfwd[index+1] - solfwd[index-1]. That term is skipped at the first
     and last stored steps, where the centered difference is undefined.
     Prints a diagnostic and exits with status 1 if tt does not match
     exactly one entry of self.PDE.times.
     """
     times = self.PDE.times
     # Exactly one time index must match; anything else is a caller error.
     match = np.flatnonzero(isequal(times, tt, 1e-14))
     if len(match) != 1:
         print('Error in ftimeincrfwd at time {}'.format(tt))
         print(np.min(np.abs(times-tt)))
         sys.exit(1)
     index = int(match[0])
     # lamhat * grad(p).grad(vtilde)
     assert isequal(tt, self.solfwd[index][1], 1e-16)
     setfct(self.p, self.solfwd[index][0])
     setfct(self.v, self.C*self.p.vector())
     # D'.dot(p): the centered difference needs both neighbours to exist.
     if self.PDE.abc and 0 < index < len(self.solfwd) - 1:
         setfct(self.p, self.solfwd[index+1][0] - self.solfwd[index-1][0])
         self.v.vector().axpy(.5*self.invDt, self.Dp*self.p.vector())
     return -1.0*self.v.vector().array()
Exemplo n.º 6
0
 def ftimeadj(self, tt):
     """ Evaluate source term for adj eqn at time tt

     Takes column `index` of self.diff, where index is the position of tt in
     self.times. It maps that column through self.PtwiseObs.BTdotvec into
     self.outvec and applies self.bcadj when one is set. Returns
     -self.st(tt) times self.outvec.array().
     Prints a diagnostic and exits with status 1 if tt does not match
     exactly one entry of self.times.
     """
     # Exactly one time index must match; anything else is a caller error.
     match = np.flatnonzero(isequal(self.times, tt, 1e-14))
     if len(match) != 1:
         print('Error in ftimeadj at time {}'.format(tt))
         print(np.min(np.abs(self.times-tt)))
         sys.exit(1)
     index = int(match[0])
     dd = self.diff[:, index]
     self.PtwiseObs.BTdotvec(dd, self.outvec)
     if self.bcadj is not None:
         self.bcadj.apply(self.outvec.vector())
     # NOTE(review): outvec is used via both .vector() and .array();
     # confirm its type provides both.
     return -1.0*self.st(tt)*self.outvec.array()
Exemplo n.º 7
0
 def ftimeincradj(self, tt):
     """ Compute rhs for incremental adjoint at time tt

     Returns -1.0 times the array of:
       self.C applied to the adjoint solution soladj[indexa],
       + self.obsop.incradj(phat, tt), with phat = solincrfwd[indexf],
       - 0.5*self.invDt*self.Dp applied to soladj[indexa-1] - soladj[indexa+1],
         when self.PDE.abc is set and indexa is an interior step.
     indexf indexes self.PDE.times. indexa indexes the reversed times, which
     is the order in which self.soladj is stored; its time stamps are
     asserted below.
     Prints a diagnostic and exits with status 1 if tt does not match
     exactly one entry of self.PDE.times.
     """
     times = self.PDE.times
     # Exactly one time index must match; anything else is a caller error.
     matchf = np.flatnonzero(isequal(times, tt, 1e-14))
     matcha = np.flatnonzero(isequal(times[::-1], tt, 1e-14))
     if len(matchf) != 1 or len(matcha) != 1:
         print('Error in ftimeincradj at time {}'.format(tt))
         print(np.min(np.abs(times-tt)))
         sys.exit(1)
     indexf, indexa = int(matchf[0]), int(matcha[0])
     # lamhat * grad(ptilde).grad(v)
     assert isequal(tt, self.soladj[indexa][1], 1e-16)
     setfct(self.v, self.soladj[indexa][0])
     setfct(self.vhat, self.C*self.v.vector())
     # B* B phat
     assert isequal(tt, self.solincrfwd[indexf][1], 1e-16)
     setfct(self.phat, self.solincrfwd[indexf][0])
     self.vhat.vector().axpy(1.0, self.obsop.incradj(self.phat, tt))
     # D'.dot(v): the centered difference needs both neighbours to exist.
     if self.PDE.abc and 0 < indexa < len(self.soladj) - 1:
         setfct(self.v, self.soladj[indexa-1][0] - self.soladj[indexa+1][0])
         self.vhat.vector().axpy(-.5*self.invDt, self.Dp*self.v.vector())
     return -1.0*self.vhat.vector().array()
Exemplo n.º 8
0
    def solveadj(self, grad=False):
        """ Solve the adjoint problem and optionally assemble the gradient.

        Builds the adjoint right-hand side through self.obsop and stores the
        adjoint solution in self.soladj. If grad is True, it also:
          * accumulates the gradient form into self.MGv, weighted by
            self.factors;
          * adds the absorbing-boundary (ABC) term when self.PDE.abc is set;
          * adds alpha_reg times the regularization gradient;
          * solves self.solverM for self.Gradv.
        """
        self.PDE.set_adj()
        self.obsop.assemble_rhsadj(self.Bp, self.dd, self.PDE.times, self.PDE.bc)
        self.PDE.ftime = self.obsop.ftimeadj
        self.soladj, _ = self.PDE.solve()
        if not grad:
            return
        self.MGv.zero()
        if self.PDE.abc:
            for buf in (self.vD, self.pD, self.p1D, self.p2D):
                buf.vector().zero()
        # The adjoint solution is stored backwards in time, hence reversed().
        steps = zip(self.solfwd, reversed(self.soladj), self.factors)
        for index, (fwd, adj, fact) in enumerate(steps):
            ttf, tta = fwd[1], adj[1]
            assert isequal(ttf, tta, 1e-16), \
                'tfwd={}, tadj={}, reldiff={}'.format(ttf, tta, abs(ttf-tta)/ttf)
            setfct(self.p, fwd[0])
            setfct(self.v, adj[0])
            self.MGv.axpy(fact, assemble(self.wkformgrad))
            if self.PDE.abc:
                # Two alternating buffers: after the axpy below the active
                # one holds p_index - p_{index-2} (just p_index for the first
                # two steps); self.vD still holds the previous step's adjoint.
                pbuf = self.p2D if index % 2 == 0 else self.p1D
                pbuf.vector().axpy(1.0, self.p.vector())
                setfct(self.pD, pbuf)
                self.MGv.axpy(fact*0.5*self.invDt, assemble(self.wkformgradD))
                setfct(pbuf, -1.0*self.p.vector())
                setfct(self.vD, self.v)
        self.MGv.axpy(self.alpha_reg, self.regularization.grad(self.PDE.lam))
        self.solverM.solve(self.Gradv, self.MGv)
Exemplo n.º 9
0
 def update(self, parameters_m):
     assert not self.timestepper == None, "You need to set a time stepping method"
     # Time options:
     if parameters_m.has_key('t0'):   self.t0 = parameters_m['t0'] 
     if parameters_m.has_key('tf'):   self.tf = parameters_m['tf'] 
     if parameters_m.has_key('Dt'):   self.Dt = parameters_m['Dt'] 
     if parameters_m.has_key('t0') or parameters_m.has_key('tf') or parameters_m.has_key('Dt'):
         self.Nt = int(round((self.tf-self.t0)/self.Dt))
         self.Tf = self.t0 + self.Dt*self.Nt
         self.times = np.linspace(self.t0, self.Tf, self.Nt+1)
         assert isequal(self.times[1]-self.times[0], self.Dt, 1e-16), "Dt modified"
         self.Dt = self.times[1] - self.times[0]
         assert isequal(self.Tf, self.tf, 1e-2), "Final time differs by more than 1%"
         if not isequal(self.Tf, self.tf, 1e-12):
             print 'Final time modified from {} to {} ({}%)'.\
             format(self.tf, self.Tf, abs(self.Tf-self.tf)/self.tf)
     # Initial conditions:
     if parameters_m.has_key('u0init'):   self.u0init = parameters_m['u0init']
     if parameters_m.has_key('utinit'):   self.utinit = parameters_m['utinit']
     if parameters_m.has_key('u1init'):   self.u1init = parameters_m['u1init']
     if parameters_m.has_key('um1init'):   self.um1init = parameters_m['um1init']
     # Medium parameters:
     setfct(self.lam, parameters_m['lambda'])
     if self.verbose: print 'lambda updated '
     if self.elastic == True:    
         setfct(self.mu, parameters_m['mu'])
         if self.verbose: print 'mu updated'
     if self.verbose: print 'assemble K',
     self.K = assemble(self.weak_k)
     if self.verbose: print ' -- K assembled'
     if parameters_m.has_key('rho'):
         setfct(self.rho, parameters_m['rho'])
         # Mass matrix:
         if self.verbose: print 'rho updated\nassemble M',
         Mfull = assemble(self.weak_m)
         if self.lump:
             self.solverM = LumpedMatrixSolverS(self.V)
             self.solverM.set_operator(Mfull, self.bc)
             self.M = self.solverM
         else:
             if mpisize == 1:
                 self.solverM = LUSolver()
                 self.solverM.parameters['reuse_factorization'] = True
                 self.solverM.parameters['symmetric'] = True
             else:
                 self.solverM = KrylovSolver('cg', 'amg')
                 self.solverM.parameters['report'] = False
             self.M = Mfull
             if not self.bc == None: self.bc.apply(Mfull)
             self.solverM.set_operator(Mfull)
         if self.verbose: print ' -- M assembled'
     # Matrix D for abs BC
     if self.abc == True:    
         if self.verbose:    print 'assemble D',
         Mfull = assemble(self.weak_m)
         Dfull = assemble(self.weak_d)
         if self.lumpD:
             self.D = LumpedMatrixSolverS(self.V)
             self.D.set_operator(Dfull, None, False)
             if self.lump:
                 self.solverMplD = LumpedMatrixSolverS(self.V)
                 self.solverMplD.set_operators(Mfull, Dfull, .5*self.Dt, self.bc)
                 self.MminD = LumpedMatrixSolverS(self.V)
                 self.MminD.set_operators(Mfull, Dfull, -.5*self.Dt, self.bc)
         else:
             self.D = Dfull
         if self.verbose:    print ' -- D assembled'
     else:   self.D = 0.0