def plot(ds: xr.Dataset,
         var: VarName.TYPE,
         index: DictLike.TYPE = None,
         file: str = None) -> None:
    """
    Plot a single variable of a dataset and optionally write the figure to disk.

    Without *file* the figure is left to pyplot for display; with *file* it is
    also saved. Supported output formats are: eps, jpeg, jpg, pdf, pgf, png,
    ps, raw, rgba, svg, svgz, tif, tiff

    :param ds: Dataset holding the variable named by *var*.
    :param var: Name of the variable to plot.
    :param index: Optional label-based index into the variable's data array,
        mapping dimension names to constant labels (``lat``/``lon`` in decimal
        degrees, ``time`` as datetime object or date string). It may also be
        given as a comma-separated string of key-value pairs, e.g.
        "lat=12.4, time='2012-05-02'". If the selection raises a ValueError,
        the whole variable is plotted instead.
    :param file: Optional path of a file to save the plot into.
    """
    data_array = ds[VarName.convert(var)]
    selection = DictLike.convert(index)

    selected = data_array
    if selection:
        try:
            selected = data_array.sel(**selection)
        except ValueError:
            # Best effort: fall back to the unindexed variable.
            selected = data_array

    figure = plt.figure(figsize=(16, 8))
    selected.plot()
    if file:
        figure.savefig(file)
def plot_scatter(ds1: xr.Dataset,
                 ds2: xr.Dataset,
                 var1: VarName.TYPE,
                 var2: VarName.TYPE,
                 indexers1: DictLike.TYPE = None,
                 indexers2: DictLike.TYPE = None,
                 title: str = None,
                 properties: DictLike.TYPE = None,
                 file: str = None) -> Figure:
    """
    Create a scatter plot of two variables given by datasets *ds1*, *ds2* and
    the variable names *var1*, *var2*.

    If *indexers1* leaves exactly one dimension of *var1* unselected, both
    selections are restricted to the coordinate range they share along that
    dimension. The second variable is then re-sampled (nearest neighbour) onto
    the coordinate values of the first, so the points can be paired.

    :param ds1: Dataset that contains the variable named by *var1*.
    :param ds2: Dataset that contains the variable named by *var2*.
    :param var1: The name of the first variable to plot
    :param var2: The name of the second variable to plot
    :param indexers1: Optional indexers into data array *var1*. The *indexers1*
        is a dictionary or comma-separated string of key-value pairs that maps
        the variable's dimension names to constant labels, e.g.
        "lat=12.4, time='2012-05-02'". Selection uses the nearest label.
    :param indexers2: Optional indexers into data array *var2*. Defaults to
        *indexers1* when *indexers1* is given.
    :param title: optional plot title
    :param properties: optional plot properties for Python matplotlib, e.g.
        "bins=512, range=(-1.5, +1.5), label='Sea Surface Temperature'"
        For full reference refer to
        https://matplotlib.org/api/lines_api.html and
        https://matplotlib.org/devdocs/api/_as_gen/matplotlib.patches.Patch.html#matplotlib.patches.Patch
    :param file: path to a file in which to save the plot
    :return: a matplotlib figure object or None if in IPython mode
    :raises ValueError: if *var1* or *var2* is missing
    """
    var_name1 = VarName.convert(var1)
    var_name2 = VarName.convert(var2)
    if not var_name1:
        raise ValueError("Missing value for 'var1'")
    if not var_name2:
        raise ValueError("Missing value for 'var2'")

    var1 = ds1[var_name1]
    var2 = ds2[var_name2]

    indexers1 = DictLike.convert(indexers1) or {}
    indexers2 = DictLike.convert(indexers2) or {}
    properties = DictLike.convert(properties) or {}

    try:
        if indexers1:
            var_data1 = var1.sel(method='nearest', **indexers1)
            if not indexers2:
                # Use the same selection for the second variable.
                indexers2 = indexers1
            var_data2 = var2.sel(method='nearest', **indexers2)

            # Dimensions of var1 not collapsed by the indexers. A plain list
            # difference keeps dimension order deterministic and ignores
            # indexer keys that are not dimensions of var1.
            remaining_dims = [dim for dim in var1.dims if dim not in indexers1]

            # Pairing is only well defined along a single shared dimension.
            # With several remaining dimensions the data is plotted as is.
            if len(remaining_dims) == 1 and remaining_dims[0] in var_data2.dims:
                dim = remaining_dims[0]
                coord1 = var_data1[dim]
                coord2 = var_data2[dim]

                # Restrict both series to their overlapping coordinate range.
                dim_min = max(coord1.min().values, coord2.min().values)
                dim_max = min(coord1.max().values, coord2.max().values)
                var_data1 = var_data1.where(
                    (coord1 >= dim_min) & (coord1 <= dim_max), drop=True)
                var_data2 = var_data2.where(
                    (coord2 >= dim_min) & (coord2 <= dim_max), drop=True)

                # Re-sample var2 onto var1's coordinates so that values pair up.
                # reindex() returns a new object; the result must be kept.
                var_data2 = var_data2.reindex(
                    method='nearest', **{dim: var_data1[dim].data})
        else:
            var_data1 = var1
            var_data2 = var2
    except ValueError:
        # Best effort: fall back to the unindexed variables.
        var_data1 = var1
        var_data2 = var2

    figure = plt.figure(figsize=(12, 8))
    ax = figure.add_subplot(111)
    ax.plot(var_data1.values, var_data2.values, '.', **properties)

    def _axis_label(var_name, indexers):
        # Variable name followed by the applied "key = value" selections.
        return var_name + "".join(", " + str(key) + " = " + str(value)
                                  for key, value in indexers.items())

    ax.set_xlabel(_axis_label(var_name1, indexers1))
    ax.set_ylabel(_axis_label(var_name2, indexers2))
    figure.tight_layout()

    if title:
        ax.set_title(title)

    if file:
        figure.savefig(file)

    return figure if not in_notebook() else None
