Example #1
def main():
    # Get options
    args = options()

    nw = args.number_of_workers
    if not os.path.isdir('PNG/MOC'):
        print('Creating a directory to place figures (PNG/MOC)... \n')
        os.system('mkdir -p PNG/MOC')
    if not os.path.isdir('ncfiles'):
        print('Creating a directory to place netCDF files (ncfiles)... \n')
        os.system('mkdir ncfiles')

    # Read in the yaml file
    diag_config_yml = yaml.load(open(args.diag_config_yml_path, 'r'),
                                Loader=yaml.Loader)

    # Create the case instance
    dcase = DiagsCase(diag_config_yml['Case'])
    args.case_name = dcase.casename
    args.savefigs = True
    args.outdir = 'PNG/MOC/'
    RUNDIR = dcase.get_value('RUNDIR')
    print('Run directory is:', RUNDIR)
    print('Casename is:', dcase.casename)
    print('Number of workers to be used:', nw)

    # set avg dates
    avg = diag_config_yml['Avg']
    if not args.start_date: args.start_date = avg['start_date']
    if not args.end_date: args.end_date = avg['end_date']

    # read grid info
    grd = MOM6grid(RUNDIR + '/' + dcase.casename + '.mom6.static.nc')
    depth = grd.depth_ocean
    # remove NaNs, otherwise genBasinMasks won't work
    depth[np.isnan(depth)] = 0.0
    basin_code = m6toolbox.genBasinMasks(grd.geolon, grd.geolat, depth)

    parallel, cluster, client = m6toolbox.request_workers(nw)

    print('Reading {} dataset...'.format(args.file_name))
    startTime = datetime.now()

    # load data
    def preprocess(ds):
        variables = ['vmo', 'vhml', 'vhGM']
        for v in variables:
            if v not in ds.variables:
                ds[v] = xr.zeros_like(ds.vo)
        return ds[variables]

    if parallel:
        ds = xr.open_mfdataset(
            RUNDIR + '/' + dcase.casename + args.file_name,
            parallel=True,
            combine="nested",  # concatenate in order of files
            concat_dim="time",  # concatenate along time
            preprocess=preprocess,
        ).chunk({"time": 12})

    else:
        ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+args.file_name, data_vars='minimal', \
                               coords='minimal', compat='override', preprocess=preprocess)

    print('Time elapsed: ', datetime.now() - startTime)

    # compute yearly means first since this will be used in the time series
    print('Computing yearly means...')
    startTime = datetime.now()
    ds_yr = ds.resample(time="1Y", closed='left').mean('time')
    print('Time elapsed: ', datetime.now() - startTime)

    print('Selecting data between {} and {}...'.format(args.start_date,
                                                       args.end_date))
    startTime = datetime.now()
    ds_sel = ds_yr.sel(time=slice(args.start_date, args.end_date))
    print('Time elapsed: ', datetime.now() - startTime)

    print('Computing time mean...')
    startTime = datetime.now()
    ds_mean = ds_sel.mean('time').compute()
    print('Time elapsed: ', datetime.now() - startTime)

    # create an ndarray subclass so a 'units' attribute can be attached
    class C(np.ndarray):
        pass

    varName = 'vmo'
    conversion_factor = 1.e-9
    tmp = np.ma.masked_invalid(ds_mean[varName].values)
    tmp = tmp[:].filled(0.)
    VHmod = tmp.view(C)
    VHmod.units = ds[varName].units
    Zmod = m6toolbox.get_z(ds, depth, varName)  # vertical coordinate for this variable

    case_name = args.case_name

    # Global MOC
    m6plot.setFigureSize([16, 9], 576, debug=False)
    axis = plt.gca()
    cmap = plt.get_cmap('dunnePM')
    zg = Zmod.min(axis=-1)
    psiPlot = MOCpsi(VHmod) * conversion_factor
    psiPlot = 0.5 * (psiPlot[0:-1, :] + psiPlot[1::, :])
    yyg = grd.geolat_c[:, :].max(axis=-1) + 0 * zg
    ci = m6plot.pmCI(0., 40., 5.)
    plotPsi(
        yyg, zg, psiPlot, ci, 'Global MOC [Sv], averaged between ' +
        args.start_date + ' and ' + args.end_date)
    plt.xlabel(r'Latitude [$\degree$N]')
    plt.suptitle(case_name)
    plt.gca().invert_yaxis()
    findExtrema(yyg, zg, psiPlot, max_lat=-30.)
    findExtrema(yyg, zg, psiPlot, min_lat=25., min_depth=250.)
    findExtrema(yyg, zg, psiPlot, min_depth=2000., mult=-1.)
    objOut = args.outdir + str(case_name) + '_MOC_global.png'
    plt.savefig(objOut)

    if 'zl' in ds:
        zl = ds.zl.values
    elif 'z_l' in ds:
        zl = ds.z_l.values
    else:
        raise ValueError("Dataset does not have vertical coordinate zl or z_l")

    # create dataset to store results
    moc = xr.Dataset(data_vars={
        'moc': (('zl', 'yq'), psiPlot),
        'amoc': (('zl', 'yq'), np.zeros((psiPlot.shape))),
        'moc_FFM': (('zl', 'yq'), np.zeros((psiPlot.shape))),
        'moc_GM': (('zl', 'yq'), np.zeros((psiPlot.shape))),
        'amoc_45': (('time'), np.zeros((ds_yr.time.shape))),
        'moc_GM_ACC': (('time'), np.zeros((ds_yr.time.shape))),
        'amoc_26': (('time'), np.zeros((ds_yr.time.shape)))
    },
                     coords={
                         'zl': zl,
                         'yq': ds.yq,
                         'time': ds_yr.time
                     })
    attrs = {
        'description': 'MOC time-mean sections and time-series',
        'units': 'Sv',
        'start_date': avg['start_date'],
        'end_date': avg['end_date'],
        'casename': dcase.casename
    }
    m6toolbox.add_global_attrs(moc, attrs)

    # Atlantic MOC
    m6plot.setFigureSize([16, 9], 576, debug=False)
    cmap = plt.get_cmap('dunnePM')
    m = 0 * basin_code
    m[(basin_code == 2) | (basin_code == 4) | (basin_code == 6) |
      (basin_code == 7) | (basin_code == 8)] = 1
    ci = m6plot.pmCI(0., 22., 2.)
    z = (m * Zmod).min(axis=-1)
    psiPlot = MOCpsi(VHmod,
                     vmsk=m * np.roll(m, -1, axis=-2)) * conversion_factor
    psiPlot = 0.5 * (psiPlot[0:-1, :] + psiPlot[1::, :])
    yy = grd.geolat_c[:, :].max(axis=-1) + 0 * z
    plotPsi(
        yy, z, psiPlot, ci, 'Atlantic MOC [Sv], averaged between ' +
        args.start_date + ' and ' + args.end_date)
    plt.xlabel(r'Latitude [$\degree$N]')
    plt.suptitle(case_name)
    plt.gca().invert_yaxis()
    findExtrema(yy, z, psiPlot, min_lat=26.5, max_lat=27.,
                min_depth=250.)  # RAPID
    findExtrema(yy, z, psiPlot, max_lat=-33.)
    findExtrema(yy, z, psiPlot)
    findExtrema(yy, z, psiPlot, min_lat=5.)
    objOut = args.outdir + str(case_name) + '_MOC_Atlantic.png'
    plt.savefig(objOut, format='png')
    moc['amoc'].data = psiPlot

    print('Plotting AMOC profile at 26N...')
    rapid_vertical = xr.open_dataset(
        '/glade/work/gmarques/cesm/datasets/RAPID/moc_vertical.nc')
    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.plot(rapid_vertical.stream_function_mar.mean('time'),
            rapid_vertical.depth,
            'k',
            label='RAPID')
    ax.plot(moc['amoc'].sel(yq=26, method='nearest'), zl, label=case_name)
    ax.legend()
    plt.gca().invert_yaxis()
    plt.grid()
    ax.set_xlabel('AMOC @ 26N [Sv]')
    ax.set_ylabel('Depth [m]')
    objOut = args.outdir + str(case_name) + '_MOC_profile_26N.png'
    plt.savefig(objOut, format='png')

    print('Computing time series...')
    startTime = datetime.now()
    # time-series
    dtime = ds_yr.time
    amoc_26 = np.zeros(len(dtime))
    amoc_45 = np.zeros(len(dtime))
    moc_GM_ACC = np.zeros(len(dtime))
    if args.debug: startTime = datetime.now()
    # loop in time
    for t in range(len(dtime)):
        tmp = np.ma.masked_invalid(ds_yr[varName][t, :].values)
        tmp = tmp[:].filled(0.)
        # m is still Atlantic ocean
        psi = MOCpsi(tmp, vmsk=m * np.roll(m, -1, axis=-2)) * conversion_factor
        psi = 0.5 * (psi[0:-1, :] + psi[1::, :])
        amoc_26[t] = findExtrema(yy,
                                 z,
                                 psi,
                                 min_lat=26.,
                                 max_lat=27.,
                                 plot=False,
                                 min_depth=250.)
        amoc_45[t] = findExtrema(yy,
                                 z,
                                 psi,
                                 min_lat=44.,
                                 max_lat=46.,
                                 plot=False,
                                 min_depth=250.)
        tmp_GM = np.ma.masked_invalid(ds_yr['vhGM'][t, :].values)
        tmp_GM = tmp_GM[:].filled(0.)
        psiGM = MOCpsi(tmp_GM) * conversion_factor
        psiGM = 0.5 * (psiGM[0:-1, :] + psiGM[1::, :])
        moc_GM_ACC[t] = findExtrema(yyg,
                                    zg,
                                    psiGM,
                                    min_lat=-65.,
                                    max_lat=-30,
                                    mult=-1.,
                                    plot=False)
    print('Time elapsed: ', datetime.now() - startTime)

    # add DataArrays to the moc dataset
    moc['amoc_26'].data = amoc_26
    moc['amoc_45'].data = amoc_45
    moc['moc_GM_ACC'].data = moc_GM_ACC

    if parallel:
        print('Releasing workers ...')
        client.close()
        cluster.close()

    print('Plotting...')
    # load AMOC time series data (5th cycle) used in Danabasoglu et al., doi:10.1016/j.ocemod.2015.11.007
    path = '/glade/p/cesm/omwg/amoc/COREII_AMOC_papers/papers/COREII.variability/data.original/'
    amoc_core_26 = xr.open_dataset(path + 'AMOCts.cyc5.26p5.nc')
    # load AMOC from POP JRA-55
    amoc_pop_26 = xr.open_dataset(
        '/glade/u/home/bryan/MOM6-modeloutputanalysis/'
        'AMOC_series_26n.g210.GIAF_JRA.v13.gx1v7.01.nc')
    # load RAPID time series
    rapid = xr.open_dataset(
        '/glade/work/gmarques/cesm/datasets/RAPID/moc_transports.nc').resample(
            time="1Y", closed='left', keep_attrs=True).mean('time',
                                                            keep_attrs=True)
    # plot
    fig = plt.figure(figsize=(12, 6))
    plt.plot(np.arange(len(moc.time)) + 1958.5,
             moc['amoc_26'].values,
             color='k',
             label=case_name,
             lw=2)
    # core data
    core_mean = amoc_core_26['MOC'].mean(axis=0).data
    core_std = amoc_core_26['MOC'].std(axis=0).data
    plt.plot(amoc_core_26.time,
             core_mean,
             label='CORE II (group mean)',
             color='#1B2ACC',
             lw=1)
    plt.fill_between(amoc_core_26.time,
                     core_mean - core_std,
                     core_mean + core_std,
                     alpha=0.25,
                     edgecolor='#1B2ACC',
                     facecolor='#089FFF')
    # pop data
    plt.plot(np.arange(len(amoc_pop_26.time)) + 1958.5,
             amoc_pop_26.AMOC_26n.values,
             color='r',
             label='POP',
             lw=1)
    # rapid
    plt.plot(np.arange(len(rapid.time)) + 2004.5,
             rapid.moc_mar_hc10.values,
             color='green',
             label='RAPID',
             lw=1)

    plt.title('AMOC @ 26 $^o$ N', fontsize=16)
    plt.ylim(5, 20)
    plt.xlim(1948, 1958.5 + len(moc.time))
    plt.xlabel('Time [years]', fontsize=16)
    plt.ylabel('Sv', fontsize=16)
    plt.legend(fontsize=13, ncol=2)
    objOut = args.outdir + str(case_name) + '_MOC_26N_time_series.png'
    plt.savefig(objOut, format='png')

    amoc_core_45 = xr.open_dataset(path + 'AMOCts.cyc5.45.nc')
    amoc_pop_45 = xr.open_dataset(
        '/glade/u/home/bryan/MOM6-modeloutputanalysis/'
        'AMOC_series_45n.g210.GIAF_JRA.v13.gx1v7.01.nc')
    # plot
    fig = plt.figure(figsize=(12, 6))
    plt.plot(np.arange(len(moc.time)) + 1958.5,
             moc['amoc_45'],
             color='k',
             label=case_name,
             lw=2)
    # core data
    core_mean = amoc_core_45['MOC'].mean(axis=0).data
    core_std = amoc_core_45['MOC'].std(axis=0).data
    plt.plot(amoc_core_45.time,
             core_mean,
             label='CORE II (group mean)',
             color='#1B2ACC',
             lw=2)
    plt.fill_between(amoc_core_45.time,
                     core_mean - core_std,
                     core_mean + core_std,
                     alpha=0.25,
                     edgecolor='#1B2ACC',
                     facecolor='#089FFF')
    # pop data
    plt.plot(np.arange(len(amoc_pop_45.time)) + 1958.5,
             amoc_pop_45.AMOC_45n.values,
             color='r',
             label='POP',
             lw=1)

    plt.title('AMOC @ 45 $^o$ N', fontsize=16)
    plt.ylim(5, 20)
    plt.xlim(1948, 1958 + len(moc.time))
    plt.xlabel('Time [years]', fontsize=16)
    plt.ylabel('Sv', fontsize=16)
    plt.legend(fontsize=14)
    objOut = args.outdir + str(case_name) + '_MOC_45N_time_series.png'
    plt.savefig(objOut, format='png')

    # Submesoscale-induced Global MOC
    class C(np.ndarray):
        pass

    varName = 'vhml'
    conversion_factor = 1.e-9
    tmp = np.ma.masked_invalid(ds_mean[varName].values)
    tmp = tmp[:].filled(0.)
    VHml = tmp.view(C)
    VHml.units = ds[varName].units
    Zmod = m6toolbox.get_z(ds, depth, varName)  # vertical coordinate for this variable
    m6plot.setFigureSize([16, 9], 576, debug=False)
    axis = plt.gca()
    cmap = plt.get_cmap('dunnePM')
    z = Zmod.min(axis=-1)
    psiPlot = MOCpsi(VHml) * conversion_factor
    psiPlot = 0.5 * (psiPlot[0:-1, :] + psiPlot[1::, :])
    yy = grd.geolat_c[:, :].max(axis=-1) + 0 * z
    ci = m6plot.pmCI(0., 20., 2.)
    plotPsi(yy,
            z,
            psiPlot,
            ci,
            'Global FFH MOC [Sv], averaged between ' + args.start_date +
            ' and ' + args.end_date,
            zval=[0., -400., -1000.])
    plt.xlabel(r'Latitude [$\degree$N]')
    plt.suptitle(case_name)
    plt.gca().invert_yaxis()
    objOut = args.outdir + str(case_name) + '_FFH_MOC_global.png'
    plt.savefig(objOut)
    moc['moc_FFM'].data = psiPlot

    # GM-induced Global MOC
    class C(np.ndarray):
        pass

    varName = 'vhGM'
    conversion_factor = 1.e-9
    tmp = np.ma.masked_invalid(ds_mean[varName].values)
    tmp = tmp[:].filled(0.)
    VHGM = tmp.view(C)
    VHGM.units = ds[varName].units
    Zmod = m6toolbox.get_z(ds, depth, varName)  # vertical coordinate for this variable
    m6plot.setFigureSize([16, 9], 576, debug=False)
    axis = plt.gca()
    cmap = plt.get_cmap('dunnePM')
    z = Zmod.min(axis=-1)
    psiPlot = MOCpsi(VHGM) * conversion_factor
    psiPlot = 0.5 * (psiPlot[0:-1, :] + psiPlot[1::, :])
    yy = grd.geolat_c[:, :].max(axis=-1) + 0 * z
    ci = m6plot.pmCI(0., 20., 2.)
    plotPsi(
        yy, z, psiPlot, ci, 'Global GM MOC [Sv], averaged between ' +
        args.start_date + ' and ' + args.end_date)
    plt.xlabel(r'Latitude [$\degree$N]')
    plt.suptitle(case_name)
    plt.gca().invert_yaxis()
    findExtrema(yy, z, psiPlot, min_lat=-65., max_lat=-30, mult=-1.)
    objOut = args.outdir + str(case_name) + '_GM_MOC_global.png'
    plt.savefig(objOut)
    moc['moc_GM'].data = psiPlot

    print('Saving netCDF files...')
    moc.to_netcdf('ncfiles/' + str(case_name) + '_MOC.nc')
    return
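
For reference, below is a minimal, self-contained sketch of the streamfunction post-processing used above: invalid values are zero-filled, the transport is converted from kg/s to Sv, and the MOCpsi output (which has one extra vertical level) is averaged back onto layer centers. The moc_psi helper here is a hypothetical stand-in for m6toolbox.MOCpsi, not the real implementation.

import numpy as np

def moc_psi(vh):
    # toy overturning streamfunction: cumulative vertical sum of transport [kg/s]
    psi = np.zeros((vh.shape[0] + 1, vh.shape[1]))
    psi[1:, :] = np.cumsum(vh, axis=0)
    return psi

vmo = np.random.rand(5, 4)                  # (z, y) transport with a NaN
vmo[0, 0] = np.nan
vmo = np.ma.masked_invalid(vmo).filled(0.)  # same zero-fill as in main()
psi = moc_psi(vmo) * 1.e-9                  # kg/s -> Sv
psi = 0.5 * (psi[:-1, :] + psi[1:, :])      # average interfaces onto layer centers
print(psi.shape)                            # (5, 4), same shape as vmo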
Example #2
def main(stream=False):
  # Get options
  args = options()
  nw = args.number_of_workers
  if not os.path.isdir('PNG/HT'):
    print('Creating a directory to place figures (PNG/HT)... \n')
    os.system('mkdir -p PNG/HT')
  if not os.path.isdir('ncfiles'):
    print('Creating a directory to place netCDF files (ncfiles)... \n')
    os.system('mkdir ncfiles')

  # Read in the yaml file
  diag_config_yml = yaml.load(open(args.diag_config_yml_path,'r'), Loader=yaml.Loader)

  # Create the case instance
  dcase = DiagsCase(diag_config_yml['Case'])
  args.case_name = dcase.casename
  args.savefigs = True; args.outdir = 'PNG/HT'
  RUNDIR = dcase.get_value('RUNDIR')
  print('Run directory is:', RUNDIR)
  print('Casename is:', dcase.casename)
  print('Variables to be processed:', args.variables)
  print('Number of workers to be used:', nw)

  # set avg dates
  avg = diag_config_yml['Avg']
  if not args.start_date : args.start_date = avg['start_date']
  if not args.end_date : args.end_date = avg['end_date']

  # read grid info
  grd = MOM6grid(RUNDIR+'/'+dcase.casename+'.mom6.static.nc')
  depth = grd.depth_ocean
  # remove NaNs, otherwise genBasinMasks won't work
  depth[np.isnan(depth)] = 0.0
  basin_code = m6toolbox.genBasinMasks(grd.geolon, grd.geolat, depth)
  parallel, cluster, client = m6toolbox.request_workers(nw)
  print('Reading dataset...')
  startTime = datetime.now()
  variables = args.variables

  def preprocess(ds):
    ''' Return the dataset with the requested variables, creating zero-filled arrays for any that are missing '''
    for var in variables:
      print('Processing {}'.format(var))
      if var not in ds.variables:
        print('WARNING: ds does not have variable {}. Creating a DataArray with zeros'.format(var))
        jm, im = grd.geolat.shape
        tm = len(ds.time)
        da = xr.DataArray(np.zeros((tm, jm, im)), dims=['time','yq','xh'], \
             coords={'yq' : grd.yq, 'xh' : grd.xh, 'time' : ds.time}).rename(var)
        ds = xr.merge([ds, da])
    return ds[variables]

  if parallel:
    ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+'.mom6.hm_*.nc', \
         parallel=True, data_vars='minimal', chunks={'time': 12},\
         coords='minimal', compat='override', preprocess=preprocess)
  else:
    ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+'.mom6.hm_*.nc', \
         data_vars='minimal', coords='minimal', compat='override', \
         preprocess=preprocess)
  print('Time elapsed: ', datetime.now() - startTime)

  print('Selecting data between {} and {}...'.format(args.start_date, args.end_date))
  startTime = datetime.now()
  ds_sel = ds.sel(time=slice(args.start_date, args.end_date))
  print('Time elapsed: ', datetime.now() - startTime)

  print('Computing yearly means...')
  startTime = datetime.now()
  ds_sel = ds_sel.resample(time="1Y", closed='left',keep_attrs=True).mean('time',keep_attrs=True)
  print('Time elapsed: ', datetime.now() - startTime)

  print('Computing time mean...')
  startTime = datetime.now()
  ds_sel = ds_sel.mean('time').load()
  print('Time elapsed: ', datetime.now() - startTime)

  if parallel:
    print('Releasing workers...')
    client.close(); cluster.close()

  varName = 'T_ady_2d'
  print('Saving netCDF files...')
  attrs = {'description': 'Time-mean poleward heat transport by components ', 'units': ds[varName].units,
       'start_date': args.start_date, 'end_date': args.end_date, 'casename': dcase.casename}
  m6toolbox.add_global_attrs(ds_sel,attrs)

  ds_sel.to_netcdf('ncfiles/'+dcase.casename+'_heat_transport.nc')
  # create an ndarray subclass so a 'units' attribute can be attached
  class C(np.ndarray): pass

  if varName in ds.variables:
    tmp = np.ma.masked_invalid(ds_sel[varName].values)
    tmp = tmp[:].filled(0.)
    advective = tmp.view(C)
    advective.units = ds[varName].units
  else:
    raise ValueError('Could not find "{}" in the dataset'.format(varName))

  varName = 'T_diffy_2d'
  if varName in ds.variables:
    tmp = np.ma.masked_invalid(ds_sel[varName].values)
    tmp = tmp[:].filled(0.)
    diffusive = tmp.view(C)
    diffusive.units = ds[varName].units
  else:
    diffusive = None
    warnings.warn('Diffusive temperature term not found. This will result in an underestimation of the heat transport.')

  varName = 'T_lbd_diffy_2d'
  if varName in ds.variables:
    tmp = np.ma.masked_invalid(ds_sel[varName].values)
    tmp = tmp[:].filled(0.)
    lbd = tmp.view(C)
    #lbd.units = ds[varName].units
  else:
    lbd = None
    warnings.warn('Lateral boundary mixing term not found. This will result in an underestimation of the heat transport.')

  plt_heat_transport_model_vs_obs(advective, diffusive, lbd, basin_code, grd, args)
  return
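
The preprocess idiom above (zero-filling variables that are absent from some history files so xr.open_mfdataset does not fail) can be reduced to a small, self-contained sketch. The variable names are taken from Example #2; the toy dataset is illustrative.

import numpy as np
import xarray as xr

def preprocess(ds, variables=('T_ady_2d', 'T_diffy_2d')):
    # add zero-filled placeholders for missing variables, then subset
    for var in variables:
        if var not in ds.variables:
            template = ds[list(ds.data_vars)[0]]
            ds[var] = xr.zeros_like(template)
    return ds[list(variables)]

ds = xr.Dataset({'T_ady_2d': (('time', 'yq', 'xh'), np.ones((2, 3, 4)))},
                coords={'time': [0, 1]})
print(preprocess(ds))  # T_diffy_2d is created as zeros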
Example #3
def horizontal_mean_diff_rms(grd, dcase, basins, args):
    '''
    Compute horizontal mean difference and rms: model versus observations.

    Parameters
    ----------

    grd : object
      Grid object generated by mom6_tools.MOM6grid.

    dcase : case object
      Object created using mom6_tools.DiagsCase.

    basins : DataArray
      Basin masks to apply. Returns horizontal mean RMSE for each basin provided.
      Basins must be generated by genBasinMasks.

    args : object
      Object with command line options.

    Returns
    -------
      Plots horizontal mean difference and rms for different basins.

    '''

    RUNDIR = dcase.get_value('RUNDIR')
    area = grd.area_t.where(grd.wet > 0)
    if args.debug: print('RUNDIR:', RUNDIR)
    parallel, cluster, client = request_workers(args.number_of_workers)

    def preprocess(ds):
        if 'thetao' not in ds.variables:
            ds["thetao"] = xr.zeros_like(ds.h)
        if 'so' not in ds.variables:
            ds["so"] = xr.zeros_like(ds.h)

        return ds

    # read dataset
    startTime = datetime.now()
    print('Reading dataset...')
    ds = xr.open_mfdataset(
        RUNDIR + '/' + dcase.casename + '.mom6.h_*.nc',
        parallel=True,
        combine="nested",  # concatenate in order of files
        concat_dim="time",  # concatenate along time
        preprocess=preprocess,
    ).chunk({"time": 12})

    if args.debug:
        print(ds)

    print('Time elapsed: ', datetime.now() - startTime)

    print('Selecting data between {} and {}...'.format(args.start_date,
                                                       args.end_date))
    ds = ds.sel(time=slice(args.start_date, args.end_date))

    # Compute climatologies
    thetao_model = ds.thetao.resample(time="1Y", closed='left', keep_attrs=True).mean(dim='time', \
                                      keep_attrs=True)

    salt_model = ds.so.resample(time="1Y", closed='left', keep_attrs=True).mean(dim='time', \
                                 keep_attrs=True)

    # TODO: improve how obs are selected
    if args.obs == 'PHC2':
        # load PHC2 data
        obs_path = '/glade/p/cesm/omwg/obs_data/phc/'
        obs_temp = xr.open_dataset(obs_path +
                                   'PHC2_TEMP_tx0.66v1_34lev_ann_avg.nc',
                                   decode_times=False)
        obs_salt = xr.open_dataset(obs_path +
                                   'PHC2_SALT_tx0.66v1_34lev_ann_avg.nc',
                                   decode_times=False)
        # get theta and salt and rename coordinates to be the same as the model's
        thetao_obs = obs_temp.TEMP.rename({
            'X': 'xh',
            'Y': 'yh',
            'depth': 'z_l'
        })
        salt_obs = obs_salt.SALT.rename({
            'X': 'xh',
            'Y': 'yh',
            'depth': 'z_l'
        })
    elif args.obs == 'WOA18':
        # load WOA18 data
        obs_path = '/glade/u/home/gmarques/Notebooks/CESM_MOM6/WOA18_remapping/'
        obs_temp = xr.open_dataset(obs_path +
                                   'WOA18_TEMP_tx0.66v1_34lev_ann_avg.nc',
                                   decode_times=False)
        obs_salt = xr.open_dataset(obs_path +
                                   'WOA18_SALT_tx0.66v1_34lev_ann_avg.nc',
                                   decode_times=False)
        # get theta and salt and rename coordinates to be the same as the model's
        thetao_obs = obs_temp.theta0.rename({'depth': 'z_l'})
        salt_obs = obs_salt.s_an.rename({'depth': 'z_l'})

    else:
        raise ValueError("The obs selected is not available.")

    # set coordinates to the same as the model's
    thetao_obs['xh'] = thetao_model.xh
    thetao_obs['yh'] = thetao_model.yh
    salt_obs['xh'] = salt_model.xh
    salt_obs['yh'] = salt_model.yh

    # compute difference
    temp_diff = thetao_model - thetao_obs
    salt_diff = salt_model - salt_obs

    # construct a 3D area with land values masked
    area3d = np.repeat(area.values[np.newaxis, :, :],
                       len(temp_diff.z_l),
                       axis=0)
    mask3d = xr.DataArray(area3d,
                          dims=(temp_diff.dims[1:4]),
                          coords={
                              temp_diff.dims[1]: temp_diff.z_l,
                              temp_diff.dims[2]: temp_diff.yh,
                              temp_diff.dims[3]: temp_diff.xh
                          })
    # keep only ocean points (NaN != NaN, so land points drop out)
    area3d_masked = mask3d.where(temp_diff[0, :] == temp_diff[0, :])

    # Horizontal Mean difference (model - obs)
    print('\n Computing Horizontal Mean difference for temperature...')
    startTime = datetime.now()
    temp_bias = HorizontalMeanDiff_da(temp_diff,
                                      weights=area3d_masked,
                                      basins=basins,
                                      debug=args.debug).rename('temp_bias')
    print('Time elapsed: ', datetime.now() - startTime)
    print('\n Computing Horizontal Mean difference for salt...')
    startTime = datetime.now()
    salt_bias = HorizontalMeanDiff_da(salt_diff,
                                      weights=area3d_masked,
                                      basins=basins,
                                      debug=args.debug).rename('salt_bias')
    print('Time elapsed: ', datetime.now() - startTime)

    # Horizontal Mean rms (model - obs)
    print('\n Computing Horizontal Mean rms for temperature...')
    startTime = datetime.now()
    temp_rms = HorizontalMeanRmse_da(temp_diff,
                                     weights=area3d_masked,
                                     basins=basins,
                                     debug=args.debug).rename('temp_rms')
    print('Time elapsed: ', datetime.now() - startTime)
    print('\n Computing Horizontal Mean rms for salt...')
    startTime = datetime.now()
    salt_rms = HorizontalMeanRmse_da(salt_diff,
                                     weights=area3d_masked,
                                     basins=basins,
                                     debug=args.debug).rename('salt_rms')
    print('Time elapsed: ', datetime.now() - startTime)

    if parallel:
        print('Releasing workers...')
        client.close()
        cluster.close()

    print('Saving netCDF files...')
    attrs = {
        'start_date': args.start_date,
        'end_date': args.end_date,
        'casename': dcase.casename,
        'obs': args.obs,
        'module': os.path.basename(__file__)
    }
    add_global_attrs(temp_bias, attrs)
    temp_bias.to_netcdf('ncfiles/' + str(dcase.casename) + '_temp_bias.nc')
    add_global_attrs(salt_bias, attrs)
    salt_bias.to_netcdf('ncfiles/' + str(dcase.casename) + '_salt_bias.nc')
    add_global_attrs(temp_rms, attrs)
    temp_rms.to_netcdf('ncfiles/' + str(dcase.casename) + '_temp_rms.nc')
    add_global_attrs(salt_rms, attrs)
    salt_rms.to_netcdf('ncfiles/' + str(dcase.casename) + '_salt_rms.nc')

    # temperature
    for reg in temp_bias.region:
        print('Generating temperature plots for:', str(reg.values))
        # remove Nan's
        temp_diff_reg = temp_bias.sel(region=reg).dropna('z_l')
        temp_rms_reg = temp_rms.sel(region=reg).dropna('z_l')
        if temp_diff_reg.z_l.max() <= 1000.0:
            splitscale = None
        else:
            splitscale = [0., -1000., -temp_diff_reg.z_l.max()]

        savefig_diff = 'PNG/Horizontal_mean_biases/' + str(
            dcase.casename) + '_' + str(reg.values) + '_temp_diff.png'
        savefig_rms = 'PNG/Horizontal_mean_biases/' + str(
            dcase.casename) + '_' + str(reg.values) + '_temp_rms.png'

        ztplot(temp_diff_reg.values,
               temp_diff_reg.time.values,
               temp_diff_reg.z_l.values * -1,
               ignore=np.nan,
               splitscale=splitscale,
               suptitle=dcase.casename,
               contour=True,
               title=str(reg.values) +
               ', Potential Temperature [C], diff (model - obs)',
               extend='both',
               colormap='dunnePM',
               autocenter=True,
               tunits='Year',
               show=False,
               clim=(-3, 3),
               save=savefig_diff,
               interactive=True)

        ztplot(temp_rms_reg.values,
               temp_rms_reg.time.values,
               temp_rms_reg.z_l.values * -1,
               ignore=np.nan,
               splitscale=splitscale,
               suptitle=dcase.casename,
               contour=True,
               title=str(reg.values) +
               ', Potential Temperature [C], rms (model - obs)',
               extend='both',
               colormap='dunnePM',
               autocenter=False,
               tunits='Year',
               show=False,
               clim=(0, 6),
               save=savefig_rms,
               interactive=True)

        plt.close('all')
    # salinity
    for reg in salt_bias.region:
        print('Generating salinity plots for ', str(reg.values))
        # remove Nan's
        salt_diff_reg = salt_bias.sel(region=reg).dropna('z_l')
        salt_rms_reg = salt_rms.sel(region=reg).dropna('z_l')
        if salt_diff_reg.z_l.max() <= 1000.0:
            splitscale = None
        else:
            splitscale = [0., -1000., -salt_diff_reg.z_l.max()]

        savefig_diff = 'PNG/Horizontal_mean_biases/' + str(
            dcase.casename) + '_' + str(reg.values) + '_salt_diff.png'
        savefig_rms = 'PNG/Horizontal_mean_biases/' + str(
            dcase.casename) + '_' + str(reg.values) + '_salt_rms.png'

        ztplot(salt_diff_reg.values,
               salt_diff_reg.time.values,
               salt_diff_reg.z_l.values * -1,
               ignore=np.nan,
               splitscale=splitscale,
               suptitle=dcase.casename,
               contour=True,
               title=str(reg.values) + ', Salinity [psu], diff (model - obs)',
               extend='both',
               colormap='dunnePM',
               autocenter=True,
               tunits='Year',
               show=False,
               clim=(-1.5, 1.5),
               save=savefig_diff,
               interactive=True)

        ztplot(salt_rms_reg.values,
               salt_rms_reg.time.values,
               salt_rms_reg.z_l.values * -1,
               ignore=np.nan,
               splitscale=splitscale,
               suptitle=dcase.casename,
               contour=True,
               title=str(reg.values) + ', Salinity [psu], rms (model - obs)',
               extend='both',
               colormap='dunnePM',
               autocenter=False,
               tunits='Year',
               show=False,
               clim=(0, 3),
               save=savefig_rms,
               interactive=True)

        plt.close('all')
    return
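
Example #3 builds its 3D weights with np.repeat and the self-comparison NaN trick; an equivalent, more idiomatic xarray sketch (toy data, not the mom6_tools API) is:

import numpy as np
import xarray as xr

# toy grid: 2D cell areas and a (z, y, x) difference field with NaNs over land
area = xr.DataArray(np.ones((3, 4)), dims=('yh', 'xh'))
diff = xr.DataArray(np.random.rand(2, 3, 4), dims=('z_l', 'yh', 'xh'))
diff[:, 0, 0] = np.nan  # a "land" point

# broadcast the 2D area to 3D and mask it wherever the field is NaN
area3d_masked = area.broadcast_like(diff).where(diff.notnull())

# area-weighted horizontal mean per depth level
hmean = (diff * area3d_masked).sum(('yh', 'xh')) / area3d_masked.sum(('yh', 'xh'))
print(hmean.values)  # one value per z_l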
Example #4
def driver(args):
    nw = args.number_of_workers
    if not os.path.isdir('PNG/TS_levels'):
        print('Creating a directory to place figures (PNG/TS_levels)... \n')
        os.system('mkdir -p PNG/TS_levels')
    if not os.path.isdir('ncfiles'):
        print('Creating a directory to place netCDF files (ncfiles)... \n')
        os.system('mkdir ncfiles')

    # Read in the yaml file
    diag_config_yml = yaml.load(open(args.diag_config_yml_path, 'r'),
                                Loader=yaml.Loader)

    # Create the case instance
    dcase = DiagsCase(diag_config_yml['Case'])
    RUNDIR = dcase.get_value('RUNDIR')
    args.casename = dcase.casename
    print('Run directory is:', RUNDIR)
    print('Casename is:', args.casename)
    print('Number of workers: ', nw)

    # set avg dates
    avg = diag_config_yml['Avg']
    if not args.start_date: args.start_date = avg['start_date']
    if not args.end_date: args.end_date = avg['end_date']

    # read grid info
    grd = MOM6grid(RUNDIR + '/' + args.casename + '.mom6.static.nc')
    grd_xr = MOM6grid(RUNDIR + '/' + args.casename + '.mom6.static.nc',
                      xrformat=True)

    # create masks
    depth = grd.depth_ocean
    # remove NaNs, otherwise genBasinMasks won't work
    depth[np.isnan(depth)] = 0.0
    basin_code = genBasinMasks(grd.geolon, grd.geolat, depth, xda=True)

    # TODO: improve how obs are selected
    if args.obs == 'PHC2':
        # load PHC2 data
        obs_path = '/glade/p/cesm/omwg/obs_data/phc/'
        obs_temp = xr.open_mfdataset(obs_path +
                                     'PHC2_TEMP_tx0.66v1_34lev_ann_avg.nc',
                                     decode_coords=False,
                                     decode_times=False)
        obs_salt = xr.open_mfdataset(obs_path +
                                     'PHC2_SALT_tx0.66v1_34lev_ann_avg.nc',
                                     decode_coords=False,
                                     decode_times=False)
    elif args.obs == 'WOA18':
        # load WOA18 data
        obs_path = '/glade/u/home/gmarques/Notebooks/CESM_MOM6/WOA18_remapping/'
        obs_temp = xr.open_dataset(
            obs_path + 'WOA18_TEMP_tx0.66v1_34lev_ann_avg.nc',
            decode_times=False).rename({'theta0': 'TEMP'})
        obs_salt = xr.open_dataset(obs_path +
                                   'WOA18_SALT_tx0.66v1_34lev_ann_avg.nc',
                                   decode_times=False).rename({'s_an': 'SALT'})
    else:
        raise ValueError("The obs selected is not available.")

    parallel, cluster, client = request_workers(nw)

    print('Reading surface dataset...')
    startTime = datetime.now()
    variables = ['thetao', 'so', 'time', 'time_bnds']

    def preprocess(ds):
        ''' Return the dataset with the requested variables '''
        return ds[variables]

    if parallel:
        ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+'.mom6.h_*.nc', \
             parallel=True, data_vars='minimal', \
             coords='minimal', compat='override', preprocess=preprocess)
    else:
        ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+'.mom6.h_*.nc', \
             data_vars='minimal', coords='minimal', compat='override', preprocess=preprocess)

    print('Time elapsed: ', datetime.now() - startTime)

    print('Selecting data between {} and {}...'.format(args.start_date,
                                                       args.end_date))
    startTime = datetime.now()
    ds = ds.sel(time=slice(args.start_date, args.end_date))
    print('Time elapsed: ', datetime.now() - startTime)

    print('\n Computing yearly means...')
    startTime = datetime.now()
    ds = ds.resample(time="1Y", closed='left',
                     keep_attrs=True).mean('time', keep_attrs=True)
    print('Time elapsed: ', datetime.now() - startTime)

    print('Time averaging...')
    startTime = datetime.now()
    temp = np.ma.masked_invalid(ds.thetao.mean('time').values)
    salt = np.ma.masked_invalid(ds.so.mean('time').values)
    print('Time elapsed: ', datetime.now() - startTime)

    print('Computing stats for different basins...')
    startTime = datetime.now()
    # construct a 3D area with land values masked
    area = np.ma.masked_where(grd.wet == 0, grd.area_t)
    tmp = np.repeat(area[np.newaxis, :, :], len(obs_temp.depth), axis=0)
    area_mom3D = xr.DataArray(tmp,
                              dims=('depth', 'yh', 'xh'),
                              coords={
                                  'depth': obs_temp.depth.values,
                                  'yh': grd.yh,
                                  'xh': grd.xh
                              })
    for k in range(len(area_mom3D.depth)):
        area_mom3D[k, :] = grd_xr.area_t.where(
            grd_xr.depth_ocean >= area_mom3D.depth[k])

    # temp
    thetao_mean = ds.thetao.mean('time')
    temp_diff = thetao_mean.rename({
        'z_l': 'depth'
    }).rename('TEMP') - obs_temp['TEMP']
    temp_stats = myStats_da(temp_diff, area_mom3D,
                            basins=basin_code).rename('thetao_bias_stats')
    # salt
    so_mean = ds.so.mean('time')
    salt_diff = so_mean.rename({
        'z_l': 'depth'
    }).rename('SALT') - obs_salt['SALT']
    salt_stats = myStats_da(salt_diff, area_mom3D,
                            basins=basin_code).rename('so_bias_stats')

    # plots
    depth = temp_stats.depth.values
    basin = temp_stats.basin.values
    # reconstruct interface depths, assuming each layer depth is the midpoint
    # of its bounding interfaces
    interfaces = np.zeros(len(depth) + 1)
    for k in range(1, len(depth) + 1):
        interfaces[k] = interfaces[k - 1] + 2 * (depth[k - 1] - interfaces[k - 1])

    reg = np.arange(len(temp_stats.basin.values) + 1)
    figname = 'PNG/TS_levels/' + str(dcase.casename) + '_'

    temp_label = r'Potential temperature [$^o$C]'
    salt_label = 'Salinity [psu]'
    # minimum
    score_plot2(basin,
                interfaces,
                temp_stats[:, 0, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=temp_stats[:, 0, :].min().values,
                units=temp_label,
                fname=figname + 'thetao_bias_min.png',
                title='Minimum temperature difference (model-{})'.format(
                    args.obs))
    score_plot2(basin,
                interfaces,
                salt_stats[:, 0, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=salt_stats[:, 0, :].min().values,
                units=salt_label,
                fname=figname + 'so_bias_min.png',
                title='Minimum salinity difference (model-{})'.format(
                    args.obs))

    # maximum
    score_plot2(basin,
                interfaces,
                temp_stats[:, 1, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=temp_stats[:, 1, :].min().values,
                units=temp_label,
                fname=figname + 'thetao_bias_max.png',
                title='Maximum temperature difference (model-{})'.format(
                    args.obs))
    score_plot2(basin,
                interfaces,
                salt_stats[:, 1, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=salt_stats[:, 1, :].min().values,
                units=salt_label,
                fname=figname + 'so_bias_max.png',
                title='Maximum salinity difference (model-{})'.format(
                    args.obs))

    # mean
    score_plot2(basin,
                interfaces,
                temp_stats[:, 2, :],
                nbins=30,
                cmap=plt.cm.seismic,
                units=temp_label,
                fname=figname + 'thetao_bias_mean.png',
                title='Mean temperature difference (model-{})'.format(
                    args.obs))
    score_plot2(basin,
                interfaces,
                salt_stats[:, 2, :],
                nbins=30,
                cmap=plt.cm.seismic,
                units=salt_label,
                fname=figname + 'so_bias_mean.png',
                title='Mean salinity difference (model-{})'.format(args.obs))

    # std
    score_plot2(basin,
                interfaces,
                temp_stats[:, 3, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=1.0E-15,
                units=temp_label,
                fname=figname + 'thetao_bias_std.png',
                title='Std temperature difference (model-{})'.format(args.obs))
    score_plot2(basin,
                interfaces,
                salt_stats[:, 3, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=1.0E-15,
                units=salt_label,
                fname=figname + 'so_bias_std.png',
                title='Std salinity difference (model-{})'.format(args.obs))
    # rms
    score_plot2(basin,
                interfaces,
                temp_stats[:, 4, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=1.0E-15,
                units=temp_label,
                fname=figname + 'thetao_bias_rms.png',
                title='Rms temperature difference (model-{})'.format(args.obs))
    score_plot2(basin,
                interfaces,
                salt_stats[:, 4, :],
                nbins=30,
                cmap=plt.cm.viridis,
                cmin=1.0E-15,
                units=salt_label,
                fname=figname + 'so_bias_rms.png',
                title='Rms salinity difference (model-{})'.format(args.obs))
    print('Time elapsed: ', datetime.now() - startTime)

    print('Saving netCDF files...')
    startTime = datetime.now()
    attrs = {
        'description': 'model - obs at depth levels',
        'start_date': args.start_date,
        'end_date': args.end_date,
        'casename': dcase.casename,
        'obs': args.obs,
        'module': os.path.basename(__file__)
    }
    # create dataset to store results
    add_global_attrs(temp_stats, attrs)
    temp_stats.to_netcdf('ncfiles/' + str(args.casename) +
                         '_thetao_bias_ann_mean_stats.nc')
    add_global_attrs(salt_stats, attrs)
    salt_stats.to_netcdf('ncfiles/' + str(args.casename) +
                         '_so_bias_ann_mean_stats.nc')

    thetao = xr.DataArray(thetao_mean,
                          dims=['z_l', 'yh', 'xh'],
                          coords={
                              'z_l': ds.z_l,
                              'yh': grd.yh,
                              'xh': grd.xh
                          }).rename('thetao')
    temp_bias = np.ma.masked_invalid(thetao.values - obs_temp['TEMP'].values)
    ds_thetao = xr.Dataset(data_vars={
        'thetao': (('z_l', 'yh', 'xh'), thetao),
        'thetao_bias': (('z_l', 'yh', 'xh'), temp_bias)
    },
                           coords={
                               'z_l': ds.z_l,
                               'yh': grd.yh,
                               'xh': grd.xh
                           })
    add_global_attrs(ds_thetao, attrs)

    ds_thetao.to_netcdf('ncfiles/' + str(args.casename) +
                        '_thetao_time_mean.nc')
    so = xr.DataArray(ds.so.mean('time'),
                      dims=['z_l', 'yh', 'xh'],
                      coords={
                          'z_l': ds.z_l,
                          'yh': grd.yh,
                          'xh': grd.xh
                      }).rename('so')
    salt_bias = np.ma.masked_invalid(so.values - obs_salt['SALT'].values)
    ds_so = xr.Dataset(data_vars={
        'so': (('z_l', 'yh', 'xh'), so),
        'so_bias': (('z_l', 'yh', 'xh'), salt_bias)
    },
                       coords={
                           'z_l': ds.z_l,
                           'yh': grd.yh,
                           'xh': grd.xh
                       })
    add_global_attrs(ds_so, attrs)
    ds_so.to_netcdf('ncfiles/' + str(args.casename) + '_so_time_mean.nc')
    print('Time elapsed: ', datetime.now() - startTime)

    if parallel:
        print('\n Releasing workers...')
        client.close()
        cluster.close()

    print('Global plots...')
    km = len(obs_temp['depth'])
    for k in range(km):
        if ds['z_l'][k].values < 1200.0:
            figname = 'PNG/TS_levels/' + str(dcase.casename) + '_' + str(
                ds['z_l'][k].values) + '_'
            temp_obs = np.ma.masked_invalid(obs_temp['TEMP'][k, :].values)
            xycompare(temp[k, :],
                      temp_obs,
                      grd.geolon,
                      grd.geolat,
                      area=grd.area_t,
                      title1='model temperature, depth =' +
                      str(ds['z_l'][k].values) + 'm',
                      title2='observed temperature, depth =' +
                      str(obs_temp['depth'][k].values) + 'm',
                      suptitle=dcase.casename + ', averaged ' +
                      str(args.start_date) + ' to ' + str(args.end_date),
                      extend='both',
                      dextend='neither',
                      clim=(-1.9, 30.),
                      dlim=(-2, 2),
                      dcolormap=plt.cm.bwr,
                      save=figname + 'global_temp.png')
            salt_obs = np.ma.masked_invalid(obs_salt['SALT'][k, :].values)
            xycompare(salt[k, :],
                      salt_obs,
                      grd.geolon,
                      grd.geolat,
                      area=grd.area_t,
                      title1='model salinity, depth =' +
                      str(ds['z_l'][k].values) + 'm',
                      title2='observed salinity, depth =' +
                      str(obs_temp['depth'][k].values) + 'm',
                      suptitle=dcase.casename + ', averaged ' +
                      str(args.start_date) + ' to ' + str(args.end_date),
                      extend='both',
                      dextend='neither',
                      clim=(30., 39.),
                      dlim=(-2, 2),
                      dcolormap=plt.cm.bwr,
                      save=figname + 'global_salt.png')

    print('Antarctic plots...')
    for k in range(km):
        if (ds['z_l'][k].values < 1200.):
            figname = 'PNG/TS_levels/' + str(dcase.casename) + '_' + str(
                ds['z_l'][k].values) + '_'
            temp_obs = np.ma.masked_invalid(obs_temp['TEMP'][k, :].values)
            polarcomparison(temp[k, :],
                            temp_obs,
                            grd,
                            title1='model temperature, depth =' +
                            str(ds['z_l'][k].values) + 'm',
                            title2='observed temperature, depth =' +
                            str(obs_temp['depth'][k].values) + 'm',
                            extend='both',
                            dextend='neither',
                            clim=(-1.9, 10.5),
                            dlim=(-2, 2),
                            dcolormap=plt.cm.bwr,
                            suptitle=dcase.casename + ', averaged ' +
                            str(args.start_date) + ' to ' + str(args.end_date),
                            proj='SP',
                            save=figname + 'antarctic_temp.png')
            salt_obs = np.ma.masked_invalid(obs_salt['SALT'][k, :].values)
            polarcomparison(salt[k, :],
                            salt_obs,
                            grd,
                            title1='model salinity, depth =' +
                            str(ds['z_l'][k].values) + 'm',
                            title2='observed salinity, depth =' +
                            str(obs_temp['depth'][k].values) + 'm',
                            extend='both',
                            dextend='neither',
                            clim=(33., 35.),
                            dlim=(-2, 2),
                            dcolormap=plt.cm.bwr,
                            suptitle=dcase.casename + ', averaged ' +
                            str(args.start_date) + ' to ' + str(args.end_date),
                            proj='SP',
                            save=figname + 'antarctic_salt.png')

    print('Arctic plots...')
    for k in range(km):
        if (ds['z_l'][k].values < 100.):
            figname = 'PNG/TS_levels/' + str(dcase.casename) + '_' + str(
                ds['z_l'][k].values) + '_'
            temp_obs = np.ma.masked_invalid(obs_temp['TEMP'][k, :].values)
            polarcomparison(temp[k, :],
                            temp_obs,
                            grd,
                            title1='model temperature, depth =' +
                            str(ds['z_l'][k].values) + 'm',
                            title2='observed temperature, depth =' +
                            str(obs_temp['depth'][k].values) + 'm',
                            extend='both',
                            dextend='neither',
                            clim=(-1.9, 11.5),
                            dlim=(-2, 2),
                            dcolormap=plt.cm.bwr,
                            suptitle=dcase.casename + ', averaged ' +
                            str(args.start_date) + ' to ' + str(args.end_date),
                            proj='NP',
                            save=figname + 'arctic_temp.png')
            salt_obs = np.ma.masked_invalid(obs_salt['SALT'][k, :].values)
            polarcomparison(salt[k, :],
                            salt_obs,
                            grd,
                            title1='model salinity, depth =' +
                            str(ds['z_l'][k].values) + 'm',
                            title2='observed salinity, depth =' +
                            str(obs_temp['depth'][k].values) + 'm',
                            extend='both',
                            dextend='neither',
                            clim=(31.5, 35.),
                            dlim=(-2, 2),
                            dcolormap=plt.cm.bwr,
                            suptitle=dcase.casename + ', averaged ' +
                            str(args.start_date) + ' to ' + str(args.end_date),
                            proj='NP',
                            save=figname + 'arctic_salt.png')
    return
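
The interface reconstruction in Example #4 (the loop that fills interfaces) deserves isolating: if each layer depth is the midpoint of its bounding interfaces, the recurrence recovers the interfaces exactly. A small self-contained check:

import numpy as np

def interfaces_from_midpoints(depth):
    # z_int[k] = z_int[k-1] + 2*(depth[k-1] - z_int[k-1]), starting at the surface
    z_int = np.zeros(len(depth) + 1)
    for k in range(1, len(depth) + 1):
        z_int[k] = z_int[k - 1] + 2.0 * (depth[k - 1] - z_int[k - 1])
    return z_int

depth = np.array([5., 15., 30.])  # layer midpoints
z_int = interfaces_from_midpoints(depth)
print(z_int)  # [ 0. 10. 20. 40.]
assert np.allclose(0.5 * (z_int[:-1] + z_int[1:]), depth)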
Example #5
def xystats(fname, variables, grd, dcase, basins, args):
    '''
    Compute and plot statistics for 2D variables.

    Parameters
    ----------

    fname : str
      Name of the file to be processed.

    variables : list of str
      List of variables to be processed.

    grd : object
      Grid object generated by mom6_tools.MOM6grid.

    dcase : case object
      Object created using mom6_tools.DiagsCase.

    basins : DataArray
      Basin masks to apply. Returns horizontal mean RMSE for each basin provided.
      Basins must be generated by genBasinMasks.

    args : object
      Object with command line options.

    Returns
    -------
      Plots min, max, mean, std and rms for the variables provided and for different basins.

    '''
    parallel, cluster, client = request_workers(args.number_of_workers)

    RUNDIR = dcase.get_value('RUNDIR')
    area = grd.area_t.where(grd.wet > 0)

    def preprocess(ds):
        ''' Compute monthly averages and return the dataset with the requested variables '''
        return ds[variables].resample(time="1M", closed='left', \
               keep_attrs=True).mean(dim='time', keep_attrs=True)

    # read forcing files
    startTime = datetime.now()
    print('Reading dataset...')
    if parallel:
        ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+fname, \
                               chunks={'time': 365}, parallel=True,  data_vars='minimal',
                               coords='minimal', preprocess=preprocess)
    else:
        ds = xr.open_mfdataset(RUNDIR + '/' + dcase.casename + fname,
                               data_vars='minimal',
                               compat='override',
                               coords='minimal',
                               preprocess=preprocess)

    print('Time elapsed: ', datetime.now() - startTime)

    for var in variables:
        startTime = datetime.now()
        print('\n Processing {}...'.format(var))
        savefig1 = 'PNG/' + dcase.casename + '_' + str(var) + '_xymean.png'
        savefig2 = 'PNG/' + dcase.casename + '_' + str(var) + '_stats.png'

        # yearly mean
        ds_var = ds[var]
        stats = myStats_da(ds_var,
                           dims=ds_var.dims[1::],
                           weights=area,
                           basins=basins)
        stats.to_netcdf('ncfiles/' + dcase.casename + '_' + str(var) +
                        '_stats.nc')
        plot_stats_da(stats, var, ds_var.attrs['units'], save=savefig2)
        ds_var_mean = ds_var.mean(dim='time')
        ds_var_mean.to_netcdf('ncfiles/' + dcase.casename + '_' + str(var) +
                              '_time_ave.nc')
        dummy = np.ma.masked_invalid(ds_var_mean.values)
        xyplot(dummy,
               grd.geolon.values,
               grd.geolat.values,
               area.values,
               save=savefig1,
               suptitle=ds_var.attrs['long_name'] + ' [' +
               ds_var.attrs['units'] + ']',
               title='Averaged between ' + str(ds_var.time[0].values) +
               ' and ' + str(ds_var.time[-1].values))

        plt.close()
        print('Time elapsed: ', datetime.now() - startTime)

    if parallel:
        # close processes
        print('Releasing workers...\n')
        client.close()
        cluster.close()

    return
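
myStats_da belongs to mom6_tools and its exact signature is not shown in this excerpt; the quantities it feeds to plot_stats_da (min, max, mean, std, rms per time step) can be sketched generically with masked arrays. This is an assumed reimplementation for illustration only.

import numpy as np

def weighted_stats(field, weights):
    # area-weighted min/max/mean/std/rms over a masked 2D field
    w = np.ma.masked_where(np.ma.getmaskarray(field), weights)
    mean = (field * w).sum() / w.sum()
    std = np.sqrt((((field - mean) ** 2) * w).sum() / w.sum())
    rms = np.sqrt(((field ** 2) * w).sum() / w.sum())
    return field.min(), field.max(), mean, std, rms

sst = np.ma.masked_invalid(np.array([[10., 12.], [np.nan, 14.]]))
area = np.array([[1., 2.], [1., 1.]])
print(weighted_stats(sst, area))  # mean is 12.0 with these weights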
Example #6
def extract_time_series(fname, variables, grd, dcase, args):
    '''
    Extract time series and save annual means.

    Parameters
    ----------

    fname : str
      Name of the file to be processed.

    variables : list of str
      List of variables to be processed.

    grd : object
      Grid object generated by mom6_tools.MOM6grid.

    dcase : case object
      Object created using mom6_tools.DiagsCase.

    args : object
      Object with command line options.

    Returns
    -------
      netCDF file with annual means.

    '''
    parallel, cluster, client = request_workers(args.number_of_workers)

    RUNDIR = dcase.get_value('RUNDIR')

    def preprocess(ds):
        ''' Compute monthly averages and return the dataset with the requested variables '''
        return ds[variables].resample(time="1M", closed='left', \
               keep_attrs=True).mean(dim='time', keep_attrs=True)

    # read forcing files
    startTime = datetime.now()
    print('Reading dataset...')
    if parallel:
        ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+fname, \
                               chunks={'time': 365}, parallel=True,  data_vars='minimal',
                               coords='minimal', preprocess=preprocess)
    else:
        ds = xr.open_mfdataset(RUNDIR + '/' + dcase.casename + fname,
                               data_vars='minimal',
                               compat='override',
                               coords='minimal',
                               preprocess=preprocess)

    print('Time elapsed: ', datetime.now() - startTime)

    # add attrs and save
    attrs = {'description': 'Annual averages of global mean ocean properties.'}
    add_global_attrs(ds, attrs)
    ds.to_netcdf('ncfiles/' + str(dcase.casename) + '_ann_ave_global_means.nc')
    if parallel:
        # close processes
        print('Releasing workers...\n')
        client.close()
        cluster.close()

    return
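
A toy round trip of the resample-then-save pattern used in Example #6, with synthetic monthly data (the output file name is illustrative):

import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range('2000-01-31', periods=24, freq='M')
ds = xr.Dataset({'thetao_mean': ('time', np.arange(24, dtype=float))},
                coords={'time': time})
ann = ds.resample(time='1Y', closed='left').mean('time')
print(ann.thetao_mean.values)  # [ 5.5 17.5]
ann.to_netcdf('toy_ann_ave_global_means.nc')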
Example #7
def main():
  # Get options
  args = options()

  nw = args.number_of_workers
  if not os.path.isdir('ncfiles'):
    print('Creating a directory to place netCDF files (ncfiles)... \n')
    os.system('mkdir ncfiles')

  # Read in the yaml file
  diag_config_yml = yaml.load(open(args.diag_config_yml_path,'r'), Loader=yaml.Loader)

  # Create the case instance
  dcase = DiagsCase(diag_config_yml['Case'])
  args.case_name = dcase.casename
  RUNDIR = dcase.get_value('RUNDIR')
  print('Run directory is:', RUNDIR)
  print('Casename is:', dcase.casename)
  print('Number of workers to be used:', nw)

  # set avg dates
  avg = diag_config_yml['Avg']
  if not args.start_date : args.start_date = avg['start_date']
  if not args.end_date : args.end_date = avg['end_date']

  # read grid info
  grd = MOM6grid(RUNDIR+'/'+dcase.casename+'.mom6.static.nc')
  depth = grd.depth_ocean
  # remove NaNs, otherwise genBasinMasks won't work
  depth[np.isnan(depth)] = 0.0
  basin_code = m6toolbox.genBasinMasks(grd.geolon, grd.geolat, depth)

  parallel, cluster, client = m6toolbox.request_workers(nw)

  print('Reading {} dataset...'.format(args.file_name))
  startTime = datetime.now()
  # load data
  def preprocess(ds):
    variables = ['diftrblo', 'difmxylo', 'difmxybo', 'diftrelo']
    for v in variables:
      if v not in ds.variables:
        ds[v] = xr.zeros_like(ds.vo)
    return ds[variables]

  ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+args.file_name,
                         parallel=True,
                         combine="nested",   # concatenate in order of files
                         concat_dim="time",  # concatenate along time
                         preprocess=preprocess,
                         ).chunk({"time": 12})

  print('Time elapsed: ', datetime.now() - startTime)
  # compute yearly means first
  print('Computing yearly means...')
  startTime = datetime.now()
  ds_yr = ds.resample(time="1Y", closed='left').mean('time')
  print('Time elapsed: ', datetime.now() - startTime)

  print('Selecting data between {} and {}...'.format(args.start_date, args.end_date))
  startTime = datetime.now()
  ds_sel = ds_yr.sel(time=slice(args.start_date, args.end_date))
  print('Time elapsed: ', datetime.now() - startTime)

  print('Computing time mean...')
  startTime = datetime.now()
  ds_mean = ds_sel.mean('time').compute()
  print('Time elapsed: ', datetime.now() - startTime)

  attrs = {'description': 'Time-mean mixing coefficients', 'units': 'm^2/s', 'start_date': avg['start_date'],
       'end_date': avg['end_date'], 'casename': dcase.casename}
  m6toolbox.add_global_attrs(ds_mean,attrs)

  print('Saving netCDF files...')
  ds_mean.to_netcdf('ncfiles/'+str(args.case_name)+'_avg_mixing_coeffs.nc')

  print('Releasing workers ...')
  client.close(); cluster.close()

  return
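
request_workers comes from m6toolbox and is not shown in this excerpt. Assuming it returns a (parallel, cluster, client) triple, a generic dask.distributed equivalent of the request/release pattern used throughout these examples might look like:

from dask.distributed import Client, LocalCluster

def request_workers(nw):
    # hypothetical stand-in: spin up a local Dask cluster when more than
    # one worker is requested, otherwise fall back to serial mode
    if nw <= 1:
        return False, None, None
    cluster = LocalCluster(n_workers=nw)
    client = Client(cluster)
    return True, cluster, client

parallel, cluster, client = request_workers(4)
# ... xr.open_mfdataset(..., parallel=True) and computations go here ...
if parallel:
    client.close()
    cluster.close()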
Example #8
def time_mean_latlon(args, grd, variables=[]):
    def preprocess(ds):
        ''' Compute yearly averages and return the dataset '''
        return ds.resample(time="1Y", closed='left', \
               keep_attrs=True).mean(dim='time', keep_attrs=True)

    if args.nw > 1:
        from mom6_tools.m6toolbox import request_workers
        parallel, cluster, client = request_workers(args.nw)

        ds = xr.open_mfdataset(args.infile, \
             parallel=True, data_vars='minimal', chunks={'time': 12},\
             coords='minimal', compat='override', preprocess=preprocess)
    else:
        ds = xr.open_mfdataset(args.infile, \
             data_vars='minimal', chunks={'time': 12},\
             coords='minimal', compat='override', preprocess=preprocess)

    if len(variables) == 0:
        # plot all 2D variables in the dataset
        variables = ds.variables

    ds1 = ds.sel(time=slice(args.start_date, args.end_date)).mean('time')

    for var in variables:
        dim = len(ds1[var].shape)
        if dim == 2:
            filename = str('PNG/%s.png' % (var))
            if os.path.isfile(filename):
                print(' \n' + '==> ' +
                      '{} has been saved, moving to the next one ...\n'.format(var))
            else:
                print("About to plot time-average for {} ({})... \n".format(
                    ds[var].long_name, var))
                data = np.ma.masked_invalid(ds1[var].values)
                units = ds[var].attrs['units']

                if args.savefigs:
                    m6plot.xyplot(data,
                                  grd.geolon,
                                  grd.geolat,
                                  area=grd.area_t,
                                  suptitle=args.case_name,
                                  title=r'%s, [%s] averaged over %s-%s' %
                                  (var, units, args.start_date, args.end_date),
                                  extend='both',
                                  save=filename)
                else:
                    m6plot.xyplot(data,
                                  grd.geolon,
                                  grd.geolat,
                                  area=grd.area_t,
                                  suptitle=args.case_name,
                                  title=r'%s, [%s] averaged over %s-%s' %
                                  (var, units, args.start_date, args.end_date),
                                  extend='both',
                                  show=True)

        if args.time_series:
            # create Dataset; stats need time-resolved (not time-averaged)
            # data, so select the date range from ds rather than using ds1
            ds_range = ds.sel(time=slice(args.start_date, args.end_date))
            dtime = ds_range.time.values
            data = np.ma.masked_invalid(ds_range[var].values)
            units = ds_range[var].attrs.get('units', '')
            ds_new = create_xarray_dataset(var, units, dtime)
            # loop in time
            for t in range(0, len(dtime)):
                #print ("==> ' + 'step # {} out of {}  ...\n".format(t+1,tm))
                # get stats
                sMin, sMax, mean, std, rms = m6plot.myStats(
                    data[t], grd.area_t)
                # update Dataset
                ds_new[var][0, t] = sMin
                ds_new[var][1, t] = sMax
                ds_new[var][2, t] = mean
                ds_new[var][3, t] = std
                ds_new[var][4, t] = rms

            # plot
            plot_area_ave_stats(ds_new, var, args)
            #if args.to_netcdf:
            # save in a netcdf file
            #ds.to_netcdf('ncfiles/'+args.case_name+'_stats.nc')
    if args.nw > 1:
        client.close()
        cluster.close()

    return
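m6plot.myStats is used above to reduce each 2D field to (sMin, sMax, mean, std, rms). A minimal numpy sketch under that assumed ordering, with area weighting and masked-point handling (area_stats is a hypothetical name, not the m6plot implementation):

import numpy as np

def area_stats(field, area):
    # min/max plus area-weighted mean, standard deviation and RMS of a
    # 2D field, honoring invalid (masked) points
    field = np.ma.masked_invalid(field)
    weights = np.ma.array(area, mask=field.mask)
    mean = (field * weights).sum() / weights.sum()
    std = np.sqrt((((field - mean) ** 2) * weights).sum() / weights.sum())
    rms = np.sqrt(((field ** 2) * weights).sum() / weights.sum())
    return field.min(), field.max(), mean, std, rms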
def driver(args):
  nw = args.number_of_workers
  if not os.path.isdir('PNG/Equatorial'):
    print('Creating a directory to place figures (PNG/Equatorial)... \n')
    os.system('mkdir -p PNG/Equatorial')
  if not os.path.isdir('ncfiles'):
    print('Creating a directory to place netCDF files (ncfiles)... \n')
    os.system('mkdir ncfiles')

  # Read in the yaml file
  diag_config_yml = yaml.load(open(args.diag_config_yml_path,'r'), Loader=yaml.Loader)

  # Create the case instance
  dcase = DiagsCase(diag_config_yml['Case'])
  RUNDIR = dcase.get_value('RUNDIR')
  args.casename = dcase.casename
  print('Run directory is:', RUNDIR)
  print('Casename is:', args.casename)
  print('Number of workers: ', nw)

  # set avg dates
  avg = diag_config_yml['Avg']
  if not args.start_date: args.start_date = avg['start_date']
  if not args.end_date: args.end_date = avg['end_date']

  # read grid info
  grd = MOM6grid(RUNDIR+'/'+args.casename+'.mom6.static.nc', xrformat=True)
  # select Equatorial region
  grd_eq = grd.sel(yh=slice(-10,10))

  # load obs
  phc_path = '/glade/p/cesm/omwg/obs_data/phc/'
  phc_temp = xr.open_mfdataset(phc_path+'PHC2_TEMP_tx0.66v1_34lev_ann_avg.nc', decode_coords=False, decode_times=False)
  phc_salt = xr.open_mfdataset(phc_path+'PHC2_SALT_tx0.66v1_34lev_ann_avg.nc', decode_coords=False, decode_times=False)
  johnson = xr.open_dataset('/glade/p/cesm/omwg/obs_data/johnson_pmel/meanfit_m.nc')

  # get T and S and rename variables
  thetao_obs = phc_temp.TEMP.rename({'X': 'xh', 'Y': 'yh', 'depth': 'z_l'})
  salt_obs = phc_salt.SALT.rename({'X': 'xh', 'Y': 'yh', 'depth': 'z_l'})

  parallel, cluster, client = request_workers(nw)

  print('Reading surface dataset...')
  startTime = datetime.now()

  # load data
  def preprocess(ds):
    variables = ['thetao', 'so', 'uo', 'time', 'time_bnds', 'e']
    return ds[variables]

  if parallel:
    ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+'.mom6.h_*.nc', \
         parallel=True, data_vars='minimal', \
         coords='minimal', compat='override', preprocess=preprocess)
  else:
    ds = xr.open_mfdataset(RUNDIR+'/'+dcase.casename+'.mom6.monthly_*.nc', combine='nested', concat_dim='time',\
         data_vars='minimal', coords='minimal', compat='override', preprocess=preprocess)

  print('Time elapsed: ', datetime.now() - startTime)

  # set obs coords to be same as model
  thetao_obs['xh'] = ds.xh
  thetao_obs['yh'] = ds.yh
  salt_obs['xh'] = ds.xh
  salt_obs['yh'] = ds.yh

  print('Selecting data between {} and {} (time) and -10 to 10 (yh)...'.format(args.start_date, \
        args.end_date))
  startTime = datetime.now()
  ds = ds.sel(time=slice(args.start_date, args.end_date)).sel(yh=slice(-10,10)).isel(z_i=slice(0,15)).isel(z_l=slice(0,14))
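  # z_i indexes layer interfaces and z_l layer centers, so keeping 15
  # interfaces is consistent with the 14 layers selected alongside them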
  print('Time elapsed: ', datetime.now() - startTime)


  print('Yearly mean...')
  startTime = datetime.now()
  ds = ds.resample(time="1Y", closed='left',keep_attrs=True).mean('time',keep_attrs=True).compute()
  print('Time elapsed: ', datetime.now() - startTime)

  print('Time averaging...')
  startTime = datetime.now()
  thetao = ds.thetao.mean('time')
  so = ds.so.mean('time')
  uo = ds.uo.mean('time')
  eta = ds.e.mean('time')
  # find point closest to eq. and select data
  j = np.abs( grd_eq.geolat[:,0].values - 0. ).argmin()
  temp_eq = np.ma.masked_invalid(thetao.isel(yh=j).values)
  salt_eq = np.ma.masked_invalid(so.isel(yh=j).values)
  u_eq    = np.ma.masked_invalid(uo.isel(yh=j).values)
  e_eq    = np.ma.masked_invalid(eta.isel(yh=j).values)
  thetao_obs_eq = np.ma.masked_invalid(thetao_obs.sel(yh=slice(-10,10)).isel(yh=j).isel(z_l=slice(0,14)).values)
  salt_obs_eq = np.ma.masked_invalid(salt_obs.sel(yh=slice(-10,10)).isel(yh=j).isel(z_l=slice(0,14)).values)
  print('Time elapsed: ', datetime.now() - startTime)

  if parallel:
    print('\n Releasing workers...')
    client.close(); cluster.close()

  print('Equatorial Upper Ocean plots...')
  y = ds.yh.values
  zz = ds.z_i.values
  x = ds.xh.values
  [X, Z] = np.meshgrid(x, zz)
  z = 0.5 * ( Z[:-1] + Z[1:])
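  # layer-center depths as midpoints between consecutive interface depths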

  figname = 'PNG/Equatorial/'+str(dcase.casename)+'_'
  yzcompare(temp_eq, thetao_obs_eq, x, -Z,
            title1 = 'model temperature', ylabel='Longitude', yunits='',
            title2 = 'observed temperature (PHC/Levitus)', #contour=True,
            suptitle=dcase.casename + ', averaged '+str(args.start_date)+ ' to ' +str(args.end_date),
            extend='neither', dextend='neither', clim=(6,31.), dlim=(-5,5), dcolormap=plt.cm.bwr,
            save=figname+'Equatorial_Global_temperature.png')

  yzcompare(salt_eq, salt_obs_eq, x, -Z,
            title1 = 'model salinity', ylabel='Longitude', yunits='',
            title2 = 'observed salinity (PHC/Levitus)', #contour=True,
            suptitle=dcase.casename + ', averaged '+str(args.start_date)+ ' to ' +str(args.end_date),
            extend='neither', dextend='neither', clim=(33.5,37.), dlim=(-1,1), dcolormap=plt.cm.bwr,
            save=figname+'Equatorial_Global_salinity.png')

  # create DataArrays and save the data
  temp_eq_da = xr.DataArray(temp_eq, dims=['zl','xh'],
                           coords={'zl' : z[:,0], 'xh' : x[:]}).rename('temp_eq')

  temp_eq_da.to_netcdf('ncfiles/'+str(args.casename)+'_temp_eq.nc')
  salt_eq_da = xr.DataArray(salt_eq, dims=['zl','xh'],
                           coords={'zl' : z[:,0], 'xh' : x[:]}).rename('salt_eq')

  salt_eq_da.to_netcdf('ncfiles/'+str(args.casename)+'_salt_eq.nc')

  # Shift model data to compare against obs
  tmp, lonh = shiftgrid(thetao.xh[-1].values, thetao[0,0,:].values, ds.thetao.xh.values)
  tmp, lonq = shiftgrid(uo.xq[-1].values, uo[0,0,:].values, uo.xq.values)

  thetao['xh'].values[:] = lonh
  so['xh'].values[:] = lonh
  uo['xq'].values[:] = lonq

  # y and z from obs
  y_obs = johnson.YLAT11_101.values
  zz = np.arange(0,510,10)
  [Y, Z_obs] = np.meshgrid(y_obs, zz)
  z_obs = 0.5 * ( Z_obs[:-1,:] + Z_obs[1:,:] )

  # y and z from model
  y_model = thetao.yh.values
  z = eta.z_i.values
  [Y, Z_model] = np.meshgrid(y_model, z)
  z_model = 0.5 * ( Z_model[0:-1,:] + Z_model[1:,:] )

  # longitudes to be compared
  longitudes = [143., 156., 165., 180., 190., 205., 220., 235., 250., 265.]

  for l in longitudes:
    # Temperature
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,8))
    dummy_model = np.ma.masked_invalid(thetao.sel(xh=l, method='nearest').values)
    dummy_obs = np.ma.masked_invalid(johnson.POTEMPM.sel(XLON=l, method='nearest').values)
    yzplot(dummy_model, y_model, -Z_model, clim=(7,30), axis=ax1, zlabel='Depth', ylabel='Latitude', title=str(dcase.casename))
    cs1 = ax1.contour( y_model + 0*z_model, -z_model, dummy_model, levels=np.arange(0,30,2), colors='k',); plt.clabel(cs1,fmt='%3.1f', fontsize=14)
    ax1.set_ylim(-400,0)
    yzplot(dummy_obs, y_obs, -Z_obs, clim=(7,30), axis=ax2, zlabel='Depth', ylabel='Latitude', title='Johnson et al (2002)')
    cs2 = ax2.contour( y_obs + 0*z_obs, -z_obs, dummy_obs, levels=np.arange(0,30,2), colors='k',); plt.clabel(cs2,fmt='%3.1f', fontsize=14)
    ax2.set_ylim(-400,0)
    plt.suptitle('Temperature [C] @ '+str(l)+ ', averaged between '+str(args.start_date)+' and '+str(args.end_date))
    plt.savefig(figname+'temperature_'+str(l)+'.png')

    # Salt
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,8))
    dummy_model = np.ma.masked_invalid(so.sel(xh=l, method='nearest').values)
    dummy_obs = np.ma.masked_invalid(johnson.SALINITYM.sel(XLON=l, method='nearest').values)
    yzplot(dummy_model, y_model, -Z_model, clim=(32,36), axis=ax1, zlabel='Depth', ylabel='Latitude', title=str(dcase.casename))
    cs1 = ax1.contour( y_model + 0*z_model, -z_model, dummy_model, levels=np.arange(32,36,0.5), colors='k',); plt.clabel(cs1,fmt='%3.1f', fontsize=14)
    ax1.set_ylim(-400,0)
    yzplot(dummy_obs, y_obs, -Z_obs, clim=(32,36), axis=ax2, zlabel='Depth', ylabel='Latitude', title='Johnson et al (2002)')
    cs2 = ax2.contour( y_obs + 0*z_obs, -z_obs, dummy_obs, levels=np.arange(32,36,0.5), colors='k',); plt.clabel(cs2,fmt='%3.1f', fontsize=14)
    ax2.set_ylim(-400,0)
    plt.suptitle('Salinity [psu] @ '+str(l)+ ', averaged between '+str(args.start_date)+' and '+str(args.end_date))
    plt.savefig(figname+'salinity_'+str(l)+'.png')

    # uo
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,8))
    dummy_model = np.ma.masked_invalid(uo.sel(xq=l, method='nearest').values)
    dummy_obs = np.ma.masked_invalid(johnson.UM.sel(XLON=l, method='nearest').values)
    yzplot(dummy_model, y_model, -Z_model, clim=(-0.6,1.2), axis=ax1, zlabel='Depth', ylabel='Latitude', title=str(dcase.casename))
    cs1 = ax1.contour( y_model + 0*z_model, -z_model, dummy_model, levels=np.arange(-1.2,1.2,0.1), colors='k',); plt.clabel(cs1,fmt='%3.1f', fontsize=14)
    ax1.set_ylim(-400,0)
    yzplot(dummy_obs, y_obs, -Z_obs, clim=(-0.6,1.2), axis=ax2, zlabel='Depth', ylabel='Latitude', title='Johnson et al (2002)')
    cs2 = ax2.contour( y_obs + 0*z_obs, -z_obs, dummy_obs, levels=np.arange(-1.2,1.2,0.1), colors='k',); plt.clabel(cs2,fmt='%3.1f', fontsize=14)
    ax2.set_ylim(-400,0)
    plt.suptitle('Eastward velocity [m/s] @ '+str(l)+ ', averaged between '+str(args.start_date)+' and '+str(args.end_date))
    plt.savefig(figname+'uo_'+str(l)+'.png')

  # Eastward velocity [m/s] along the Equatorial Pacific
  x_obs = johnson.XLON.values
  [X_obs, Z_obs] = np.meshgrid(x_obs, zz)
  z_obs = 0.5 * ( Z_obs[:-1,:] + Z_obs[1:,:] )

  x_model = so.xh.values
  z = eta.z_i.values
  [X, Z_model] = np.meshgrid(x_model, z)
  z_model = 0.5 * ( Z_model[:-1,:] + Z_model[1:,:] )

  fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,8))
  dummy_obs = np.ma.masked_invalid(johnson.UM.sel(YLAT11_101=0).values)
  dummy_model = np.ma.masked_invalid(uo.sel(yh=0, method='nearest').values)
  yzplot(dummy_model, x_model, -Z_model, clim=(-0.6,1.2), axis=ax1, landcolor=[0., 0., 0.], title=str(dcase.casename), ylabel='Longitude')
  cs1 = ax1.contour( x_model + 0*z_model, -z_model, dummy_model, levels=np.arange(-1.2,1.2,0.1),  colors='k'); plt.clabel(cs1,fmt='%2.1f', fontsize=14)
  ax1.set_xlim(143,265); ax1.set_ylim(-400,0)
  yzplot(dummy_obs, x_obs, -Z_obs, clim=(-0.4,1.2), ylabel='Longitude', yunits='',  axis=ax2, title='Johnson et al (2002)')
  cs1 = ax2.contour( x_obs + 0*z_obs, -z_obs, dummy_obs,  levels=np.arange(-1.2,1.2,0.1), colors='k'); plt.clabel(cs1,fmt='%2.1f', fontsize=14)
  ax2.set_xlim(143,265); ax2.set_ylim(-400,0)
  plt.suptitle('Eastward velocity [m/s] along the Equatorial Pacific, averaged between '+str(args.start_date)+' and '+str(args.end_date))
  plt.savefig(figname+'Equatorial_Pacific_uo.png')

  plt.close('all')
  return
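shiftgrid above realigns model longitudes with the observational grid. A minimal sketch of that kind of operation, assuming Basemap-style semantics (return the shifted field and longitudes starting near lon0); shift_longitudes is a hypothetical name, not the helper used above:

import numpy as np

def shift_longitudes(lon0, field, lon):
    # roll the longitude axis (assumed last axis of 'field') so it starts
    # at the first point >= lon0, wrapping leading points around by +360
    i0 = int(np.argmax(lon >= lon0))
    new_lon = np.concatenate([lon[i0:], lon[:i0] + 360.0])
    new_field = np.concatenate([field[..., i0:], field[..., :i0]], axis=-1)
    return new_field, new_lon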