def pgradient(var, lat1, lat2, lon1, lon2, plev):
    """Return d/dp of a lat-lon variable at pressure level plev.

    Subsets var to the lat-lon box and a +/- 100 hPa window around plev,
    computes the pressure gradient, and returns the slice at plev with
    updated metadata (name, units, filestr).
    """
    pwidth = 100
    p1, p2 = plev - pwidth, plev + pwidth
    var = atm.subset(var, {'lat' : (lat1, lat2), 'lon' : (lon1, lon2),
                           'plev' : (p1, p2)}, copy=False)
    latlonstr = latlon_filestr(lat1, lat2, lon1, lon2)
    # Copy the attrs dict -- assigning 'attrs = var.attrs' aliases it, and
    # the mutations below would corrupt the input variable's metadata
    attrs = dict(var.attrs)
    pname = atm.get_coord(var, 'plev', 'name')
    pdim = atm.get_coord(var, 'plev', 'dim')
    pres = var[pname]
    pres = atm.pres_convert(pres, pres.attrs['units'], 'Pa')
    dvar_dp = atm.gradient(var, pres, axis=pdim)
    dvar_dp = atm.subset(dvar_dp, {pname : (plev, plev)}, copy=False,
                         squeeze=True)
    varnm = 'D%sDP' % var.name
    name = '%s%d' % (varnm, plev)
    dvar_dp.name = name
    attrs['long_name'] = 'd/dp of ' + var.attrs['long_name']
    attrs['standard_name'] = 'd/dp of ' + var.attrs['standard_name']
    attrs['units'] = ('(%s)/Pa' % attrs['units'])
    attrs[pname] = plev
    attrs['filestr'] = '%s_%s' % (name, latlonstr)
    attrs['varnm'] = varnm
    dvar_dp.attrs = attrs
    return dvar_dp
def daily_corr(ind1, ind2, yearnm='year'):
    """Return the correlation between daily timeseries ind1 and ind2 for
    each year, as a pd.Series indexed by year.
    """
    if ind1.name == ind2.name:
        raise ValueError('ind1 and ind2 have the same name ' + ind1.name)
    years = ind1[yearnm]
    corr = np.zeros(years.shape)
    for y, year in enumerate(years):
        subset_dict = {yearnm : (year, None)}
        df = atm.subset(ind1, subset_dict).to_series().to_frame(name=ind1.name)
        df[ind2.name] = atm.subset(ind2, subset_dict).to_series()
        # .iloc replaces DataFrame.as_matrix(), which was removed in pandas 1.0
        corr[y] = df.corr().iloc[0, 1]
    corr = pd.Series(corr, index=pd.Index(years, name=yearnm))
    return corr
def all_data(datafiles, npre, npost, lon1, lon2, compdays, comp_attrs):
    """Read onset-aligned daily fields and compute sector and composite means.

    Returns (data, sectordata, sector_latmax, comp, sectorcomp) dicts keyed
    by variable name.
    """
    data, sectordata = collections.OrderedDict(), collections.OrderedDict()
    comp, sectorcomp = collections.OrderedDict(), collections.OrderedDict()
    sector_latmax = {}

    for varnm in datafiles:
        print('Reading daily data for ' + varnm)
        var, onset, retreat = utils.load_dailyrel(datafiles[varnm])
        var = housekeeping(atm.subset(var, {'dayrel' : (-npre, npost)}))

        # Sector means and composite averages
        sectorvar = atm.dim_mean(var, 'lon', lon1, lon2)
        compvar = get_composites(var, compdays, comp_attrs)
        sectorcompvar = get_composites(sectorvar, compdays, comp_attrs)

        # Latitude of maximum subcloud theta_e
        if varnm in ('THETA_E950', 'THETA_E_LML'):
            sector_latmax[varnm] = theta_e_latmax(sectorvar)

        # Take the climatological (multi-year) mean when years are present
        if 'year' in var.dims:
            var = atm.dim_mean(var, 'year')
            sectorvar = atm.dim_mean(sectorvar, 'year')
            compvar = atm.dim_mean(compvar, 'year')
            sectorcompvar = atm.dim_mean(sectorcompvar, 'year')

        # Pack everything into dicts for output
        data[varnm], sectordata[varnm] = var, sectorvar
        comp[varnm], sectorcomp[varnm] = compvar, sectorcompvar

    return data, sectordata, sector_latmax, comp, sectorcomp
# Beispiel #4
# 0
def composite(data, compdays, return_avg=True, daynm='Dayrel'):
    """Return composite data fields for selected days.

    Parameters
    ----------
    data : xray.DataArray
        Daily data to composite.
    compdays : dict of arrays or lists
        Lists of days to include in each composite.
    return_avg : bool, optional
        If True, return the mean of the selected days, otherwise
        return the extracted individual days for each composite.
    daynm : str, optional
        Name of day dimension in data.

    Returns
    -------
    comp : dict of xray.DataArrays
        Composite data fields for each key in compdays.keys().
    """
    _, attrs, _, _ = atm.meta(data)
    comp = collections.OrderedDict()

    for key in compdays:
        days = compdays[key]
        selected = atm.subset(data, {daynm : (days, None)})
        if return_avg:
            # Average over the selected days and restore the metadata
            selected = selected.mean(dim=daynm)
            selected.attrs = attrs
            selected.attrs[daynm] = days
        comp[key] = selected

    return comp
# Beispiel #5
# 0
def wrapyear(data, data_prev, data_next, daymin, daymax, year=None):
    """Wrap daily data from previous and next years for extended day ranges.
    """
    daynm = atm.get_coord(data, 'day', 'name')

    def leap_adjust(darr, yr):
        # Squeeze singleton dims; trim the padded day 366 in non-leap years
        darr = atm.squeeze(darr)
        if yr is not None and atm.isleap(yr):
            return darr, 366
        return atm.subset(darr, {'day' : (1, 365)}), 365

    data_out, ndays = leap_adjust(data, year)
    if data_prev is not None:
        data_prev, ndays_prev = leap_adjust(data_prev, year - 1)
        # Shift previous year's days to negative values
        data_prev[daynm] = data_prev[daynm] - ndays_prev
        data_out = xray.concat([data_prev, data_out], dim=daynm)
    if data_next is not None:
        data_next, _ = leap_adjust(data_next, year + 1)
        # Shift next year's days past the end of the current year
        data_next[daynm] = data_next[daynm] + ndays
        data_out = xray.concat([data_out, data_next], dim=daynm)

    return atm.subset(data_out, {daynm : (daymin, daymax)})
# Beispiel #6
# 0
def eddy_decomp(var, nt, lon1, lon2, taxis=0):
    """Decompose variable into time-mean, stationary eddy and transient
    eddy components."""
    lonname = atm.get_coord(var, 'lon', 'name')
    tstr = 'Time mean (%d-%s rolling)' % (nt, var.dims[taxis])
    lonstr = atm.latlon_labels([lon1, lon2], 'lon', deg_symbol=False)
    lonstr = 'zonal mean (' + '-'.join(lonstr) + ')'
    name, attrs, coords, dims = atm.meta(var)

    # Rolling time mean, then its zonal average over the sector
    varbar = atm.rolling_mean(var, nt, axis=taxis, center=True)
    varbarzon = atm.subset(varbar, {lonname : (lon1, lon2)}).mean(dim=lonname)
    varbarzon.attrs = attrs

    comp = xray.Dataset()
    comp[name + '_AVG'] = varbarzon
    comp[name + '_AVG'].attrs['component'] = tstr + ', ' + lonstr

    # Stationary eddy: departure of the time mean from its zonal mean
    comp[name + '_ST'] = varbar - varbarzon
    comp[name + '_ST'].attrs = attrs
    comp[name + '_ST'].attrs['component'] = 'Stationary eddy'

    # Transient eddy: departure from the rolling time mean
    comp[name + '_TR'] = var - varbar
    comp[name + '_TR'].attrs = attrs
    comp[name + '_TR'].attrs['component'] = 'Transient eddy'

    return comp
