Example #1
0
def plot(ds: xr.Dataset,
         var: VarName.TYPE,
         index: DictLike.TYPE = None,
         file: str = None) -> None:
    """
    Plot a variable, optionally save the figure in a file.

    The figure is created with pyplot (16 x 8 inches). If *file* is given, the
    figure is additionally written to that path. The following file formats for
    saving the plot are supported: eps, jpeg, jpg, pdf, pgf, png, ps, raw, rgba,
    svg, svgz, tif, tiff

    :param ds: Dataset that contains the variable named by *var*.
    :param var: The name of the variable to plot
    :param index: Optional index into the variable's data array. The *index* is a dictionary
                  that maps the variable's dimension names to constant labels. For example,
                  ``lat`` and ``lon`` are given in decimal degrees, while a ``time`` value may be provided as
                  datetime object or a date string. *index* may also be a comma-separated string of key-value pairs,
                  e.g. "lat=12.4, time='2012-05-02'". If selecting with *index* raises a ``ValueError``,
                  the index is ignored and the whole variable is plotted.
    :param file: path to a file in which to save the plot
    """
    data_array = ds[VarName.convert(var)]
    selection = DictLike.convert(index)

    # Plot the whole variable unless a usable index narrows it down.
    to_plot = data_array
    if selection:
        try:
            to_plot = data_array.sel(**selection)
        except ValueError:
            # Best effort: an index that does not fit the variable is ignored
            # and the full variable is plotted instead.
            pass

    figure = plt.figure(figsize=(16, 8))
    to_plot.plot()
    if file:
        figure.savefig(file)
Example #2
0
def plot_scatter(ds1: xr.Dataset,
                 ds2: xr.Dataset,
                 var1: VarName.TYPE,
                 var2: VarName.TYPE,
                 indexers1: DictLike.TYPE = None,
                 indexers2: DictLike.TYPE = None,
                 title: str = None,
                 properties: DictLike.TYPE = None,
                 file: str = None) -> Figure:
    """
    Create a scatter plot of two variables given by datasets *ds1*, *ds2* and the
    variable names *var1*, *var2*.

    Each variable is reduced by its indexers using nearest-neighbour label selection.
    When *indexers1* is given and exactly one dimension of *var1* remains after the
    selection (and *var2* shares it), the following happens so that the two value
    arrays can be paired point by point:

    - both selections are clipped to the coordinate range they have in common;
    - *var2* is re-sampled (nearest neighbour) onto the coordinates of *var1*.

    If a selection raises a ``ValueError``, the unselected variables are plotted instead.

    :param ds1: Dataset that contains the variable named by *var1*.
    :param ds2: Dataset that contains the variable named by *var2*.
    :param var1: The name of the first variable to plot
    :param var2: The name of the second variable to plot
    :param indexers1: Optional indexers into data array *var1*. The *indexers1* is a dictionary
           or comma-separated string of key-value pairs that maps the variable's dimension names
           to constant labels. e.g. "lat=12.4, time='2012-05-02'".
    :param indexers2: Optional indexers into data array *var2*. If not given, *indexers1*
           is used for *var2* as well.
    :param title: optional plot title
    :param properties: optional plot properties for Python matplotlib,
           e.g. "bins=512, range=(-1.5, +1.5), label='Sea Surface Temperature'"
           For full reference refer to
           https://matplotlib.org/api/lines_api.html and
           https://matplotlib.org/devdocs/api/_as_gen/matplotlib.patches.Patch.html#matplotlib.patches.Patch
    :param file: path to a file in which to save the plot
    :return: a matplotlib figure object or None if in IPython mode
    :raise ValueError: if *var1* or *var2* does not name a variable
    """
    var_name1 = VarName.convert(var1)
    var_name2 = VarName.convert(var2)
    if not var_name1:
        raise ValueError("Missing value for 'var1'")
    if not var_name2:
        raise ValueError("Missing value for 'var2'")
    var1 = ds1[var_name1]
    var2 = ds2[var_name2]

    indexers1 = DictLike.convert(indexers1) or {}
    indexers2 = DictLike.convert(indexers2) or {}
    properties = DictLike.convert(properties) or {}

    if not indexers2:
        # Use the same selection for both variables unless told otherwise.
        indexers2 = indexers1

    try:
        var_data1 = var1.sel(method='nearest', **indexers1) if indexers1 else var1
        # Apply indexers2 whenever present, so the y-axis label matches the plotted data.
        var_data2 = var2.sel(method='nearest', **indexers2) if indexers2 else var2

        if indexers1 and len(var_data1.dims) == 1 and var_data1.dims[0] in var_data2.dims:
            dim = var_data1.dims[0]
            coord1 = var_data1[dim]
            coord2 = var_data2[dim]
            # Restrict both variables to the coordinate range they have in common.
            lower = max(coord1.min(), coord2.min())
            upper = min(coord1.max(), coord2.max())
            var_data1 = var_data1.where((coord1 >= lower) & (coord1 <= upper), drop=True)
            var_data2 = var_data2.where((coord2 >= lower) & (coord2 <= upper), drop=True)
            # Put var2 onto var1's coordinates so the value arrays pair up element-wise.
            # reindex() returns a new object, hence the assignment.
            var_data2 = var_data2.reindex(method='nearest', **{dim: var_data1[dim].data})
    except ValueError:
        # Best effort: if a selection does not fit the data, plot the variables as they are.
        var_data1 = var1
        var_data2 = var2

    figure = plt.figure(figsize=(12, 8))
    ax = figure.add_subplot(111)
    ax.plot(var_data1.values, var_data2.values, '.', **properties)

    # Axis labels: variable name followed by the applied indexers, e.g. "sst, lat = 12.4".
    ax.set_xlabel(', '.join([var_name1] +
                            [str(key) + ' = ' + str(value) for key, value in indexers1.items()]))
    ax.set_ylabel(', '.join([var_name2] +
                            [str(key) + ' = ' + str(value) for key, value in indexers2.items()]))
    if title:
        ax.set_title(title)
    # Lay out after the title is set so that it is taken into account.
    figure.tight_layout()

    if file:
        figure.savefig(file)

    return figure if not in_notebook() else None
