def test_integer_levels(self):
    """Integer ``levels`` produce a discrete colormap spanning vmin..vmax."""
    shifted = self.data + 1

    # Explicit bounds plus five levels: the outermost levels coincide with
    # the bounds, and the discrete map has one color per interval.
    params = _determine_cmap_params(shifted, levels=5, vmin=0, vmax=5,
                                    cmap='Blues')
    levels = params['levels']
    self.assertEqual(params['vmin'], levels[0])
    self.assertEqual(params['vmax'], levels[-1])
    cmap = params['cmap']
    self.assertEqual(cmap.name, 'Blues')
    self.assertEqual(params['extend'], 'neither')
    self.assertEqual(cmap.N, 5)
    self.assertEqual(params['cnorm'].N, 6)

    # Narrow bounds without an explicit cmap: default map, and the data is
    # expected to extend only above vmax.
    params = _determine_cmap_params(shifted, levels=5, vmin=0.5, vmax=1.5)
    self.assertEqual(params['cmap'].name, 'viridis')
    self.assertEqual(params['extend'], 'max')
def test_list_levels(self):
    """Explicit level sequences override vmin/vmax and are kept verbatim."""
    shifted = self.data + 1
    expected_levels = [0, 1, 2, 3, 4, 5]

    # The supplied vmin/vmax must be ignored in favour of the level list.
    params = _determine_cmap_params(shifted, levels=expected_levels,
                                    vmin=0, vmax=3)
    for key, expected in (('vmin', 0), ('vmax', 5)):
        self.assertEqual(params[key], expected)
    self.assertEqual(params['cmap'].N, 5)
    self.assertEqual(params['cnorm'].N, 6)

    # Any array-like container holding the levels should round-trip.
    containers = (list, np.array, pd.Index, DataArray)
    for container in containers:
        params = _determine_cmap_params(shifted,
                                        levels=container(expected_levels))
        self.assertArrayEqual(params['levels'], expected_levels)
def test_center(self):
    """A ``center`` value yields symmetric bounds and a divergent map."""
    center = 0.5
    params = _determine_cmap_params(self.data, center=center)

    # The bounds sit at equal distances on either side of the center.
    self.assertEqual(params['vmax'] - center, center - params['vmin'])
    self.assertEqual(params['cmap'], 'RdBu_r')
    self.assertEqual(params['extend'], 'neither')
    # No discrete levels were requested, so neither levels nor norm is set.
    for key in ('levels', 'cnorm'):
        self.assertIsNone(params[key])
def test_robust(self):
    """``robust=True`` clips the bounds to the 2nd/98th percentiles."""
    params = _determine_cmap_params(self.data, robust=True)

    low, high = (np.percentile(self.data, q) for q in (2, 98))
    self.assertEqual(params['vmin'], low)
    self.assertEqual(params['vmax'], high)
    self.assertEqual(params['cmap'].name, 'viridis')
    self.assertEqual(params['extend'], 'both')
    for key in ('levels', 'cnorm'):
        self.assertIsNone(params[key])
def geo_plot(darray, ax=None, method='contourf', projection='PlateCarree',
             grid=False, **kwargs):
    """ Create a global plot of a given variable.

    Parameters:
    -----------
    darray : xray.DataArray
        The darray to be plotted; its ``lon`` and ``lat`` coordinate values
        and its ``data`` are handed to the plotting function.
    ax : axis
        An existing GeoAxes instance (anything with a ``projection``
        attribute), else one will be created.
    method : str
        Name of the ``matplotlib.pyplot`` function used for plotting; it
        must also be a key of ``_PLOTTYPE_ARGS``, whose entry supplies
        default keyword arguments for that method.
    projection : str or tuple
        Name of the cartopy projection to use, or a 2-element
        ``(name, kwargs_dict)`` sequence whose dict is passed when
        initializing it; see func:`make_geoaxes` for more information.
        Only used when `ax` is None.
    grid : bool
        Include lat-lon grid overlay
    **kwargs : dict
        Any additional keyword arguments to pass to the plotter, including
        colormap params. If 'vmin' is not in this set of optional keyword
        arguments, the colormap parameters are inferred from the data; any
        'vmax', 'cmap', 'levels' or 'extend' given here are fed into that
        inference rather than being overwritten by it.

    Returns:
    --------
    ax : GeoAxes
        The axis that was plotted on.
    gp : object
        Whatever the plotting function returned.

    Raises:
    -------
    ValueError
        If `method` is not a known plot type, `projection` names no cartopy
        projection or is a sequence without exactly 2 values, or `ax` is not
        a GeoAxes.
    """
    # Set up plotting function: per-method defaults, overridden by the caller
    if method not in _PLOTTYPE_ARGS:
        raise ValueError("Don't know how to deal with '%s' method" % method)
    extra_args = _PLOTTYPE_ARGS[method].copy()
    extra_args.update(kwargs)

    # Alias the pyplot function named by the requested method
    plot_func = getattr(plt, method)

    # `transform` should be the ORIGINAL coordinate system - which is always
    # a simple lat-lon coordinate system in CESM output
    extra_args['transform'] = ccrs.PlateCarree()

    if ax is None:
        # Create a new cartopy axis object for plotting
        if isinstance(projection, (list, tuple)):
            if len(projection) != 2:
                raise ValueError("Expected 'projection' to only have 2 values")
            projection, proj_kwargs = projection
        else:
            proj_kwargs = {}
        # Look the projection class up by name in the cartopy CRS namespace,
        # so callers can simply pass a string naming the projection wanted.
        proj_cls = (getattr(ccrs, projection, None)
                    if isinstance(projection, str) else None)
        if proj_cls is None:
            raise ValueError("Unknown cartopy projection %r" % (projection,))
        ax = plt.axes(projection=proj_cls(**proj_kwargs))
    else:
        # Set current axis to the one passed as argument
        if not hasattr(ax, 'projection'):
            raise ValueError("Expected `ax` to be a GeoAxes instance")
        plt.sca(ax)

    # Setup map
    ax.set_global()
    ax.coastlines()
    try:
        gl = ax.gridlines(crs=extra_args['transform'], draw_labels=True,
                          linewidth=0.5, color='grey', alpha=0.8)
    except TypeError:
        # Labelled gridlines could not be drawn for this projection
        # (presumably a cartopy limitation - confirm for your version).
        warnings.warn("Could not label the given map projection.")
    else:
        lon_ticks = [-180, -90, 0, 90, 180]
        lat_ticks = [-90, -60, -30, 0, 30, 60, 90]
        gl.xlabels_top = False
        gl.ylabels_right = False
        gl.xlines = grid
        gl.ylines = grid
        gl.xlocator = mticker.FixedLocator(lon_ticks)
        gl.ylocator = mticker.FixedLocator(lat_ticks)
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER

    # Infer colormap settings if not provided
    if 'vmin' not in kwargs:
        warnings.warn("Re-inferring color parameters...")
        # Feed the caller's own colormap choices into the inference so they
        # shape the result instead of being silently overwritten by it.
        infer_kws = {key: kwargs[key]
                     for key in ('vmax', 'cmap', 'levels', 'extend')
                     if key in kwargs}
        cmap_kws = _determine_cmap_params(darray.data, **infer_kws)
        # NOTE(review): the result includes a 'cnorm' entry that is passed
        # through to `plot_func` as-is; matplotlib names this argument
        # `norm` - confirm 'cnorm' is handled as intended.
        extra_args.update(cmap_kws)

    gp = plot_func(darray.lon.values, darray.lat.values, darray.data,
                   **extra_args)
    return ax, gp
