def ftimeincradj(self, tt):
    """ Compute rhs for incremental adjoint at time tt

    The rhs is the observation term self.obsop.incradj applied to the
    stored incremental forward solution at tt. Unless self.GN is set
    (Gauss-Newton approximation), it also includes the terms built from
    the stored adjoint solutions (soladji, solppadji and, when
    self.PDE.parameters['abc'] is set, solpadji).

    Returns the negated assembled vector. If tt does not match exactly
    one entry of self.PDE.times, prints a diagnostic and exits with a
    non-zero status.
    """
    try:
        # indexf indexes self.PDE.times directly. indexa indexes the
        # reversed time grid, i.e. the order in which the adjoint
        # solutions are stored.
        indexf = int(np.where(isequal(self.PDE.times, tt, 1e-14))[0])
        indexa = int(np.where(isequal(self.PDE.times[::-1], tt, 1e-14))[0])
    except Exception:
        print('Error in ftimeincradj at time {}'.format(tt))
        print(np.min(np.abs(self.PDE.times-tt)))
        # Non-zero status: exiting with 0 would report success.
        sys.exit(1)
    # B* B phat
    setfct(self.phat, self.solincrfwd[indexf][0])
    self.qhat.vector().zero()
    self.qhat.vector().axpy(1.0, self.obsop.incradj(self.phat, tt))
    if not self.GN:
        # bhat: bhat*grad(ptilde).grad(v)
        setfct(self.q, self.soladji[indexa][0])
        self.qhat.vector().axpy(1.0, self.C*self.q.vector())
        # ahat: ahat*ptilde*q'':
        setfct(self.q, self.solppadji[indexa][0])
        self.qhat.vector().axpy(1.0, self.get_incra(self.q.vector()))
        # ABC:
        if self.PDE.parameters['abc']:
            # self.phat is reused as scratch space; its previous value was
            # already consumed above.
            setfct(self.phat, self.solpadji[indexa][0])
            self.qhat.vector().axpy(-0.5, self.Dp*self.phat.vector())
    return -1.0*self.qhat.vector()
def mult(self, lamhat, y):
    """
    mult(self, lamhat, y): return y = Hessian * lamhat
    inputs:
        y, lamhat = Function(V).vector()

    Solves the incremental forward problem (rhs self.ftimeincrfwd) and the
    incremental adjoint problem (rhs self.ftimeincradj) for the direction
    lamhat. y is zeroed, then receives the time integral, weighted by
    self.factors, of assemble(self.wkformhess). When self.PDE.abc is set,
    it also receives the absorbing-boundary term assemble(self.wkformhessD).
    Finally the regularization Hessian term, scaled by self.alpha_reg, is
    added.
    """
    self.regularization.assemble_hessian(lamhat)
    setfct(self.lamhat, lamhat)
    self.C = assemble(self.wkformrhsincr)
    if self.PDE.abc:
        self.Dp = assemble(self.wkformDprime)
    # solve for phat
    self.PDE.set_fwd()
    self.PDE.ftime = self.ftimeincrfwd
    self.solincrfwd, _ = self.PDE.solve()
    # solve for vhat
    self.PDE.set_adj()
    self.PDE.ftime = self.ftimeincradj
    self.solincradj, _ = self.PDE.solve()
    # Compute Hessian*lamhat
    y.zero()
    if self.PDE.abc:
        for fct in (self.vD, self.vhatD, self.p1D, self.p2D,
                    self.p1hatD, self.p2hatD):
            fct.vector().zero()
    # Adjoint solutions are stored backward in time, hence reversed().
    for index, (fwd, adj, incrfwd, incradj, fact) in enumerate(zip(
            self.solfwd, reversed(self.soladj),
            self.solincrfwd, reversed(self.solincradj), self.factors)):
        ttf, tta, ttf2 = incrfwd[1], incradj[1], fwd[1]
        # Guard the relative difference against t == 0 (e.g. t0 = 0), which
        # would otherwise turn a failed check into a ZeroDivisionError.
        scale = abs(ttf) if ttf != 0.0 else 1.0
        assert isequal(ttf, tta, 1e-16), \
        'tincrfwd={}, tincradj={}, reldiff={}'.format(
            ttf, tta, abs(ttf-tta)/scale)
        assert isequal(ttf, ttf2, 1e-16), \
        'tincrfwd={}, tfwd={}, reldiff={}'.format(
            ttf, ttf2, abs(ttf-ttf2)/scale)
        setfct(self.p, fwd[0])
        setfct(self.v, adj[0])
        setfct(self.phat, incrfwd[0])
        setfct(self.vhat, incradj[0])
        y.axpy(fact, assemble(self.wkformhess))
        if self.PDE.abc:
            # Two accumulators alternate between even and odd steps. For
            # n >= 2, dp then holds p_n - p_{n-2} (dph likewise for phat),
            # while vD / vhatD still hold v / vhat from the previous step.
            if index % 2 == 0:
                acc_p, acc_phat = self.p2D, self.p2hatD
            else:
                acc_p, acc_phat = self.p1D, self.p1hatD
            acc_p.vector().axpy(1.0, self.p.vector())
            acc_phat.vector().axpy(1.0, self.phat.vector())
            setfct(self.dp, acc_p)
            setfct(self.dph, acc_phat)
            # Weight by the time-quadrature factor, matching the ABC term
            # of the gradient in solveadj (fact*0.5*invDt, not 1.0*0.5*invDt).
            y.axpy(fact*0.5*self.invDt, assemble(self.wkformhessD))
            setfct(acc_p, -1.0*self.p.vector())
            setfct(acc_phat, -1.0*self.phat.vector())
            setfct(self.vD, self.v)
            setfct(self.vhatD, self.vhat)
    # add regularization term
    y.axpy(self.alpha_reg, self.regularization.hessian(lamhat))
