예제 #1
0
 def _do_EVD(self):
     """Compute and cache the eigen-decomposition (EVD) of the covariance.

     Works from the factor ``self._R`` rather than from the covariance
     itself. The squared singular values of ``self._R`` are taken as the
     eigenvalues, and its right singular vectors as the eigenvectors
     (presumably because the covariance is ``self._R.T @ self._R`` --
     TODO confirm). Eigenvalues are passed through ``CovMat._clip``. Only
     the strictly positive ones, together with their vectors, are handed
     to ``self._assign_EVD``.

     Does nothing if ``self.has_done_EVD()`` is already true.
     """
     if self.has_done_EVD():
         return

     # The left singular vectors are not needed.
     _, sing_vals, UT = svd0(self._R)
     eigvals = CovMat._clip(sing_vals**2)
     rank = (eigvals > 0).sum()
     self._assign_EVD(UT.shape[1], rank, eigvals[:rank], UT[:rank].T)
예제 #2
0
def effective_N(YR, dyR, xN, g):
    """Effective ensemble size N.

    As measured by the finite-size EnKF-N. The inflation factor ``l1`` is
    found by Newton iterations (``Newton_m``, started at 1.0) on the dual
    EnKF-N cost function, and ``(N - 1) / l1**2`` is returned.

    Parameters
    ----------
    YR : 2D array
        Obs-ensemble anomalies, one row per member (``N, Ny = YR.shape``).
        The name suggests they are already R-whitened -- TODO confirm.
    dyR : 1D array
        Innovation, in the same obs space as ``YR``.
    xN, g :
        Hyper-prior parameters, passed on to ``hyperprior_coeffs``.

    Returns
    -------
    float
        ``za = (N - 1) / l1**2``.
    """
    n_ens, n_obs = YR.shape
    nm1 = n_ens - 1
    rank = min(n_ens, n_obs)

    # Singular values of YR, and the innovation in its singular basis.
    _, s, UT = svd0(YR)
    du = UT @ dyR

    eN, cL = hyperprior_coeffs(s, n_ens, xN, g)

    def padded(arr):
        return pad0(arr, rank)

    def denom(l1):
        return padded((l1 * s)**2) + nm1

    def cost(l1):
        # Dual cost function (in terms of l1). The Newton solver below does
        # not use it; it is kept for the alternative optimizers noted there.
        return np.sum(du**2/denom(l1)) + eN/l1**2 + cL*np.log(l1**2)

    def cost_d1(l1):
        # d(cost)/d(l1)
        return (-2*l1 * np.sum(padded(s**2) * du**2/denom(l1)**2)
                + -2*eN/l1**3
                + 2*cL/l1)

    def cost_d2(l1):
        # Second-derivative expression used by Newton_m.
        # NOTE(review): the exact d2(cost)/d(l1)2 also contains the term
        # -2*np.sum(padded(s**2) * du**2/denom(l1)**2), which is absent
        # here. Newton_m still seeks a root of cost_d1; confirm the
        # omission is intended.
        return (8*l1**2 * np.sum(padded(s**4) * du**2/denom(l1)**3)
                + 6*eN/l1**4
                + -2*cL/l1**2)

    l1 = Newton_m(cost_d1, cost_d2, 1.0)
    # Alternatives:
    # l1 = fmin_bfgs(cost, x0=[1], gtol=1e-4, disp=0)
    # l1 = minimize_scalar(cost, bracket=(sqrt(prior_mode), 1e2), tol=1e-4).x

    return nm1 / l1**2
예제 #3
0
                def local_analysis(ii):
                    """Local ETKF analysis for one batch of state indices.

                    Notation:
                     - ii: inds for the state batch defining the locality
                     - jj: inds for the associated obs

                    Returns the updated local ensemble (columns ``ii``) and
                    the normalization ``za`` that was used. If no obs are
                    associated with ``ii``, returns ``E[:, ii]`` unchanged
                    together with ``N1``.
                    """
                    jj, tapering = obs_taperer(ii)
                    if len(jj) == 0:
                        return E[:, ii], N1  # nothing to assimilate

                    Y_jj = Y[:, jj]
                    dy_jj = dy[jj]

                    # Effective ensemble size: adaptive (EnKF-N) or fixed.
                    # Evaluated on the untapered local obs.
                    if infl == '-N':
                        za = effective_N(Y_jj, dy_jj, xN, g)
                    else:
                        za = N1

                    # Taper the local obs anomalies and innovation (in place).
                    root_taper = sqrt(tapering)
                    Y_jj *= root_taper
                    dy_jj *= root_taper

                    if len(jj) < N:
                        # Fewer local obs than members: SVD of Y_jj.
                        V, sig, _ = svd0(Y_jj)
                        d = pad0(sig**2, N) + za
                        Pw = (V * d**(-1.0)) @ V.T
                        T = (V * d**(-0.5)) @ V.T * sqrt(za)
                    else:
                        # Otherwise: EVD of the N-by-N matrix.
                        d, V = sla.eigh(Y_jj @ Y_jj.T + za * eye(N))
                        T = V @ diag(d**(-0.5)) @ V.T * sqrt(za)
                        Pw = V @ diag(d**(-1.0)) @ V.T

                    A_ii = A[:, ii]
                    mean_incr = dy_jj @ Y_jj.T @ Pw @ A_ii
                    return mu[ii] + mean_incr + T @ A_ii, za
