def assimilate(self, HMM, xx, yy):
    """Filter forwards in time, then smooth backwards.

    One ensemble is kept per time index ``k`` (``ens``), along with the
    forecast ensembles (``ens_f``). The backward pass updates ``ens[k]``
    from the difference between ``ens[k + 1]`` and ``ens_f[k + 1]``.

    Parameters
    ----------
    HMM
        Must provide ``tseq`` (with ``K`` and ``ticker``), ``Dyn``
        (callable, with ``M`` and ``noise``), ``Obs`` (callable, with
        ``noise``) and ``X0`` (with ``sample``).
    xx
        Not used by this method.
    yy
        Observations, indexed by the observation index ``ko``.

    Results are reported through ``self.stats.assess``; nothing is returned.
    """
    tseq, dyn, obs = HMM.tseq, HMM.Dyn, HMM.Obs
    stats = self.stats

    ens = zeros((tseq.K + 1, self.N, dyn.M))
    ens_f = ens.copy()
    ens[0] = HMM.X0.sample(self.N)

    # Forward pass
    for k, ko, t, dt in progbar(tseq.ticker):
        ens[k] = dyn(ens[k - 1], t - dt, dt)
        ens[k] = add_noise(ens[k], dt, dyn.noise, self.fnoise_treatm)
        ens_f[k] = ens[k]

        # Nothing to assimilate at this time step.
        if ko is None:
            continue

        stats.assess(k, ko, 'f', E=ens[k])
        obs_ens = obs(ens[k], t)
        ens[k] = EnKF_analysis(
            ens[k], obs_ens, obs.noise, yy[ko], self.upd_a, stats, ko)
        ens[k] = post_process(ens[k], self.infl, self.rot)
        stats.assess(k, ko, 'a', E=ens[k])

    # Backward pass: k = K-1, ..., 0
    for k in progbar(range(tseq.K - 1, -1, -1)):
        anom = center(ens[k])[0]
        anom_f = center(ens_f[k + 1])[0]

        gain = tinv(anom_f) @ anom
        gain *= self.DeCorr

        correction = ens[k + 1] - ens_f[k + 1]
        ens[k] += correction @ gain

    # Report the final estimates: 'u' at every k, 's' where an obs exists.
    for k, ko, _, _ in progbar(tseq.ticker, desc='Assessing'):
        stats.assess(k, ko, 'u', E=ens[k])
        if ko is not None:
            stats.assess(k, ko, 's', E=ens[k])
def sqrt_core():
    """Apply the square-root update to the anomalies and rebuild the ensemble.

    Reads ``A``, ``N``, ``Nx``, ``Q12``, ``Q``, ``dt`` and ``mu`` from the
    enclosing scope.

    Returns
    -------
    tuple
        ``(E, T, Qa12)`` where ``E = mu + A2`` is the updated ensemble.
        ``T`` and ``Qa12`` are only computed when ``N <= Nx``; in the
        "left-multiplying" branch they remain ``np.nan``.
    """
    T = np.nan  # cause error if used
    Qa12 = np.nan  # cause error if used

    # Work on a copy instead of using (the implicitly nonlocal) A,
    # which would change A outside as well. NB: This is a bug in Datum!
    A2 = A.copy()

    if N <= Nx:
        Ainv = tinv(A2.T)
        Qa12 = Ainv @ Q12
        # (N x N) transform, built from the Qa12 computed just above.
        T = funm_psd(eye(N) + dt * (N - 1) * (Qa12 @ Qa12.T), sqrt)
        A2 = T @ A2
    else:  # "Left-multiplying" form
        P = A2.T @ A2 / (N - 1)
        L = funm_psd(eye(Nx) + dt * mrdiv(Q, P), sqrt)
        A2 = A2 @ L.T

    E = mu + A2
    return E, T, Qa12