def solve(self):
    """ General solver method

    Time-steps from self.get_tt(0) through self.get_tt(self.Nt) using the
    scheme named by self.timestepper ('backward' or 'centered').

    u0 is copied from self.u0init. u1 is copied from self.u1init when it
    is given. Otherwise u1 is built from self.utinit with a second-order
    Taylor step.

    Returns (solout, self.computeerror()), where solout is a list of
    [nodal values array, time] pairs, one per time level. Exits with
    status 1 if the time stepper is unknown.
    """
    if self.timestepper == 'backward':
        iterate = self.iteration_backward
    elif self.timestepper == 'centered':
        iterate = self.iteration_centered
    else:
        print("Time stepper not implemented")
        sys.exit(1)
    if self.verbose:
        print('Compute solution')
    solout = []  # Store computed solution
    # u0:
    tt = self.get_tt(0)
    if self.verbose:
        print('Compute solution -- time {}'.format(tt))
    setfct(self.u0, self.u0init)
    solout.append([self.u0.vector().array(), tt])
    # Compute u1:
    if self.u1init is not None:
        # Copy values instead of rebinding self.u1. Rebinding would alias
        # u1init, which the time loop then overwrites in place through
        # setfct(self.u1, self.u2).
        setfct(self.u1, self.u1init)
    else:
        assert self.utinit is not None
        # Taylor step:
        #   u1 = u0 + fwdadj*Dt*ut + Dt^2/2 * M^-1 (f - fwdadj*D ut - K u0)
        setfct(self.rhs, self.ftime(tt))
        self.rhs.vector().axpy(-self.fwdadj, self.D*self.utinit.vector())
        self.rhs.vector().axpy(-1.0, self.K*self.u0.vector())
        if self.bc is not None:
            self.bc.apply(self.rhs.vector())
        self.solverM.solve(self.sol.vector(), self.rhs.vector())
        setfct(self.u1, self.u0)
        self.u1.vector().axpy(self.fwdadj*self.Dt, self.utinit.vector())
        self.u1.vector().axpy(0.5*self.Dt**2, self.sol.vector())
    tt = self.get_tt(1)
    if self.verbose:
        print('Compute solution -- time {}'.format(tt))
    solout.append([self.u1.vector().array(), tt])
    # Iteration
    for nn in range(2, self.Nt+1):
        # iterate() presumably fills self.u2 from u0, u1 -- see the
        # iteration_* methods.
        iterate(tt)
        # Advance to next time step
        setfct(self.u0, self.u1)
        setfct(self.u1, self.u2)
        tt = self.get_tt(nn)
        if self.verbose:
            print('Compute solution -- time {}, rhs {}'.format(
                tt, np.max(np.abs(self.ftime(tt)))))
        solout.append([self.u1.vector().array(), tt])
    if self.fwdadj > 0.0:
        assert isequal(tt, self.Tf, 1e-16), \
        'tt={}, Tf={}, reldiff={}'.format(tt, self.Tf, abs(tt-self.Tf)/self.Tf)
    else:
        assert isequal(tt, self.t0, 1e-16), \
        'tt={}, t0={}, reldiff={}'.format(tt, self.t0, abs(tt-self.t0))
    return solout, self.computeerror()