예제 #4
0
    def assimilate(self, HMM, xx, yy):
        """Run the finite-size EnKF (EnKF-N) through the whole time sequence.

        Between obs times the ensemble is propagated, with noise added by
        ``add_noise``. At each obs time:

        1. An inflation factor ``l1`` is estimated. If ``self.dual``, this
           uses the dual form (Newton iterations on the scalar ``l1``).
           Otherwise it uses a linearized primal form (Newton iterations
           on the ensemble weights).
        2. The prior anomalies are inflated by ``l1``.
        3. A square-root (ETKF-like) update is applied.

        Parameters
        ----------
        HMM :
            Problem setup; this method reads ``Dyn``, ``Obs``, ``t``, ``X0``
            and ``Ny`` from it.
        xx :
            Truth trajectory. Not used by this method.
        yy :
            Observations, indexed by ``kObs``.

        Writes ``self.stats`` (``assess``, ``infl[kObs]``, ``trHK[kObs]``).
        """
        # Unpack
        Dyn, Obs, chrono, X0, stats = \
            HMM.Dyn, HMM.Obs, HMM.t, HMM.X0, self.stats
        R, N, N1 = HMM.Obs.noise.C, self.N, self.N - 1

        # Init
        E = X0.sample(N)
        stats.assess(0, E=E)

        # Loop
        for k, kObs, t, dt in progbar(chrono.ticker):
            # Forecast
            E = Dyn(E, t - dt, dt)
            E = add_noise(E, dt, Dyn.noise, self.fnoise_treatm)

            # Analysis
            if kObs is not None:
                stats.assess(k, kObs, 'f', E=E)
                Eo = Obs(E, t)
                y = yy[kObs]

                mu = np.mean(E, 0)
                A = E - mu

                xo = np.mean(Eo, 0)
                Y = Eo - xo
                dy = y - xo

                # SVD of the R-whitened obs anomalies. du is the whitened
                # innovation expressed in the matching singular basis.
                V, s, UT = svd0(Y @ R.sym_sqrt_inv.T)
                du = UT @ (dy @ R.sym_sqrt_inv.T)

                def dgn_N(l1):
                    # Eigenvalues (in basis V) of the inverse of Pw below:
                    # (l1*s)^2 zero-padded to length N, plus N1.
                    return pad0((l1 * s)**2, N) + N1

                # Adjust hyper-prior
                # xN_ = noise_level(self.xN,stats,chrono,N1,kObs,A,
                #                   locals().get('A_old',None))
                eN, cL = hyperprior_coeffs(s, N, self.xN, self.g)

                if self.dual:
                    # Make dual cost function (in terms of l1)
                    def pad_rk(arr):
                        return pad0(arr, min(N, Obs.M))

                    def dgn_rk(l1):
                        return pad_rk((l1 * s)**2) + N1

                    def J(l1):
                        val = np.sum(du**2/dgn_rk(l1)) \
                            + eN/l1**2 \
                            + cL*np.log(l1**2)
                        return val

                    # Derivatives (not required with minimize_scalar):
                    def Jp(l1):
                        val = -2*l1 * np.sum(pad_rk(s**2) * du**2/dgn_rk(l1)**2) \
                            + -2*eN/l1**3 + 2*cL/l1
                        return val

                    # NOTE(review): the exact derivative of Jp also contains
                    # -2*np.sum(pad_rk(s**2) * du**2/dgn_rk(l1)**2), absent
                    # here. Newton_m still seeks a root of Jp; confirm.
                    def Jpp(l1):
                        val = 8*l1**2 * np.sum(pad_rk(s**4) * du**2/dgn_rk(l1)**3) \
                            + 6*eN/l1**4 + -2*cL/l1**2
                        return val

                    # Find inflation factor (optimize)
                    l1 = Newton_m(Jp, Jpp, 1.0)
                    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
                    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2),
                    #                      tol=1e-4).x

                else:
                    # Primal form, in a fully linearized version.
                    def za(w):
                        return zeta_a(eN, cL, w)

                    # Primal cost (kept for the commented-out fmin_bfgs call).
                    def J(w):
                        return .5*np.sum(((dy-w@Y)@R.sym_sqrt_inv.T)**2) + \
                            .5*N1*cL*np.log(eN + w@w)

                    # Derivatives (not required with fmin_bfgs):
                    def Jp(w):
                        return -Y @ R.inv @ (dy - w @ Y) + w * za(w)

                    # Jpp   = lambda w:  Y@R.inv@Y.T + \
                    #     za(w)*(eye(N) - 2*np.outer(w,w)/(eN + w@w))
                    # Approx: no radial-angular cross-deriv:
                    # Jpp   = lambda w:  Y@R.inv@Y.T + za(w)*eye(N)

                    def nvrs(w):
                        # inverse of Jpp-approx
                        return (V * (pad0(s**2, N) + za(w))**-1.0) @ V.T

                    # Find w (optimize)
                    wa = Newton_m(Jp, nvrs, zeros(N), is_inverted=True)
                    # wa   = Newton_m(Jp,Jpp ,zeros(N))
                    # wa   = fmin_bfgs(J,zeros(N),Jp,disp=0)
                    l1 = sqrt(N1 / za(wa))

                # Uncomment to revert to ETKF
                # l1 = 1.0

                # Explicitly inflate prior
                # => formulae look different from `bib.bocquet2015expanding`.
                A *= l1
                Y *= l1

                # Compute sqrt update
                Pw = (V * dgn_N(l1)**(-1.0)) @ V.T
                w = dy @ R.inv @ Y.T @ Pw
                # For the anomalies:
                if not self.Hess:
                    # Regular ETKF (i.e. sym sqrt) update (with inflation)
                    T = (V * dgn_N(l1)**(-0.5)) @ V.T * sqrt(N1)
                    # = (Y@R.inv@Y.T/N1 + eye(N))**(-0.5)
                else:
                    # Also include angular-radial co-dependence.
                    # Note: denominator not squared coz
                    # unlike `bib.bocquet2015expanding` we have inflated Y.
                    Hw = Y @ R.inv @ Y.T / N1 + eye(N) - 2 * np.outer(
                        w, w) / (eN + w @ w)
                    T = funm_psd(
                        Hw, lambda x: x**-.5)  # is there a sqrtm Woodbury?

                E = mu + w @ A + T @ A
                E = post_process(E, self.infl, self.rot)

                stats.infl[kObs] = l1
                # Relative obs influence tr(HK)/Ny for the INFLATED prior
                # (Y was scaled by l1 above, and Pw uses dgn_N(l1)). In the
                # R-whitened obs space the nonzero eigenvalues of HK are
                # (l1*s)^2 / ((l1*s)^2 + N1), so the numerator must include
                # l1 as well.
                stats.trHK[kObs] = ((
                    (l1 * s)**2 + N1)**(-1.0) * (l1 * s)**2).sum() / HMM.Ny

            stats.assess(k, kObs, E=E)