# Beispiel #7
# 0
def contourf_latday(var,
                    is_precip=False,
                    clev=None,
                    cticks=None,
                    climits=None,
                    nc_pref=40,
                    grp=None,
                    xlims=(-120, 200),
                    xticks=np.arange(-120, 201, 30),
                    ylims=(-60, 60),
                    yticks=np.arange(-60, 61, 20),
                    dlist=None,
                    grid=False,
                    ind_nm='onset',
                    xlabels=True):
    """Create a filled contour plot of data on latitude-day grid.

    Precipitation fields use a sequential colormap (PuBuGn); all others
    use a diverging colormap (RdBu_r).  Contour levels are computed
    automatically when clev is None, or from a contour interval when
    clev is a scalar.
    """
    var = atm.subset(var, {'lat': ylims})
    vals = var.values.T
    lat = atm.get_coord(var, 'lat')
    days = atm.get_coord(var, 'dayrel')
    # Symmetric color levels when the field takes both signs
    if var.min() < 0:
        symmetric = True
    else:
        symmetric = False
    if is_precip:
        cmap = 'PuBuGn'
        extend = 'max'
    else:
        cmap = 'RdBu_r'
        extend = 'both'

    # 'is None' fixes the original '== None', which performs an elementwise
    # comparison (and raises on truth-testing) when clev is an ndarray
    if clev is None:
        cint = atm.cinterval(vals, n_pref=nc_pref, symmetric=symmetric)
        clev = atm.clevels(vals, cint, symmetric=symmetric)
    elif len(atm.makelist(clev)) == 1:
        # Scalar clev is interpreted as a contour interval
        if is_precip:
            clev = np.arange(0, 10 + clev / 2.0, clev)
        else:
            clev = atm.clevels(vals, clev, symmetric=symmetric)

    plt.contourf(days, lat, vals, clev, cmap=cmap, extend=extend)
    plt.colorbar(ticks=cticks)
    plt.clim(climits)
    atm.ax_lims_ticks(xlims, xticks, ylims, yticks)
    plt.grid(grid)

    if dlist is not None:
        # Mark key days (e.g. onset/retreat) with vertical lines
        for d0 in dlist:
            plt.axvline(d0, color='k')
    if xlabels:
        plt.gca().set_xticklabels(xticks)
        plt.xlabel('Days Since ' + ind_nm.capitalize())
    else:
        plt.gca().set_xticklabels([])
    if grp is not None and grp.col == 0:
        plt.ylabel('Latitude')

    return None
# Beispiel #8
# 0
 def leap_adjust(data, year):
     # Squeeze singleton dimensions and return (data, ndays) where ndays is
     # the number of calendar days in 'year' (365, or 366 for a leap year).
     # NOTE(review): appears to be a stray copy of the helper nested inside
     # wrapyear(); assumes daily data padded out to 366 days -- confirm.
     data = atm.squeeze(data)
     ndays = 365
     if year is not None and atm.isleap(year):
         ndays += 1
     else:
         # Remove NaN for day 366 in non-leap year
         data = atm.subset(data, {'day' : (1, ndays)})
     return data, ndays
def ssn_average(var, onset, retreat, season):
    """Average var over each year's season, bounded by onset/retreat days,
    and concatenate the yearly means along the 'year' dimension."""
    pieces = []
    for y, year in enumerate(var['year'].values):
        days = season_days(season, year, onset.values[y], retreat.values[y])
        var_yr = atm.subset(var, {'year' : (year, year)}, squeeze=False)
        pieces.append(var_yr.sel(dayrel=days).mean(dim='dayrel'))
    if len(pieces) == 1:
        return pieces[0]
    return xray.concat(pieces, dim='year')
# Beispiel #10
# 0
def get_strength_indices(years, mfc, precip, onset, retreat, yearnm='year',
                         daynm='day', varnm1='MFC', varnm2='PCP'):
    """Return various indices of the monsoon strength.

    Inputs mfc and precip are the unsmoothed daily values averaged over
    the monsoon area.

    For each of MFC and precip, computes season averages and totals over
    JJAS and over the length of the rainy season (LRS, onset to retreat),
    returning everything as a pd.DataFrame indexed by year.  Either mfc
    or precip may be None to skip it.
    """

    ssn = xray.Dataset()
    coords = {yearnm : years}
    ssn['onset'] = xray.DataArray(onset, coords=coords)
    ssn['retreat'] = xray.DataArray(retreat, coords=coords)
    ssn['length'] = ssn['retreat'] - ssn['onset']

    # Only include the inputs that were provided
    data_in = {}
    if mfc is not None:
        data_in[varnm1] = mfc
    if precip is not None:
        data_in[varnm2] = precip

    # Pre-allocate NaN arrays for each output index
    for key in data_in:
        for key2 in ['_JJAS_AVG', '_JJAS_TOT', '_LRS_AVG', '_LRS_TOT']:
            ssn[key + key2] = xray.DataArray(np.nan * np.ones(len(years)),
                                             coords=coords)

    for key in data_in:
        for y, year in enumerate(years):
            # LRS spans onset day through the day before retreat
            d1 = int(onset.values[y])
            d2 = int(retreat.values[y] - 1)
            days_jjas = atm.season_days('JJAS', atm.isleap(year))
            data = atm.subset(data_in[key], {yearnm : (year, None)})
            data_jjas = atm.subset(data, {daynm : (days_jjas, None)})
            data_lrs = atm.subset(data, {daynm : (d1, d2)})
            # Totals are the average rate times the number of days
            ssn[key + '_JJAS_AVG'][y] = data_jjas.mean(dim=daynm).values
            ssn[key + '_LRS_AVG'][y] = data_lrs.mean(dim=daynm).values
            ssn[key + '_JJAS_TOT'][y] = ssn[key + '_JJAS_AVG'][y] * len(days_jjas)
            ssn[key + '_LRS_TOT'][y] = ssn[key + '_LRS_AVG'][y] * ssn['length'][y]

    ssn = ssn.to_dataframe()
    return ssn
def latlon_data(var, lat1, lat2, lon1, lon2, plev=None):
    """Extract a lat-lon (and optional single pressure level) subset of data,
    updating the variable's name and file-string metadata."""
    varnm = var.name
    subset_dict = {'lat' : (lat1, lat2), 'lon' : (lon1, lon2)}
    latlonstr = latlon_filestr(lat1, lat2, lon1, lon2)
    if plev is None:
        name = varnm
    else:
        # Append the pressure level to the name, e.g. 'U' -> 'U200'
        name = varnm + '%d' % plev
        subset_dict['plev'] = (plev, plev)
    var = atm.subset(var, subset_dict, copy=False, squeeze=True)
    var.name = name
    var.attrs['filestr'] = '%s_%s' % (name, latlonstr)
    var.attrs['varnm'] = varnm
    return var