def ftimeincrfwd(self, tt):
    """ Compute rhs for incremental forward at time tt

    Combines three terms at tt:
      - self.C applied to the stored forward solution (self.solfwdi);
      - self.get_incra applied to its second time derivative
        (self.solppfwdi);
      - when self.PDE.parameters['abc'] is set, self.Dp applied to
        self.solpfwdi (presumably the first time derivative).

    Returns the negated vector. If tt does not match exactly one entry of
    self.PDE.times, prints a diagnostic and exits with a non-zero status.
    """
    try:
        index = int(np.where(isequal(self.PDE.times, tt, 1e-14))[0])
    except Exception:
        print('Error in ftimeincrfwd at time {}'.format(tt))
        print(np.min(np.abs(self.PDE.times-tt)))
        # Non-zero status: exiting with 0 would report success.
        sys.exit(1)
    # bhat: bhat*grad(p).grad(qtilde)
    setfct(self.p, self.solfwdi[index][0])
    self.q.vector().zero()
    self.q.vector().axpy(1.0, self.C*self.p.vector())
    # ahat: ahat*p''*qtilde:
    setfct(self.p, self.solppfwdi[index][0])
    self.q.vector().axpy(1.0, self.get_incra(self.p.vector()))
    # ABC:
    if self.PDE.parameters['abc']:
        # self.phat is used only as scratch space here.
        setfct(self.phat, self.solpfwdi[index][0])
        self.q.vector().axpy(0.5, self.Dp*self.phat.vector())
    return -1.0*self.q.vector()
def ftimeincrfwd(self, tt):
    """ Compute rhs for incremental forward at time tt

    The rhs is self.C applied to the stored forward solution at tt. When
    self.PDE.abc is set, it adds 0.5/Dt * Dp applied to the centered
    difference p_{n+1} - p_{n-1} of the stored forward solutions.

    Returns the negated vector as a numpy array. If tt does not match
    exactly one entry of self.PDE.times, prints a diagnostic and exits
    with a non-zero status.
    """
    try:
        index = int(np.where(isequal(self.PDE.times, tt, 1e-14))[0])
    except Exception:
        print('Error in ftimeincrfwd at time {}'.format(tt))
        print(np.min(np.abs(self.PDE.times-tt)))
        # Non-zero status: exiting with 0 would report success.
        sys.exit(1)
    # lamhat * grad(p).grad(vtilde)
    assert isequal(tt, self.solfwd[index][1], 1e-16)
    setfct(self.p, self.solfwd[index][0])
    setfct(self.v, self.C*self.p.vector())
    # D'.dot(p)
    # The centered difference needs both neighbours, so it is skipped at
    # the first and at the last stored time step (the last one has no
    # successor and is reached e.g. by the verbose print in solve()).
    if self.PDE.abc and 0 < index < len(self.solfwd) - 1:
        setfct(self.p,
               self.solfwd[index+1][0] - self.solfwd[index-1][0])
        self.v.vector().axpy(.5*self.invDt, self.Dp*self.p.vector())
    return -1.0*self.v.vector().array()