def assimilate(self, HMM, xx, yy):
    """Forward filtering pass followed by a backward smoothing pass.

    ``smoothed[k]`` holds the ensemble at time index ``k`` and
    ``forecast[k]`` the corresponding forecast ensemble. After the forward
    pass, ``smoothed[k]`` is corrected (for decreasing ``k``) with the
    difference ``smoothed[k + 1] - forecast[k + 1]``.

    Parameters
    ----------
    HMM
        Must provide ``Dyn``, ``Obs``, ``t`` (with ``K`` and ``ticker``)
        and ``X0``.
    xx
        Not used by this method.
    yy
        Observations, indexed by the observation index ``kObs``.

    Results are reported through ``self.stats.assess``; nothing is returned.
    """
    model, observe, timeline, prior, stats = (
        HMM.Dyn, HMM.Obs, HMM.t, HMM.X0, self.stats)

    smoothed = zeros((timeline.K + 1, self.N, model.M))
    forecast = smoothed.copy()
    smoothed[0] = prior.sample(self.N)

    # Forward pass
    for k, kObs, t, dt in progbar(timeline.ticker):
        smoothed[k] = model(smoothed[k - 1], t - dt, dt)
        smoothed[k] = add_noise(
            smoothed[k], dt, model.noise, self.fnoise_treatm)
        forecast[k] = smoothed[k]

        has_obs = kObs is not None
        if has_obs:
            stats.assess(k, kObs, 'f', E=smoothed[k])
            Eo = observe(smoothed[k], t)
            y = yy[kObs]
            smoothed[k] = EnKF_analysis(
                smoothed[k], Eo, observe.noise, y, self.upd_a, stats, kObs)
            smoothed[k] = post_process(smoothed[k], self.infl, self.rot)
            stats.assess(k, kObs, 'a', E=smoothed[k])

    # Backward pass, from the second-to-last time index down to 0.
    for k in progbar(range(timeline.K - 1, -1, -1)):
        A_now = center(smoothed[k])[0]
        A_next_f = center(forecast[k + 1])[0]

        J = tinv(A_next_f) @ A_now
        J *= self.cntr

        smoothed[k] += (smoothed[k + 1] - forecast[k + 1]) @ J

    # Assess the final estimates.
    for k, kObs, _, _ in progbar(timeline.ticker, desc='Assessing'):
        stats.assess(k, kObs, 'u', E=smoothed[k])
        if kObs is not None:
            stats.assess(k, kObs, 's', E=smoothed[k])