def test_divergentcontrol(self):
    """Check when ``_determine_cmap_params`` picks a divergent colormap."""
    neg = self.data - 0.1
    pos = self.data

    def check(arr, want_min, want_max, divergent, **cmap_kwargs):
        """Assert the bounds and colormap inferred for ``arr``."""
        params = _determine_cmap_params(arr, **cmap_kwargs)
        self.assertEqual(params['vmin'], want_min)
        self.assertEqual(params['vmax'], want_max)
        if divergent:
            # The divergent map is compared directly against its name.
            self.assertEqual(params['cmap'], "RdBu_r")
        else:
            self.assertEqual(params['cmap'].name, "viridis")

    # Defaults: positive data gets a sequential map, data crossing zero a
    # divergent one.
    check(pos, 0, 1, False)
    check(neg, -0.9, 0.9, True)

    # With center=False, giving a single bound does not make it divergent.
    check(neg, -0.1, 0.9, False, vmin=-0.1, center=False)
    check(neg, -0.1, 0.5, False, vmax=0.5, center=False)
    # center=False on its own disables the divergent map as well.
    check(neg, -0.1, 0.9, False, center=False)

    # An explicit center still yields a divergent map.
    check(neg, -0.9, 0.9, True, center=0)

    # A single bound forces symmetric limits around the center.
    check(neg, -0.1, 0.1, True, vmin=-0.1)
    check(neg, -0.5, 0.5, True, vmax=0.5)
    check(neg, -0.4, 0.6, True, vmax=0.6, center=0.1)

    # For positive data that only happens when the given bound is negative.
    check(pos, -0.1, 0.1, True, vmin=-0.1)
    check(pos, 0.1, 1, False, vmin=0.1)
    check(pos, 0, 0.5, False, vmax=0.5)

    # Supplying both bounds gives a non-divergent map.
    check(neg, -0.2, 0.6, False, vmin=-0.2, vmax=0.6)
# Read colorfile arguments; else, infer the color parameters.
# Settings used whenever color parameters have to be inferred, shared by both
# branches so variables missing from a colorfile look like freshly inferred
# ones.
cmap_infer_kwargs = dict(levels=21, robust=True, extend='both')
if args.colorfile is not None:
    colorfile = args.colorfile
    try:
        with open(colorfile, 'rb') as f:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # colorfiles from trusted sources.
            color_data = pickle.load(f)
    except OSError:
        print("Could not open colorfile '%s'" % colorfile, file=sys.stderr)
        sys.exit(1)
    except (pickle.UnpicklingError, EOFError) as err:
        # The file exists but is truncated or is not a pickle.
        print("Could not read colorfile '%s': %s" % (colorfile, err),
              file=sys.stderr)
        sys.exit(1)

    # Warn about variables where color will be freshly inferred
    for v in dataset.variables:
        if v in dataset.dims or v in color_data:
            continue
        warnings.warn("Couldn't find color data for %s" % v)
        color_data[v] = _determine_cmap_params(dataset[v].data,
                                               **cmap_infer_kwargs)
else:
    print("Inferring new colormaps")
    color_data = {}
    for v in dataset.variables:
        if v in dataset.dims:
            continue
        print(" " + v)
        color_data[v] = _determine_cmap_params(dataset[v].data,
                                               **cmap_infer_kwargs)

# Save/update the new colorfile
print("Saving colormap data")
fn_basename, _ = os.path.splitext(args.nc_file)