# Beispiel #12
# 0
def get_data_rel(varid, plev, years, datafiles, data, onset, npre, npost):
    """Return daily data aligned relative to onset/withdrawal day.

    Reads daily data for the requested years and, for 'basic' variables,
    shifts the day axis so that day 0 is the onset day.  When a single
    year's day range extends before day 1 or past the end of the year,
    the neighboring year's file is included (if it exists) so the range
    can be filled, then trimmed back to the requested year.
    """

    years = atm.makelist(years)
    onset = atm.makelist(onset)
    datafiles = atm.makelist(datafiles)

    daymin = min(onset) - npre
    daymax = max(onset) + npost

    # For a single year, add extra year before/after, if necessary
    wrap_single = False
    years_in = years
    if len(years) == 1 and var_type(varid) == 'basic':
        filenm = datafiles[0]
        year = years[0]
        if daymin < 1:
            # Range extends into the previous year
            wrap_single = True
            file_pre = filenm.replace(str(year), str(year - 1))
            if os.path.isfile(file_pre):
                years_in = [year - 1] + years_in
                datafiles = [file_pre] + datafiles
        if daymax > len(atm.season_days('ANN', year)):
            # Range extends into the next year
            wrap_single = True
            file_post = filenm.replace(str(year), str(year + 1))
            if os.path.isfile(file_post):
                years_in = years_in + [year + 1]
                datafiles = datafiles + [file_post]

    var = get_daily_data(varid, plev, years_in, datafiles, data, daymin=daymin,
                         daymax=daymax)

    # Get rid of extra years
    if wrap_single:
        var = atm.subset(var, {'year' : (years[0], years[0])})

    # Make sure year dimension is included for single year
    if len(years) == 1 and 'year' not in var.dims:
        var = atm.expand_dims(var, 'year', years[0], axis=0)

    # Align relative to onset day
    # (not needed for calc variables since they're already aligned)
    if var_type(varid) == 'basic':
        print('Aligning data relative to onset day')
        var = daily_rel2onset(var, onset, npre, npost)

    return var
# Beispiel #13
# 0
def daily_rel2onset(data, d_onset, npre, npost):
    """Return subset of daily data aligned relative to onset day.

    Parameters
    ----------
    data : xray.DataArray
        Daily data.
    d_onset : ndarray
        Array of onset date (day of year) for each year.
    npre, npost : int
        Number of days before and after onset to extract.

    Returns
    -------
    data_out : xray.DataArray
        Subset of N days of daily data for each year, where
        N = npre + npost + 1 and the day dimension is
        dayrel = day - d_onset.
    """

    name, attrs, coords, dimnames = atm.meta(data)
    yearnm = atm.get_coord(data, 'year', 'name')
    daynm = atm.get_coord(data, 'day', 'name')
    years = atm.makelist(atm.get_coord(data, 'year'))

    if isinstance(d_onset, xray.DataArray):
        d_onset = d_onset.values
    else:
        d_onset = atm.makelist(d_onset)

    relnm = daynm + 'rel'
    pieces = []

    for y, year in enumerate(years):
        onset_day = d_onset[y]
        # Extract this year's window around onset and re-index the day
        # coordinate relative to the onset day
        sub = atm.subset(data, {yearnm : (year, None),
                                daynm : (onset_day - npre, onset_day + npost)})
        sub = sub.rename({daynm : relnm})
        sub[relnm] = sub[relnm] - onset_day
        sub[relnm].attrs['long_name'] = 'Day of year relative to onset day'
        pieces.append(sub)

    if len(pieces) == 1:
        data_out = pieces[0]
    else:
        data_out = xray.concat(pieces, dim=yearnm)
    data_out.attrs['d_onset'] = d_onset

    return data_out
# Beispiel #14
# 0
def plot_index_years(index, nrow=3, ncol=4, fig_kw=None, gridspec_kw=None,
                     incl_fit=False, suptitle='', xlabel='Day', ylabel='Index',
                     xlims=None, ylims=None, xticks=np.arange(0, 401, 100),
                     grid=True):
    """Plot daily timeseries of monsoon onset/retreat index each year.

    Parameters
    ----------
    index : xray.Dataset
        Must contain 'tseries', 'onset' and 'retreat'; optionally
        'tseries_fit_onset' / 'tseries_fit_retreat' for incl_fit.
    fig_kw, gridspec_kw : dict, optional
        Figure and gridspec options; when None the previous hard-coded
        defaults are used.  (None sentinels replace the original mutable
        dict defaults, which are shared across calls and could be
        mutated downstream.)
    incl_fit : bool, optional
        If True, overlay the fitted onset (red) and retreat (blue)
        timeseries when present.

    Returns
    -------
    grp : atm.FigGroup
    """
    if fig_kw is None:
        fig_kw = {'figsize' : (11, 7), 'sharex' : True, 'sharey' : True}
    if gridspec_kw is None:
        gridspec_kw = {'left' : 0.1, 'right' : 0.95, 'wspace' : 0.05,
                       'hspace' : 0.1}

    years = atm.get_coord(index, 'year')
    days = atm.get_coord(index, 'day')
    grp = atm.FigGroup(nrow, ncol, fig_kw=fig_kw, gridspec_kw=gridspec_kw,
                       suptitle=suptitle)
    for year in years:
        grp.next()
        ind = atm.subset(index, {'year' : (year, year)}, squeeze=True)
        ts = ind['tseries']
        d0_list = [ind['onset'], ind['retreat']]
        plt.plot(days, ts, 'k')
        # Mark onset and retreat days with vertical lines
        for d0 in d0_list:
            plt.axvline(d0, color='k')
        if incl_fit and 'tseries_fit_onset' in ind:
            plt.plot(days, ind['tseries_fit_onset'], 'r')
        if incl_fit and 'tseries_fit_retreat' in ind:
            plt.plot(days, ind['tseries_fit_retreat'], 'b')
        atm.text(year, (0.05, 0.9))
        atm.ax_lims_ticks(xlims=xlims, ylims=ylims, xticks=xticks)
        plt.grid(grid)
        # Only label the outer edges of the subplot grid
        if grp.row == grp.nrow - 1:
            plt.xlabel(xlabel)
        if grp.col == 0:
            plt.ylabel(ylabel)

    return grp
# Beispiel #15
# 0
        latbuf = 5
        lat = atm.get_coord(data[key1][varnm], 'lat')
        latbig = atm.biggify(lat, data[key1][varnm], tile=True)
        vals = data[key1][varnm].values
        vals = np.where(abs(latbig)>latbuf, vals, np.nan)
        data[key1][varnm].values = vals

# ----------------------------------------------------------------------
# Sector mean data

# NOTE(review): 'XDim'/'YDim' look like MERRA HDF grid dimension names --
# confirm the variables in 'data' are on that grid.
lonname, latname = 'XDim', 'YDim'
sectordata = {}
for key1 in data:
    sectordata[key1] = collections.OrderedDict()
    for varnm in data[key1]:
        # Subset to the longitude sector, then take the sector zonal mean
        var = atm.subset(data[key1][varnm], {lonname : (lon1, lon2)})
        sectordata[key1][varnm] = var.mean(dim=lonname)

# ----------------------------------------------------------------------
# Plotting params and utilities

def plusminus(num):
    """Format num as an explicitly signed integer string ('+0' for zero)."""
    if num != 0:
        return atm.format_num(num, ndecimals=0, plus_sym=True)
    return '+0'

