def add_deriv_four(x, uVar, K, n):
    """Return ``uVar * d^n(uVar)/dx^n`` with the x-derivative taken spectrally.

    The viscosity coefficient is NOT applied here; the caller multiplies the
    result by it.

    x: (1D) array along which the FFT is performed
    uVar: the field that is both differentiated and used as the multiplier
    K: the (2D) array of wave numbers
    n: the order of the derivative (hdegs)

    The trailing ``2`` handed to the ``tf`` transforms is passed through
    unchanged (presumably the x axis of ``uVar`` -- TODO confirm).
    """
    spectrum = tf.obfft(x, uVar, 2)
    wavenumbers = tf.k_of_x(x)
    ik_pow_n = (1j * K) ** n  # Fourier multiplier of the n-th x-derivative
    deriv = np.real(tf.obifft(wavenumbers, spectrum * ik_pow_n, 2))
    return uVar * deriv
# Volume-integrate the pointwise dKE/dt over the control volume. The index
# bounds p1, p2, p3 are defined outside this view: x runs over p2..p1
# (inclusive) and z from p3 upward.
dtKEtot = integrate_vol(x[p2:p1+1], z[p3:], dtKE[:, p3:, p2:p1+1])
del dtKE  # the pointwise field is no longer needed once integrated

dEdt = dtKEtot  # +dtPEtot  (the PE tendency is currently left out)

tic = print_elpsd('dEdt', tic)
# %%

#
#
# Compute dissipation terms
#
#

# Wavenumbers along x, laid out on the (z, x) grid for spectral derivatives
k = tf.k_of_x(x)
K, M = np.meshgrid(k, z)
# Signed coefficient per (direction, derivative order): order-2 (Laplacian)
# terms are added with +nu, order-4 (biharmonic) terms with -nu2.
dffs = {'nuh2': nuh, 'nuh4': -nuh2, 'nuz2': nuz, 'nuz4': -nuz2}
dissip = 0.*Var['u']  # zero accumulator with the shape of Var['u']
counter = 1
tic_tmp = time.time()
print('Start Dissipation')

for ii in ['u', 'v', 'w']:
    """ First compute and add each term. Once done, integrate the result.
    Looks like more loops than necessary, but I find it easier to read. """
    for jj in ['2', '4']:
        # name of the stored z-derivative, e.g. 'd2udz2' or 'd4udz4'
        dzNm = 'd' + jj + ii + 'dz' + jj
        dissip += dffs['nuz'+jj] * add_deriv_load(openderivs, Var[ii], dzNm)
        dissip += dffs['nuh'+jj] * add_deriv_four(x, Var[ii], K, int(jj))
        # 3 variables x 2 orders = 6 progress steps
        counter = print_int_elpsd('Dissipation', counter, 6, tic)
        # NOTE(review): `dset`, `fld` and `file` are not defined by this
        # loop; the line looks pasted in from a plotting routine and will
        # raise NameError unless those names exist at module level.
        dset[fld] = file['tasks'][fld]
        # if fld == 'PV':
        #     clim[fld] = 2e-9  # np.abs(np.amin(dset[fld]))
        #     # print(clim[fld])
        # else:
        #     clim[fld] = np.amax(np.abs(dset[fld][:]))  # global extremum
​
    _, X, Z = retrieve_3d(file['tasks'][tasks[0]], 0, start)  # 0 is time axis
    X = np.delete(X, (0), axis=0)
    X = np.delete(X, (0), axis=1)
    Z = np.delete(Z, (0), axis=0)
    Z = np.delete(Z, (0), axis=1)
    # c_obj = plt.contour(X, Z, buoy, 10)
    # blevs = c_obj.levels  # that way I will always plot the same
​
    k = tf.k_of_x(X[0, :])
    K, _ = np.meshgrid(k, Z[:, 0])
​
    return clim, dset, blevs, K, X, X/1000, Z
​
​
def print_one_task(dadata, buoy, Xp, Z, lim, blevs, task, dpi, iteration,
                   output, title, im):
    """Draw one field as a pseudocolour map overlaid with buoyancy contours.

    Only columns ``im`` onward are drawn. Colours use the 'RdBu_r' map
    clipped to [-lim, lim]; ``buoy`` is contoured at levels ``blevs`` with
    thin solid black lines. Axes are labelled z (m) and x (km).

    ``task``, ``dpi``, ``iteration``, ``output`` and ``title`` are accepted
    but not used by this body.
    """
    window = (slice(None), slice(im, None))  # all rows, columns im onward
    fg, ax = plt.subplots(1, 1, figsize=(8, 4))
    x_win = Xp[window]
    z_win = Z[window]
    cbax = ax.pcolormesh(x_win, z_win, dadata[window],
                         cmap='RdBu_r', vmin=-lim, vmax=lim)
    ax.contour(x_win, z_win, buoy[window], blevs,
               colors='k', linewidths=0.5, linestyles='-')

    ax.set_ylabel('$z$ (m)')
    ax.set_xlabel('$x$ (km)')
