Example #1
def plot_frames(dataset, variable_name, **kwargs):

    lon = get_data(dataset, 'lon')
    lat = get_data(dataset, 'lat')
    data = get_data(dataset, variable_name) 
    data = data.where(
        (lon > float(kwargs['mpMinLonF'])) & (lon <= float(kwargs['mpMaxLonF'])) &
        (lat > float(kwargs['mpMinLatF'])) & (lat <= float(kwargs['mpMaxLatF']))
    )
    print('Min of data:', data.min().values)

    # Get data range so that all frames will be consistent
    print('Get mins/maxes over entire dataset...'); sys.stdout.flush()
    kwargs['cnLevels'] = get_contour_levels(data)
    kwargs['cnLevelSelectionMode'] = 'ExplicitLevels'

    # Loop over time series and make a plot for each
    os.makedirs('tmp_frames', exist_ok=True)
    print('Looping over %i time indices'%len(dataset.time)); sys.stdout.flush()
    frames = []
    for i in range(len(dataset.time)):
        frame_name = f'tmp_frames/{variable_name}.{i}.png'
        pl = plot_frame(lon.isel(time=i).values, lat.isel(time=i).values, data.isel(time=i).values, frame_name, **kwargs)
        frames.append(frame_name)
        update_progress(i, len(dataset.time), bar_length=10)

    # Return list of frames
    return frames
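Example #1 relies on a get_contour_levels helper that is not shown. A minimal sketch, assuming it wraps the same nice_cntr_levels call that Examples #9 and #12 use directly:

def get_contour_levels(data, max_steps=13):
    # Hypothetical helper: derive explicit "nice" contour levels from the
    # global min/max so every frame shares one color scale, centering the
    # levels about zero when the data spans both signs.
    vmin = float(data.min().values)
    vmax = float(data.max().values)
    *__, clevels = nice_cntr_levels(vmin, vmax, returnLevels=True,
                                    max_steps=max_steps,
                                    aboutZero=(vmin < 0 and vmax > 0))
    return clevels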
Example #2
def main(var_name, animation_name, *inputfiles, **kwargs):
    with dask.config.set(scheduler='single-threaded'):

        # Get some info about dataset by opening once
        with open_mfdataset(sorted(inputfiles),
                            drop_variables=('P3_input_dim', 'P3_output_dim'),
                            chunks={'time': 1}) as ds:
            # Find mins/maxes
            data = get_data(ds, var_name)
            lon = get_data(ds, 'lon')
            lat = get_data(ds, 'lat')
            if 'mpMinLonF' in kwargs.keys():
                data = data.where(lon > float(kwargs['mpMinLonF']))
            if 'mpMaxLonF' in kwargs.keys():
                data = data.where(lon <= float(kwargs['mpMaxLonF']))
            if 'mpMinLatF' in kwargs.keys():
                data = data.where(lat > float(kwargs['mpMinLatF']))
            if 'mpMaxLatF' in kwargs.keys():
                data = data.where(lat <= float(kwargs['mpMaxLatF']))
            if 'vmin' not in kwargs.keys(): kwargs['vmin'] = data.min().values
            if 'vmax' not in kwargs.keys(): kwargs['vmax'] = data.max().values
            #percentile = 2
            #cmin = min([numpy.nanpercentile(data.values, percentile) for da in data_arrays])
            #cmax = max([numpy.nanpercentile(data.values, 100-percentile) for da in data_arrays])
            # Find length of time dimension to loop over later
            time = ds['time'].copy(deep=True)
            ntime = len(ds.time)

        # Make a bunch of plots, save separately, then stitch together later
        os.makedirs('tmp_frames', exist_ok=True)
        frames = []
        print('Loop over time series...')
        sys.stdout.flush()
        for i in range(ntime):
            frame_name = f'tmp_frames/{var_name}.{i}.png'
            kwargs['tiMainString'] = str(time.isel(time=i).values)
            # Try to run this as a subprocess to prevent an excruciatingly
            # frustrating memory leak
            #plot_map(var_name, frame_name, *inputfiles, time_index=i, **kwargs)
            args = [
                "./e3smplot/pyngl/plot_map.py", var_name, frame_name,
                *inputfiles, f"time_index={i}",
                *[f"{k}={v}" for k, v in kwargs.items()]
            ]
            subprocess.run(args)
            # Trim frame
            subprocess.run(
                f'convert -trim {frame_name} {frame_name}'.split(' '))
            frames.append(frame_name)
            update_progress(i + 1, ntime)

    print('Animate frames...')
    sys.stdout.flush()
    animate_frames(animation_name, frames)

    print('Remove temporary files...')
    sys.stdout.flush()
    for frame in frames:
        os.remove(frame)
Example #3
def get_ice_cloud_mask(ds, threshold=1e-5):
    cldice = get_data(ds, 'CLDICE')
    cld_mask = cldice.where(
        cldice > threshold).notnull()  #(cldice > threshold)
    cld_mask.attrs = {
        'long_name': 'Ice cloud mask',
        'units': 'none',
        'description': f'CLDICE > {threshold}',
    }
    return cld_mask
Example #4
def get_liq_cld_mask(ds, threshold=1e-5):
    cldliq = get_data(ds, 'CLDLIQ')
    cld_mask = (cldliq > threshold)
    cld_mask.attrs = {
        'name': 'liq_cld_mask',
        'long_name': 'Liquid cloud mask',
        'units': 'none',
        'description': f'CLDLIQ > {threshold}',
    }
    return cld_mask
Example #5
def plot_frames(dataset, variable_name, **kwargs):

    # Get data range so that all frames will be consistent
    if 'vmin' not in kwargs.keys():
        kwargs['vmin'] = get_data(dataset, variable_name).min().values
    if 'vmax' not in kwargs.keys():
        kwargs['vmax'] = get_data(dataset, variable_name).max().values

    # Loop over time series and make a plot for each
    frames = []
    print('Looping over %i time indices' % len(dataset.time))
    sys.stdout.flush()
    for i in range(len(dataset.time)):
        frame_name = '%s.%i.png' % (variable_name, i)
        plot_frame(dataset.isel(time=i), variable_name, frame_name, **kwargs)
        frames.append(frame_name)
        update_progress(i + 1, len(dataset.time))

    # Return list of frames
    return frames
Example #6
def get_liq_cloud_mask(ds, threshold=1e-5):
    cldliq = get_data(ds, 'CLDLIQ')
    cld_mask = cldliq.where(
        cldliq > threshold).notnull()  #(cldliq > threshold)
    #cld_mask = cldliq.copy()
    cld_mask.attrs = {
        'long_name': 'Liquid cloud mask',
        'units': 'none',
        'description': f'CLDLIQ > {threshold}',
    }
    return cld_mask
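Examples #3 and #6 return boolean DataArrays, so they compose elementwise. A hedged usage sketch, assuming both masks come from the same dataset and grid:

# A column counts as cloudy where either condensate exceeds its threshold.
ice_mask = get_ice_cloud_mask(ds)
liq_mask = get_liq_cloud_mask(ds)
total_mask = ice_mask | liq_mask
total_mask.attrs = {
    'long_name': 'Cloud mask',
    'units': 'none',
    'description': 'CLDICE or CLDLIQ > 1e-05',
}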
Example #7
def main(vname, mapfile, inputfile, outputfile, **kwargs):

    # Read mapping file
    ds_map = xarray.open_mfdataset(mapfile)

    # Read data
    ds = xarray.open_mfdataset(
        inputfile, drop_variables=('P3_input_dim', 'P3_output_dim'),
        chunks={'time': 1, 'lev': 1}, engine='netcdf4',
    )
    da = get_data(ds, vname)

    # Remap
    da_out = apply_map(ds_map, da, ndim=1)
    print('Write to file...'); sys.stdout.flush()
    da_out.to_netcdf(outputfile)
Example #8
def main(vname, inputfile, outputfile, **kwargs):

    # Process kwargs (values arrive as command-line strings; eval assumes trusted input)
    for k, v in kwargs.items():
        kwargs[k] = eval(v)

    # Open dataset
    ds = xarray.open_mfdataset(inputfile,
                               drop_variables=('P3_input_dim',
                                               'P3_output_dim'),
                               **kwargs)

    # Compute cloud masks and save to disk
    da = get_data(ds, vname)
    ds_out = xarray.Dataset({vname: da})
    ds_out.to_netcdf(outputfile, encoding={vname: {'_FillValue': -9999}})

    # Clean up
    ds_out.close()
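The eval() loop in Example #8 assumes fully trusted input. A more defensive sketch (a swapped-in technique, not the original code) uses ast.literal_eval, which accepts only Python literals:

import ast

def coerce_kwargs(kwargs):
    # Cast string kwargs to Python literals where possible; values that are
    # not literals (e.g. plain strings like 'netcdf4') pass through as-is.
    out = {}
    for k, v in kwargs.items():
        try:
            out[k] = ast.literal_eval(v)
        except (ValueError, SyntaxError):
            out[k] = v
    return out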
Example #9
def main(varname,
         plotname,
         *datafiles,
         gridfile=None,
         time_index=None,
         vmin=None,
         vmax=None,
         **kwargs):

    # Read data
    ds_data = xarray.open_mfdataset(
        sorted(datafiles),
        chunks={'time': 1},
        drop_variables=('P3_input_dim', 'P3_output_dim'),
    )
    data = get_data(ds_data, varname)
    if gridfile is not None:
        ds_grid = xarray.open_dataset(gridfile).rename({'grid_size': 'ncol'})
        if 'lon' in ds_grid and 'lat' in ds_grid:
            x = ds_grid['lon']
            y = ds_grid['lat']
        elif 'grid_corner_lon' in ds_grid and 'grid_corner_lat' in ds_grid:
            x = ds_grid['grid_corner_lon']
            y = ds_grid['grid_corner_lat']
        else:
            raise RuntimeError('No valid coordinates in grid file.')
    else:
        x = ds_data['lon']
        y = ds_data['lat']

    # Make sure we don't have time or level dimensions
    if 'time' in data.dims:
        if time_index is None:
            data = data.mean(dim='time', keep_attrs=True).squeeze()
        else:
            data = data.isel(time=int(time_index))
    if 'lev' in data.dims:
        data = data.isel(lev=-1).squeeze()
    if 'time' in x.dims: x = x.isel(time=0)
    if 'time' in y.dims: y = y.isel(time=0)

    # Setup the canvas
    wks = ngl.open_wks(
        os.path.splitext(plotname)[1][1:],
        os.path.splitext(plotname)[0])

    # Get contour levels; the explicit type casting deals with problems calling
    # this standalone code using subprocess.run() with string arguments, where
    # all kwargs are going to be interpreted as strings
    if vmin is None: vmin = data.min().values
    if vmax is None: vmax = data.max().values
    if float(vmin) < 0 and float(vmax) > 0:
        *__, clevels = nice_cntr_levels(float(vmin),
                                        float(vmax),
                                        returnLevels=True,
                                        max_steps=13,
                                        aboutZero=True)
    else:
        *__, clevels = nice_cntr_levels(float(vmin),
                                        float(vmax),
                                        returnLevels=True,
                                        max_steps=13)
    kwargs['cnLevels'] = clevels  #get_contour_levels(data)
    kwargs['cnLevelSelectionMode'] = 'ExplicitLevels'

    # Make plot
    if 'lbTitleString' not in kwargs.keys():
        kwargs['lbTitleString'] = f'{data.long_name} ({data.units})'
    plot = plot_map(wks,
                    x.values,
                    y.values,
                    data.values,
                    mpGeophysicalLineColor='white',
                    lbOrientation='horizontal',
                    cnFillMode='RasterFill',
                    cnLineLabelsOn=False,
                    cnLinesOn=False,
                    **kwargs)

    ngl.destroy(wks)
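The comment about explicit type casting reflects how Example #2 drives this script: subprocess.run passes every kwarg as a "key=value" string. A minimal sketch of the receiving side, using a hypothetical parse_cli_args helper rather than the project's actual entry point:

import sys

def parse_cli_args(argv):
    # Split positional arguments from "key=value" pairs. Values stay
    # strings, so callees cast with float()/int() at the point of use.
    args, kwargs = [], {}
    for token in argv[1:]:
        if '=' in token:
            key, value = token.split('=', 1)
            kwargs[key] = value
        else:
            args.append(token)
    return args, kwargs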
Example #10
        # Open file
        ds_v = xarray.open_dataset(f_d)
        ds_p = xarray.open_dataset(f_p)

        # Empty dict to hold remapped variables
        ds_out = xarray.Dataset({})

        # Loop over variables in file
        for v in ds_v.variables.keys():

            # If this variable does not have a level dimension, move on
            if 'lev' not in ds_v.variables[v].dims or v in exclude_variables:
                continue

            # Read data
            d = get_data(ds_v, v)
            p = get_data(ds_p, 'PMID')

            # Do vertical remap
            axis = 1
            shape_out = list(d.shape)
            shape_out[axis] = len(levels)
            dims_out = list(d.dims)
            dims_out[axis] = 'plev'
            coords_out = {
                c: (levels if c == 'plev' else d.coords[c])
                for c in dims_out
            }
            d_on_levels = xarray.DataArray(
                np.empty(shape_out),
                name=v,
                dims=dims_out,
                coords=coords_out,
            )
Example #11
def main(test_files,
         cntl_files,
         names,
         vname,
         fig_name,
         maps=None,
         time_offsets=None,
         t1=None,
         t2=None,
         dpi=400,
         verbose=False,
         **kwargs):

    # One entry per dataset (test and control)
    if time_offsets is None: time_offsets = [None, None]
    if maps is None: maps = [None, None]

    # Load datasets
    print('Load data...')
    sys.stdout.flush()
    #with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    datasets = [
        open_dataset(f, time_offset=dt, chunks={'time': 1})
        for (f, dt) in zip((test_files, cntl_files), time_offsets)
    ]

    # Subset for overlapping time periods
    if verbose:
        print('Get overlapping time range...')
        sys.stdout.flush()
    if t1 is None: t1 = max([ds.time[0].values for ds in datasets])
    if t2 is None: t2 = min([ds.time[-1].values for ds in datasets])
    datasets = [ds.sel(time=slice(str(t1), str(t2))) for ds in datasets]

    # Rename
    datasets = [
        ds.rename({'level': 'plev'}) if 'level' in ds.dims else ds
        for ds in datasets
    ]

    # Select (and remask) data
    if verbose:
        print('Select data...')
        sys.stdout.flush()
    data_arrays = [mask_all_zero(get_data(ds, vname)) for ds in datasets]
    coords = [get_coords(ds) for ds in datasets]
    weights = [get_area_weights(ds) for ds in datasets]

    for ii in range(len(weights)):
        *__, weights[ii] = xarray.broadcast(data_arrays[ii], weights[ii])

    # TODO: we need to interpolate to common pressure levels here

    # Compute time averages
    if verbose:
        print('Compute time averages...')
        sys.stdout.flush()
    means = [da.mean(dim='time', keep_attrs=True) for da in data_arrays]
    weights = [
        wgt.mean(dim='time', keep_attrs=True) if 'time' in wgt.dims else wgt
        for wgt in weights
    ]

    # Remap data to a common grid (since we will both compute zonal mean and
    # compute diffs)
    if verbose:
        print('Remap to common grid...')
        sys.stdout.flush()
    dims = means[-1].dims[-2:]
    coords = {d: means[-1].coords[d] for d in dims}
    means = [
        apply_weights_wrap(f, m, x=coords['lon'], y=coords['lat'])
        if f is not None else m for (f, m) in zip(maps, means)
    ]
    weights = [
        apply_weights_wrap(f, m, x=coords['lon'], y=coords['lat'])
        if f is not None else m for (f, m) in zip(maps, weights)
    ]

    # Compute *zonal* average. Note that this is tricky for unstructured data.
    # Our approach is to compute means over latitude bins, and return a new
    # coordinate for latitude of the zonal mean. The function defined in
    # e3sm_utils takes the data, a set of latitude weights, and the input
    # latitudes, and optionally the number of new latitude bands within which to
    # average the data, and returns the binned/averaged data and the new
    # latitudes.
    #if map_file is not None:
    #    print('Apply map...')
    #    means[0] = apply_map(means[0], map_file, template=means[1])
    #    weights[0] = apply_map(weights[0], map_file, template=means[1])
    #    lats[0] = lats[1]
    #    weights, *__ = zip(*[xarray.broadcast(w, d) for (w, d) in zip(weights, means)])
    #    means = [zonal_mean(d, weights=w) for (d, w) in zip(means, weights)]
    #else:
    #    print('Try using our slow zonal mean routine...')
    #    means, lats = zip(*[calculate_zonal_mean(d, w, y) for (d, w, y) in zip(means, weights, lats)])
    if verbose:
        print('Compute zonal means...')
        sys.stdout.flush()
    weights, *__ = zip(
        *[xarray.broadcast(w, d) for (w, d) in zip(weights, means)])
    means = [zonal_mean(d, weights=w) for (d, w) in zip(means, weights)]

    # Make plots of zonal averages
    if verbose:
        print('Make pcolor plots of zonal means...')
        sys.stdout.flush()
    figure = make_zonal_profile_figure(means, names)
    figure.savefig(fig_name, bbox_inches='tight', dpi=dpi)

    # Finally, trim whitespace from our figures
    if verbose:
        print('Trimming whitespace from figure...')
        sys.stdout.flush()
    subprocess.call(f'convert -trim {fig_name} {fig_name}'.split(' '))
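The comment block above describes the binned zonal mean for unstructured grids without showing it. A sketch under stated assumptions (data, weights, and lat share an 'ncol' dimension; the real routine lives in e3sm_utils):

import numpy as np

def calculate_zonal_mean(data, weights, lat, nbins=90):
    # Average unstructured columns into latitude bands and return the band
    # centers as the new latitude coordinate of the zonal mean.
    bins = np.linspace(-90.0, 90.0, nbins + 1)
    lat_centers = 0.5 * (bins[:-1] + bins[1:])
    wsum = (data * weights).groupby_bins(lat, bins).sum(dim='ncol')
    norm = weights.groupby_bins(lat, bins).sum(dim='ncol')
    return wsum / norm, lat_centers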
Example #12
def main(outputfile, variable_name, *inputfiles, **kwargs):
    import xarray
    import gc
    # Open files
    #dataset = open_files(sorted(inputfiles))
    with xarray.open_mfdataset(
        sorted(inputfiles), drop_variables=('P3_input_dim',
        'P3_output_dim'), chunks={'time': 1}
    ) as ds:

        # Get data
        data = get_data(ds, variable_name)
        lon = get_data(ds, 'lon')
        lat = get_data(ds, 'lat')

        # Subset
        data = data.where(
            (lon > float(kwargs['mpMinLonF'])) & (lon <= float(kwargs['mpMaxLonF'])) &
            (lat > float(kwargs['mpMinLatF'])) & (lat <= float(kwargs['mpMaxLatF']))
        )

        # Get mins and maxes
        minval = data.min().values
        maxval = data.max().values
        print(f'Data range: {minval} to {maxval}')

        # Get contour levels
        *__, clevels = nice_cntr_levels(minval, maxval, returnLevels=True, max_steps=13)
        kwargs['cnLevels'] = clevels #get_contour_levels(data)
        kwargs['cnLevelSelectionMode'] = 'ExplicitLevels'
        print('clevels: ', kwargs['cnLevels'])

        ntime = len(ds.time)

    # Loop
    print('Looping over %i time indices' % ntime); sys.stdout.flush()
    os.makedirs('tmp_frames', exist_ok=True)
    frames = []
    for i in range(ntime):
        gc.collect()
        with xarray.open_mfdataset(
            sorted(inputfiles), drop_variables=('P3_input_dim',
            'P3_output_dim'), chunks={'time': 1}
        ) as ds:
            data = get_data(ds, variable_name).isel(time=i).values
            lon = get_data(ds, 'lon').isel(time=i).values
            lat = get_data(ds, 'lat').isel(time=i).values
            plot_name = f'tmp_frames/{variable_name}.{i}.png'
            frames.append(plot_frame(lon, lat, data, plot_name, **kwargs))
            update_progress(i, ntime, bar_length=10)
            del data
            del lon
            del lat

    # Pull out keyword args meant for the animation step
    animate_kw = {}
    for key in ('time_per_frame',):
        if key in kwargs.keys():
            animate_kw[key] = kwargs.pop(key)

    # Plot frames
    #frames = plot_frames(dataset, variable_name, **kwargs)

    # Stitch together frames into single animation
    animate_frames(outputfile, frames, **animate_kw)

    # Clean up
    remove_frames(frames)
Example #13
def plot_frame(dataset,
               variable_name,
               frame_name,
               lat_min=None,
               lat_max=None,
               lon_min=None,
               lon_max=None,
               **kwargs):

    # Select data
    data = get_data(dataset, variable_name).squeeze()
    lon = get_data(dataset, 'lon').squeeze()
    lat = get_data(dataset, 'lat').squeeze()

    # Fix coordinates
    lon = fix_longitudes(lon)

    # Open figure
    figure, axes = pyplot.subplots(
        figsize=(10, 8),
        subplot_kw=dict(projection=crs.PlateCarree(central_longitude=180)))

    # Set extent
    if all(v is not None for v in [lat_min, lat_max, lon_min, lon_max]):
        axes.set_extent(
            [float(lon_min),
             float(lon_max),
             float(lat_min),
             float(lat_max)])
        if 'ncol' in data.dims:
            criteria = ((lon >= float(lon_min)) & (lon <= float(lon_max)) &
                        (lat >= float(lat_min)) & (lat <= float(lat_max)))

            data = data.where(criteria).dropna('ncol')
            lon = lon.where(criteria).dropna('ncol')
            lat = lat.where(criteria).dropna('ncol')

    # Plot data
    if 'ncol' in data.dims:
        pl = axes.tripcolor(lon,
                            lat,
                            data,
                            transform=crs.PlateCarree(),
                            **kwargs)
    else:
        pl = axes.pcolormesh(lon,
                             lat,
                             data.transpose('lat', 'lon'),
                             transform=crs.PlateCarree(),
                             **kwargs)

    # Label plot
    axes.set_title('time = %s' % (data['time'].values))
    axes.coastlines()

    # Add a colorbar
    cb = pyplot.colorbar(pl,
                         orientation='horizontal',
                         label='%s (%s)' % (data.long_name, data.units),
                         shrink=0.8,
                         pad=0.02)

    # Save figure
    figure.savefig(frame_name, dpi=100)
    pyplot.close()
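A hedged usage example for Example #13; the dataset and variable names are illustrative assumptions, not part of the original code:

# Plot the first time step over a tropical strip, assuming `ds` holds
# 'lon', 'lat', and a 'PRECT' field.
plot_frame(ds.isel(time=0), 'PRECT', 'PRECT.0.png',
           lat_min=-30, lat_max=30, lon_min=90, lon_max=270,
           cmap='viridis')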
Example #14
def main(varname,
         outputfile,
         testfiles,
         cntlfiles,
         t1=None,
         t2=None,
         maps=None,
         percentile=5,
         verbose=False,
         **kwargs):
    #
    # Open datasets if needed (may also pass datasets directly rather than filenames)
    #
    if verbose: myprint('Open datasets...')
    datasets = [
        f if isinstance(f, xarray.Dataset) else open_dataset(*f)
        for f in (testfiles, cntlfiles)
    ]
    #
    # Subset data
    #
    if verbose: myprint('Subset consistent time periods...')
    if t1 is None: t1 = max([ds.time[0].values for ds in datasets])
    if t2 is None: t2 = min([ds.time[-1].values for ds in datasets])
    if verbose: myprint('Comparing period {} to {}'.format(str(t1), str(t2)))
    datasets = [ds.sel(time=slice(str(t1), str(t2))) for ds in datasets]
    #
    # Compute time average
    #
    if verbose: myprint('Compute time averages...')
    datasets = [ds.mean(dim='time', keep_attrs=True) for ds in datasets]
    #
    # Read selected data from file
    # TODO: set case names
    #
    if verbose: myprint('Get data...')
    data_arrays = [get_data(ds, varname) for ds in datasets]
    lons = [get_data(ds, 'lon') for ds in datasets]
    lats = [get_data(ds, 'lat') for ds in datasets]
    #
    # Area needed for weighted average calculation
    #
    if verbose: myprint('Get area weights...')
    area_arrays = [get_area_weights(ds) for ds in datasets]
    #
    # Remap if needed
    #
    if verbose: myprint('Remap to lat/lon grid if needed...')
    if maps is not None:
        map_datasets = [
            xarray.open_dataset(f) if f is not None else None for f in maps
        ]
        area_arrays = [
            apply_map(m, f)[0] if f is not None else m
            for (m, f) in zip(area_arrays, map_datasets)
        ]
        data_arrays = [
            apply_map(m, f)[0] if f is not None else m
            for (m, f) in zip(data_arrays, map_datasets)
        ]

    # redefine lons and lats after remap
    # TODO: this is not unstructured grid-friendly
    lons = [d.lon for d in data_arrays]
    lats = [d.lat for d in data_arrays]
    #
    # Compute differences
    #
    if verbose: myprint('Compute differences...')
    data_arrays.append(data_arrays[0] - data_arrays[1])
    area_arrays.append(area_arrays[0])
    lons.append(lons[0])
    lats.append(lats[0])
    data_arrays[-1].attrs = data_arrays[0].attrs
    data_arrays[-1].attrs['description'] = 'test minus control'
    #
    # Pop labels out of kwargs
    #
    if 'labels' in kwargs.keys():
        labels = (*kwargs.pop('labels'), 'Difference')
    else:
        labels = ('test', 'cntl', 'diff')
    #
    # Make figure
    #
    if verbose: myprint('Make plots...')
    figure, axes = pyplot.subplots(
        1,
        3,
        figsize=(15, 5),
        subplot_kw=dict(projection=crs.PlateCarree(central_longitude=180)))
    cmaps = ['viridis', 'viridis', 'RdBu_r']
    vmin = min([
        numpy.nanpercentile(da.values, percentile) for da in data_arrays[:-1]
    ])
    vmax = max([
        numpy.nanpercentile(da.values, 100 - percentile)
        for da in data_arrays[:-1]
    ])
    vmins = [vmin, vmin, -abs(data_arrays[-1].max().values)]
    vmaxs = [vmax, vmax, abs(data_arrays[-1].max().values)]
    plots = [
        plot_map(lons[i],
                 lats[i],
                 data_arrays[i],
                 axes=axes[i],
                 cmap=cmaps[i],
                 vmin=vmins[i],
                 vmax=vmaxs[i],
                 **kwargs) for i in range(len(data_arrays))
    ]
    #
    # Annotate maps
    #
    if verbose: myprint('Annotate plots...')
    for i in range(len(data_arrays)):
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            data_mean = area_average(data_arrays[i], area_arrays[i])
        label = labels[i]
        axes[i].set_title(
            f'{label}\nmin = {data_arrays[i].min().values:.2f}; max = {data_arrays[i].max().values:.2f}; mean = {data_mean.values:.2f}'
        )
    #
    # Save figure
    #
    figure.savefig(outputfile, bbox_inches='tight')
    pyplot.close()
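A hedged invocation sketch for Example #14; the file patterns, labels, and the tuple-wrapping for open_dataset are illustrative assumptions:

main('PRECT', 'PRECT_diff.png',
     ('test/*.h0.*.nc',), ('cntl/*.h0.*.nc',),
     labels=('test run', 'control run'),
     percentile=5, verbose=True)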
Example #15
output_path = '/global/cfs/cdirs/e3sm/terai/SCREAM/DYAMOND2/Output/20201127'
all_files = glob(f'{output_path}/*.eam.h[0-9].*.nc')
figure, ax = pyplot.subplots(1, 1)  # figsize=(10, 10))

for ivar, v in enumerate(variable_names):
    # find files
    print(v)
    these_files = [f for f in all_files if can_retrieve_field(f, v)]

    # Load files
    print('load dataset')
    ds = open_dataset(sorted(these_files), chunks={'time': 1, 'lev': 1})

    # Get data
    print('get data')
    data = get_data(ds, v)

    # Compute area averages
    print('compute averages')
    #w = get_area_weights(ds)
    #m = area_average(data, w, dims=[d for d in data.dims if d != 'time'])
    m = data.mean(dim=[d for d in data.dims if d != 'time'])

    # Convert units
    if False:  #data.attrs['units'] == 'none':
        print('Convert units')
        m = 100.0 * m
        m.attrs['units'] = '%'

    # Plot timeseries
    print('print values')