def plot_map(ds: xr.Dataset,
             var: VarName.TYPE = None,
             index: DictLike.TYPE = None,
             time: Union[str, int] = None,
             region: PolygonLike.TYPE = None,
             projection: str = 'PlateCarree',
             central_lon: float = 0.0,
             file: str = None) -> None:
    """
    Plot the given variable from the given dataset on a map with coastlines.

    In case no variable name is given, the first data variable of the dataset
    is plotted. Every dimension other than ``lat`` and ``lon`` that is not
    selected via *index* or *time* is reduced to its first slice (position 0).
    This includes ``time``. It is also possible to set extents of the plot.
    If no extents are given, a global plot is created.

    The plot can either be shown using pyplot functionality, or saved,
    if a path is given. The following file formats for saving the plot
    are supported: eps, jpeg, jpg, pdf, pgf, png, ps, raw, rgba, svg,
    svgz, tif, tiff

    :param ds: xr.Dataset to plot
    :param var: variable name in the dataset to plot
    :param index: Optional index into the variable's data array. The *index*
        is a dictionary that maps the variable's dimension names to constant
        labels. For example, ``lat`` and ``lon`` are given in decimal degrees,
        while a ``time`` value may be provided as datetime object or a date
        string. *index* may also be a comma-separated string of key-value
        pairs, e.g. "lat=12.4, time='2012-05-02'". The caller's dictionary is
        not modified.
    :param time: time slice to plot. An ``int`` is a position along the
        ``time`` dimension (0 selects the first slice). Any other value is used
        as a ``time`` label and overrides a ``time`` entry in *index*.
    :param region: Region to plot
    :param projection: name of a global projection, see
        http://scitools.org.uk/cartopy/docs/v0.15/crs/projections.html
    :param central_lon: central longitude of the projection in degrees
    :param file: path to a file in which to save the plot
    :raises NotImplementedError: if *ds* is not an xr.Dataset
    :raises ValueError: if the dataset has no data variables, the region is
        not a valid bounding box, or the projection name is unknown
    """
    if not isinstance(ds, xr.Dataset):
        raise NotImplementedError('Only raster datasets are currently '
                                  'supported')

    if not var:
        # Default to the first data variable of the dataset.
        var_name = next(iter(ds.data_vars), None)
        if var_name is None:
            raise ValueError('Dataset contains no data variables to plot')
    else:
        var_name = VarName.convert(var)
    var = ds[var_name]

    # Label-based selections (applied with ``sel``). Copy so that a dict
    # passed in by the caller is never mutated.
    index = dict(DictLike.convert(index) or {})
    # Position-based selections (applied with ``isel`` after ``sel``).
    positions = {}

    if isinstance(time, int):
        # 0 is a valid position, hence no truthiness test here.
        if 'time' in var.dims:
            index.pop('time', None)
            positions['time'] = time
    elif time:
        index['time'] = time

    # Reduce every other non-spatial dimension to its first slice. This must
    # be positional: label 0 is not valid for e.g. datetime coordinates.
    for dim_name in var.dims:
        if (dim_name not in ('lat', 'lon')
                and dim_name not in index
                and dim_name not in positions):
            positions[dim_name] = 0

    if region is None:
        extents = None
    else:
        region = PolygonLike.convert(region)
        lon_min, lat_min, lon_max, lat_max = region.bounds
        if not _check_bounding_box(lat_min, lat_max, lon_min, lon_max):
            raise ValueError(
                'Provided plot extents do not form a valid bounding box '
                'within [-180.0,+180.0,-90.0,+90.0]')
        extents = [lon_min, lon_max, lat_min, lat_max]

    # See http://scitools.org.uk/cartopy/docs/v0.15/crs/projections.html#
    projection_classes = {
        'PlateCarree': ccrs.PlateCarree,
        'LambertCylindrical': ccrs.LambertCylindrical,
        'Mercator': ccrs.Mercator,
        'Miller': ccrs.Miller,
        'Mollweide': ccrs.Mollweide,
        'Orthographic': ccrs.Orthographic,
        'Robinson': ccrs.Robinson,
        'Sinusoidal': ccrs.Sinusoidal,
        'NorthPolarStereo': ccrs.NorthPolarStereo,
        'SouthPolarStereo': ccrs.SouthPolarStereo,
    }
    projection_class = projection_classes.get(projection)
    if projection_class is None:
        raise ValueError('illegal projection')
    proj = projection_class(central_longitude=central_lon)

    try:
        var_data = var.sel(**index) if index else var
        if positions:
            var_data = var_data.isel(**positions)
    except ValueError:
        # Best effort: fall back to the unindexed variable.
        var_data = var

    # Region bounds and data coordinates are plain lon/lat degrees.
    data_crs = ccrs.PlateCarree()

    fig = plt.figure(figsize=(16, 8))
    ax = plt.axes(projection=proj)
    if extents is None:
        # set_global() also works for projections that cannot represent
        # lat = +/-90 (e.g. Mercator), unlike set_extent().
        ax.set_global()
    else:
        ax.set_extent(extents, crs=data_crs)

    ax.coastlines()
    # ``transform`` names the CRS the data is given in, not the map projection.
    var_data.plot.contourf(ax=ax, transform=data_crs)

    if file:
        fig.savefig(file)