# ----------------------------------------------------------------------
# Composite averages
# Progress message only; the composites themselves are computed later.
print('Computing composites relative to onset day')
# Beispiel #16
# 0
def plot_tseries_together(data, onset=None, years=None, suptitle='',
                          figsize=(14,10), legendsize=10,
                          legendloc='lower right', nrow=3, ncol=4,
                          yearnm='year', daynm='day', standardize=True,
                          label_attr=None, data_style=None, onset_style=None,
                          show_days=False):
    """Plot multiple daily timeseries together each year.

    Parameters
    ----------
    data : xray.Dataset
        Dataset of timeseries variables to plot together.
    onset : ndarray or dict of ndarrays, optional
        Array of onset day for each year, or dict of onset arrays (e.g.
        to compare onset days from different methods).
    years : ndarray, optional
        Subset of years to include.  If omitted, all years are included.
    suptitle : str, optional
        Supertitle for plot.
    figsize : 2-tuple, optional
        Size of each figure.
    legendsize : int, optional
        Font size for legend
    legendloc : str, optional
        Legend location
    nrow, ncol : int, optional
        Number of rows, columns in each figure.
    yearnm, daynm : str, optional
        Name of year and day dimensions in data.
    standardize : bool, optional
        If True, standardize each timeseries by dividing by its
        standard deviation.
    label_attr : str, optional
        Attribute of each data variable to use for labels.  If omitted,
        then the variable name is used.
    data_style, onset_style : list or dict, optional
        Matlab-style strings for each data variable or onset index.
    show_days : bool, optional
        If True, annotate each subplot with a textbox showing the
        onset days.
    """

    if years is None:
        # All years
        years = data[yearnm].values
    data = atm.subset(data, {yearnm : (years, None)})

    if label_attr is not None:
        labels = {nm : data[nm].attrs[label_attr] for nm in data.data_vars}

    if onset is not None:
        # Normalize onset to a dict of onset arrays with a style per entry
        if isinstance(onset, dict):
            if onset_style is None:
                onset_style = {key : 'k' for key in onset.keys()}
        else:
            onset = {'onset' : onset}
            if onset_style is None:
                onset_style = {'onset' : 'k'}
        # Stacked axes-fraction positions for annotating each onset index
        textpos = {key : (0.05, 0.9 - 0.1*i) for i, key in enumerate(onset)}

    # Plot each year
    for y, year in enumerate(years):
        df = atm.subset(data, {yearnm : (year, None)}).to_dataframe()
        df.drop(yearnm, axis=1, inplace=True)
        if label_attr is not None:
            df.rename(columns=labels, inplace=True)

        if standardize:
            # Convert each timeseries to standardized anomalies
            for key in df.columns:
                df[key] = (df[key] - np.nanmean(df[key])) / np.nanstd(df[key])
            ylabel = 'Standardized Timeseries'
        else:
            ylabel = 'Timeseries'

        # Start a new figure every nrow * ncol years
        if y % (nrow * ncol) == 0:
            fig, axes = plt.subplots(nrow, ncol, figsize=figsize, sharex=True)
            plt.subplots_adjust(left=0.08, right=0.95, wspace=0.2, hspace=0.2)
            plt.suptitle(suptitle)
            yplot = 1
        else:
            yplot += 1

        i, j = atm.subplot_index(nrow, ncol, yplot)
        ax = axes[i-1, j-1]
        df.plot(ax=ax, style=data_style)
        ax.grid()
        # Only show the legend on the first subplot of each figure
        if yplot == 1:
            ax.legend(fontsize=legendsize, loc=legendloc)
        else:
            ax.legend_.remove()
        if onset is not None:
            # Mark each onset day with a vertical line
            for key in onset:
                d0 = onset[key][y]
                ax.plot([d0, d0], ax.get_ylim(), onset_style[key])
                if show_days:
                    atm.text(d0, textpos[key], ax=ax, color=onset_style[key])
        # Axis labels only on the left column and bottom row
        if j == 1:
            ax.set_ylabel(ylabel)
        if i == nrow:
            ax.set_xlabel('Day')
        else:
            ax.set_xlabel('')
        ax.set_title(year)
# Beispiel #17
# 0
def onset_changepoint(precip_acc, onset_range=(1, 250),
                      retreat_range=(201, 366), order=1, yearnm='year',
                      daynm='day'):
    """Return monsoon onset/retreat based on changepoint in precip.

    Uses piecewise least-squares fit of data to detect changepoints.

    Parameters
    ----------
    precip_acc : xray.DataArray
        Accumulated precipitation.
    onset_range, retreat_range : 2-tuple of ints, optional
        Range of days to use when calculating onset / retreat.
    order : int, optional
        Order of polynomial to fit.
    yearnm, daynm : str, optional
        Name of year and day dimensions in precip_acc.

    Returns
    -------
    chp : xray.Dataset
        Onset/retreat days, daily timeseries, piecewise polynomial
        fits, and rss values.

    Reference
    ---------
    Cook, B. I., & Buckley, B. M. (2009). Objective determination of
    monsoon season onset, withdrawal, and length. Journal of Geophysical
    Research, 114(D23), D23109. doi:10.1029/2009JD012795
    """

    def split(x, n):
        # Split an array at index n
        return x[:n], x[n:]

    def piecewise_polyfit(x, y, n, order=1):
        # Fit separate polynomials to x[:n] and x[n:]; return the combined
        # prediction and the residual sum of squares
        y = np.ma.masked_array(y, np.isnan(y))
        x1, x2 = split(x, n)
        y1, y2 = split(y, n)
        p1 = np.ma.polyfit(x1, y1, order)
        p2 = np.ma.polyfit(x2, y2, order)
        if np.isnan(p1).any() or np.isnan(p2).any():
            raise ValueError('NaN for polyfit coeffs. Check data.')
        ypred1 = np.polyval(p1, x1)
        ypred2 = np.polyval(p2, x2)
        ypred = np.concatenate([ypred1, ypred2])
        rss = np.sum((y - ypred)**2)
        return ypred, rss

    def find_changepoint(x, y, order=1):
        # Try each candidate changepoint and keep the one minimizing RSS
        rss = np.nan * x
        for n in range(2, len(x)- 2):
            _, rssval = piecewise_polyfit(x, y, n, order)
            rss[n] = rssval
        n0 = np.nanargmin(rss)
        x0 = x[n0]
        # Bug fix: pass 'order' to the final fit -- previously this call
        # omitted it, so the returned fit always used the default order=1
        # even when the changepoint search used a different order.
        ypred, _ = piecewise_polyfit(x, y, n0, order)
        return x0, ypred, rss

    if yearnm not in precip_acc.dims:
        precip_acc = atm.expand_dims(precip_acc, yearnm, -1, axis=0)
    years = precip_acc[yearnm].values
    chp = xray.Dataset()
    chp['tseries'] = precip_acc

    for key, drange in zip(['onset', 'retreat'], [onset_range, retreat_range]):
        print('Calculating ' + key)
        print(drange)
        dmin, dmax = drange
        precip_sub = atm.subset(precip_acc, {daynm : (dmin, dmax)})
        dsub = precip_sub[daynm].values

        d_cp = np.nan * np.ones(years.shape)
        pred = np.nan * np.ones(precip_sub.shape)
        rss = np.nan * np.ones(precip_sub.shape)
        for y, year in enumerate(years):
            # Cut out any NaNs from day range
            pcp_yr = precip_sub[y]
            ind = np.where(np.isfinite(pcp_yr))[0]
            islice = slice(ind.min(), ind.max() + 1)
            pcp_yr = pcp_yr[islice]
            days_yr = pcp_yr[daynm].values
            print('%d (%d-%d)' % (year, min(days_yr), max(days_yr)))
            results = find_changepoint(days_yr, pcp_yr, order)
            d_cp[y], pred[y, islice], rss[y, islice] = results
        chp[key] = xray.DataArray(d_cp, dims=[yearnm], coords={yearnm : years})
        chp['tseries_fit_' + key] = xray.DataArray(
            pred, dims=[yearnm, daynm], coords={yearnm : years, daynm : dsub})
        chp['rss_' + key] = xray.DataArray(
            rss, dims=[yearnm, daynm], coords={yearnm : years, daynm : dsub})

    chp.attrs['order'] = order
    chp.attrs['onset_range'] = onset_range
    chp.attrs['retreat_range'] = retreat_range

    return chp
    ds = atm.combine_daily_years(None, relfiles, years, yearname='year')
    ds = ds.mean(dim='year')
    ds.attrs['years'] = years
    print('Saving to ' + savefile)
    ds.to_netcdf(savefile)

# ----------------------------------------------------------------------
# Concatenate plevels in climatology and save