def assimilate(self, HMM, xx, yy):
    """Run the iterative ensemble smoother over all data-assimilation windows.

    For every observation index ``kObs`` the ensemble at the start of the
    window is iteratively re-weighted (control vector ``w`` for the mean,
    transform matrix ``T`` for the anomalies), re-forecast through the
    window, and compared with the observation ``yy[kObs]``. The update
    flavour is selected by the substrings 'Sqrt', 'PertObs' or 'Order1'
    in ``self.upd_a`` and by the ``self.MDA`` flag.

    Parameters
    ----------
    HMM
        Must provide ``Dyn``, ``Obs`` (with ``noise.C``), ``t`` (with
        ``KObs``, ``kkObs`` and ``cycle``) and ``X0``.
    xx
        Not used by this method.
    yy
        Observations, indexed by the observation index ``kObs``.

    Raises
    ------
    AssertionError
        If ``HMM.Dyn.noise.C`` is non-zero (model noise is not supported).

    Results are reported through ``self.stats``; nothing is returned.
    """
    Dyn, Obs, chrono, stats, N = \
        HMM.Dyn, HMM.Obs, HMM.t, self.stats, self.N
    R, KObs, N1 = HMM.Obs.noise.C, HMM.t.KObs, N - 1
    Rm12 = R.sym_sqrt_inv

    assert Dyn.noise.C == 0, (
        "Q>0 not yet supported."
        " See Sakov et al 2017: 'An iEnKF with mod. error'")

    if self.bundle:
        EPS = 1e-4  # Sakov/Boc use T=EPS*eye(N), with EPS=1e-4, but I ...
    else:
        EPS = 1.0  # ... prefer using T=EPS*T, yielding a conditional cloud shape

    # Initial ensemble
    E = HMM.X0.sample(N)

    # Loop over DA windows (DAW).
    for kObs in progbar(np.arange(-1, KObs + self.Lag + 1)):
        kLag = kObs - self.Lag
        DAW = range(max(0, kLag + 1), min(kObs, KObs) + 1)

        # Assimilation (if ∃ "not-fully-assimilated" obs).
        if 0 <= kObs <= KObs:

            # Init iterations.
            X0, x0 = center(E)  # Decompose ensemble.
            w = np.zeros(N)     # Control vector for the mean state.
            T = np.eye(N)       # Anomalies transform matrix.
            Tinv = np.eye(N)
            # Explicit Tinv [instead of tinv(T)] allows for merging MDA code
            # with iEnKS/EnRML code, and flop savings in 'Sqrt' case.

            for iteration in np.arange(self.nIter):
                # Reconstruct smoothed ensemble.
                E = x0 + (w + EPS * T) @ X0
                # Forecast.
                for kCycle in DAW:
                    for k, t, dt in chrono.cycle(kCycle):  # noqa
                        E = Dyn(E, t - dt, dt)
                # Observe.
                Eo = Obs(E, t)

                # Undo the bundle scaling of ensemble.
                if EPS != 1.0:
                    E = inflate_ens(E, 1 / EPS)
                    Eo = inflate_ens(Eo, 1 / EPS)

                # Assess forecast stats; store {Xf, T_old} for analysis
                # assessment.
                if iteration == 0:
                    stats.assess(k, kObs, 'f', E=E)
                    Xf, _ = center(E)
                # T_old is only a reference to the current T. Every update
                # of T below must therefore REBIND T (never `T -= ...`),
                # otherwise T_old would change with it and `T - T_old`
                # in the final increment would vanish.
                T_old = T

                # Prepare analysis.
                y = yy[kObs]              # Get current obs.
                Y, xo = center(Eo)        # Get obs {anomalies, mean}.
                dy = (y - xo) @ Rm12.T    # Transform obs space.
                Y = Y @ Rm12.T            # Transform obs space.
                Y0 = Tinv @ Y             # "De-condition" the obs anomalies.
                V, s, _ = svd0(Y0)        # Decompose Y0.

                # Set "cov normlzt fctr" za ("effective ensemble size")
                # => pre_infl^2 = (N-1)/za.
                if self.xN is None:
                    za = N1
                else:
                    za = zeta_a(*hyperprior_coeffs(s, N, self.xN), w)
                if self.MDA:
                    # inflation (factor: nIter) of the ObsErrCov.
                    za *= self.nIter

                # Post. cov (approx) of w,
                # estimated at current iteration, raised to power.
                def Cowp(expo):
                    return (V * (pad0(s**2, N) + za)**-expo) @ V.T

                Cow1 = Cowp(1.0)

                if self.MDA:
                    # View update as annealing (progressive assimilation).
                    Cow1 = Cow1 @ T  # apply previous update
                    dw = dy @ Y.T @ Cow1
                    if 'PertObs' in self.upd_a:
                        # == "ES-MDA". By Emerick/Reynolds
                        D = mean0(np.random.randn(*Y.shape)) \
                            * np.sqrt(self.nIter)
                        T = T - (Y + D) @ Y.T @ Cow1
                    elif 'Sqrt' in self.upd_a:
                        # == "ETKF-ish". By Raanes
                        T = Cowp(0.5) * np.sqrt(za) @ T
                    elif 'Order1' in self.upd_a:
                        # == "DEnKF-ish". By Emerick
                        T = T - 0.5 * Y @ Y.T @ Cow1
                    # Tinv = eye(N) [as initialized] coz MDA does not
                    # de-condition.

                else:
                    # View update as Gauss-Newton optimization of
                    # log-posterior.
                    grad = Y0 @ dy - w * za  # Cost function gradient
                    dw = grad @ Cow1         # Gauss-Newton step
                    # "ETKF-ish". By Bocquet/Sakov.
                    if 'Sqrt' in self.upd_a:
                        # Sqrt-transforms
                        T = Cowp(0.5) * np.sqrt(N1)
                        Tinv = Cowp(-.5) / np.sqrt(N1)
                        # Tinv saves time [vs tinv(T)] when Nx<N
                    # "EnRML". By Oliver/Chen/Raanes/Evensen/Stordal.
                    elif 'PertObs' in self.upd_a:
                        # The perturbations are drawn once and then reused.
                        D = mean0(np.random.randn(*Y.shape)) \
                            if iteration == 0 else D
                        gradT = -(Y + D) @ Y0.T + N1 * (np.eye(N) - T)
                        T = T + gradT @ Cow1
                        # Tinv= tinv(T, threshold=N1)  # unstable
                        Tinv = sla.inv(T + 1)  # the +1 is for stability.
                    # "DEnKF-ish". By Raanes.
                    elif 'Order1' in self.upd_a:
                        # Included for completeness; does not make much sense.
                        gradT = -0.5 * Y @ Y0.T + N1 * (np.eye(N) - T)
                        T = T + gradT @ Cow1
                        Tinv = tinv(T, threshold=N1)

                w += dw
                if dw @ dw < self.wtol * N:
                    break

            # Assess (analysis) stats.
            # The final_increment is a linearization to
            # (i) avoid re-running the model and
            # (ii) reproduce EnKF in case nIter==1.
            final_increment = (dw + T - T_old) @ Xf
            # See docs/snippets/iEnKS_Ea.jpg.
            stats.assess(k, kObs, 'a', E=E + final_increment)
            stats.iters[kObs] = iteration + 1
            if self.xN:
                stats.infl[kObs] = np.sqrt(N1 / za)

            # Final (smoothed) estimate of E at [kLag].
            E = x0 + (w + T) @ X0
            E = post_process(E, self.infl, self.rot)

        # Slide/shift DAW by propagating smoothed ('s') ensemble from [kLag].
        if -1 <= kLag < KObs:
            if kLag >= 0:
                stats.assess(chrono.kkObs[kLag], kLag, 's', E=E)
            for k, t, dt in chrono.cycle(kLag + 1):
                stats.assess(k - 1, None, 'u', E=E)
                E = Dyn(E, t - dt, dt)

    # Final assessment, at the last time index reached by the loops above.
    stats.assess(k, KObs, 'us', E=E)