def ftimeadj(self, tt):
    """ Evaluate source term for adj eqn at time tt

    Takes the column of self.diff that corresponds to tt and passes it
    through self.PtwiseObs.BTdotvec into self.outvec (presumably applying
    B^T). It then applies self.bcadj if one is set, and scales the result
    by -self.st(tt).

    Returns a numpy array. If tt does not match exactly one entry of
    self.times, prints a diagnostic and exits with a non-zero status.
    """
    try:
        index = int(np.where(isequal(self.times, tt, 1e-14))[0])
    except Exception:
        print('Error in ftimeadj at time {}'.format(tt))
        print(np.min(np.abs(self.times-tt)))
        # Non-zero status: exiting with 0 would report success.
        sys.exit(1)
    dd = self.diff[:, index]
    self.PtwiseObs.BTdotvec(dd, self.outvec)
    if self.bcadj is not None:
        # NOTE(review): self.outvec is accessed via .vector() here but via
        # .array() in the return below. Only one of these is likely to match
        # its actual type -- confirm against BTdotvec.
        self.bcadj.apply(self.outvec.vector())
    return -1.0*self.st(tt)*self.outvec.array()
def ftimeincradj(self, tt):
    """ Compute rhs for incremental adjoint at time tt

    Sums the following terms:
      - self.C applied to the stored adjoint solution at tt;
      - self.obsop.incradj applied to the incremental forward solution
        at tt;
      - when self.PDE.abc is set, -0.5/Dt * Dp applied to a centered
        difference of the stored adjoint solutions.

    Returns the negated vector as a numpy array. If tt does not match
    exactly one entry of self.PDE.times, prints a diagnostic and exits
    with a non-zero status.
    """
    try:
        # indexf indexes self.PDE.times. indexa indexes the reversed grid,
        # which is the storage order of self.soladj.
        indexf = int(np.where(isequal(self.PDE.times, tt, 1e-14))[0])
        indexa = int(np.where(isequal(self.PDE.times[::-1], tt, 1e-14))[0])
    except Exception:
        print('Error in ftimeincradj at time {}'.format(tt))
        print(np.min(np.abs(self.PDE.times-tt)))
        # Non-zero status: exiting with 0 would report success.
        sys.exit(1)
    # lamhat * grad(ptilde).grad(v)
    assert isequal(tt, self.soladj[indexa][1], 1e-16)
    setfct(self.v, self.soladj[indexa][0])
    setfct(self.vhat, self.C*self.v.vector())
    # B* B phat
    assert isequal(tt, self.solincrfwd[indexf][1], 1e-16)
    setfct(self.phat, self.solincrfwd[indexf][0])
    self.vhat.vector().axpy(1.0, self.obsop.incradj(self.phat, tt))
    # D'.dot(v)
    # The centered difference needs both neighbours, so it is skipped at
    # the first and at the last stored adjoint step (the last one has no
    # successor and is reached e.g. by the verbose print in solve()).
    if self.PDE.abc and 0 < indexa < len(self.soladj) - 1:
        setfct(self.v,
               self.soladj[indexa-1][0] - self.soladj[indexa+1][0])
        self.vhat.vector().axpy(-.5*self.invDt, self.Dp*self.v.vector())
    return -1.0*self.vhat.vector().array()
def solveadj(self, grad=False):
    """
    Solve the adjoint problem and, optionally, compute the gradient.

    The adjoint rhs is prepared by self.obsop.assemble_rhsadj and evaluated
    through self.obsop.ftimeadj. The adjoint solution is stored in
    self.soladj, in the order the solver returns it (backward in time).

    If grad is True, self.MGv receives the time integral, weighted by
    self.factors, of assemble(self.wkformgrad). When self.PDE.abc is set,
    it also receives the absorbing-boundary term assemble(self.wkformgradD).
    The regularization gradient, scaled by self.alpha_reg, is then added,
    and self.solverM is used to solve for self.Gradv.
    """
    self.PDE.set_adj()
    self.obsop.assemble_rhsadj(self.Bp, self.dd, self.PDE.times, self.PDE.bc)
    self.PDE.ftime = self.obsop.ftimeadj
    self.soladj, _ = self.PDE.solve()
    if grad:
        self.MGv.zero()
        if self.PDE.abc:
            for fct in (self.vD, self.pD, self.p1D, self.p2D):
                fct.vector().zero()
        for index, (fwd, adj, fact) in enumerate(
                zip(self.solfwd, reversed(self.soladj), self.factors)):
            ttf, tta = fwd[1], adj[1]
            # Guard the relative difference against t == 0 (e.g. t0 = 0),
            # which would otherwise turn a failed check into a
            # ZeroDivisionError.
            scale = abs(ttf) if ttf != 0.0 else 1.0
            assert isequal(ttf, tta, 1e-16), \
            'tfwd={}, tadj={}, reldiff={}'.format(ttf, tta, abs(ttf-tta)/scale)
            setfct(self.p, fwd[0])
            setfct(self.v, adj[0])
            self.MGv.axpy(fact, assemble(self.wkformgrad))
            if self.PDE.abc:
                # Two accumulators alternate between even and odd steps.
                # For n >= 2, pD then holds p_n - p_{n-2}, while vD still
                # holds v from the previous step.
                acc = self.p2D if index % 2 == 0 else self.p1D
                acc.vector().axpy(1.0, self.p.vector())
                setfct(self.pD, acc)
                self.MGv.axpy(fact*0.5*self.invDt, assemble(self.wkformgradD))
                setfct(acc, -1.0*self.p.vector())
                setfct(self.vD, self.v)
        self.MGv.axpy(self.alpha_reg, self.regularization.grad(self.PDE.lam))
        self.solverM.solve(self.Gradv, self.MGv)