files = [savestr % plev + '_' + yearstr + '.nc' for plev in plevs]
ubudget = xray.Dataset()
# Name and position of the pressure-level dimension to insert
pname, pdim = 'Height', 1
subset_dict = {'lat' : (-60, 60), 'lon' : (40, 120)}
for i, plev in enumerate(plevs):
    filenm = files[i]
    print('Loading ' + filenm)
    with xray.open_dataset(filenm) as ds:
        ds = atm.subset(ds, subset_dict)
        ds.load()
    # Tag every variable with its pressure level so files can be stacked
    for nm in ds.data_vars:
        ds[nm] = atm.expand_dims(ds[nm], pname, plev, axis=pdim)
    if i == 0:
        ubudget = ds
    else:
        ubudget = xray.concat([ubudget, ds], dim=pname)
ubudget.coords[pname].attrs['units'] = 'hPa'
# Derive the output filename by stripping the level from the first input
savefile = files[0]
savefile = savefile.replace('%d' % plevs[0], '')
print('Saving to ' + savefile)
ubudget.to_netcdf(savefile)
import matplotlib.pyplot as plt
import xray
import atmos as atm

# ----------------------------------------------------------------------
# Read some data from OpenDAP url

url = ('http://goldsmr3.sci.gsfc.nasa.gov/opendap/MERRA_MONTHLY/'
    'MAIMCPASM.5.2.0/1979/MERRA100.prod.assim.instM_3d_asm_Cp.197901.hdf')

ds = xray.open_dataset(url)
# Surface pressure, converted from Pa to hPa
ps = ds['PS'] / 100

plt.figure()
atm.pcolor_latlon(ps, cmap='jet')

# Subset to a lat-lon box and re-plot
# NOTE(review): positional atm.subset signature here differs from the dict
# form used elsewhere in this file -- confirm which API version is in use.
lon1, lon2 = 0, 100
lat1, lat2 = -45, 45
ps_sub = atm.subset(ps, 'lon', lon1, lon2, 'lat', lat1, lat2)
plt.figure()
atm.pcolor_latlon(ps_sub, cmap='jet')
index = utils.get_onset_indices(onset_nm, indfiles, years)
# Smoothed daily MFC timeseries
mfc = atm.rolling_mean(index['ts_daily'], nroll, center=True)
onset = index['onset']
ssn_length=index['length'].mean(dim='year')

# Onset-aligned daily timeseries
data = {}
data['MFC'] = utils.daily_rel2onset(mfc, onset, npre, npost)
data[pcp_nm] = utils.daily_rel2onset(pcp, onset, npre, npost)
data['MFC_ACC'] = utils.daily_rel2onset(index['tseries'], onset, npre, npost)

for nm in varnms:
    print('Loading ' + relfiles[nm])
    with xray.open_dataset(relfiles[nm]) as ds:
        if nm == 'PSI':
            # Streamfunction from meridional wind, plus a single-level slice
            data[nm] = atm.streamfunction(ds['V'])
            psimid = atm.subset(data[nm], {'plev' : (pmid, pmid)},
                                squeeze=True)
            psimid.name = 'PSI%d' % pmid
            data['PSI%d' % pmid] = psimid
        elif nm == 'VFLXLQV':
            # Latent heat flux from moisture flux: multiply by Lv
            var = atm.dim_mean(ds['VFLXQV'], 'lon', lon1, lon2)
            data[nm] = var * atm.constants.Lv.values
        elif nm == theta_nm:
            # Also compute the meridional gradient of theta
            theta = ds[nm]
            _, _, dtheta = atm.divergence_spherical_2d(theta, theta)
            data[nm] = atm.dim_mean(ds[nm], 'lon', lon1, lon2)
            data[dtheta_nm] = atm.dim_mean(dtheta, 'lon', lon1, lon2)
        elif nm == dtheta_nm:
            # Already computed alongside theta above
            continue
        else:
            data[nm] = atm.dim_mean(ds[nm], 'lon', lon1, lon2)
# Beispiel #21
# 0
# Changepoint-based onset indices from accumulated MFC and precip
index['CHP_MFC'] = indices.onset_changepoint(mfc_acc)
index['CHP_PCP'] = indices.onset_changepoint(precip_acc)
for key in ['CHP_MFC', 'CHP_PCP']:
    index[key].attrs['title'] = key

# ----------------------------------------------------------------------
# Monsoon strength indices

def detrend(vals, index):
    """Linearly detrend vals, normalize to unit standard deviation, and
    return the result as a pd.Series with the given index."""
    residual = scipy.signal.detrend(vals)
    residual = residual / np.std(residual)
    return pd.Series(residual, index=index)

# MERRA MFC
# JJAS mean of area-averaged moisture flux convergence
mfc_JJAS = atm.subset(mfcbar, {'day' : (atm.season_days('JJAS'), None)})
mfc_JJAS = mfc_JJAS.mean(dim='day')

# ERA-Interim MFC
era = pd.read_csv(eraIfile, index_col=0)

# All India Rainfall Index, convert to mm/day
air = pd.read_csv(airfile, skiprows=4, index_col=0).loc[years]
air /= len(atm.season_days('JJAS'))

# Combine the three strength indices into one DataFrame indexed by year
strength = mfc_JJAS.to_series().to_frame(name='MERRA')
strength['ERAI'] = era
strength['AIR'] = air

# Detrended indices
for key in strength.columns:
# ----------------------------------------------------------------------
mldfile = '/home/jennifer/datastore/mld/ifremer_mld_DT02_c1m_reg2.0.nc'
suptitle = 'Ifremer Mixed Layer Depths'

ds = xray.open_dataset(mldfile)
mld = ds['mld']
# Replace the file's missing-value sentinel with NaN
missval = mld.attrs['mask_value']
vals = mld.values
vals = np.ma.masked_array(vals, vals==missval)
vals = np.ma.filled(vals, np.nan)
mld.values = vals

# Sector mean
lon1, lon2 = 60, 100
mldbar = atm.subset(mld, {'lon' : (lon1, lon2)}).mean(dim='lon')


# ----------------------------------------------------------------------
# Plots

cmap = 'hot_r'
axlims = (-30, 30, 40, 120)
clim1, clim2 = 0, 80
figsize = (12, 9)
months = [4, 5, 6]

plt.figure(figsize=figsize)
plt.suptitle(suptitle)

# Lat-lon maps
import matplotlib.pyplot as plt
import xray
import atmos as atm

# ----------------------------------------------------------------------
# Read some data from OpenDAP url

url = ('http://goldsmr3.sci.gsfc.nasa.gov/opendap/MERRA_MONTHLY/'
       'MAIMCPASM.5.2.0/1979/MERRA100.prod.assim.instM_3d_asm_Cp.197901.hdf')

ds = xray.open_dataset(url)
# Surface pressure, converted from Pa to hPa
ps = ds['PS'] / 100

plt.figure()
atm.pcolor_latlon(ps, cmap='jet')

# Subset to a lat-lon box and re-plot
lon1, lon2 = 0, 100
lat1, lat2 = -45, 45
ps_sub = atm.subset(ps, 'lon', lon1, lon2, 'lat', lat1, lat2)
plt.figure()
atm.pcolor_latlon(ps_sub, cmap='jet')
psfile = atm.homedir() + 'dynamics/python/atmos-tools/data/topo/ncep2_ps.nc'
ps = atm.get_ps_clim(lat, lon, psfile)
# Pa -> hPa
ps = ps / 100

figsize = (7, 9)
omitzero = False