def iEnKS_update(upd_a, E, DAW, HMM, stats, EPS, y, time, Rm12, xN, MDA,
                 threshold):
    """Perform the iEnKS update.

    This implementation includes several flavours and forms,
    specified by `upd_a` (See `iEnKS`)

    Parameters
    ----------
    upd_a
        Update flavour; tested for the substrings 'Sqrt', 'PertObs'
        and 'Order1'.
    E
        Ensemble at the start of the window; ``E.shape`` is unpacked as
        ``(N, Nx)``.
    DAW
        Iterable of cycle indices passed to ``HMM.t.cycle``.
    HMM
        Must provide ``Dyn``, ``Obs`` and ``t.cycle``.
    stats
        Receives the 'f' and 'a' assessments, ``iters`` and (if `xN` is
        truthy) ``infl``.
    EPS
        Scaling applied to ``T`` when reconstructing the ensemble; undone
        after the forecast when different from 1.0.
    y
        The observation to assimilate.
    time
        Triple ``(k, kObs, t)``. ``k`` and ``t`` are overwritten by the
        forecast loop.
    Rm12
        Applied as ``@ Rm12.T`` to transform the observation space.
    xN
        If ``None``, ``za = N - 1``; otherwise ``za`` is computed via
        ``zeta_a`` and ``hyperprior_coeffs``.
    MDA
        If truthy, view the update as annealing; otherwise as
        Gauss-Newton iterations.
    threshold
        Pair ``(nIter, wtol)``: maximum number of iterations and
        tolerance on ``dw @ dw`` (scaled by ``N``).

    Returns
    -------
    The final (smoothed) ensemble ``x0 + (w + T) @ X0``.
    """
    # distribute variable
    k, kObs, t = time
    nIter, wtol = threshold
    N, Nx = E.shape

    # Init iterations.
    N1 = N - 1
    X0, x0 = center(E)  # Decompose ensemble.
    w = np.zeros(N)     # Control vector for the mean state.
    T = np.eye(N)       # Anomalies transform matrix.
    Tinv = np.eye(N)
    # Explicit Tinv [instead of tinv(T)] allows for merging MDA code
    # with iEnKS/EnRML code, and flop savings in 'Sqrt' case.

    for iteration in np.arange(nIter):
        # Reconstruct smoothed ensemble.
        E = x0 + (w + EPS * T) @ X0
        # Forecast.
        for kCycle in DAW:
            for k, t, dt in HMM.t.cycle(kCycle):  # noqa
                E = HMM.Dyn(E, t - dt, dt)
        # Observe.
        Eo = HMM.Obs(E, t)

        # Undo the bundle scaling of ensemble.
        if EPS != 1.0:
            E = inflate_ens(E, 1 / EPS)
            Eo = inflate_ens(Eo, 1 / EPS)

        # Assess forecast stats; store {Xf, T_old} for analysis assessment.
        if iteration == 0:
            stats.assess(k, kObs, 'f', E=E)
            Xf, _ = center(E)
        # T_old is only a reference to the current T. Every update of T
        # below must therefore REBIND T (never `T -= ...`), otherwise T_old
        # would change with it and `T - T_old` in the final increment
        # would vanish.
        T_old = T

        # Prepare analysis.
        Y, xo = center(Eo)        # Get obs {anomalies, mean}.
        dy = (y - xo) @ Rm12.T    # Transform obs space.
        Y = Y @ Rm12.T            # Transform obs space.
        Y0 = Tinv @ Y             # "De-condition" the obs anomalies.
        V, s, _ = svd0(Y0)        # Decompose Y0.

        # Set "cov normlzt fctr" za ("effective ensemble size")
        # => pre_infl^2 = (N-1)/za.
        if xN is None:
            za = N1
        else:
            za = zeta_a(*hyperprior_coeffs(s, N, xN), w)
        if MDA:
            # inflation (factor: nIter) of the ObsErrCov.
            za *= nIter

        # Post. cov (approx) of w,
        # estimated at current iteration, raised to power.
        def Cowp(expo):
            return (V * (pad0(s**2, N) + za)**-expo) @ V.T

        Cow1 = Cowp(1.0)

        if MDA:  # View update as annealing (progressive assimilation).
            Cow1 = Cow1 @ T  # apply previous update
            dw = dy @ Y.T @ Cow1
            if 'PertObs' in upd_a:   # == "ES-MDA". By Emerick/Reynolds
                D = mean0(np.random.randn(*Y.shape)) * np.sqrt(nIter)
                T = T - (Y + D) @ Y.T @ Cow1
            elif 'Sqrt' in upd_a:    # == "ETKF-ish". By Raanes
                T = Cowp(0.5) * np.sqrt(za) @ T
            elif 'Order1' in upd_a:  # == "DEnKF-ish". By Emerick
                T = T - 0.5 * Y @ Y.T @ Cow1
            # Tinv = eye(N) [as initialized] coz MDA does not de-condition.

        else:  # View update as Gauss-Newton optimization of log-posterior.
            grad = Y0 @ dy - w * za  # Cost function gradient
            dw = grad @ Cow1         # Gauss-Newton step
            # "ETKF-ish". By Bocquet/Sakov.
            if 'Sqrt' in upd_a:
                # Sqrt-transforms
                T = Cowp(0.5) * np.sqrt(N1)
                Tinv = Cowp(-.5) / np.sqrt(N1)
                # Tinv saves time [vs tinv(T)] when Nx<N
            # "EnRML". By Oliver/Chen/Raanes/Evensen/Stordal.
            elif 'PertObs' in upd_a:
                # The perturbations are drawn once and then reused.
                D = mean0(np.random.randn(*Y.shape)) \
                    if iteration == 0 else D
                gradT = -(Y + D) @ Y0.T + N1 * (np.eye(N) - T)
                T = T + gradT @ Cow1
                # Tinv= tinv(T, threshold=N1)  # unstable
                Tinv = sla.inv(T + 1)  # the +1 is for stability.
            # "DEnKF-ish". By Raanes.
            elif 'Order1' in upd_a:
                # Included for completeness; does not make much sense.
                gradT = -0.5 * Y @ Y0.T + N1 * (np.eye(N) - T)
                T = T + gradT @ Cow1
                Tinv = tinv(T, threshold=N1)

        w += dw
        if dw @ dw < wtol * N:
            break

    # Assess (analysis) stats.
    # The final_increment is a linearization to
    # (i) avoid re-running the model and
    # (ii) reproduce EnKF in case nIter==1.
    final_increment = (dw + T - T_old) @ Xf
    # See docs/snippets/iEnKS_Ea.jpg.
    stats.assess(k, kObs, 'a', E=E + final_increment)
    stats.iters[kObs] = iteration + 1
    if xN:
        stats.infl[kObs] = np.sqrt(N1 / za)

    # Final (smoothed) estimate of E at [kLag].
    E = x0 + (w + T) @ X0

    return E