def update(self, parameters_m): assert not self.timestepper == None, "You need to set a time stepping method" # Time options: if parameters_m.has_key('t0'): self.t0 = parameters_m['t0'] if parameters_m.has_key('tf'): self.tf = parameters_m['tf'] if parameters_m.has_key('Dt'): self.Dt = parameters_m['Dt'] if parameters_m.has_key('t0') or parameters_m.has_key('tf') or parameters_m.has_key('Dt'): self.Nt = int(round((self.tf-self.t0)/self.Dt)) self.Tf = self.t0 + self.Dt*self.Nt self.times = np.linspace(self.t0, self.Tf, self.Nt+1) assert isequal(self.times[1]-self.times[0], self.Dt, 1e-16), "Dt modified" self.Dt = self.times[1] - self.times[0] assert isequal(self.Tf, self.tf, 1e-2), "Final time differs by more than 1%" if not isequal(self.Tf, self.tf, 1e-12): print 'Final time modified from {} to {} ({}%)'.\ format(self.tf, self.Tf, abs(self.Tf-self.tf)/self.tf) # Initial conditions: if parameters_m.has_key('u0init'): self.u0init = parameters_m['u0init'] if parameters_m.has_key('utinit'): self.utinit = parameters_m['utinit'] if parameters_m.has_key('u1init'): self.u1init = parameters_m['u1init'] if parameters_m.has_key('um1init'): self.um1init = parameters_m['um1init'] # Medium parameters: setfct(self.lam, parameters_m['lambda']) if self.verbose: print 'lambda updated ' if self.elastic == True: setfct(self.mu, parameters_m['mu']) if self.verbose: print 'mu updated' if self.verbose: print 'assemble K', self.K = assemble(self.weak_k) if self.verbose: print ' -- K assembled' if parameters_m.has_key('rho'): setfct(self.rho, parameters_m['rho']) # Mass matrix: if self.verbose: print 'rho updated\nassemble M', Mfull = assemble(self.weak_m) if self.lump: self.solverM = LumpedMatrixSolverS(self.V) self.solverM.set_operator(Mfull, self.bc) self.M = self.solverM else: if mpisize == 1: self.solverM = LUSolver() self.solverM.parameters['reuse_factorization'] = True self.solverM.parameters['symmetric'] = True else: self.solverM = KrylovSolver('cg', 'amg') 
self.solverM.parameters['report'] = False self.M = Mfull if not self.bc == None: self.bc.apply(Mfull) self.solverM.set_operator(Mfull) if self.verbose: print ' -- M assembled' # Matrix D for abs BC if self.abc == True: if self.verbose: print 'assemble D', Mfull = assemble(self.weak_m) Dfull = assemble(self.weak_d) if self.lumpD: self.D = LumpedMatrixSolverS(self.V) self.D.set_operator(Dfull, None, False) if self.lump: self.solverMplD = LumpedMatrixSolverS(self.V) self.solverMplD.set_operators(Mfull, Dfull, .5*self.Dt, self.bc) self.MminD = LumpedMatrixSolverS(self.V) self.MminD.set_operators(Mfull, Dfull, -.5*self.Dt, self.bc) else: self.D = Dfull if self.verbose: print ' -- D assembled' else: self.D = 0.0