예제 #5
0
def EnKF_analysis(E, Eo, hnoise, y, upd_a, stats, kObs):
    """The EnKF analysis update, in many flavours and forms.

    The update is specified via 'upd_a'.

    Main references: `bib.sakov2008deterministic`,
    `bib.sakov2008implications`, `bib.hoteit2015mitigating`

    Parameters
    ----------
    E : 2D array
        Prior ensemble, one row per member (``N, Nx = E.shape``).
    Eo : 2D array
        Observed ensemble (one row per member).
    hnoise :
        Obs-noise object. This function reads ``hnoise.C`` (the obs-error
        covariance, using its ``.full``, ``.inv`` and ``.sym_sqrt_inv``),
        ``hnoise.sample`` and ``hnoise.M``.
    y : 1D array
        The observation.
    upd_a : str
        Update flavour. The checks are, in this order:

        * contains 'PertObs';
        * contains 'Sqrt', with sub-variants 'explicit', 'svd', 'sS', and
          EVD otherwise. Plain 'Sqrt' switches to 'svd' when N > Nx;
        * contains 'Serial', with sub-variants 'Stoch'/'ESOPS'/'Var1', and
          the Potter scheme otherwise;
        * equals 'DEnKF'.
    stats, kObs :
        If the chosen flavour computes tr(HK), it is stored as
        ``stats.trHK[kObs] = tr(HK) / hnoise.M``.

    Returns
    -------
    E : 2D array
        The analysis ensemble.

    Raises
    ------
    KeyError
        If ``upd_a`` matches none of the flavours.
    """
    R = hnoise.C  # Obs noise cov
    N, Nx = E.shape  # Dimensionality
    N1 = N - 1  # Ens size - 1

    mu = np.mean(E, 0)  # Ens mean
    A = E - mu  # Ens anomalies

    xo = np.mean(Eo, 0)  # Obs ens mean
    Y = Eo - xo  # Obs ens anomalies
    dy = y - xo  # Mean "innovation"

    # Diagnostics, filled in only by the flavours that compute them.
    # Each flavour provides either tr(HK) directly, or an obs-space matrix HK
    # (whose trace is taken). Explicit initialization replaces the previous
    # fragile `'trHK' in locals()` introspection.
    trHK = None
    HK = None

    if 'PertObs' in upd_a:
        # Uses classic, perturbed observations (Burgers'98)
        C = Y.T @ Y + R.full * N1
        D = mean0(hnoise.sample(N))
        YC = mrdiv(Y, C)
        KG = A.T @ YC
        HK = Y.T @ YC
        dE = (KG @ (y - D - Eo).T).T
        E = E + dE

    elif 'Sqrt' in upd_a:
        # Uses a symmetric square root (ETKF)
        # to deterministically transform the ensemble.

        # The various versions below differ only numerically.
        # EVD is default, but for large N use SVD version.
        if upd_a == 'Sqrt' and N > Nx:
            upd_a = 'Sqrt svd'

        if 'explicit' in upd_a:
            # Not recommended due to numerical costs and instability.
            # Implementation using inv (in ens space)
            Pw = sla.inv(Y @ R.inv @ Y.T + N1 * eye(N))
            T = sla.sqrtm(Pw) * sqrt(N1)
            HK = R.inv @ Y.T @ Pw @ Y
        elif 'svd' in upd_a:
            # Implementation using svd of Y R^{-1/2}.
            V, s, _ = svd0(Y @ R.sym_sqrt_inv.T)
            d = pad0(s**2, N) + N1
            Pw = (V * d**(-1.0)) @ V.T
            T = (V * d**(-0.5)) @ V.T * sqrt(N1)
            # docs/snippets/trHK.jpg
            trHK = np.sum((s**2 + N1)**(-1.0) * s**2)
        elif 'sS' in upd_a:
            # Same as 'svd', but with slightly different notation
            # (sometimes used by Sakov) using the normalization sqrt(N1).
            S = Y @ R.sym_sqrt_inv.T / sqrt(N1)
            V, s, _ = svd0(S)
            d = pad0(s**2, N) + 1
            Pw = (V * d**(-1.0)) @ V.T / N1  # = G/(N1)
            T = (V * d**(-0.5)) @ V.T
            # docs/snippets/trHK.jpg
            trHK = np.sum((s**2 + 1)**(-1.0) * s**2)
        else:  # 'eig' in upd_a:
            # Implementation using eig. val. decomp.
            d, V = sla.eigh(Y @ R.inv @ Y.T + N1 * eye(N))
            T = V @ diag(d**(-0.5)) @ V.T * sqrt(N1)
            Pw = V @ diag(d**(-1.0)) @ V.T
            HK = R.inv @ Y.T @ Pw @ Y  # reuse Pw rather than recompute it
        w = dy @ R.inv @ Y.T @ Pw
        E = mu + w @ A + T @ A

    elif 'Serial' in upd_a:
        # Observations assimilated one-at-a-time:
        inds = serial_inds(upd_a, y, R, A)
        #  Requires de-correlation:
        dy = dy @ R.sym_sqrt_inv.T
        Y = Y @ R.sym_sqrt_inv.T
        # Enhancement in the nonlinear case:
        # re-compute Y each scalar obs assim.
        # But: little benefit, model costly (?),
        # updates cannot be accumulated on S and T.

        if any(x in upd_a for x in ['Stoch', 'ESOPS', 'Var1']):
            # More details: Misc/Serial_ESOPS.py.
            for i, j in enumerate(inds):

                # Perturbation creation
                if 'ESOPS' in upd_a:
                    # "2nd-O exact perturbation sampling"
                    if i == 0:
                        # Init -- increase nullspace by 1
                        V, s, UT = svd0(A)
                        s[N - 2:] = 0
                        A = svdi(V, s, UT)
                        v = V[:, N - 2]
                    else:
                        # Orthogonalize v wrt. the new A
                        #
                        # v = Zj - Yj (from paper) requires Y==HX.
                        # Instead: `mult` should be c*ones(Nx) so we can
                        # project v into ker(A) such that v@A is null.
                        # (Yj here is from the previous iteration.)
                        mult = (v @ A) / (Yj @ A)  # noqa
                        v = v - mult[0] * Yj  # noqa
                        v /= sqrt(v @ v)
                    Zj = v * sqrt(N1)  # Standardized perturbation along v
                    Zj *= np.sign(rand() - 0.5)  # Random sign
                else:
                    # The usual stochastic perturbations.
                    Zj = mean0(randn(N))  # Un-coloured noise
                    if 'Var1' in upd_a:
                        Zj *= sqrt(N / (Zj @ Zj))

                # Select j-th obs
                Yj = Y[:, j]  # [j] obs anomalies
                dyj = dy[j]  # [j] innov mean
                DYj = Zj - Yj  # [j] innov anomalies
                DYj = DYj[:, None]  # Make 2d vertical

                # Kalman gain computation
                C = Yj @ Yj + N1  # Total obs cov
                KGx = Yj @ A / C  # KG to update state
                KGy = Yj @ Y / C  # KG to update obs

                # Updates
                A += DYj * KGx
                mu += dyj * KGx
                Y += DYj * KGy
                dy -= dyj * KGy
            E = mu + A
        else:
            # "Potter scheme", "EnSRF"
            # - EAKF's two-stage "update-regress" form yields
            #   the same *ensemble* as this.
            # - The form below may be derived as "serial ETKF",
            #   but does not yield the same
            #   ensemble as 'Sqrt' (which processes obs as a batch)
            #   -- only the same mean/cov.
            T = eye(N)
            for j in inds:
                Yj = Y[:, j]
                C = Yj @ Yj + N1
                Tj = np.outer(Yj, Yj / (C + sqrt(N1 * C)))
                T -= Tj @ T
                Y -= Tj @ Y
            w = dy @ Y.T @ T / N1
            E = mu + w @ A + T @ A

    elif 'DEnKF' == upd_a:
        # Uses "Deterministic EnKF" (sakov'08)
        C = Y.T @ Y + R.full * N1
        YC = mrdiv(Y, C)
        KG = A.T @ YC
        HK = Y.T @ YC
        E = E + KG @ dy - 0.5 * (KG @ Y.T).T

    else:
        raise KeyError("No analysis update method found: '" + upd_a + "'.")

    # Diagnostic: relative influence of observations
    if trHK is not None:
        stats.trHK[kObs] = trHK / hnoise.M
    elif HK is not None:
        stats.trHK[kObs] = HK.trace() / hnoise.M

    return E
