Example #1
 def make_MOC_file(self):
     # ca. 20 sec per file: 50 min for ctrl, 26 min for rcp
     if self.run in ['ctrl', 'rcp']:
         DXU  = xr_DXU(self.domain)            # [m]
         DZU  = xr_DZ(self.domain, grid='U')   # [m]
         # MASK = Atlantic_mask('ocn')   # Atlantic
         # MASK = boolean_mask('ocn', 2) # Pacific
         MASK = boolean_mask(self.domain, 0)   # global ocean
         for i, (y,m,s) in enumerate(IterateOutputCESM(domain=self.domain,
                                                       run=self.run,
                                                       tavg='yrly',
                                                       name='UVEL_VVEL')):
             # ca. 20 sec per year
             print(i, y, s)
             ds = xr.open_dataset(s, decode_times=False)
             MOC = calculate_MOC(ds=ds, DXU=DXU, DZU=DZU, MASK=MASK)
             if i==0:
                 MOC_out = MOC.copy()
             else:
                 MOC_out = xr.concat([MOC_out, MOC], dim='time')
         MOC_out.to_netcdf(f'{path_results}/MOC/GMOC_{self.run}.nc')
     
     elif self.run in ['lpd', 'lpi']:
         # TODO: pass a file pattern or list of paths; the bare
         # xr.open_mfdataset() call here is an unfinished stub
         ds = xr.open_mfdataset()
     
     return
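A note on the accumulation above: calling xr.concat inside the loop copies the growing array on every iteration, which scales quadratically with the number of files. A minimal sketch of the linear list-then-concat pattern, using the same repository helpers (IterateOutputCESM, calculate_MOC) and assuming domain, run, DXU, DZU, and MASK are set as in the method:

import xarray as xr

moc_list = []  # collect yearly MOC fields, concatenate once at the end
for i, (y, m, s) in enumerate(IterateOutputCESM(domain=domain, run=run,
                                                tavg='yrly', name='UVEL_VVEL')):
    ds = xr.open_dataset(s, decode_times=False)
    moc_list.append(calculate_MOC(ds=ds, DXU=DXU, DZU=DZU, MASK=MASK))
MOC_out = xr.concat(moc_list, dim='time')  # single O(n) concatenation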
Example #2
def geometry_file(domain):
    """ returns xr.Dataset with geometry fields """
    assert domain in ['ocn', 'ocn_low']
    fn = f'{path_prace}/Mov/ds_geo_{domain}.nc'
    if os.path.exists(fn):
        ds_geo = xr.open_dataset(fn)
    else:
        print(f'creating geometry file:  {fn}')
        if domain == 'ocn': fe = file_ex_ocn_ctrl
        elif domain == 'ocn_low': fe = file_ex_ocn_lpd
        geometrics = [
            'TAREA', 'UAREA', 'dz', 'DXT', 'DXU', 'HTN', 'HUS', 'DYT', 'DYU',
            'HTE', 'HUW', 'REGION_MASK'
        ]
        ds_geo = xr.open_dataset(fe, decode_times=False)[geometrics].drop(
            ['TLONG', 'TLAT', 'ULONG', 'ULAT']).squeeze()
        DZT = xr_DZ(domain)
        DZU = xr_DZ(domain, grid='U')
        DZT.name, DZU.name = 'DZT', 'DZU'
        ds_geo = xr.merge([ds_geo, DZT, DZU])
        ds_geo.to_netcdf(fn)
    return ds_geo
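The function caches its result: the first call builds and writes the NetCDF file, later calls simply reopen it. A short usage sketch with the repository's helpers and paths:

ds_geo = geometry_file('ocn_low')  # creates ds_geo_ocn_low.nc on first call
DZT, DZU = ds_geo.DZT, ds_geo.DZU  # partial-bottom-cell thicknesses [m]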
Example #3
def create_dz_mean(domain):
    """ average depth [m] per level of """
    assert domain in ['ocn', 'ocn_low', 'ocn_rect']
    
    (z, lat, lon) = dll_dims_names(domain)
    
    fn = f'{path_results}/geometry/dz_mean_{domain}.nc'
    if os.path.exists(fn):
        dz_mean = xr.open_dataarray(fn, decode_times=False)
    else:
        DZT     = xr_DZ(domain)
        dz_mean = DZT.where(DZT>0).mean(dim=(lat, lon))
        dz_mean.to_netcdf(fn)
        
    return dz_mean
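Like geometry_file above, this caches to NetCDF. A usage sketch:

dz_mean = create_dz_mean('ocn')  # 1D DataArray over the depth dimension
print(dz_mean.dims)              # e.g. ('z_t',) for POP output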
Example #4
def calculate_BSF(ds):
    """ barotropic stream function [m^3/s] 
    
    input:
    ds  .. xr Dataset
    """
    assert 'UVEL' in ds

    DYU = xr_DYU('ocn')
    # removing unnecessary TLAT/TLONG coordinates: they are not always exactly
    # equal across objects (rounding errors) and then raise alignment errors
    DYU = DYU.drop(['TLAT', 'TLONG'])

    DZU = xr_DZ('ocn', grid='U')

    UVEL = ds.UVEL / 1e2  # [cm] to [m]
    if 'TLAT' in UVEL.coords: UVEL = UVEL.drop('TLAT')
    if 'TLONG' in UVEL.coords: UVEL = UVEL.drop('TLONG')

    BSF = (((UVEL * DZU).sum(dim='z_t')) * DYU)
    # cumulative meridional integration; 2400 is nlat of the 0.1 deg ocn grid
    for j in np.arange(1, 2400):
        BSF[j, :] += BSF[j - 1, :]

    return BSF
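The explicit row loop is a running sum along nlat; xarray's cumsum expresses the same meridional integration in one call. A sketch with the same UVEL, DZU, and DYU:

BSF = ((UVEL * DZU).sum(dim='z_t') * DYU).cumsum(dim='nlat')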
Example #5
    def all_transports(self, run, quantity):
        """ computes heat or salt fluxes """
        assert run in ['ctrl', 'lpd']

        assert quantity in ['SALT', 'OHC']

        if quantity == 'OHC':
            VN, UE = 'VNT', 'UET'
            conversion = rho_sw * cp_sw
            qstr = 'heat'
            unit_out = 'W'
        elif quantity == 'SALT':
            VN, UE = 'VNS', 'UES'
            conversion = rho_sw * 1e-3
            qstr = 'salt'
            unit_out = 'kg/s'

        if run == 'ctrl':
            domain = 'ocn'
            all_transports_list = []

        elif run == 'lpd':
            domain = 'ocn_low'
            mf_fn = f'{path_prace}/{run}/ocn_yrly_{VN}_{UE}_*.nc'
            kwargs = {
                'concat_dim': 'time',
                'decode_times': False,
                'drop_variables': ['TLONG', 'TLAT', 'ULONG', 'ULAT'],
                'parallel': True
            }
            ds = xr.open_mfdataset(mf_fn, **kwargs)

        DZ = xr_DZ(domain=domain)
        adv = self.all_advection_cells(domain=domain)
        AREA = xr_AREA(domain=domain)
        dims = list(dll_dims_names(domain=domain))  # (depth, lat, lon) names

        for i, pair in enumerate(tqdm(neighbours)):
            name = f'{qstr}_flux_{regions_dict[pair[0]]}_to_{regions_dict[pair[1]]}'
            adv_E = adv[
                f'adv_E_{regions_dict[pair[0]]}_to_{regions_dict[pair[1]]}']
            adv_N = adv[
                f'adv_N_{regions_dict[pair[0]]}_to_{regions_dict[pair[1]]}']
            # 1 where either advection flag is non-zero, NaN elsewhere (0/0)
            MASK = ((abs(adv_E) + abs(adv_N)) /
                    (abs(adv_E) + abs(adv_N))).copy()
            adv_E = adv_E.where(MASK == 1, drop=True)
            adv_N = adv_N.where(MASK == 1, drop=True)
            DZ_ = DZ.where(MASK == 1, drop=True)
            AREA_ = AREA.where(MASK == 1, drop=True)
            if run == 'ctrl':
                for j, (y,m,f) in tqdm(enumerate(IterateOutputCESM(domain='ocn', run='ctrl',\
                                                                   tavg='yrly', name=f'{VN}_{UE}'))):
                    ds = xr.open_dataset(f,
                                         decode_times=False).where(MASK == 1,
                                                                   drop=True)
                    transport = ((adv_E * ds[UE] + adv_N * ds[VN]) * AREA_ *
                                 DZ_).sum(dim=dims) * conversion
                    transport.name = name
                    transport.attrs['units'] = unit_out
                    if j == 0: transport_t = transport
                    else:
                        transport_t = xr.concat([transport_t, transport],
                                                dim='time')

                all_transports_list.append(transport_t)

            elif run == 'lpd':
                ds_ = ds.where(MASK == 1, drop=True)
                transport = ((adv_E * ds_[UE] + adv_N * ds_[VN]) * AREA_ *
                             DZ_).sum(dim=dims) * conversion

                transport.name = name
                transport.attrs['units'] = unit_out
                if i == 0: all_transports = transport
                else: all_transports = xr.merge([all_transports, transport])

        if run == 'ctrl': all_transports = xr.merge(all_transports_list)

        all_transports.to_netcdf(
            f'{path_prace}/{quantity}/{quantity}_fluxes_{run}.nc')
        return all_transports
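The loop above relies on two module-level structures: neighbours, a list of (from, to) region-index pairs, and regions_dict, which maps indices to basin names. A hypothetical sketch of their shape (the indices and names here are illustrative, not the repository's actual values):

regions_dict = {1: 'Atlantic_Basin', 2: 'Pacific_Basin', 3: 'Indian_Basin'}  # illustrative
neighbours   = [(1, 2), (1, 3)]  # adjacent-basin pairs iterated with tqdm above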
Example #6
    def generate_OHC_files(self, run, year=None, pwqd=False):
        """ non-detrended OHC files for full length of simulations
        
        One file contains integrals (all global and by basin):
        
        x,y,z .. scalars        
        x,y   .. vertical profiles 
        x     .. "zonal" integrals 
        
        A separate file each for 4 different depth levels
        z     .. 2D maps, global only, but for different vertical levels
 
        (ocn:      takes about 14 min per year)
        (ocn_rect: takes about  3 seconds per year: 70 yrs approx 3 mins)
        """
        
        def t2da(da, t):
            """adds time dimension to xr DataArray, then sets time value to t"""
            da = da.expand_dims('time')
            da = da.assign_coords(time=[t])
            return da

        def t2ds(da, name, t):
            """ 
            adds time dimension to xr DataArray, then sets time value to t,
            and then returns as array in xr dataset
            """
            da = t2da(da, t)
            ds = da.to_dataset(name=name)
            return ds
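        # e.g. t2ds(da_g, 'OHC_global', 365*y) wraps an integral in a Dataset
        # with a singleton 'time' dimension labeled by the model date in days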
        start = datetime.datetime.now()

        def tss():  # elapsed time since start
            return datetime.datetime.now() - start

        print(f'{start}  start OHC calculation: run={run}')
        assert run in ['ctrl', 'rcp', 'lpd', 'lpi']

        if run=='rcp':
            domain = 'ocn'
        elif run=='ctrl':
            domain = 'ocn_rect'
        elif run in ['lpd', 'lpi']:
            domain = 'ocn_low'
            
        (z, lat, lon) = dll_dims_names(domain)

        # geometry
        DZT  = xr_DZ(domain)#.chunk(chunks={z:1})
        AREA = xr_AREA(domain)
        HTN  = xr_HTN(domain)
        LATS = xr_LATS(domain)
        
        def round_tlatlon(das):
            """ rounds TLAT and TLONG to 2 decimals
            some files' coordinates differ in their last digit
            rounding them avoids problems when concatenating
            """
            das['TLAT']   = das['TLAT'].round(decimals=2)
            das['TLONG']  = das['TLONG'].round(decimals=2)
            return das
        if domain=='ocn':
            round_tlatlon(HTN)
            round_tlatlon(LATS)

        MASK = boolean_mask(domain, mask_nr=0)
        DZT  = DZT.where(MASK)#.chunk(chunks={z:1})
        # with chunking for ctrl_rect: 21 sec per iteration, 15 sec without
        AREA = AREA.where(MASK)
        HTN  = HTN.where(MASK)
        LATS = LATS.where(MASK)
        
        if pwqd:  name = 'TEMP_pwqd'
        else:     name = 'TEMP_PD'
            
        
        for y,m,file in IterateOutputCESM(domain=domain, run=run, tavg='yrly', name=name):
            if year is not None and y != year:  # optionally restrict to one year
                continue
                    
            if pwqd:  file_out = f'{path_samoc}/OHC/OHC_integrals_{run}_{y:04d}_pwqd.nc'
            else:     file_out = f'{path_samoc}/OHC/OHC_integrals_{run}_{y:04d}.nc'
                

            if os.path.exists(file_out) and year is None:
                # TODO: could check here that all fields exist in the file
                continue
            print(f'{tss()} {y}, {file}')

            t   = y*365  # time in days since year 0, for consistency with CESM date output
            ds  = xr.open_dataset(file, decode_times=False).TEMP
            print(f'{tss()} opened dataset')
            if domain=='ocn':
                ds = ds.drop(['ULONG', 'ULAT'])
                ds = round_tlatlon(ds)

            OHC = ds*rho_sw*cp_sw  # volumetric heat content [J m^-3]
            ds.close()
            OHC = OHC.where(MASK)

            OHC_DZT = OHC*DZT
            print(f'{tss()}  {y} calculated OHC & OHC_DZT')
            
            # global, global levels, zonal, zonal levels integrals for different regions
            for mask_nr in tqdm([0,1,2,3,6,7,8,9,10]):
                name = regions_dict[mask_nr]
                da = OHC.where(boolean_mask(domain, mask_nr=mask_nr))
                
                da_g = (da*AREA*DZT).sum(dim=[z, lat, lon])
                da_g.attrs['units'] = '[J]'
                ds_g  = t2ds(da_g , f'OHC_{name}', t)

                da_gl = (da*AREA).sum(dim=[lat, lon])
                da_gl.attrs['units'] = '[J m^-1]'
                ds_gl = t2ds(da_gl, f'OHC_levels_{name}', t)

                if domain=='ocn':  da_z  = xr_int_zonal(da=da, HTN=HTN, LATS=LATS, AREA=AREA, DZ=DZT)
                else:  da_z = (da*HTN*DZT).sum(dim=[z, lon])
                da_z.attrs['units'] = '[J m^-1]'
                ds_z = t2ds(da_z , f'OHC_zonal_{name}', t)
                
                if domain=='ocn':  da_zl = xr_int_zonal_level(da=da, HTN=HTN, LATS=LATS, AREA=AREA, DZ=DZT)
                else:  da_zl = (da*HTN).sum(dim=[lon])
                da_zl.attrs['units'] = '[J m^-2]'
                ds_zl = t2ds(da_zl, f'OHC_zonal_levels_{name}', t)
                if mask_nr==0:   ds_new = xr.merge([ds_g, ds_gl, ds_z, ds_zl])
                else:            ds_new = xr.merge([ds_new, ds_g, ds_gl, ds_z, ds_zl])
                    
            print(f'{tss()}  done with horizontal calculations')
            
            # vertical integrals
            # full depth
            da_v  = OHC_DZT.sum(dim=z)                         #   0-6000 m
            da_v.attrs = {'depths':f'{OHC_DZT[z][0].values:.0f}-{OHC_DZT[z][-1].values:.0f}',
                          'units':'[J m^-2]'}
            
            if domain in ['ocn', 'ocn_rect']:  zsel = [[0,9], [0,20], [20,26]]
            elif domain=='ocn_low':            zsel = [[0,9], [0,36], [36,45]]
            
            #   0- 100 m
            da_va = OHC_DZT.isel({z:slice(zsel[0][0], zsel[0][1])}).sum(dim=z)  
            da_va.attrs = {'depths':f'{OHC_DZT[z][zsel[0][0]].values:.0f}-{OHC_DZT[z][zsel[0][1]].values:.0f}',
                           'units':'[J m^-2]'}
            
            #   0- 700 m
            da_vb = OHC_DZT.isel({z:slice(zsel[1][0],zsel[1][1])}).sum(dim=z)  
            da_vb.attrs = {'depths':f'{OHC_DZT[z][zsel[1][0]].values:.0f}-{OHC_DZT[z][zsel[1][1]].values:.0f}',
                           'units':'[J m^-2]'}
            
            # 700-2000 m
            da_vc = OHC_DZT.isel({z:slice(zsel[2][0],zsel[2][1])}).sum(dim=z)  
            da_vc.attrs = {'depths':f'{OHC_DZT[z][zsel[2][0]].values:.0f}-{OHC_DZT[z][zsel[2][1]].values:.0f}',
                           'units':'[J m^-2]'}
            
            ds_v  = t2ds(da_v , 'OHC_vertical_0_6000m'  , t)
            ds_va = t2ds(da_va, 'OHC_vertical_0_100m'   , t)
            ds_vb = t2ds(da_vb, 'OHC_vertical_0_700m'   , t)
            ds_vc = t2ds(da_vc, 'OHC_vertical_700_2000m', t)

            ds_new = xr.merge([ds_new, ds_v, ds_va, ds_vb, ds_vc])

            print(f'{tss()}  done making datasets')
            ds_new.to_netcdf(path=file_out, mode='w')
            ds_new.close()


        # combining the yearly files is not implemented here; see the sketch below
        print(f'{datetime.datetime.now()}  done\n')
        # note: for run 'ctrl', year 205 is wrong and should be averaged
        # by executing `fix_ctrl_year_205()`
        return
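The combination of the yearly files is not implemented in this excerpt; a hedged sketch of how the per-year integrals could be merged (non-pwqd case, assuming the naming scheme used in the loop above):

import xarray as xr

fns = f'{path_samoc}/OHC/OHC_integrals_{run}_????.nc'  # yearly files written above
ohc = xr.open_mfdataset(fns, combine='nested', concat_dim='time',
                        decode_times=False)
ohc.to_netcdf(f'{path_samoc}/OHC/OHC_integrals_{run}.nc')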
Example #7
    """ Weddell Gyre transport [m^3/s]
    
    input:
    BSF .. xr DataArray
    """
    WGT = BSF.sel(WG_center)
    return WGT


if __name__ == '__main__':

    run = sys.argv[1]
    if run == 'ctrl': yy = np.arange(200, 230)
    elif run == 'lpd': yy = np.arange(500, 530)
    elif run in ['rcp', 'lr1']: yy = np.arange(2000, 2101)

    if run in ['ctrl', 'rcp']:
        DZU = xr_DZ('ocn', grid='U')
        DYU = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False).DYU
    elif run in ['lpd', 'lr1']:
        DZU = xr_DZ('ocn_low', grid='U')
        DYU = xr.open_dataset(file_ex_ocn_lpd, decode_times=False).DYU
    psi_list = []
    for i, y in enumerate(yy):
        UVEL = xr.open_dataset(
            f'{path_prace}/{run}/ocn_yrly_UVEL_VVEL_{y:04d}.nc',
            decode_times=False).UVEL

        # /1e10: assuming UVEL in [cm/s] and DYU in [cm], this yields [Sv]
        psi_list.append((DYU * (UVEL * DZU).sum('z_t') / 1e10).cumsum('nlat'))
    psi = xr.concat(psi_list, dim='time')
    psi.to_netcdf(f'{path_prace}/BSF/BSF_{run}.nc')
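The concatenated psi carries no time labels; a small sketch that tags each yearly field with its model year, assuming the yy array defined above:

psi = xr.concat(psi_list, dim='time').assign_coords(time=yy)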
Example #8
    def make_TEMP_SALT_transport_section(self, section_dict, years=None):
        """ zonal or meridional sections along grid lines

        section_dict maps a section name to an (i, j) pair; one of i/j is an
        int, the other a (start, stop) tuple, translated to a selection such
        as {'nlat': j_34, 'nlon': slice(i_SA, i_CGH)}
        """
        print(self.run)
        assert self.domain in ['ocn', 'ocn_low'], 'only implemented for ocn and ocn_low grids'
        if years is None:  years = np.arange(2000,2101)
        dss = {}
        DZT = xr_DZ(domain=self.domain)
        DZT.name = 'DZT'
        if self.run in ['ctrl', 'rcp']:   fn = file_ex_ocn_ctrl
        elif self.run in ['lpd', 'lr1']:  fn = file_ex_ocn_lpd
        geometry = xr.open_dataset(fn, decode_times=False).drop('time')[['DYT', 'DYU', 'DXT', 'DXU', 'z_t', 'dz']]

        for k, y in enumerate(years):
            print(y)
            fn_UVEL_VVEL = f'{path_prace}/{self.run}/ocn_yrly_UVEL_VVEL_{y:04d}.nc'
            fn_SALT_VNS_UES = f'{path_prace}/{self.run}/ocn_yrly_SALT_VNS_UES_{y:04d}.nc'
            fn_TEMP = f'{path_prace}/{self.run}/ocn_yrly_TEMP_{y:04d}.nc'
            fn_VNT_UET = f'{path_prace}/{self.run}/ocn_yrly_VNT_UET_{y:04d}.nc'
            
            UVEL_VVEL = xr.open_dataset(fn_UVEL_VVEL, decode_times=False).drop('time')
            SALT_VNS_UES = xr.open_dataset(fn_SALT_VNS_UES, decode_times=False).drop('time')
            TEMP = xr.open_dataset(fn_TEMP, decode_times=False).drop('time')
            VNT_UET = xr.open_dataset(fn_VNT_UET, decode_times=False).drop('time')
            
            ds = xr.merge([UVEL_VVEL, SALT_VNS_UES, TEMP, VNT_UET, geometry, DZT],compat='override')
            ds = ds.assign_coords(time=365*y+31).expand_dims('time')  # 31.1. of each year
            
            for key in section_dict.keys():
                fn =  f'{path_prace}/{self.run}/section_{key}_{self.run}_{y:04d}.nc'
                # the first year is always processed so TLAT/TLONG are cached below
                if y>years[0] and os.path.exists(fn):  continue
                else:
                    (i,j) = section_dict[key]
                    ds_ = ds

                    if isinstance(i, int) and isinstance(j, tuple):
                        sel_dict = {'nlat':slice(j[0],j[1]), 'nlon':i}
                        if k==0:
                            lon1 = f'{ds_.TLONG.sel(nlon=i,nlat=j[0]).values:6.1f}'
                            lon2 = f'{ds_.TLONG.sel(nlon=i,nlat=j[1]).values:6.1f}'
                            lat1 = f'{ds_.TLAT.sel(nlon=i,nlat=j[0]).values:6.1f}'
                            lat2 = f'{ds_.TLAT.sel(nlon=i,nlat=j[1]).values:6.1f}'
                            print(f'{key:10} merid section: {str(sel_dict):50},{lon1}/{lon2}E,{lat1}N-{lat2}N')
                        list1 = ['UVEL', 'SALT', 'TEMP', 'UES', 'UET', 'DZT', 'DYT', 'DYU', 'z_t', 'dz']
                    elif isinstance(i, tuple) and isinstance(j, int):
                        if i[1]<i[0]:
                            ds_ = shift_ocn_low(ds_)
                            if i[0]>160:
                                i = (i[0]-320, i[1])

                        sel_dict = {'nlat':j, 'nlon':slice(i[0],i[1])}
                        if k==0:
                            lon1 = f'{ds_.TLONG.sel(nlon=i[0],nlat=j).values:6.1f}'
                            lon2 = f'{ds_.TLONG.sel(nlon=i[1], nlat=j).values:6.1f}'
                            lat1 = f'{ds_.TLAT.sel(nlon=i[0],nlat=j).values:6.1f}'
                            lat2 = f'{ds_.TLAT.sel(nlon=i[1],nlat=j).values:6.1f}'
                            print(f'{key:10} zonal section: {str(sel_dict):50},{lon1}E-{lon2}E,{lat1}/{lat2}N')
                        list1 = ['VVEL', 'SALT', 'TEMP', 'VNS', 'VNT', 'DZT', 'DXT', 'DXU', 'z_t', 'dz']
                    else: raise ValueError('one of i/j needs to be length 2 tuple of ints and the other an int')
                    ds_ = ds_[list1].sel(sel_dict)

                    if k==0:
                        TLAT, TLONG = ds_.TLAT, ds_.TLONG
                    else:
                        ds_['TLAT'], ds_['TLONG'] = TLAT, TLONG 
                    ds_.to_netcdf(fn)
        return
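A hypothetical sketch of the section_dict argument, matching the convention that one index is an int and the other a (start, stop) tuple (the indices below are illustrative, not the repository's actual grid indices):

section_dict = {
    'SA34S': (1042, (250, 310)),  # meridional section: nlon int, nlat tuple
    'Drake': ((480, 520), 310),   # zonal section: nlon tuple, nlat int
}
# passed to make_TEMP_SALT_transport_section(section_dict, years=...)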
Example #9
File: OHC.py Project: AJueling/CESM
def OHC_parallel(run, mask_nr=0):
    """ ocean heat content calculation """
    print('***************************************************')
    print('* should be run with a dask scheduler             *')
    print('`from dask.distributed import Client, LocalCluster`')
    print('`cluster = LocalCluster(n_workers=2)`')
    print('`client = Client(cluster)`')
    print('***************************************************')

    print(
        f'\n{datetime.datetime.now()}  start OHC calculation: run={run} mask_nr={mask_nr}'
    )
    assert run in ['ctrl', 'rcp', 'lpd', 'lpi']
    assert type(mask_nr) == int
    assert 0 <= mask_nr < 13

    file_out = f'{path_samoc}/OHC/OHC_integrals_{regions_dict[mask_nr]}_{run}.nc'

    if run in ['ctrl', 'rcp']:
        domain = 'ocn'
    elif run in ['lpd', 'lpi']:
        domain = 'ocn_low'

    MASK = boolean_mask(domain, mask_nr)

    # geometry
    DZT = xr_DZ(domain)
    AREA = xr_AREA(domain)
    HTN = xr_HTN(domain)
    LATS = xr_LATS(domain)
    print(f'{datetime.datetime.now()}  done with geometry')

    # multi-file
    file_list = ncfile_list(domain=domain, run=run, tavg='yrly', name='TEMP_PD')
    OHC = xr.open_mfdataset(paths=file_list,
                            concat_dim='time',
                            decode_times=False,
                            compat='minimal',
                            parallel=True)
    OHC = OHC.drop(['ULAT', 'ULONG']).TEMP * cp_sw * rho_sw  # [J m^-3]
    if mask_nr != 0:
        OHC = OHC.where(MASK)
    print(f'{datetime.datetime.now()}  done loading data')

    for ds in [OHC, HTN, LATS]:
        round_tlatlon(ds)
    OHC_DZT = OHC * DZT
    print(f'{datetime.datetime.now()}  done OHC_DZT')

    # xr DataArrays
    da_g = xr_int_global(da=OHC, AREA=AREA, DZ=DZT)
    da_gl = xr_int_global_level(da=OHC, AREA=AREA, DZ=DZT)
    da_v = OHC_DZT.sum(dim='z_t')  #xr_int_vertical(da=OHC, DZ=DZT)
    da_va = OHC_DZT.isel(z_t=slice(0, 9)).sum(dim='z_t')  # above 100 m
    da_vb = OHC_DZT.isel(z_t=slice(9, 42)).sum(dim='z_t')  # below 100 m
    da_z = xr_int_zonal(da=OHC, HTN=HTN, LATS=LATS, AREA=AREA, DZ=DZT)
    da_zl = xr_int_zonal_level(da=OHC, HTN=HTN, LATS=LATS, AREA=AREA, DZ=DZT)
    print(f'{datetime.datetime.now()}  done calculations')

    # xr Datasets
    ds_g = da_g.to_dataset(name='OHC_global')
    ds_gl = da_gl.to_dataset(name='OHC_global_levels')
    ds_v = da_v.to_dataset(name='OHC_vertical')
    ds_va = da_va.to_dataset(name='OHC_vertical_above_100m')
    ds_vb = da_vb.to_dataset(name='OHC_vertical_below_100m')
    ds_z = da_z.to_dataset(name='OHC_zonal')
    ds_zl = da_zl.to_dataset(name='OHC_zonal_levels')
    print(f'{datetime.datetime.now()}  done dataset')

    print(f'output: {file_out}')

    ds_new = xr.merge([ds_g, ds_gl, ds_z, ds_zl, ds_v, ds_va, ds_vb])
    ds_new.to_netcdf(path=file_out, mode='w')
    print(f'{datetime.datetime.now()}  done\n')

    return ds_new
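Following the banner the function prints, a minimal driver sketch that sets up the suggested dask scheduler before calling it:

from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=2)  # as suggested by the printed hint
client = Client(cluster)
ds_ohc = OHC_parallel(run='ctrl', mask_nr=0)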