def assimilate(self, HMM, xx, yy):
    """Run a weighted-ensemble (particle) filter with an EnKF-based proposal.

    At each observation time the weights are updated from the innovations.
    When ``trigger_resampling`` fires, each member is duplicated ``xN``
    times, the duplicates are sampled around EnKF-updated centres,
    re-weighted, and resampled back down to ``N`` members.

    Parameters
    ----------
    HMM
        Must provide ``Dyn`` (with ``M`` and ``noise.C``), ``Obs`` (with
        ``noise.C``), ``tseq.ticker`` and ``X0``.
    xx
        Not used by this method.
    yy
        Observations, indexed by the observation index ``ko``.

    Results are reported through ``self.stats.assess``; nothing is returned.
    """
    N, xN, Nx = self.N, self.xN, HMM.Dyn.M
    Rm12, Ri = HMM.Obs.noise.C.sym_sqrt_inv, HMM.Obs.noise.C.inv

    E = HMM.X0.sample(N)
    w = 1 / N * np.ones(N)

    # Standard-normal draws; kept between resamplings if self.re_use is set.
    DD = None

    self.stats.assess(0, E=E, w=w)

    for k, ko, t, dt in progbar(HMM.tseq.ticker):
        E = HMM.Dyn(E, t - dt, dt)
        if HMM.Dyn.noise.C != 0:
            E += np.sqrt(dt) * (rnd.randn(N, Nx) @ HMM.Dyn.noise.C.Right)

        if ko is not None:
            self.stats.assess(k, ko, 'f', E=E, w=w)
            y = yy[ko]
            Eo = HMM.Obs(E, t)
            wD = w.copy()  # weights from before the importance weighting

            # Importance weighting
            innovs = (y - Eo) @ Rm12.T
            w = reweight(w, innovs=innovs)

            # Resampling
            if trigger_resampling(w, self.NER, [self.stats, E, k, ko]):
                # Weighted covariance factors
                Aw = raw_C12(E, wD)
                Yw = raw_C12(Eo, wD)

                # EnKF-without-perturbations update
                if N > Nx:
                    C = Yw.T @ Yw + HMM.Obs.noise.C.full
                    KG = mrdiv(Aw.T @ Yw, C)
                    cntrs = E + (y - Eo) @ KG.T
                    Pa = Aw.T @ Aw - KG @ Yw.T @ Aw
                    P_cholU = funm_psd(Pa, np.sqrt)
                    if DD is None or not self.re_use:
                        DD = rnd.randn(N * xN, Nx)
                        chi2 = np.sum(DD**2, axis=1) * Nx / N
                        log_q = -0.5 * chi2
                else:
                    V, sig, _ = svd0(Yw @ Rm12.T)
                    dgn = pad0(sig**2, N) + 1
                    Pw = (V * dgn**(-1.0)) @ V.T
                    cntrs = E + (y - Eo) @ Ri @ Yw.T @ Pw @ Aw
                    P_cholU = (V * dgn**(-0.5)).T @ Aw
                    # Generate N·xN random numbers from NormDist(0,1),
                    # and compute log(q(x))
                    if DD is None or not self.re_use:
                        rnk = min(Nx, N - 1)
                        DD = rnd.randn(N * xN, N)
                        chi2 = np.sum(DD**2, axis=1) * rnk / N
                        log_q = -0.5 * chi2
                    # NB: the DoF_linalg/DoF_stoch correction
                    # is only correct "on average".
                    # It is inexact "in proportion" to V@V.T - Id,
                    # where V,s,UT = tsvd(Aw).
                    # Anyways, we're computing the tsvd of Aw below,
                    # so might as well compute q(x) instead of q(xi).

                # Duplicate
                ED = cntrs.repeat(xN, 0)
                wD = wD.repeat(xN) / xN

                # Sample q
                AD = DD @ P_cholU
                ED = ED + AD

                # log(prior_kernel(x))
                s = self.Qs * auto_bandw(N, Nx)
                innovs_pf = AD @ tinv(s * Aw)
                # NB: Correct: innovs_pf = (ED-E_orig) @ tinv(s*Aw)
                # But it seems to make no difference on well-tuned
                # performance !
                log_pf = -0.5 * np.sum(innovs_pf**2, axis=1)

                # log(likelihood(x))
                innovs = (y - HMM.Obs(ED, t)) @ Rm12.T
                log_L = -0.5 * np.sum(innovs**2, axis=1)

                # Update weights
                log_tot = log_L + log_pf - log_q
                wD = reweight(wD, logL=log_tot)

                # Resample and reduce. Retry with a larger wroot until the
                # resampled indices contain no duplicates, or wroot_max is
                # reached (in which case E is left as it was).
                wroot = 1.0
                while wroot < self.wroot_max:
                    idx, w = resample(wD, self.resampl, wroot=wroot, N=N)
                    dups = sum(mask_unique_of_sorted(idx))
                    if dups == 0:
                        E = ED[idx]
                        break
                    else:
                        wroot += 0.1

        self.stats.assess(k, ko, 'u', E=E, w=w)