예제 #6
0
    def assimilate(self, HMM, xx, yy):
        """Run a particle filter whose resampling step uses EnKF-based proposals.

        Between obs times the particles are propagated, with additive
        dynamical noise if ``Dyn.noise.C != 0``. At each obs time the
        weights are updated by the likelihood. If ``trigger_resampling``
        fires, the following happens:

        1. Each particle is moved by an EnKF-without-perturbations update
           built from the weighted prior.
        2. Each particle is duplicated ``xN`` times and jittered with
           Gaussian draws.
        3. The copies are re-weighted by likelihood * prior-kernel /
           proposal.
        4. They are resampled back down to ``N`` particles.

        Parameters
        ----------
        HMM :
            Problem setup; reads ``Dyn``, ``Obs``, ``t`` and ``X0``.
        xx :
            Truth trajectory. Not used by this method.
        yy :
            Observations, indexed by ``kObs``.

        Writes ``self.stats`` via ``stats.assess``.
        """
        Dyn, Obs, chrono, X0, stats = \
            HMM.Dyn, HMM.Obs, HMM.t, HMM.X0, self.stats
        N, xN, Nx, Rm12, Ri = \
            self.N, self.xN, Dyn.M, Obs.noise.C.sym_sqrt_inv, Obs.noise.C.inv

        E = X0.sample(N)
        w = 1 / N * np.ones(N)

        # Standard-normal draws for the jitter. They are re-used across
        # analyses when self.re_use is set.
        DD = None

        stats.assess(0, E=E, w=w)

        for k, kObs, t, dt in progbar(chrono.ticker):
            E = Dyn(E, t - dt, dt)
            if Dyn.noise.C != 0:
                E += np.sqrt(dt) * (randn(N, Nx) @ Dyn.noise.C.Right)

            if kObs is not None:
                stats.assess(k, kObs, 'f', E=E, w=w)
                y = yy[kObs]
                Eo = Obs(E, t)
                wD = w.copy()  # prior weights (before this obs)

                # Importance weighting
                innovs = (y - Eo) @ Rm12.T
                w = reweight(w, innovs=innovs)

                # Resampling
                if trigger_resampling(w, self.NER, [stats, E, k, kObs]):
                    # Weighted covariance factors
                    Aw = raw_C12(E, wD)
                    Yw = raw_C12(Eo, wD)

                    # EnKF-without-pertubations update
                    if N > Nx:
                        C = Yw.T @ Yw + Obs.noise.C.full
                        KG = mrdiv(Aw.T @ Yw, C)
                        cntrs = E + (y - Eo) @ KG.T
                        Pa = Aw.T @ Aw - KG @ Yw.T @ Aw
                        P_cholU = funm_psd(Pa, np.sqrt)
                        if DD is None or not self.re_use:
                            DD = randn(N * xN, Nx)
                            # NOTE(review): DD has Nx columns here and Pa is
                            # presumably Nx-by-Nx, so the DoF ratio analogous
                            # to `rnk / N` in the other branch would seem to
                            # be Nx/Nx = 1, not Nx/N. Confirm this is intended.
                            chi2 = np.sum(DD**2, axis=1) * Nx / N
                            log_q = -0.5 * chi2
                    else:
                        V, sig, UT = svd0(Yw @ Rm12.T)
                        dgn = pad0(sig**2, N) + 1
                        Pw = (V * dgn**(-1.0)) @ V.T
                        cntrs = E + (y - Eo) @ Ri @ Yw.T @ Pw @ Aw
                        P_cholU = (V * dgn**(-0.5)).T @ Aw
                        # Generate N·xN random numbers from NormDist(0,1),
                        # and compute log(q(x))
                        if DD is None or not self.re_use:
                            rnk = min(Nx, N - 1)
                            DD = randn(N * xN, N)
                            chi2 = np.sum(DD**2, axis=1) * rnk / N
                            log_q = -0.5 * chi2
                        # NB: the DoF_linalg/DoF_stoch correction
                        # is only correct "on average".
                        # It is inexact "in proportion" to [expression lost
                        # in extraction -- TODO restore from upstream],
                        # where V,s,UT = tsvd(Aw).
                        # Anyways, we're computing the tsvd of Aw below,
                        # so might as well compute q(x) instead of q(xi).

                    # Duplicate
                    ED = cntrs.repeat(xN, 0)
                    wD = wD.repeat(xN) / xN

                    # Sample q
                    AD = DD @ P_cholU
                    ED = ED + AD

                    # log(prior_kernel(x))
                    s = self.Qs * auto_bandw(N, Nx)
                    innovs_pf = AD @ tinv(s * Aw)
                    # NB: Correct: innovs_pf = (ED-E_orig) @ tinv(s*Aw)
                    #     But it seems to make no difference on well-tuned performance !
                    log_pf = -0.5 * np.sum(innovs_pf**2, axis=1)

                    # log(likelihood(x))
                    innovs = (y - Obs(ED, t)) @ Rm12.T
                    log_L = -0.5 * np.sum(innovs**2, axis=1)

                    # Update weights
                    log_tot = log_L + log_pf - log_q
                    wD = reweight(wD, logL=log_tot)

                    # Resample and reduce.
                    # Flatten the weights by increasing `wroot` until a
                    # resampling without duplicates is found (`dups` is
                    # presumably the number of repeated indices), or until
                    # `wroot` reaches `self.wroot_max`.
                    # The last attempt is ALWAYS applied, duplicates and
                    # all. Previously, exhausting the loop overwrote `w` but
                    # left `E` un-resampled, so E and w disagreed. At least
                    # one attempt is made even if wroot_max <= 1.
                    wroot = 1.0
                    while True:
                        idx, w = resample(wD, self.resampl, wroot=wroot, N=N)
                        dups = sum(mask_unique_of_sorted(idx))
                        wroot += 0.1
                        if dups == 0 or wroot >= self.wroot_max:
                            break
                    E = ED[idx]
            stats.assess(k, kObs, 'u', E=E, w=w)