# Compare streamfunctions computed two ways for each season and sector
# NOTE(review): the plotting for each figure appears truncated after the
# first subplot -- the rest of this snippet was cut off.
for ssn in ['ANN', 'DJF', 'JJA', 'MAR']:
    for lonlims in [(0, 360), (60, 100)]:
        lon1, lon2 = lonlims
        lonstr = atm.latlon_str(lon1, lon2, 'lon')
        suptitle = ssn + ' ' + lonstr
        months = atm.season_months(ssn)
        v = data['V'].sel(month=months)
        if (lon2 - lon1) < 360:
            # Sector subset: scale streamfunction by sector fraction
            v = atm.subset(v, {'lon': (lon1, lon2)})
            sector_scale = (lon2 - lon1) / 360.0
            psbar = atm.dim_mean(ps, 'lon', lon1, lon2)
            clev = 10
        else:
            sector_scale = None
            psbar = atm.dim_mean(ps, 'lon')
            clev = 20
        vssn = v.mean(dim='month')
        vssn_bar = atm.dim_mean(vssn, 'lon')
        # psi1: streamfunction then zonal mean; psi2: zonal mean then
        # streamfunction
        psi1 = atm.streamfunction(vssn, sector_scale=sector_scale)
        psi1 = atm.dim_mean(psi1, 'lon')
        psi2 = atm.streamfunction(vssn_bar, sector_scale=sector_scale)
        plt.figure(figsize=figsize)
        plt.suptitle(suptitle)
        plt.subplot(2, 1, 1)
#               'wspace' : 0.3, 'hspace' : 0.2, 'bottom' : 0.06}
nrow, ncol, figsize = 3, 4, (11, 9)
gridspec_kw = {'width_ratios' : [1, 1, 1, 1.5], 'left' : 0.03, 'right' : 0.96,
               'wspace' : 0.35, 'hspace' : 0.2, 'bottom' : 0.06, 'top' : 0.9}
fig_kw = {'figsize' : figsize}
legend_opts = {'fontsize' : 9, 'handlelength' : 2.5, 'frameon' : False}
grp = atm.FigGroup(nrow, ncol, advance_by='col', fig_kw=fig_kw,
                   gridspec_kw=gridspec_kw, suptitle=suptitle)
for varnm in comp:
    # Display name: surface-layer (LML) variables are relabeled as '_EB'
    if varnm == 'THETA_E_LML':
        varstr = 'THETA_EB'
    elif varnm.endswith('LML'):
        varstr = varnm.replace('LML', '_EB')
    else:
        varstr = varnm.upper()
    dat = {key : atm.subset(comp[varnm][key], subset_dict)
           for key in keys}
    # Symmetric color limits for anomaly plots; fixed limits otherwise
    if anom_plot:
        cmax = max([abs(dat[key]).max().values for key in keys])
        cmin = -cmax
    else:
        cmin, cmax = climits[varnm][0], climits[varnm][1]
    # Lat-lon maps of composites
    for j, key in enumerate(keys):
        grp.next()
        if comp_attrs[key]['axis'] == 1:
            cmap = get_colormap(varnm, anom_plot)
        else:
            cmap = 'RdBu_r'
        atm.pcolor_latlon(dat[key], axlims=axlims, cmap=cmap, fancy=False)
        plt.xticks(range(40, 121, 20))
# Beispiel #26
# 0
        plt.xlabel('Longitude')
    if grp is not None and grp.col == 0:
        plt.ylabel('Rel Day')


nrow, ncol = 2, 2
fig_kw = {'figsize' : (11, 7), 'sharex' : True, 'sharey' : True}
gridspec_kw = {'left' : 0.07, 'right' : 0.99, 'bottom' : 0.07, 'top' : 0.9,
               'wspace' : 0.05}
suptitle = 'Cross-Eq <V*MSE> (%s)' % units
grp = atm.FigGroup(nrow, ncol, fig_kw=fig_kw, gridspec_kw=gridspec_kw,
                   suptitle=suptitle)
# Longitude-day contour plots for each variable and longitude range
for lonrange in [(40, 120), (lon1, lon2)]:
    for nm in data_eq.data_vars:
        grp.next()
        var = atm.subset(data_eq[nm], {'lon' : lonrange})
        contour_londay(var, grp=grp)
        plt.title(nm, fontsize=11)
    plt.gca().invert_yaxis()


# ----------------------------------------------------------------------
# Line plots of sector means

days = atm.get_coord(eq_int, 'dayrel')
nms = data_eq.data_vars
# Line styles for each MSE component; legend locations per sector
styles = {'VMSE' : {'color' : 'k', 'linewidth' : 2}, 'VCPT' : {'color' : 'k'},
          'VPHI' : {'color' : 'k', 'linestyle' : 'dashed'},
          'VLQV' : {'color' : 'k', 'alpha' : 0.4, 'linewidth' : 1.5}}
locs = {'40E-60E' : 'upper left', '40E-100E' : 'upper left',
        '60E-100E' : 'lower left'}
# Beispiel #27
# 0
def closest_day(nm, ts1, ts_sm, d0, buf=20):
    """Return the day near d0 where the smoothed series best matches ts1.

    Searches dayrel in [d0 - buf, d0 + buf] of the smoothed timeseries
    ts_sm[nm] for the value closest to ts1[nm] at day d0.
    """
    target = ts1[nm].sel(dayrel=d0).values
    window = atm.subset(ts_sm[nm], {'dayrel' : (d0 - buf, d0 + buf)})
    idx = int(np.argmin(abs(window - target)))
    return int(window['dayrel'][idx])
# Beispiel #28
# 0
# Sanity check: the test MFC timeseries should match mfcbar
plt.plot(mfcbar)
plt.plot(mfc_test_bar)
print(mfc_test_bar - mfcbar)

# ----------------------------------------------------------------------
# Vertical gradient du/dp

lon1, lon2 = 40, 120
pmin, pmax = 100, 300  # pressure-level range to extract
subset_dict = {'XDim': (lon1, lon2), 'Height': (pmin, pmax)}

# Read one day of MERRA zonal wind and average over sub-daily times
urls = merra.merra_urls([year])
month, day = 7, 15
url = urls['%d%02d%02d' % (year, month, day)]
with xray.open_dataset(url) as ds:
    u = atm.subset(ds['U'], subset_dict, copy=False)
    u = u.mean(dim='TIME')

pres = u['Height']
pres = atm.pres_convert(pres, pres.attrs['units'], 'Pa')
dp = np.gradient(pres)

# Calc 1
# Reference calculation: centered differences column by column.
# NOTE(review): dp (an array from np.gradient) is passed as the second
# argument to np.gradient; modern numpy treats an array there as coordinate
# values rather than per-point spacing -- confirm the intended numpy version.
dims = u.shape
dudp = np.nan * u
for i in range(dims[1]):
    for j in range(dims[2]):
        dudp.values[:, i, j] = np.gradient(u[:, i, j], dp)

# Test atm.gradient
dudp_test = atm.gradient(u, pres, axis=0)
# lat, lon range to extract
lon1, lon2 = 40, 120
lat1, lat2 = -60, 60
lons = atm.latlon_labels([lon1, lon2], 'lon', deg_symbol=False)
lats = atm.latlon_labels([lat1, lat2], 'lat', deg_symbol=False)
latlon = '%s-%s_%s-%s' % (lons[0], lons[1], lats[0], lats[1])

savefile = datadir + ('merra_precip_%s_days%d-%d_%d-%d.nc' %
                      (latlon, daymin, daymax, years.min(), years.max()))

subset_dict = {'day' : (daymin, daymax), 'lat' : (lat1, lat2),
               'lon' : (lon1, lon2)}

# Read one file per year, then concatenate along a new 'year' dimension.
# Collecting the arrays and concatenating once avoids the O(n^2) copying
# of calling xray.concat inside the loop.
precip_list = []
for year in years:
    datafile = datadir + 'merra_precip_%d.nc' % year
    print('Loading ' + datafile)
    with xray.open_dataset(datafile) as ds:
        precip1 = atm.subset(ds['PRECTOT'], subset_dict)
        precip1 = precip1.load()
        precip1.coords['year'] = year
    precip_list.append(precip1)