Example #3
0
def plot_map(ds: xr.Dataset,
             var: VarName.TYPE = None,
             index: DictLike.TYPE = None,
             time: Union[str, int] = None,
             region: PolygonLike.TYPE = None,
             projection: str = 'PlateCarree',
             central_lon: float = 0.0,
             file: str = None) -> None:
    """
    Plot the given variable from the given dataset on a map with coastal lines.
    In case no variable name is given, the first encountered variable in the
    dataset is plotted. In case no time index is given, the first time slice
    is taken. It is also possible to set extents of the plot. If no extents
    are given, a global plot is created.

    Every dimension other than ``lat`` and ``lon`` that is not fixed by *index*
    or *time* is reduced to its first element (positional index 0).

    The plot can either be shown using pyplot functionality, or saved,
    if a path is given. The following file formats for saving the plot
    are supported: eps, jpeg, jpg, pdf, pgf, png, ps, raw, rgba, svg,
    svgz, tif, tiff

    :param ds: xr.Dataset to plot
    :param var: variable name in the dataset to plot
    :param index: Optional index into the variable's data array. The *index* is a dictionary
                  that maps the variable's dimension names to constant labels. For example,
                  ``lat`` and ``lon`` are given in decimal degrees, while a ``time`` value may be provided as
                  datetime object or a date string. *index* may also be a comma-separated string of key-value pairs,
                  e.g. "lat=12.4, time='2012-05-02'".
    :param time: time slice to plot, either a time label (e.g. a date string) or an integer
                 position along the variable's ``time`` coordinate; ignored if the variable
                 has no ``time`` dimension
    :param region: Region to plot; its bounds are interpreted as lon/lat in degrees
    :param projection: name of a global projection, see http://scitools.org.uk/cartopy/docs/v0.15/crs/projections.html
                       One of PlateCarree, LambertCylindrical, Mercator, Miller, Mollweide, Orthographic,
                       Robinson, Sinusoidal, NorthPolarStereo, SouthPolarStereo.
    :param central_lon: central longitude of the projection in degrees
    :param file: path to a file in which to save the plot
    :raise NotImplementedError: if *ds* is not an ``xr.Dataset``
    :raise ValueError: if no variable can be determined, if *region* does not form a valid
                       bounding box, or if *projection* is not supported
    """
    if not isinstance(ds, xr.Dataset):
        raise NotImplementedError('Only raster datasets are currently '
                                  'supported')

    if var:
        var_name = VarName.convert(var)
    else:
        # Fall back to the first data variable of the dataset.
        var_name = next(iter(ds.data_vars), None)
    if not var_name:
        raise ValueError('No variable to plot')

    var = ds[var_name]

    # Copy, so that a dict passed in by the caller is not modified below.
    label_index = dict(DictLike.convert(index) or {})

    # 0 is a valid position, hence test the type rather than truthiness.
    if isinstance(time, int) and 'time' in var.dims and 'time' in var.coords:
        time = var.coords['time'].values[time]
    if isinstance(time, str) and not time:
        # An empty string means "no time given".
        time = None
    if time is not None and 'time' in var.dims:
        label_index['time'] = time

    # Reduce all other non-spatial dimensions to their first element (by position,
    # not by label) so that a lat/lon field remains for contouring.
    first_slices = {dim_name: 0 for dim_name in var.dims
                    if dim_name not in ('lat', 'lon') and dim_name not in label_index}

    if region is None:
        extents = None  # no region: plot the whole globe
    else:
        lon_min, lat_min, lon_max, lat_max = PolygonLike.convert(region).bounds
        if not _check_bounding_box(lat_min, lat_max, lon_min, lon_max):
            raise ValueError(
                'Provided plot extents do not form a valid bounding box '
                'within [-180.0,+180.0,-90.0,+90.0]')
        extents = [lon_min, lon_max, lat_min, lat_max]

    # See http://scitools.org.uk/cartopy/docs/v0.15/crs/projections.html#
    projection_classes = {
        'PlateCarree': ccrs.PlateCarree,
        'LambertCylindrical': ccrs.LambertCylindrical,
        'Mercator': ccrs.Mercator,
        'Miller': ccrs.Miller,
        'Mollweide': ccrs.Mollweide,
        'Orthographic': ccrs.Orthographic,
        'Robinson': ccrs.Robinson,
        'Sinusoidal': ccrs.Sinusoidal,
        'NorthPolarStereo': ccrs.NorthPolarStereo,
        'SouthPolarStereo': ccrs.SouthPolarStereo,
    }
    try:
        projection_class = projection_classes[projection]
    except KeyError:
        raise ValueError('illegal projection') from None
    proj = projection_class(central_longitude=central_lon)

    try:
        var_data = var.sel(**label_index) if label_index else var
        if first_slices:
            var_data = var_data.isel(**first_slices)
    except ValueError:
        # Best effort: if the index does not fit the variable, plot it unselected.
        var_data = var

    # Region bounds and the lat/lon data grid are both plain lon/lat degrees,
    # i.e. PlateCarree, independent of the display projection.
    data_crs = ccrs.PlateCarree()

    fig = plt.figure(figsize=(16, 8))
    ax = plt.axes(projection=proj)
    if extents:
        ax.set_extent(extents, crs=data_crs)
    else:
        ax.set_global()

    ax.coastlines()
    # 'transform' is the CRS of the data, not of the map.
    var_data.plot.contourf(ax=ax, transform=data_crs)
    if file:
        fig.savefig(file)