예제 #7
0
    def assimilate(self, HMM, xx, yy):
        """Iterative ensemble smoother over sliding DA windows (DAW).

        For each window holding obs that are not yet fully assimilated, the
        ensemble at the start of the window is iteratively updated. The
        control vector ``w`` updates the mean, and the transform ``T``
        updates the anomalies. By default the iterations are Gauss-Newton
        steps on the log-posterior; they are annealing steps when
        ``self.MDA`` is set. The anomaly-update flavour is selected by
        ``self.upd_a`` ('Sqrt', 'PertObs' or 'Order1'). The window is then
        slid forward by propagating the smoothed ensemble from ``kLag``.

        Parameters
        ----------
        HMM :
            Problem setup; reads ``Dyn``, ``Obs``, ``t`` and ``X0``.
        xx :
            Truth trajectory. Not used by this method.
        yy :
            Observations, indexed by ``kObs``.

        Asserts that the model noise ``Dyn.noise.C`` is 0.
        Writes ``self.stats`` (``assess``, ``iters``, ``infl``).
        """
        Dyn, Obs, chrono, X0, stats, N = \
            HMM.Dyn, HMM.Obs, HMM.t, HMM.X0, self.stats, self.N
        R, KObs, N1 = HMM.Obs.noise.C, HMM.t.KObs, N - 1
        Rm12 = R.sym_sqrt_inv  # Whitening transform, R^{-1/2}.

        assert Dyn.noise.C == 0, (
            "Q>0 not yet supported."
            " See Sakov et al 2017: 'An iEnKF with mod. error'")

        # Ensemble "bundle" scaling. Sakov/Bocquet use T=EPS*eye(N) with
        # EPS=1e-4. Here the transform itself is scaled instead (T=EPS*T, see
        # the reconstruction of E below), which the author preferred because
        # it yields a conditional cloud shape. EPS=1.0 means no scaling.
        if self.bundle:
            EPS = 1e-4
        else:
            EPS = 1.0

        # Initial ensemble
        E = X0.sample(N)

        # Loop over DA windows (DAW).
        for kObs in progbar(np.arange(-1, KObs + self.Lag + 1)):
            kLag = kObs - self.Lag
            DAW = range(max(0, kLag + 1), min(kObs, KObs) + 1)

            # Assimilation (if there are obs not yet fully assimilated).
            if 0 <= kObs <= KObs:

                # Init iterations.
                # NB: this rebinds X0 (initially HMM.X0) to the anomalies.
                X0, x0 = center(E)  # Decompose ensemble.
                w = np.zeros(N)  # Control vector for the mean state.
                T = np.eye(N)  # Anomalies transform matrix.
                Tinv = np.eye(N)
                # Explicit Tinv [instead of tinv(T)] allows for merging MDA code
                # with iEnKS/EnRML code, and flop savings in 'Sqrt' case.

                # NOTE(review): if self.nIter == 0 this loop does not run, and
                # dw, T_old and Xf used after it may be unbound.
                for iteration in np.arange(self.nIter):
                    # Reconstruct smoothed ensemble.
                    E = x0 + (w + EPS * T) @ X0
                    # Forecast.
                    for kCycle in DAW:
                        for k, t, dt in chrono.cycle(kCycle):
                            E = Dyn(E, t - dt, dt)
                    # Observe.
                    Eo = Obs(E, t)

                    # Undo the bundle scaling of ensemble.
                    if EPS != 1.0:
                        E = inflate_ens(E, 1 / EPS)
                        Eo = inflate_ens(Eo, 1 / EPS)

                    # Assess forecast stats; store {Xf, T_old} for analysis assessment.
                    if iteration == 0:
                        stats.assess(k, kObs, 'f', E=E)
                        Xf, xf = center(E)
                    # Copy (not alias): the MDA 'PertObs'/'Order1' branches
                    # below update T in place (`T -= ...`). With an alias,
                    # T - T_old would always be 0 in final_increment.
                    T_old = T.copy()

                    # Prepare analysis.
                    y = yy[kObs]  # Get current obs.
                    Y, xo = center(Eo)  # Get obs {anomalies, mean}.
                    dy = (y - xo) @ Rm12.T  # Transform obs space.
                    Y = Y @ Rm12.T  # Transform obs space.
                    Y0 = Tinv @ Y  # "De-condition" the obs anomalies.
                    V, s, UT = svd0(Y0)  # Decompose Y0.

                    # Set "cov normlzt fctr" za ("effective ensemble size")
                    # => pre_infl^2 = (N-1)/za.
                    if self.xN is None:
                        za = N1
                    else:
                        za = zeta_a(*hyperprior_coeffs(s, N, self.xN), w)
                    if self.MDA:
                        # inflation (factor: nIter) of the ObsErrCov.
                        za *= self.nIter

                    # Post. cov (approx) of w,
                    # estimated at current iteration, raised to power.
                    def Cowp(expo):
                        return (V * (pad0(s**2, N) + za)**-expo) @ V.T

                    Cow1 = Cowp(1.0)

                    if self.MDA:  # View update as annealing (progressive assimilation).
                        Cow1 = Cow1 @ T  # apply previous update
                        dw = dy @ Y.T @ Cow1
                        if 'PertObs' in self.upd_a:  # == "ES-MDA". By Emerick/Reynolds
                            D = mean0(randn(*Y.shape)) * np.sqrt(self.nIter)
                            T -= (Y + D) @ Y.T @ Cow1
                        elif 'Sqrt' in self.upd_a:  # == "ETKF-ish". By Raanes
                            T = Cowp(0.5) * np.sqrt(za) @ T
                        elif 'Order1' in self.upd_a:  # == "DEnKF-ish". By Emerick
                            T -= 0.5 * Y @ Y.T @ Cow1
                        # Tinv = eye(N) [as initialized] coz MDA does not de-condition.

                    else:  # View update as Gauss-Newton optimzt. of log-posterior.
                        grad = Y0 @ dy - w * za  # Cost function gradient
                        dw = grad @ Cow1  # Gauss-Newton step
                        # "ETKF-ish". By Bocquet/Sakov.
                        if 'Sqrt' in self.upd_a:
                            # Sqrt-transforms
                            T = Cowp(0.5) * np.sqrt(N1)
                            Tinv = Cowp(-.5) / np.sqrt(N1)
                            # Tinv saves time [vs tinv(T)] when Nx<N
                        # "EnRML". By Oliver/Chen/Raanes/Evensen/Stordal.
                        elif 'PertObs' in self.upd_a:
                            D = mean0(randn(*Y.shape)) if iteration == 0 else D
                            gradT = -(Y + D) @ Y0.T + N1 * (np.eye(N) - T)
                            T = T + gradT @ Cow1
                            # Tinv= tinv(T, threshold=N1)  # unstable
                            # NB: adds 1 to every entry of T (not the identity).
                            Tinv = sla.inv(T + 1)  # the +1 is for stability.
                        # "DEnKF-ish". By Raanes.
                        elif 'Order1' in self.upd_a:
                            # Included for completeness; does not make much sense.
                            gradT = -0.5 * Y @ Y0.T + N1 * (np.eye(N) - T)
                            T = T + gradT @ Cow1
                            Tinv = tinv(T, threshold=N1)

                    w += dw
                    if dw @ dw < self.wtol * N:
                        break

                # Assess (analysis) stats.
                # The final_increment is a linearization to
                # (i) avoid re-running the model and
                # (ii) reproduce EnKF in case nIter==1.
                final_increment = (dw + T - T_old) @ Xf
                # See docs/snippets/iEnKS_Ea.jpg.
                stats.assess(k, kObs, 'a', E=E + final_increment)
                stats.iters[kObs] = iteration + 1
                if self.xN:
                    stats.infl[kObs] = np.sqrt(N1 / za)

                # Final (smoothed) estimate of E at [kLag].
                E = x0 + (w + T) @ X0
                E = post_process(E, self.infl, self.rot)

            # Slide/shift DAW by propagating smoothed ('s') ensemble from [kLag].
            if -1 <= kLag < KObs:
                if kLag >= 0:
                    stats.assess(chrono.kkObs[kLag], kLag, 's', E=E)
                for k, t, dt in chrono.cycle(kLag + 1):
                    stats.assess(k - 1, None, 'u', E=E)
                    E = Dyn(E, t - dt, dt)

        stats.assess(k, KObs, 'us', E=E)