if len(precip_list) == 1:
    # Single year: keep 'year' as a scalar coordinate (matches previous
    # behavior, which skipped concat for the first year)
    precip = precip_list[0]
else:
    precip = xray.concat(precip_list, dim='year')

print('Converting to mm/day')
precip.values = atm.precip_convert(precip, precip.units, 'mm/day')
precip.attrs['units'] = 'mm/day'

print('Saving to ' + savefile)
atm.save_nc(savefile, precip)
# Beispiel #30
# 0
# Mask out grid points where CHP index is ill-defined
def applymask(ds, mask_in):
    """Set masked grid points to NaN in every variable of ds (in place)."""
    for varname in ds.data_vars:
        # Broadcast the 2-D mask up to the variable's full shape
        expanded = atm.biggify(mask_in, ds[varname], tile=True)
        masked = np.ma.masked_array(ds[varname], mask=expanded)
        ds[varname].values = masked.filled(np.nan)
    return ds

# Mask out grid points whose Jun-Sep rainfall fraction is too small for the
# CHP index to be well-defined
if ptsmaskfile is not None:
    fracmin = 0.5  # minimum fraction of annual rain falling in Jun 1 - Sep 30
    day1 = atm.mmdd_to_jday(6, 1)
    day2 = atm.mmdd_to_jday(9, 30)
    with xray.open_dataset(ptsmaskfile) as ds:
        pcp = ds['PREC'].sel(lat=index_pts.lat).sel(lon=index_pts.lon).load()
    pcp_ssn = atm.subset(pcp, {'day' : (day1, day2)})
    pcp_frac = pcp_ssn.sum(dim='day') / pcp.sum(dim='day')
    mask = pcp_frac < fracmin
    index_pts = applymask(index_pts, mask)
    for key in pts_reg:
        pts_reg[key] = applymask(pts_reg[key], mask)

# MFC budget
# .load() inside the context manager pulls data into memory before the
# file handle is closed
with xray.open_dataset(mfcbudget_file) as mfc_budget:
    mfc_budget.load()
mfc_budget = mfc_budget.rename({'DWDT' : 'dw/dt'})
mfc_budget['P-E'] = mfc_budget['PRECTOT'] - mfc_budget['EVAP']
# Optionally smooth all budget terms with a centered rolling mean
if nroll is not None:
    for nm in mfc_budget.data_vars:
        mfc_budget[nm] = atm.rolling_mean(mfc_budget[nm], nroll, center=True)
# Beispiel #31
# 0
def onset_HOWI(uq_int, vq_int, npts=50, nroll=7, days_pre=range(138, 145),
               days_post=range(159, 166), yearnm='year', daynm='day',
               maxbreak=7):
    """Return monsoon Hydrologic Onset/Withdrawal Index.

    Parameters
    ----------
    uq_int, vq_int : xray.DataArrays
        Vertically integrated moisture fluxes.
    npts : int, optional
        Number of points to use to define HOWI index.
    nroll : int, optional
        Number of days for rolling mean.
    days_pre, days_post : list of ints, optional
        Default values correspond to May 18-24 and June 8-14 (numbered
        as non-leap year).
    yearnm, daynm : str, optional
        Name of year and day dimensions in DataArray
    maxbreak:
        Maximum number of days with index <=0 to consider a break in
        monsoon season rather than end of monsoon season.

    Returns
    -------
    howi : xray.Dataset
        HOWI daily timeseries for each year and monsoon onset and retreat
        days for each year.

    Reference
    ---------
    J. Fasullo and P. J. Webster, 2003: A hydrological definition of
        Indian monsoon onset and withdrawal. J. Climate, 16, 3200-3211.

    Notes
    -----
    In some years the HOWI index can give a bogus onset or bogus retreat
    when the index briefly goes above or below 0 for a few days.  To deal
    with these cases, I'm defining the monsoon season as the longest set
    of consecutive days with HOWI that is positive or has been negative
    for no more than `maxbreak` number of days (monsoon break).
    """

    # Coordinate metadata of the input fluxes (lat/lon names vary by dataset)
    _, _, coords, _ = atm.meta(uq_int)
    latnm = atm.get_coord(uq_int, 'lat', 'name')
    lonnm = atm.get_coord(uq_int, 'lon', 'name')

    # Magnitude of the vertically integrated moisture transport (VIMT)
    ds = xray.Dataset()
    ds['uq'] = uq_int
    ds['vq'] = vq_int
    ds['vimt'] = np.sqrt(ds['uq']**2 + ds['vq']**2)

    # Climatological moisture fluxes
    dsbar = ds.mean(dim=yearnm)
    ds['uq_bar'], ds['vq_bar'] = dsbar['uq'], dsbar['vq']
    ds['vimt_bar'] = np.sqrt(ds['uq_bar']**2 + ds['vq_bar']**2)

    # Pre- and post- monsoon climatology composites
    dspre = atm.subset(dsbar, {daynm : (days_pre, None)}).mean(dim=daynm)
    dspost = atm.subset(dsbar, {daynm : (days_post, None)}).mean(dim=daynm)
    dsdiff = dspost - dspre
    ds['uq_bar_pre'], ds['vq_bar_pre'] = dspre['uq'], dspre['vq']
    ds['uq_bar_post'], ds['vq_bar_post'] = dspost['uq'], dspost['vq']
    ds['uq_bar_diff'], ds['vq_bar_diff'] = dsdiff['uq'], dsdiff['vq']

    # Magnitude of vector difference
    vimt_bar_diff = np.sqrt(dsdiff['uq']**2 + dsdiff['vq']**2)
    ds['vimt_bar_diff'] = vimt_bar_diff

    # Top N difference vectors
    def top_n(data, n):
        """Return a mask with the highest n values in 2D array."""
        # Works on a copy: each found maximum is blanked to NaN so the
        # next nanargmax finds the next-largest value
        vals = data.copy()
        mask = np.ones(vals.shape, dtype=bool)
        for k in range(n):
            i, j = np.unravel_index(np.nanargmax(vals), vals.shape)
            mask[i, j] = False
            vals[i, j] = np.nan
        return mask

    # Mask to extract top N points
    mask = top_n(vimt_bar_diff, npts)
    ds['mask'] = xray.DataArray(mask, coords={latnm: coords[latnm],
                                              lonnm: coords[lonnm]})

    # Apply mask to DataArrays
    def applymask(data, mask):
        # Broadcast the 2-D mask to data's full shape, fill masked
        # points with NaN, and rewrap with the original coordinates
        _, _, coords, _ = atm.meta(data)
        maskbig = atm.biggify(mask, data, tile=True)
        vals = np.ma.masked_array(data, maskbig).filled(np.nan)
        data_out = xray.DataArray(vals, coords=coords)
        return data_out

    ds['vimt_bar_masked'] = applymask(ds['vimt_bar'], mask)
    ds['vimt_bar_diff_masked'] = applymask(vimt_bar_diff, mask)
    ds['uq_masked'] = applymask(ds['uq'], mask)
    ds['vq_masked'] = applymask(ds['vq'], mask)
    ds['vimt_masked'] = np.sqrt(ds['uq_masked']**2 + ds['vq_masked']**2)

    # Timeseries data averaged over selected N points
    ds['howi_clim_raw'] = ds['vimt_bar_masked'].mean(dim=latnm).mean(dim=lonnm)
    ds['howi_raw'] = ds['vimt_masked'].mean(dim=latnm).mean(dim=lonnm)

    # Normalize to [-1, 1] using the climatological min/max, so the
    # climatology crosses zero at onset/withdrawal by construction
    howi_min = ds['howi_clim_raw'].min().values
    howi_max = ds['howi_clim_raw'].max().values
    def applynorm(data):
        return 2 * (data - howi_min) / (howi_max - howi_min) - 1
    ds['howi_norm'] = applynorm(ds['howi_raw'])
    ds['howi_clim_norm'] = applynorm(ds['howi_clim_raw'])

    # Apply n-day rolling mean
    # NOTE(review): pd.rolling_mean was removed in pandas >= 0.23; this
    # function assumes a legacy pandas version -- confirm environment.
    def rolling(data, nroll):
        center = True
        _, _, coords, _ = atm.meta(data)
        dims = data.shape
        vals = np.zeros(dims)
        if len(dims) > 1:
            # (year, day) input: smooth each year's timeseries separately
            nyears = dims[0]
            for y in range(nyears):
                vals[y] = pd.rolling_mean(data.values[y], nroll, center=center)
        else:
            vals = pd.rolling_mean(data.values, nroll, center=center)
        data_out = xray.DataArray(vals, coords=coords)
        return data_out

    ds['howi_norm_roll'] = rolling(ds['howi_norm'], nroll)
    ds['howi_clim_norm_roll'] = rolling(ds['howi_clim_norm'], nroll)

    # Index timeseries dataset
    howi = xray.Dataset()
    howi['tseries'] = ds['howi_norm_roll']
    howi['tseries_clim'] = ds['howi_clim_norm_roll']

    # Find zero crossings for onset and withdrawal indices
    nyears = len(howi[yearnm])
    onset = np.zeros(nyears, dtype=int)
    retreat = np.zeros(nyears, dtype=int)
    for y in range(nyears):
        # List of days with positive HOWI index
        pos = howi[daynm].values[howi['tseries'][y].values > 0]

        # In case of extra zero crossings, find the longest set of days
        # with positive index
        splitpos = atm.splitdays(pos)
        lengths = np.array([len(v) for v in splitpos])
        imonsoon = lengths.argmax()
        monsoon = splitpos[imonsoon]

        # In case there is a break in the monsoon season, check the
        # sets of days before and after and add to monsoon season
        # if applicable
        if imonsoon > 0:
            predays = splitpos[imonsoon - 1]
            if monsoon.min() - predays.max() <= maxbreak:
                predays = np.arange(predays.min(), monsoon.min())
                monsoon = np.concatenate([predays, monsoon])
        if imonsoon < len(splitpos) - 1:
            postdays = splitpos[imonsoon + 1]
            if postdays.min() - monsoon.max() <= maxbreak:
                postdays = np.arange(monsoon.max() + 1, postdays.max() + 1)
                monsoon = np.concatenate([monsoon, postdays])

        # Onset and retreat days
        # Retreat is defined as the first day after the monsoon season ends
        onset[y] = monsoon[0]
        retreat[y] = monsoon[-1] + 1

    howi['onset'] = xray.DataArray(onset, coords={yearnm : howi[yearnm]})
    howi['retreat'] = xray.DataArray(retreat, coords={yearnm : howi[yearnm]})
    howi.attrs = {'npts' : npts, 'nroll' : nroll, 'maxbreak' : maxbreak,
                  'days_pre' : days_pre, 'days_post' : days_post}

    return howi, ds