def MEBTerms(s, setup, plst):  # Mechanical Energy Budget Terms
    """Compute, save and return the integrated mechanical energy budget.

    The control volume (CV) spans x-indices ``plst[1]`` (left boundary) to
    ``plst[0]`` (right boundary, inclusive) and z-indices from ``plst[2]``
    (bottom) to the top; time records before ``plst[3]`` are skipped.
    Per-processor NetCDF files ``XZ``, ``derivs`` and ``t_derivs`` under
    ``setup + '/2D/'`` are read one processor at a time, starting with the
    processor that holds the CV bottom.

    Parameters
    ----------
    s : object
        Run parameters and background fields (``np``, ``nz``, ``nnz``,
        ``nx``, ``x``, ``z``, ``dx``, ``dz``, ``g``, ``rho_0``, ``f``,
        ``nuh``, ``nuh2``, ``nuz``, ``nuz2``, ``RLoc``, ``N2Loc``,
        ``S2Loc``, ``RoGLoc``, ``d2Rdz2``, ...).
    setup : str
        Run directory; results are written to ``setup/MEBterms.npz``.
    plst : sequence of int
        ``[i_right, i_left, k_bottom, it_start]``.

    Returns
    -------
    dict
        Time ``t`` and the time series of each budget term: boundary fluxes
        ``pUpls``/``pUmns``/``pWmns``, ``AdPE``, ``GSP``, ``LSP``, ``LC``,
        ``dEdt`` and dissipation ``LDz``/``HDz``/``LDh``/``HDh``. The same
        arrays are saved to ``MEBterms.npz``.
    """

    Var = {}  # this dictionary will be the recipient of the loaded variables

    tic = time.time()
    tic_strt = time.time()

    # these dictionaries will be the recipients of the file IDs
    fID_XZ = {}
    fID_d = {}
    fID_td = {}

    # locate the processor containing the bottom of the box
    npb = plst[2] // (s.nz // s.np)

    # integrand sizes: time records, x points and z points of the CV
    ntint = s.nites // s.sk_it + 1 - plst[3]
    nxint = plst[0] - plst[1] + 1
    if s.np > 1:
        nzint = (s.np - npb) * (s.nz // s.np) + 1
    else:
        nzint = s.nnz - plst[2]

    print(nzint)

    # Pointwise integrands. (The integrated time series are produced by the
    # integrations at the end, so they need no pre-allocation.)
    to_int_shp = (ntint, nzint, nxint)
    phi_xpls = np.zeros((ntint, nzint))
    phi_xmns = np.zeros((ntint, nzint))
    phi_zmns = np.zeros((ntint, nxint))
    AdPE = np.zeros(to_int_shp)
    GSP = np.zeros(to_int_shp)
    LSP = np.zeros(to_int_shp)
    LC = np.zeros(to_int_shp)
    dtE = np.zeros(to_int_shp)
    LDz = np.zeros(to_int_shp)
    HDz = np.zeros(to_int_shp)
    # horizontal terms keep the full x width: spectral x-derivatives need it
    LDh = np.zeros((ntint, nzint, s.nx))
    HDh = np.zeros((ntint, nzint, s.nx))

    k = tf.k_of_x(s.x)

    # open and read for each file
    for pr in range(npb, s.np):

        print(' ')
        print('---------------- Proc #{0:02d}/{1:02d} ----------------'.format(
            pr - npb + 1, s.np - npb))
        print(' ')

        prs = '{0:02d}'.format(pr)  # string version of pr, 2 digits

        if s.np > 1:
            kb = pr * s.nz // s.np  # bottom z index of one proc: proc# * pts/proc
            kt = (pr + 1) * s.nz // s.np  # top z index of one proc
            kbi = (pr - npb) * s.nz // s.np  # bottom z index for integrands
            kti = (pr - npb + 1) * s.nz // s.np  # top z index for integrands
            if pr == s.np - 1:
                kt += 1
                kti += 1
        else:  # then these numbers have to be the CV bounds
            kb = plst[2]
            kt = s.nnz
            kbi = 0  # bottom z index of one proc for integrands
            kti = nzint  # top z index of one proc for integrands

        if s.np == 1:  # string version of pr, as suffix of file name
            npstr = ''
        else:
            npstr = '_0' + prs

        # bottom boundary
        if s.np > 1:  # bottom of a CV is the bottom of a processor domain
            kbott = 0
        else:  # bottom of CV is the actual bottom of CV as prescribed
            kbott = plst[2]

        # %% Loading variables ------------------------------------------------
        print('Start loading, proc #' + prs)

        fID_XZ[prs] = nc.Dataset(setup + '/2D/XZ' + npstr + '.nc', 'r')
        fID_d[prs] = nc.Dataset(setup + '/2D/derivs' + npstr + '.nc', 'r')
        fID_td[prs] = nc.Dataset(setup + '/2D/t_derivs' + npstr + '.nc', 'r')

        for vv in ['u', 'v', 'w', 's1']:
            Var[vv] = (fID_XZ[prs].variables[vv + 'Var'][plst[3]:,
                                                         kbott:, :].copy())
            tic = print_elpsd('loading ' + vv + ', proc #' + prs, tic)

        # buoyancy perturbation, and the same divided by the background N^2
        bVar = -s.g / s.rho_0 * (Var['s1'][:, :, :] - s.RLoc[kb:kt, :])
        bN2 = bVar / s.N2Loc[kb:kt, :]

        if pr == npb:
            t = fID_XZ[prs].variables['tVar'][plst[3]:].copy()

        # %% ------------------------------------------------------------------
        # Flux calculations

        if pr == npb:  # works no matter how many procs there are
            pVar_zmns = (fID_d[prs].variables['pVar'][plst[3]:, kbott,
                                                      plst[1]:plst[0] +
                                                      1].copy())
            phi_zmns = -(pVar_zmns + 0.5 * bVar[:, 0, plst[1]:plst[0] + 1] *
                         bN2[:, 0, plst[1]:plst[0] + 1]
                         ) * Var['w'][:, 0, plst[1]:plst[0] + 1]
            # the - sign above is because of orientation of outward normal

        # right boundary
        pVar_xpls = (fID_d[prs].variables['pVar'][plst[3]:, kbott:,
                                                  plst[0]].copy())
        phi_xpls[:, kbi:kti] = (pVar_xpls + 0.5 * bVar[:, :, plst[0]] *
                                bN2[:, :, plst[0]]) * Var['u'][:, :, plst[0]]
        # u, v, etc. were only loaded from plst[3]

        # left boundary
        pVar_xmns = (fID_d[prs].variables['pVar'][plst[3]:, kbott:,
                                                  plst[1]].copy())
        phi_xmns[:, kbi:kti] = -(pVar_xmns + 0.5 * bVar[:, :, plst[1]] *
                                 bN2[:, :, plst[1]]) * Var['u'][:, :, plst[1]]
        # the - sign above is because of orientation of outward normal

        # Zero flux occurs through top surface

        tic = print_elpsd('Fluxes', tic)

        # %% ------------------------------------------------------------------
        # Advection of PE

        # cf. Notability sheet
        dN2dx = s.S2Loc[kb:kt, plst[1]:plst[0] + 1] / (s.dltH * s.Lz)
        dN2dz = s.d2Rdz2[kb:kt, plst[1]:plst[0] + 1] * (-s.g / s.rho_0)
        AdPE[:, kbi:kti, :] = (0.5 * bN2[:, :, plst[1]:plst[0] + 1]**2 *
                               (Var['u'][:, :, plst[1]:plst[0] + 1] * dN2dx +
                                Var['w'][:, :, plst[1]:plst[0] + 1] * dN2dz))

        tic = print_elpsd('PE advection', tic)

        # %% ------------------------------------------------------------------
        # Geostrophic Shear Production (GSP)

        GSP[:, kbi:kti, :] = (Var['v'][:, :, plst[1]:plst[0] + 1] *
                              Var['w'][:, :, plst[1]:plst[0] + 1] *
                              s.S2Loc[kb:kt, plst[1]:plst[0] + 1] / s.f)

        tic = print_elpsd('GSP', tic)

        # %% ------------------------------------------------------------------
        # Lateral Conversion (LC)

        LC[:, kbi:kti, :] = (Var['u'][:, :, plst[1]:plst[0] + 1] *
                             bN2[:, :, plst[1]:plst[0] + 1] *
                             s.S2Loc[kb:kt, plst[1]:plst[0] + 1])

        tic = print_elpsd('LC', tic)

        # %% ------------------------------------------------------------------
        # Lateral Shear Production (LSP)

        LSP[:, kbi:kti, :] = (Var['v'][:, :, plst[1]:plst[0] + 1] *
                              Var['u'][:, :, plst[1]:plst[0] + 1] * s.f *
                              s.RoGLoc[kb:kt, plst[1]:plst[0] + 1])

        tic = print_elpsd('LSP', tic)

        # %% ------------------------------------------------------------------
        # Compute dissipation terms and dEdt

        K, _ = np.meshgrid(k, s.z[kb:kt])

        counter = 1
        print('Proc #' + prs + ', start dissipation')

        # 8 progress steps: 3 velocity components x 2 orders + 2 for buoyancy
        for ii in ['u', 'v', 'w']:
            dotProd = Var[ii][:, :, plst[1]:plst[0] + 1]
            to_four_deriv = Var[ii]  # full x width for the spectral derivative

            dtE[:, kbi:kti, :] += add_deriv_load(fID_td[prs], dotProd,
                                                 'd' + ii + 'dt', kbott, plst)

            LDz[:, kbi:kti, :] += add_deriv_load(fID_d[prs], dotProd,
                                                 'd2' + ii + 'dz2', kbott,
                                                 plst)
            LDh[:, kbi:kti, :] += add_deriv_four(s.x, to_four_deriv, K, 2)
            counter = print_int_elpsd('Dissipation', counter, 8, tic)

            HDz[:, kbi:kti, :] -= add_deriv_load(fID_d[prs], dotProd,
                                                 'd4' + ii + 'dz4', kbott,
                                                 plst)
            HDh[:, kbi:kti, :] -= add_deriv_four(s.x, to_four_deriv, K, 4)
            counter = print_int_elpsd('Dissipation', counter, 8, tic)

        dtE[:, kbi:kti, :] += (add_deriv_load(
            fID_td[prs], -s.g * bN2[:, :, plst[1]:plst[0] + 1] / s.rho_0,
            'ds1dt', kbott, plst))

        LDz[:, kbi:kti, :] += (add_deriv_load(
            fID_d[prs], -s.g * bN2[:, :, plst[1]:plst[0] + 1] / s.rho_0,
            'd2s1dz2', kbott, plst))
        LDh[:, kbi:kti, :] += (add_deriv_four(s.x, bVar, K, 2) /
                               s.N2Loc[kb:kt, :])
        counter = print_int_elpsd('Dissipation', counter, 8, tic)

        HDz[:, kbi:kti, :] -= (add_deriv_load(
            fID_d[prs], -s.g * bN2[:, :, plst[1]:plst[0] + 1] / s.rho_0,
            'd4s1dz4', kbott, plst))
        HDh[:, kbi:kti, :] -= (add_deriv_four(s.x, bVar, K, 4) /
                               s.N2Loc[kb:kt, :])
        counter = print_int_elpsd('Dissipation', counter, 8, tic)

        # dissipation of the thermal wind (disabled)
        # d2TWdz2 = s.S2Loc / (s.f*s.dltH*s.Lz)
        # LDh += - s.g / s.rho_0 * s.d2Rdx2[kb:kt, :] * bN2[:, :, :]
        #               # + s.d2TWdx2[kb:kt, :] * Var['v'][:, :, :])
        # LDz += - s.g / s.rho_0 * s.d2Rdz2[kb:kt, :] * bN2[:, :, :]
        # + d2TWdz2[kb:kt, :] * Var['v'][:, :, :]

        tic = print_elpsd('Dissipation and dEdt', tic)

        # %% ------------------------------------------------------------------
        fID_XZ[prs].close()
        fID_d[prs].close()
        fID_td[prs].close()

    # %% Integrations
    # integrate the three paths of energy flux (taking CW path)
    PHIm = np.trapz(phi_xmns, dx=s.dz, axis=1)
    PHIp = np.trapz(phi_xpls, dx=s.dz, axis=1)
    PHIz = np.trapz(phi_zmns, dx=s.dx, axis=1)

    # volume integrations
    AdPEtot = int_vol(s.dx, s.dz, AdPE)
    GSPtot = int_vol(s.dx, s.dz, GSP)
    LCtot = int_vol(s.dx, s.dz, LC)
    LSPtot = int_vol(s.dx, s.dz, LSP)
    dtEtot = int_vol(s.dx, s.dz, dtE)
    LDztot = int_vol(s.dx, s.dz, s.nuz * LDz)
    HDztot = int_vol(s.dx, s.dz, s.nuz2 * HDz)
    # horizontal terms were computed on the full x range: restrict to the CV
    LDhtot = int_vol(s.dx, s.dz, s.nuh * LDh[:, :, plst[1]:plst[0] + 1])
    HDhtot = int_vol(s.dx, s.dz, s.nuh2 * HDh[:, :, plst[1]:plst[0] + 1])

    tic = print_elpsd('Integrations', tic)

    # %%
    # Final step: place everything in a dictionary and save it
    dic_of_terms = {
        't': t,
        'pUpls': PHIp,
        'pUmns': PHIm,
        'pWmns': PHIz,
        'AdPE': AdPEtot,
        'GSP': GSPtot,
        'LSP': LSPtot,
        'LC': LCtot,
        'dEdt': dtEtot,
        'LDz': LDztot,
        'HDz': HDztot,
        'LDh': LDhtot,
        'HDh': HDhtot
    }

    out_path = '{0}/MEBterms.npz'.format(setup)
    try:  # drop any stale output (no shell call needed)
        os.remove(out_path)
    except FileNotFoundError:
        pass
    np.savez(out_path, **dic_of_terms)

    # %%
    time_tot = time.time() - tic_strt
    mins = int(time_tot / 60)
    secs = int(time_tot - 60 * mins)

    print(' ')
    print('          ***********')
    print(' ')
    print('All done! Elapsed time = {0:3d}:{1:02d}'.format(mins, secs))
    print(' ')
    print('          ***********')
    print(' ')

    return dic_of_terms
def _close_nc_handles(*handles):
    """Close netCDF datasets given either directly or as dicts of datasets."""
    for handle in handles:
        if isinstance(handle, dict):
            for dataset in handle.values():
                dataset.close()
        else:
            handle.close()


def KEBTerms(s, setup, plst):
    """Compute, save and return the integrated kinetic energy budget.

    The control volume (CV) spans x-indices ``plst[1]`` to ``plst[0]``
    (inclusive) and z-indices from ``plst[2]`` to the top; all time records
    are used.

    Parameters
    ----------
    s : object
        Run parameters and background fields (``np``, ``x``, ``z``, ``g``,
        ``rho_0``, ``f``, ``nuh``, ``nuh2``, ``nuz``, ``nuz2``, ``RLoc``,
        ``S2Loc``, ``RoGLoc``, ...).
    setup : str
        Run directory; results are written to ``setup/KEBterms.npz``.
    plst : sequence of int
        ``[i_right, i_left, k_bottom, ...]`` CV index bounds.

    Returns
    -------
    dict
        Time ``t`` and the time series of: boundary fluxes ``pUpls``,
        ``pUmns``, ``pWmns``; ``GSP``; ``LSP``; ``dKEdt``; Laplacian and
        hyperviscous dissipation ``LDiss``/``HDiss``; conversion to PE
        ``toPE``. The same arrays are saved to ``KEBterms.npz``.

    Raises
    ------
    NameError
        If ``s.np > 1``: the boundary fluxes are only implemented for a
        single processor file. The datasets are closed before it propagates.
    """

    Var = {}
    print('Start loading.')

    tic = time.time()
    tic_strt = tic  # start of the whole routine, for the final timing

    if s.np == 1:
        fID_XZ = nc.Dataset(setup + '/2D/XZ.nc', 'r')
        fID_d = nc.Dataset(setup + '/2D/derivs.nc', 'r')
        fID_td = nc.Dataset(setup + '/2D/t_derivs.nc', 'r')
    else:
        fID_XZ = {}
        fID_d = {}
        fID_td = {}
        for ii in range(s.np):
            fID_XZ[str(ii)] = nc.Dataset(
                '{0}/2D/XZ_{1:03d}.nc'.format(setup, ii), 'r')
            fID_d[str(ii)] = nc.Dataset(
                '{0}/2D/derivs_{1:03d}.nc'.format(setup, ii), 'r')
            fID_td[str(ii)] = nc.Dataset(
                '{0}/2D/t_derivs_{1:03d}.nc'.format(setup, ii), 'r')

    # Everything that reads the files sits in try/finally so the datasets
    # are always closed (the old close calls came after `return` and never
    # ran).
    try:
        if s.np == 1:
            t = fID_XZ.variables['tVar'][:].copy()
            for vv in ['u', 'v', 'w', 's1']:
                Var[vv] = fID_XZ.variables[vv + 'Var'][:, :, :].copy()
                tic = print_elpsd('loading of ' + vv, tic)
        else:
            # stack the processors' slabs along z (axis 1)
            t = fID_XZ['0'].variables['tVar'][:].copy()
            for vv in ['u', 'v', 'w', 's1']:
                utmp = fID_XZ['0'].variables[vv + 'Var'][:, :, :].copy()
                for jj in range(1, s.np):
                    to_cat = fID_XZ[str(jj)].variables[vv +
                                                       'Var'][:, :, :].copy()
                    utmp = np.concatenate((utmp, to_cat), axis=1)
                Var[vv] = utmp

        # %%
        # Conversion to PE calculation

        print('Shape of RLoc = {}'.format(s.RLoc.shape))

        conv = s.g * (Var['s1'][:, :, :] - s.RLoc[:, :]) * Var['w'] / s.rho_0

        Ctot = integrate_vol(s.x[plst[1]:plst[0] + 1], s.z[plst[2]:],
                             conv[:, plst[2]:, plst[1]:plst[0] + 1])

        tic = print_elpsd('Conversion term', tic)

        # %%
        # Flux calculations

        if s.np > 1:
            raise NameError(
                "We don't have a procedure for the flux with np>1")

        pVar_xplus = fID_d.variables['pVar'][:, :, plst[0]].copy()
        phi_xplus = pVar_xplus * Var['u'][:, :, plst[0]]  # [W/m^2]

        pVar_xminus = fID_d.variables['pVar'][:, :, plst[1]].copy()
        phi_xminus = -pVar_xminus * Var['u'][:, :, plst[1]]

        pVar_zminus = fID_d.variables['pVar'][:, plst[2], :].copy()
        phi_zminus = -pVar_zminus * Var['w'][:, plst[2], :]

        # Zero flux occurs through top surface

        # integrate the three paths of energy flux (taking CW path)
        PHI1 = np.trapz(phi_xplus[:, plst[2]:], s.z[plst[2]:], axis=1)
        PHI2 = np.trapz(phi_xminus[:, plst[2]:], s.z[plst[2]:], axis=1)
        PHI3 = np.trapz(phi_zminus[:, plst[1]:plst[0] + 1],
                        s.x[plst[1]:plst[0] + 1],
                        axis=1)

        tic = print_elpsd('Fluxes', tic)

        # %%
        # Geostrophic Shear Production (GSP); the PE contribution
        # (u * b / N2) was deliberately removed.

        GSP = 0. * Var['u']
        GSP[:, :, :] = s.S2Loc[:, :] * Var['v'][:, :, :] * Var[
            'w'][:, :, :] / s.f

        GSPtot = integrate_vol(s.x[plst[1]:plst[0] + 1], s.z[plst[2]:],
                               GSP[:, plst[2]:, plst[1]:plst[0] + 1])

        tic = print_elpsd('GSP', tic)

        # %%
        # Lateral Shear Production (LSP)

        LSP = 0. * Var['u']
        LSP[:, :, :] = (s.f * s.RoGLoc[:, :] * Var['v'][:, :, :] *
                        Var['u'][:, :, :])

        LSPtot = integrate_vol(s.x[plst[1]:plst[0] + 1], s.z[plst[2]:],
                               LSP[:, plst[2]:, plst[1]:plst[0] + 1])

        tic = print_elpsd('LSP', tic)

        # %%
        # Change in KE (NOT PE) with time, at every grid point

        dtKE = 0. * Var['u']
        counter = 1

        for ii in ['u', 'v', 'w']:
            dtKE += add_deriv_load(fID_td, Var[ii], 'd{}dt'.format(ii))
            counter = print_int_elpsd('dKEdt', counter, 3, tic)

        dtKEtot = integrate_vol(s.x[plst[1]:plst[0] + 1], s.z[plst[2]:],
                                dtKE[:, plst[2]:, plst[1]:plst[0] + 1])

        dEdt = dtKEtot  # +dtPEtot

        tic = print_elpsd('dEdt', tic)

        # %%
        # Compute dissipation terms

        k = tf.k_of_x(s.x)
        K, _ = np.meshgrid(k, s.z)

        ldiss = 0. * Var['u']  # Laplacian dissipation
        hdiss = 0. * Var['u']  # hyperviscous dissipation
        counter = 1
        print('Start Dissipation')

        # 12 progress steps: 3 components x (z, x) x (Laplacian, hyper)
        for ii in ['u', 'v', 'w']:
            ldiss += s.nuz * add_deriv_load(fID_d, Var[ii],
                                            'd2' + ii + 'dz2')
            counter = print_int_elpsd('Dissipation', counter, 12, tic)
            ldiss += s.nuh * add_deriv_four(s.x, Var[ii], K, 2)
            counter = print_int_elpsd('Dissipation', counter, 12, tic)
            hdiss += -s.nuz2 * add_deriv_load(fID_d, Var[ii],
                                              'd4' + ii + 'dz4')
            counter = print_int_elpsd('Dissipation', counter, 12, tic)
            hdiss += -s.nuh2 * add_deriv_four(s.x, Var[ii], K, 4)
            counter = print_int_elpsd('Dissipation', counter, 12, tic)

        # dissipation of the thermal wind (disabled)
        # d2ThWdz2 = s.S2Loc / (s.f*s.dltH*s.Lz)
        # dissip = dissip[:, :, :]  # + (nuz*d2ThWdz2[:, :])*Var['v'][:, :, :]

        LDtot = integrate_vol(s.x[plst[1]:plst[0] + 1], s.z[plst[2]:],
                              ldiss[:, plst[2]:, plst[1]:plst[0] + 1])
        HDtot = integrate_vol(s.x[plst[1]:plst[0] + 1], s.z[plst[2]:],
                              hdiss[:, plst[2]:, plst[1]:plst[0] + 1])

        tic = print_elpsd('Dissipation', tic)
    finally:
        _close_nc_handles(fID_XZ, fID_d, fID_td)

    # %%
    # Final step: place everything in a dictionary and save it
    dic_of_terms = {
        't': t,
        'pUpls': PHI1,
        'pUmns': PHI2,
        'pWmns': PHI3,
        'GSP': GSPtot,
        'LSP': LSPtot,
        'dKEdt': dEdt,
        'LDiss': LDtot,
        'HDiss': HDtot,
        'toPE': Ctot
    }

    out_path = '{0}/KEBterms.npz'.format(setup)
    try:  # drop any stale output (no shell call needed)
        os.remove(out_path)
    except FileNotFoundError:
        pass
    np.savez(out_path, **dic_of_terms)

    # Budget closure check (LHS = RHS), kept for reference:
    #    Ebudget = -PHI - GSPtot - LSPtot - dEdt + LDtot + HDtot - Ctot
    #    plt.figure()
    #    plt.plot(t, -PHI1, 'k', label='x+ Flux')
    #    plt.plot(t, -PHI2, 'k--', label='x- Flux')
    #    plt.plot(t, -PHI3, 'k+-', label='z- Flux')
    #    plt.plot(t, -GSPtot, 'm', label='GSP')
    #    plt.plot(t, -LSPtot, 'm--', label='LSP')
    #    plt.plot(t, -dEdt, 'b', label='dEdt')
    #    plt.plot(t, Dtot, 'g', label='Dissipation')
    #    plt.plot(t, -Ctot, 'c', label='-Conversion')
    #    plt.plot(t, Ebudget, 'r--', label='Cumulative Budget', lw=2)
    #    plt.xlabel("Time [$s$]")
    #    plt.ylabel("Integrated Energy Per Unit Density [$m^4/s^3$]")
    #    plt.title("Full Accounting of Energy Budget")
    #    plt.grid()
    #    plt.legend(loc='upper left')
    #    plt.show()

    time_tot = time.time() - tic_strt
    mins = int(time_tot / 60)
    secs = int(time_tot - 60 * mins)

    print(' ')
    print('          ******')
    print(' ')
    print('All done! Elapsed time = {0:3d}:{1:02d}'.format(mins, secs))
    print(' ')
    print('          ******')
    print(' ')

    return dic_of_terms
# Example #6
# 0
def read_one_set(root, fl):
    """Read a run's parameter files, derive grids and background fields, pickle.

    Parameters
    ----------
    root : str
        Run directory *including* its trailing separator (paths are built
        as ``root + 'codes_etc/...'`` and ``root + '2D/'``).
    fl : str
        Name of the output-file family. It must match an entry in
        ``io_params`` and prefix the files in ``root + '2D/'``.

    Side effects
    ------------
    Writes ``root + 'spec_params.pkl'`` holding the parameter dictionary.
    Returns None. If ``problem_params`` is missing, prints a notice and
    returns early.

    Raises
    ------
    NameError
        On an unknown RecType or BCs value, or when ``fl`` is not listed in
        ``io_params``.
    FileNotFoundError
        When no output file of the family ``fl`` is found in ``root/2D``.
    """
    print(root)
    ProbParPath = '{0}codes_etc/input/problem_params'.format(root)

    if not os.path.exists(ProbParPath):
        print('    => does not exist')
        return

    # %% ----------------------------------------------------------------------
    # problem_params is relatively stable in terms of what is where.
    # Therefore, it is possible to assume that the line numbers won't change.
    # Each value sits in the first 15 characters of its line.

    with open(ProbParPath) as fid:
        lines = fid.readlines()

    d2pkl = {
        'nx': int(lines[1][:15]),
        'ny': int(lines[2][:15]),
        'nz': int(lines[3][:15]),
        'nL': int(lines[4][:15]),
        'nsp': int(lines[5][:15]),
        'dt': float(replace_d_by_e(lines[6][:15])),
        't_end': float(replace_d_by_e(lines[8][:15])),
        'BCs': lines[9][:15].strip(),
        'Lx': float(replace_d_by_e(lines[10][:15])),
        'Ly': float(replace_d_by_e(lines[11][:15])),
        'Lz': float(replace_d_by_e(lines[12][:15])),
        'g': float(replace_d_by_e(lines[16][:15])),
        'f': float(replace_d_by_e(lines[17][:15])),
        'rho_0': float(replace_d_by_e(lines[18][:15])),
        'nuh': float(replace_d_by_e(lines[19][:15])),
        'nuz': float(replace_d_by_e(lines[20][:15])),
        'kappah': float(replace_d_by_e(lines[21][:15])),
        'kappaz': float(replace_d_by_e(lines[22][:15])),
        'hdeg': int(lines[25][:15]),
        'nuh2': float(replace_d_by_e(lines[30][:15])),
        'nuz2': float(replace_d_by_e(lines[31][:15])),
        'kappah2': float(replace_d_by_e(lines[32][:15])),
        'kappaz2': float(replace_d_by_e(lines[33][:15])),
        'hdeg2': int(lines[36][:15]),
        'nlflag': int(lines[37][:15])
    }

    # %% ----------------------------------------------------------------------
    # Contrary to problem_params, the user_params module in user_module.f90
    # changes often. Better use a more flexible approach: parse every
    # "type :: name = value" declaration until the end of the module.

    UsrModPath = '{0}codes_etc/input/user_module.f90'.format(root)
    with open(UsrModPath) as fid:
        for line in fid:
            els = line.split()
            if line.count(':: '):
                if '=' not in els:
                    continue  # declaration without an initial value
                idx_eq = els.index('=')
                var = els[idx_eq - 1]
                val = els[idx_eq + 1]
                if els[0].count('real(kind=8)'):
                    d2pkl[var] = float(replace_d_by_e(val))
                elif els[0].count('integer'):
                    d2pkl[var] = int(val)
                elif els[0].count('logical'):
                    d2pkl[var] = bool(val.lower().count('true'))
            elif line.count('end module user_params'):
                break

    # %% ----------------------------------------------------------------------
    # Reading io_params: the line naming `fl` is followed by its record type
    # and its output stride; text after '!' is a comment.

    with open('{0}codes_etc/input/io_params'.format(root)) as fid:
        lines = fid.readlines()

    idx = next((ln_cnt for ln_cnt, line in enumerate(lines)
                if line.split('!', 1)[0].strip() == fl), None)
    if idx is None:
        raise NameError('{0} not found in io_params.'.format(fl))

    d2pkl.update({
        'RecType': lines[idx + 1].split('!', 1)[0].strip(),
        'sk_it': int(lines[idx + 2].split('!', 1)[0])
    })

    if d2pkl['RecType'] == 'new':
        d2pkl['iteID'] = '_0000000000'
    elif d2pkl['RecType'] == 'append':
        d2pkl['iteID'] = ''
    else:
        raise NameError('No RecType? io_params not read.')

    # Number of procs: files of this family carrying the first-iteration ID
    ListFl = [item for item in os.listdir('{0}2D/'.format(root))
              if item.startswith(fl)]
    d2pkl['np'] = sum(1 for nm in ListFl if d2pkl['iteID'] in nm)

    if d2pkl['np'] == 1:
        d2pkl['npID'] = ''
    elif d2pkl['np'] > 1:
        d2pkl['npID'] = '_000'
    else:
        raise FileNotFoundError('No {0}{1}* file found in {2}2D/'.format(
            fl, d2pkl['iteID'], root))

    # Number of iterations
    d2pkl['nites'] = int(np.floor(d2pkl['t_end'] / d2pkl['dt'])) + 1

    # %% ----------------------------------------------------------------------
    # Coordinates

    rbf = nc.netcdf_file(
        '{0}2D/{1}{2}{3}.nc'.format(root, fl, d2pkl['iteID'], d2pkl['npID']),
        'r')
    d2pkl.update({
        'x': rbf.variables['xVar'][:].copy(),
        'z': rbf.variables['zVar'][:].copy(),
        't': rbf.variables['tVar'][:].copy(),
        'V': d2pkl['Lx'] * d2pkl['Lz'],
        'dx': d2pkl['Lx'] / d2pkl['nx'],
        'dy': d2pkl['Ly'] / d2pkl['ny'],
        'dz': d2pkl['Lz'] / d2pkl['nz']
    })
    if d2pkl['ny'] > 1:
        d2pkl['y'] = rbf.variables['yVar'][:].copy()
        d2pkl['V'] *= d2pkl['Ly']
    rbf.close()

    d2pkl['xs'] = d2pkl['x'] - d2pkl['Lx'] * 0.5

    # vertical coordinate: append the z slabs of the other processors
    for ii in range(1, d2pkl['np']):
        pID = '_{0:03d}'.format(ii)
        rbf = nc.netcdf_file(
            '{0}2D/{1}{2}.nc'.format(root, 'rhobar_0000000000', pID), 'r')
        d2pkl['z'] = np.concatenate(
            (d2pkl['z'], rbf.variables['zVar'][:].copy()))
        rbf.close()

    d2pkl['zs'] = d2pkl['z'] - d2pkl['Lz']

    d2pkl['nnz'] = len(d2pkl['z'])

    # grids
    d2pkl['xz'], d2pkl['zx'] = np.meshgrid(d2pkl['x'], d2pkl['z'])
    if d2pkl['ny'] > 1:
        d2pkl['xy'], d2pkl['yx'] = np.meshgrid(d2pkl['x'], d2pkl['y'])
        d2pkl['yz'], d2pkl['zy'] = np.meshgrid(d2pkl['y'], d2pkl['z'])

    d2pkl['zxs'] = d2pkl['zx'] - d2pkl['Lz']
    d2pkl['xzs'] = d2pkl['xz'] - d2pkl['xctr0']
    d2pkl['xz0'] = d2pkl['xz'] - d2pkl['Lx'] + d2pkl['xctr0']

    # time coordinate: with 'new' records, one file per output iteration
    if d2pkl['RecType'] == 'new':
        for ii in range(1, d2pkl['nites']):
            pID = '_{0:010d}'.format(ii)
            rbf = nc.netcdf_file(
                '{0}2D/{1}{2}{3}.nc'.format(root, fl, pID, d2pkl['npID']), 'r')
            d2pkl['t'] = np.concatenate(
                (d2pkl['t'], rbf.variables['tVar'][:].copy()))
            rbf.close()

    # spectral coordinates
    d2pkl['k'] = tf.k_of_x(d2pkl['x'])
    d2pkl['dk'] = d2pkl['k'][1] - d2pkl['k'][0]

    if d2pkl['BCs'] == 'zslip':
        d2pkl['m'] = tf.k_of_x(d2pkl['z'][:-1])
    elif d2pkl['BCs'] == 'zperiodic':
        d2pkl['m'] = tf.k_of_x(d2pkl['z'])
    else:
        raise NameError('Unknown BCs {0!r} in problem_params.'.format(
            d2pkl['BCs']))
    d2pkl['dm'] = d2pkl['m'][1] - d2pkl['m'][0]

    d2pkl['km'], d2pkl['mk'] = np.meshgrid(d2pkl['k'], d2pkl['m'])

    if d2pkl['ny'] > 1:
        d2pkl['l'] = tf.k_of_x(d2pkl['y'])
        d2pkl['dl'] = d2pkl['l'][1] - d2pkl['l'][0]

        d2pkl['kl'], d2pkl['lk'] = np.meshgrid(d2pkl['k'], d2pkl['l'])
        d2pkl['lm'], d2pkl['ml'] = np.meshgrid(d2pkl['l'], d2pkl['m'])

    # Misc: Earth rotation and beta-plane
    Omega = 2. * np.pi / 86400
    Rearth = 6375.e3
    lat = np.arcsin(0.5 * d2pkl['f'] / Omega)
    if d2pkl['beta_y_n']:
        d2pkl['beta'] = -2. * Omega * np.cos(lat) / Rearth
    else:
        d2pkl['beta'] = 0.
    d2pkl['f_local'] = d2pkl['f'] + d2pkl['beta'] * d2pkl['xzs']
    d2pkl['f0'] = d2pkl['f'] + d2pkl['beta'] * (d2pkl['Lx'] * 0.5 -
                                                d2pkl['xctr0'])

    # %% ----------------------------------------------------------------------
    # Geostrophic flow and stratification
    # Not ready for 3D

    front_is_decaying = d2pkl['dltH'] > 0. and d2pkl['FrG'] > 0.
    if front_is_decaying:
        alpha1 = d2pkl['dltH'] * d2pkl['Lz'] * (
            1. - np.exp(-d2pkl['totdepth'] / d2pkl['dltH'] / d2pkl['Lz']))
        alpha2 = np.exp(d2pkl['zxs'] / d2pkl['Lz'] / d2pkl['dltH'])
        alpha3 = np.exp(d2pkl['zxs'] / d2pkl['dltH'] /
                        d2pkl['Lz']) / d2pkl['dltH'] / d2pkl['Lz']
        alpha4 = (np.exp(d2pkl['zxs'] / d2pkl['dltH'] / d2pkl['Lz']) -
                  np.exp(-d2pkl['totdepth'] / d2pkl['dltH'] /
                         d2pkl['Lz'])) * d2pkl['dltH'] * d2pkl['Lz']
    elif d2pkl['dltH'] <= 0. and d2pkl['FrG'] > 0.:
        alpha1 = d2pkl['totdepth']
        alpha2 = 1.
        alpha3 = 0.
        alpha4 = d2pkl['zxs'] + d2pkl['totdepth']
    else:
        alpha1 = 0.
        alpha2 = 0.
        alpha3 = 0.
        alpha4 = 0.

    Bhat = (8 * 3.**(-1.5) * alpha1 * d2pkl['FrG']**2 * d2pkl['up_N2'] /
            d2pkl['RoG'])
    if d2pkl['FrG'] > 0.:
        dltS2 = 0.5 * Bhat / (abs(d2pkl['f0']) * np.sqrt(d2pkl['up_N2']) *
                              d2pkl['FrG'])
    else:
        dltS2 = 1.e8

    # Main front:
    Gamma0 = 0.5 * (1. - np.tanh(d2pkl['xz0'] / dltS2))
    dGamma0dx = -(1. + np.tanh(d2pkl['xz0'] / dltS2)) * Gamma0 / dltS2
    d2Gamma0dx2 = 2. * np.tanh(d2pkl['xz0'] / dltS2) * Gamma0 * (
        1. + np.tanh(d2pkl['xz0'] / dltS2)) / dltS2**2

    # Secondary front ensuring x-periodicity of N^2
    Gamma1 = -1. / np.cosh(d2pkl['xz'] / d2pkl['chi1'])**2
    dGamma1dx = -2. * np.tanh(
        d2pkl['xz'] / d2pkl['chi1']) * Gamma1 / d2pkl['chi1']
    d2Gamma1dx2 = 2. * Gamma1 * (
        2. - 3. / np.cosh(d2pkl['xz'] / d2pkl['chi1'])**2) / d2pkl['chi1']**2

    # Total front-induced buoyancy perturbation shape
    Gamma = Gamma0 + Gamma1
    dGammadx = dGamma0dx + dGamma1dx
    d2Gammadx2 = d2Gamma0dx2 + d2Gamma1dx2

    # Non-frontal stratification
    if d2pkl['z0_Gill'] < 0.:
        N21 = d2pkl['up_N2']
        BN1 = d2pkl['up_N2'] * d2pkl['zxs']
        dN21dz = 0.
    else:
        N21 = d2pkl['up_N2'] * (1 - d2pkl['zxs'] / d2pkl['z0_Gill'])**(-2)
        BN1 = d2pkl['up_N2'] * d2pkl['zxs'] * (
            1 - d2pkl['zxs'] / d2pkl['z0_Gill'])**(-1)
        dN21dz = d2pkl['up_N2'] * 2. * (1 - d2pkl['zxs'] / d2pkl['z0_Gill'])**(
            -3) / d2pkl['z0_Gill']

    # Frontal part of d2R/dz2 (up to the -rho_0/g factor). alpha3 is zero
    # outside the decaying-front branch; dividing it by dltH*Lz there gave
    # NaN when dltH == 0.
    if front_is_decaying:
        front_d2z = Bhat * Gamma * alpha3 / (d2pkl['dltH'] * d2pkl['Lz'])
    else:
        front_d2z = np.zeros_like(Gamma)

    # Geostrophic Flow dictionary
    d2pkl.update({
        'RLoc': (BN1 + Bhat * Gamma * alpha2) * (-d2pkl['rho_0'] / d2pkl['g']),
        'N2Loc':
        N21 + Bhat * Gamma * alpha3,
        'S2Loc':
        Bhat * dGammadx * alpha2,
        'TWLoc':
        Bhat * dGammadx * alpha4 / d2pkl['f_local'],
        'RoGLoc':
        Bhat * (d2pkl['f_local'] * d2Gammadx2 - d2pkl['beta'] * dGammadx) *
        alpha4 / d2pkl['f_local']**3,
        'd2Rdx2':
        Bhat * d2Gammadx2 * alpha2 * (-d2pkl['rho_0'] / d2pkl['g']),
        'd2Rdz2':
        (front_d2z + dN21dz) * (-d2pkl['rho_0'] / d2pkl['g'])
    })

    # %% ----------------------------------------------------------------------
    # We can pickle that!

    pkl_path = '{0}spec_params.pkl'.format(root)
    try:  # drop any stale pickle (no shell call needed)
        os.remove(pkl_path)
    except FileNotFoundError:
        pass
    with open(pkl_path, 'wb') as pkl_file:
        pickle.dump(d2pkl, pkl_file, -1)
def read_one_set(root, fl):
    """Collect the parameters and grids of one output set and pickle them.

    Reads ``problem_params``, the ``user_params`` module of
    ``user_module.f90`` and ``io_params`` from ``{root}codes_etc/input/``.
    Reads the coordinates from the netCDF files in ``{root}2D/``. Derives the
    background (geostrophic) fields. Stores everything in a dictionary that
    is pickled to ``{root}spec_params.pkl``.

    Parameters
    ----------
    root : str
        Run directory. It is used as a plain string prefix, so it must end
        with a path separator.
    fl : str
        Name of the output set as listed in ``io_params``. It is also the
        prefix of that set's netCDF files in ``{root}2D/``.

    Returns
    -------
    None
        Results are only written to ``{root}spec_params.pkl``. Nothing is
        written when ``{root}codes_etc/input/problem_params`` does not exist.

    Raises
    ------
    NameError
        If ``fl`` is not listed in io_params, or if its record type is
        neither 'new' nor 'append'.
    IOError
        If no netCDF file of set ``fl`` is found in ``{root}2D/``.
    """
    print(root)
    ProbParPath = '{0}codes_etc/input/problem_params'.format(root)

    if not os.path.exists(ProbParPath):
        print('    => does not exist')
        return

    # %% ----------------------------------------------------------------------
    # problem_params is relatively stable in terms of what is where.
    # Therefore, it is possible to assume that the line numbers won't change.
    # Each value sits in the first 15 characters of its line.

    with open(ProbParPath) as fid:
        lines = fid.readlines()

    d2pkl = {
        'nx': int(lines[1][:15]),
        'ny': int(lines[2][:15]),
        'nz': int(lines[3][:15]),
        'nL': int(lines[4][:15]),
        'nsp': int(lines[5][:15]),
        'dt': float(replace_d_by_e(lines[6][:15])),
        't_end': float(replace_d_by_e(lines[8][:15])),
        'BCs': lines[9][:15].strip(),
        'Lx': float(replace_d_by_e(lines[10][:15])),
        'Ly': float(replace_d_by_e(lines[11][:15])),
        'Lz': float(replace_d_by_e(lines[12][:15])),
        'g': float(replace_d_by_e(lines[16][:15])),
        'f': float(replace_d_by_e(lines[17][:15])),
        'rho_0': float(replace_d_by_e(lines[18][:15])),
        'nuh': float(replace_d_by_e(lines[19][:15])),
        'nuz': float(replace_d_by_e(lines[20][:15])),
        'kappah': float(replace_d_by_e(lines[21][:15])),
        'kappaz': float(replace_d_by_e(lines[22][:15])),
        'hdeg': int(lines[25][:15]),
        'nuh2': float(replace_d_by_e(lines[30][:15])),
        'nuz2': float(replace_d_by_e(lines[31][:15])),
        'kappah2': float(replace_d_by_e(lines[32][:15])),
        'kappaz2': float(replace_d_by_e(lines[33][:15])),
        'hdeg2': int(lines[36][:15]),
        'nlflag': int(lines[37][:15])}

    # %% ----------------------------------------------------------------------
    # Contrary to problem_params, the user_params module in user_module.f90
    # changes often. Better use a more flexible approach: every line holding
    # "<type> :: <name> = <value>" is parsed according to its declared type,
    # until the end of the user_params module.

    UsrModPath = '{0}codes_etc/input/user_module.f90'.format(root)
    with open(UsrModPath) as fid:
        for line in fid:
            els = line.split()
            if line.count(':: '):
                if '=' not in els:
                    # Declaration without an initial value: nothing to read.
                    continue
                idx_eq = els.index('=')
                var = els[idx_eq-1]
                val = els[idx_eq+1]
                if els[0].count('real(kind=8)'):
                    d2pkl[var] = float(replace_d_by_e(val))
                elif els[0].count('integer'):
                    d2pkl[var] = int(val)
                elif els[0].count('logical'):
                    d2pkl[var] = bool(val.lower().count('true'))
            elif line.count('end module user_params'):
                break

    # %% ----------------------------------------------------------------------
    # Reading io_params: the line naming the output set is followed by its
    # record type and its skip-iteration count. Text after '!' is a comment.

    with open('{0}codes_etc/input/io_params'.format(root)) as fid:
        entries = [ln.split('!', 1)[0].strip() for ln in fid.readlines()]

    if fl not in entries:
        raise NameError('{0} not found in io_params.'.format(fl))
    idx = entries.index(fl)
    d2pkl.update({
        'RecType': entries[idx+1],
        'sk_it': int(entries[idx+2])})

    if d2pkl['RecType'] == 'new':
        d2pkl['iteID'] = '_0000000000'
    elif d2pkl['RecType'] == 'append':
        d2pkl['iteID'] = ''
    else:
        raise NameError('No RecType? io_params not read.')

    # Number of procs: files of this set whose name contains iteID (an empty
    # iteID, i.e. 'append' mode, matches every file of the set).
    set_files = [nm for nm in os.listdir('{0}2D/'.format(root))
                 if nm.startswith(fl)]
    d2pkl['np'] = len([nm for nm in set_files if d2pkl['iteID'] in nm])

    if d2pkl['np'] == 1:
        d2pkl['npID'] = ''
    elif d2pkl['np'] > 1:
        d2pkl['npID'] = '_000'
    else:
        raise IOError('No file of set {0} found in {1}2D/'.format(fl, root))

    # Number of iterations
    d2pkl['nites'] = int(np.floor(d2pkl['t_end']/d2pkl['dt'])) + 1

    # %% ----------------------------------------------------------------------
    # Coordinates

    rbf = nc.netcdf_file('{0}2D/{1}{2}{3}.nc'.format(root, fl, d2pkl['iteID'],
                         d2pkl['npID']), 'r')
    d2pkl.update({
        'x': rbf.variables['xVar'][:].copy(),
        'z': rbf.variables['zVar'][:].copy(),
        't': rbf.variables['tVar'][:].copy(),
        'V': d2pkl['Lx']*d2pkl['Lz'],
        'dx': d2pkl['Lx']/d2pkl['nx'],
        'dy': d2pkl['Ly']/d2pkl['ny'],
        'dz': d2pkl['Lz']/d2pkl['nz']})
    if d2pkl['ny'] > 1:
        d2pkl['y'] = rbf.variables['yVar'][:].copy()
        d2pkl['V'] *= d2pkl['Ly']
    rbf.close()

    d2pkl['xs'] = d2pkl['x'] - d2pkl['Lx']*0.5

    # vertical coordinate: the file of processor 0 only holds its own slab,
    # so append the z values found in the rhobar files of the others.
    for ii in range(1, d2pkl['np']):
        pID = '_{0:03d}'.format(ii)
        rbf = nc.netcdf_file('{0}2D/{1}{2}.nc'.format(
            root, 'rhobar_0000000000', pID), 'r')
        d2pkl['z'] = np.concatenate(
            (d2pkl['z'], rbf.variables['zVar'][:].copy()))
        rbf.close()

    d2pkl['zs'] = d2pkl['z'] - d2pkl['Lz']

    d2pkl['nnz'] = len(d2pkl['z'])

    # grids
    d2pkl['xz'], d2pkl['zx'] = np.meshgrid(d2pkl['x'], d2pkl['z'])
    if d2pkl['ny'] > 1:
        d2pkl['xy'], d2pkl['yx'] = np.meshgrid(d2pkl['x'], d2pkl['y'])
        d2pkl['yz'], d2pkl['zy'] = np.meshgrid(d2pkl['y'], d2pkl['z'])

    d2pkl['zxs'] = d2pkl['zx'] - d2pkl['Lz']
    d2pkl['xzs'] = d2pkl['xz'] - d2pkl['Lx']
    d2pkl['xz0'] = d2pkl['xz'] - d2pkl['Lx'] + d2pkl['xctr0']

    # time coordinate: with 'new' records each iteration is a separate file.
    if d2pkl['RecType'] == 'new':
        for ii in range(1, d2pkl['nites']):
            pID = '_{0:010d}'.format(ii)
            rbf = nc.netcdf_file(
                '{0}2D/{1}{2}{3}.nc'.format(root, fl, pID, d2pkl['npID']), 'r')
            d2pkl['t'] = np.concatenate(
                (d2pkl['t'], rbf.variables['tVar'][:].copy()))
            rbf.close()

    # spectral coordinates
    d2pkl['k'] = tf.k_of_x(d2pkl['x'])
    d2pkl['dk'] = d2pkl['k'][1] - d2pkl['k'][0]

    if d2pkl['BCs'] == 'zslip':
        d2pkl['m'] = tf.k_of_x(d2pkl['z'][:-1])
    elif d2pkl['BCs'] == 'zperiodic':
        d2pkl['m'] = tf.k_of_x(d2pkl['z'])
    d2pkl['dm'] = d2pkl['m'][1] - d2pkl['m'][0]

    d2pkl['km'], d2pkl['mk'] = np.meshgrid(d2pkl['k'], d2pkl['m'])

    if d2pkl['ny'] > 1:
        d2pkl['l'] = tf.k_of_x(d2pkl['y'])
        d2pkl['dl'] = d2pkl['l'][1] - d2pkl['l'][0]

        d2pkl['kl'], d2pkl['lk'] = np.meshgrid(d2pkl['k'], d2pkl['l'])
        d2pkl['lm'], d2pkl['ml'] = np.meshgrid(d2pkl['l'], d2pkl['m'])

    # Misc: latitude inferred from f, then beta (zero unless beta_y_n).
    Omega = 2.*np.pi/86400
    Rearth = 6375.e3
    lat = np.arcsin(0.5*d2pkl['f']/Omega)
    if d2pkl['beta_y_n']:
        d2pkl['beta'] = -2.*Omega*np.cos(lat)/Rearth
    else:
        d2pkl['beta'] = 0.
    d2pkl['f_local'] = d2pkl['f'] + d2pkl['beta']*d2pkl['xzs']
    d2pkl['f0'] = d2pkl['f'] + d2pkl['beta']*(d2pkl['Lx']*0.5 -
                                              d2pkl['xctr0'])

    # %% ----------------------------------------------------------------------
    # Geostrophic flow and stratification

    # Useful values for the jet (a is indexed a[0], a[1], a[2] below)
    a, _ = gamma_exp3(0.)

    # Vertical Stratification:
    N210 = (d2pkl['Phi0']*d2pkl['f'])**2

    if d2pkl['z0_Gill'] > 0.:
        N21 = N210 / (1. - d2pkl['zxs'] / d2pkl['z0_Gill']) ** 2
        BN1 = N210 * d2pkl['zxs'] / (1. - d2pkl['zxs'] / d2pkl['z0_Gill'])
        # d(N21)/dz
        dN21dz = (N210 * 2 * (1. - d2pkl['zxs']/d2pkl['z0_Gill'])**(-3) /
                  d2pkl['z0_Gill'])
    else:
        N21 = N210 * np.ones(d2pkl['zxs'].shape)
        BN1 = N210 * d2pkl['zxs']
        dN21dz = 0. * N21

    dltHm = d2pkl['dltH'] * d2pkl['Lz']

    # Vertical profiles of the front: exp(z/dltHm), its z-derivative and
    # integrals.
    alpha = {
        '1': (1. - np.exp(-d2pkl['totdepth']/dltHm)) * dltHm,
        '2': np.exp(d2pkl['zxs'] / dltHm),
        '3': np.exp(d2pkl['zxs'] / dltHm) / dltHm,
        '4': dltHm * (np.exp(d2pkl['zxs'] / dltHm) -
                      np.exp(-d2pkl['totdepth']/dltHm))}

    # ###### Total front-induced buoyancy perturbation shape
    RiG0 = d2pkl['FrG']**-1
    BuG0 = d2pkl['RoG']*(RiG0*d2pkl['RoG']*(a[1]/a[2])**2 - a[0]/a[2])

    chi0 = d2pkl['Phi0']/np.sqrt(BuG0) * dltHm
    Bhat = (d2pkl['RoG']*d2pkl['Phi0']**2/(a[2]*BuG0) * dltHm * d2pkl['f']**2)

    # dG0[k], dG1[k] are assumed to be the k-th derivatives of the shape
    # function w.r.t. its argument (hence the 1/chi**k factors) - TODO confirm
    _, dG0 = gamma_exp3(d2pkl['xzs']/chi0)
    _, dG1 = gamma_exp3(d2pkl['xz']/d2pkl['chi1'])

    # x-derivatives of G0(x/chi0)*G1(x/chi1) by the Leibniz rule.
    dG = {
        '0': dG0['0']*dG1['0'],
        '1': dG0['1']*dG1['0']/chi0 + dG0['0']*dG1['1']/d2pkl['chi1'],
        '2': (dG0['2']*dG1['0']/chi0**2 + dG0['0']*dG1['2']/d2pkl['chi1']**2 +
              2*dG0['1']*dG1['1']/(chi0*d2pkl['chi1'])),
        '3': (dG0['3']*dG1['0']/chi0**3 + dG0['0']*dG1['3']/d2pkl['chi1']**3 +
              3*dG0['1']*dG1['2']/(chi0*d2pkl['chi1']**2) +
              3*dG0['2']*dG1['1']/(chi0**2*d2pkl['chi1']))}

    # Geostrophic Flow dictionary
    d2pkl.update({
        'RLoc': (BN1 + Bhat*dG['0']*alpha['2']) * (-d2pkl['rho_0']/d2pkl['g']),
        'N2Loc': N21 + Bhat*dG['0']*alpha['3'],
        'S2Loc': Bhat*dG['1']*alpha['2'],
        'TWLoc': Bhat*dG['1']*alpha['4'] / d2pkl['f_local'],
        'RoGLoc': Bhat*(d2pkl['f_local']*dG['2'] - d2pkl['beta']*dG['1']
                        )*alpha['4'] / d2pkl['f_local']**3,
        'd2Rdx2': Bhat*dG['2']*alpha['2'] * (-d2pkl['rho_0']/d2pkl['g']),
        'd2Rdz2': (Bhat*dG['0']*alpha['3']/dltHm + dN21dz
                   ) * (-d2pkl['rho_0']/d2pkl['g'])})

    # %% ----------------------------------------------------------------------
    # We can pickle that! Any previous pickle is deleted first.

    PklPath = '{0}spec_params.pkl'.format(root)
    if os.path.isfile(PklPath):
        os.remove(PklPath)
    with open(PklPath, 'wb') as pkl_file:
        pickle.dump(d2pkl, pkl_file, -1)