예제 #8
0
    def assimilate(self, HMM, xx, yy):
        """Run incremental (Gauss-Newton) 4D-Var over sliding DA windows.

        The background covariance B is static. The state is parameterized
        as ``x = x0 + B12 @ w``, where ``B12`` is the symmetric square root
        of B, and ``w`` is iterated with Gauss-Newton steps. The tangent-
        linear models are accumulated onto ``B12``.

        Parameters
        ----------
        HMM :
            Problem setup; reads ``Dyn``, ``Obs``, ``t`` and ``X0``.
        xx :
            Truth trajectory. Read only when ``self.B`` is None or 'clim'.
            In that case B is the covariance of ``xx``.
        yy :
            Observations, indexed by ``kObs``.

        Writes ``self.stats`` (``assess``, ``iters``).
        """
        Dyn, Obs, chrono, X0, stats = HMM.Dyn, HMM.Obs, HMM.t, HMM.X0, self.stats
        R, KObs = HMM.Obs.noise.C, HMM.t.KObs
        Rm12 = R.sym_sqrt_inv
        Nx = Dyn.M

        # Set background covariance. Note that it is static (compare to iEnKS).
        # Strings are compared only when self.B IS a string. For an array,
        # `self.B in (None, 'clim')` compares elementwise and raises
        # "truth value of an array ... is ambiguous".
        if self.B is None or (isinstance(self.B, str) and self.B == 'clim'):
            # Use climatological cov, ...
            B = np.cov(xx.T)  # ... estimated from truth
        elif isinstance(self.B, str) and self.B == 'eye':
            B = np.eye(Nx)
        else:
            B = self.B
        # Scale out of place, so that a user-supplied self.B is never
        # modified. In-place `*=` would compound the scaling on every run.
        B = B * self.xB
        B12 = CovMat(B).sym_sqrt

        # Init
        x = X0.mu
        stats.assess(0, mu=x, Cov=B)
        # X (sqrt-factor) and Cow1 (w-space cov) describe the current
        # uncertainty as X @ Cow1 @ X.T. They are initialized to the
        # background (B12 @ B12.T) so that propagation before the first
        # analysis (e.g. with Lag=0) does not hit unbound names.
        X = B12
        Cow1 = np.eye(Nx)

        # Loop over DA windows (DAW).
        for kObs in progbar(np.arange(-1, KObs + self.Lag + 1)):
            kLag = kObs - self.Lag
            DAW = range(max(0, kLag + 1), min(kObs, KObs) + 1)

            # Assimilation (if there are obs not yet fully assimilated).
            if 0 <= kObs <= KObs:

                # Init iterations.
                w = np.zeros(Nx)  # Control vector for the mean state.
                x0 = x.copy()  # Increment reference.

                for iteration in np.arange(self.nIter):
                    # Reconstruct smoothed state.
                    x = x0 + B12 @ w
                    X = B12  # Aggregate composite TLMs onto B12
                    # Forecast.
                    for kCycle in DAW:
                        for k, t, dt in chrono.cycle(kCycle):
                            X = Dyn.linear(x, t - dt, dt) @ X
                            x = Dyn(x, t - dt, dt)

                    # Assess forecast stats
                    if iteration == 0:
                        stats.assess(k, kObs, 'f', mu=x, Cov=X @ X.T)

                    # Observe.
                    Y = Obs.linear(x, t) @ X
                    xo = Obs(x, t)

                    # Analysis prep.
                    y = yy[kObs]  # Get current obs.
                    dy = Rm12 @ (y - xo)  # Transform obs space.
                    Y = Rm12 @ Y  # Transform obs space.
                    V, s, UT = svd0(Y.T)  # Decomp for lin-alg update comps.

                    # Post. cov (approx) of w,
                    # estimated at current iteration, raised to power.
                    Cow1 = (V * (pad0(s**2, Nx) + 1)**-1.0) @ V.T

                    # Compute analysis update.
                    grad = Y.T @ dy - w  # Cost function gradient
                    dw = Cow1 @ grad  # Gauss-Newton step
                    w += dw  # Step

                    if dw @ dw < self.wtol * Nx:
                        break

                # Assess (analysis) stats.
                final_increment = X @ dw
                stats.assess(k,
                             kObs,
                             'a',
                             mu=x + final_increment,
                             Cov=X @ Cow1 @ X.T)
                stats.iters[kObs] = iteration + 1

                # Final (smoothed) estimate at [kLag].
                x = x0 + B12 @ w
                X = B12

            # Slide/shift DAW by propagating smoothed ('s') state from [kLag].
            if -1 <= kLag < KObs:
                if kLag >= 0:
                    stats.assess(chrono.kkObs[kLag],
                                 kLag,
                                 's',
                                 mu=x,
                                 Cov=X @ Cow1 @ X.T)
                for k, t, dt in chrono.cycle(kLag + 1):
                    # State-space cov of the propagated estimate. The former
                    # `Y @ Y.T` has obs-space dimensions (Y is R^{-1/2} H X).
                    stats.assess(k - 1, None, 'u', mu=x, Cov=X @ Cow1 @ X.T)
                    X = Dyn.linear(x, t - dt, dt) @ X
                    x = Dyn(x, t - dt, dt)

        stats.assess(k, KObs, 'us', mu=x, Cov=X @ Cow1 @ X.T)