# Beispiel #32
# 0
# Surface pressure climatology, converted from Pa to hPa
topo = atm.get_ps_clim(lat, lon) / 100
topo.units = 'hPa'

# ----------------------------------------------------------------------
# Correct for topography

u_orig = u
u = atm.correct_for_topography(u_orig, topo)

# ----------------------------------------------------------------------
# Zonal mean zonal wind
season = 'jjas'
lon1, lon2 = 60, 100
cint = 5  # contour interval
months = atm.season_months(season)

# Average over the longitude sector and the season's months
uplot = atm.subset(u, 'lon', lon1, lon2, 'mon', months)
uplot = uplot.mean(['lon', 'mon'])

# Sector-mean surface pressure for masking topography in the plots
ps_plot = atm.subset(topo, 'lon', lon1, lon2)
ps_plot = ps_plot.mean('lon')

plt.figure()
cs = atm.contour_latpres(uplot, clev=cint, topo=ps_plot)
clev = atm.clevels(uplot, cint, omitzero=True)
plt.clabel(cs, clev[::2], fmt='%02d')  # label every other contour level

plt.figure()
atm.contourf_latpres(uplot, clev=cint, topo=ps_plot)
# Beispiel #33
# 0
 def extract_year(data, year, years):
     if year in years:
         data_out = atm.subset(data, {'year' : (year, year)})
     else:
         data_out = None
     return data_out
# NCEP2 surface pressure climatology, converted from Pa to hPa
psfile = atm.homedir() + 'dynamics/python/atmos-tools/data/topo/ncep2_ps.nc'
ps = atm.get_ps_clim(lat, lon, psfile)
ps = ps / 100

figsize = (7, 9)
omitzero = False

# Compare streamfunction computed two ways for each season / lon sector:
# psi1 from full (lat, lon) v then zonally averaged, vs psi2 from the
# zonal-mean v directly
for ssn in ['ANN', 'DJF', 'JJA', 'MAR']:
    for lonlims in [(0, 360), (60, 100)]:
        lon1, lon2 = lonlims
        lonstr = atm.latlon_str(lon1, lon2, 'lon')
        suptitle = ssn + ' ' + lonstr
        months = atm.season_months(ssn)
        v = data['V'].sel(month=months)
        if (lon2 - lon1) < 360:
            # Sector case: subset and scale by the fraction of the globe
            v = atm.subset(v, {'lon' : (lon1, lon2)})
            sector_scale = (lon2 - lon1) / 360.0
            psbar = atm.dim_mean(ps, 'lon', lon1, lon2)
            clev = 10
        else:
            sector_scale = None
            psbar = atm.dim_mean(ps, 'lon')
            clev = 20
        vssn = v.mean(dim='month')
        vssn_bar = atm.dim_mean(vssn, 'lon')
        psi1 = atm.streamfunction(vssn, sector_scale=sector_scale)
        psi1 = atm.dim_mean(psi1, 'lon')
        psi2 = atm.streamfunction(vssn_bar, sector_scale=sector_scale)
        plt.figure(figsize=figsize)
        plt.suptitle(suptitle)
        # NOTE(review): the loop body appears truncated here by a paste
        # splice -- the plotting of psi1/psi2 into the subplots is missing.
        plt.subplot(2, 1, 1)
# Sanity check: the test MFC timeseries should match mfcbar
# NOTE(review): this section duplicates an earlier du/dp gradient section
# in this file -- likely a paste artifact.
plt.plot(mfcbar)
plt.plot(mfc_test_bar)
print(mfc_test_bar - mfcbar)

# ----------------------------------------------------------------------
# Vertical gradient du/dp

lon1, lon2 = 40, 120
pmin, pmax = 100, 300  # pressure-level range to extract
subset_dict = {'XDim' : (lon1, lon2), 'Height' : (pmin, pmax)}

# Read one day of MERRA zonal wind and average over sub-daily times
urls = merra.merra_urls([year])
month, day = 7, 15
url = urls['%d%02d%02d' % (year, month, day)]
with xray.open_dataset(url) as ds:
    u = atm.subset(ds['U'], subset_dict, copy=False)
    u = u.mean(dim='TIME')

pres = u['Height']
pres = atm.pres_convert(pres, pres.attrs['units'], 'Pa')
dp = np.gradient(pres)

# Calc 1
# Reference calculation: centered differences column by column.
# NOTE(review): dp (an array from np.gradient) is passed as the second
# argument to np.gradient; modern numpy treats an array there as coordinate
# values rather than per-point spacing -- confirm the intended numpy version.
dims = u.shape
dudp = np.nan * u
for i in range(dims[1]):
    for j in range(dims[2]):
        dudp.values[:, i, j] = np.gradient(u[:, i, j], dp)

# Test atm.gradient
dudp_test = atm.gradient(u, pres, axis=0)