def test_cmax_equals_cmin(byte_arr):
    """Bytescale raises a ValueError when `cmax` and `cmin` are equal."""
    expected_msg = "`cmax` and `cmin` should not be the same value. "
    with pytest.raises(ValueError, match=expected_msg):
        es.bytescale(byte_arr, cmax=100, cmin=100)
def test_bytescale_high_low_val():
    """Check bytescale's argument validation and its 0-255 output range."""
    arr = np.random.randint(300, size=(10, 10))

    # Each keyword set is invalid on its own: `high` above 255, `low` below
    # 0, and `high` smaller than `low`.
    invalid_ranges = (
        {"high": 300},
        {"low": -100},
        {"high": 100, "low": 150},
    )
    for bad_kwargs in invalid_ranges:
        with pytest.raises(ValueError):
            es.bytescale(arr, **bad_kwargs)

    # A valid high/low pair; this call also runs with cmin/cmax left unset.
    val_arr = es.bytescale(arr, low=0, high=255)
    assert (val_arr.min(), val_arr.max()) == (0, 255)

    # A `cmax` below `cmin` must be rejected.
    with pytest.raises(ValueError):
        es.bytescale(arr, cmax=50, cmin=100)

    # TODO: write test case for cmax == cmin
    # An earlier attempt at that call broke for unknown reasons.

    # Explicit, valid cmin/cmax bounds still give the full 0-255 range.
    scale_arr = es.bytescale(arr, cmax=240, cmin=10)
    assert (scale_arr.min(), scale_arr.max()) == (0, 255)
def test_low_val_range(byte_arr):
    """A negative `low` value makes bytescale raise a ValueError."""
    low_msg = "`low` should be greater than or equal to 0."
    with pytest.raises(ValueError, match=low_msg):
        es.bytescale(byte_arr, low=-100)
def test_high_lessthan_low(byte_arr):
    """Bytescale raises a ValueError when `high` is below `low`."""
    order_msg = "`high` should be greater than or equal to `low`."
    with pytest.raises(ValueError, match=order_msg):
        es.bytescale(byte_arr, low=150, high=100)
def test_cmax_cmin_work(byte_arr):
    """Valid `cmin`/`cmax` values return an array spanning 0-255."""
    scale_arr = es.bytescale(byte_arr, cmax=240, cmin=10)
    assert (scale_arr.min(), scale_arr.max()) == (0, 255)
def plot_rgb(arr, rgb=(0, 1, 2), figsize=(10, 10), str_clip=2, ax=None,
             extent=None, title="", stretch=None, dest_file=None):
    """Plot three bands of an array as an RGB composite, optionally saving it.

    Parameters
    ----------
    arr : numpy array
        A 3-dimensional array in rasterio band order (bands, rows, columns).
    rgb : list (default = (0, 1, 2))
        Indices of the three bands to plot.
    figsize : tuple (default = (10, 10))
        Size of the figure created when ``ax`` is not supplied.
    str_clip : int (default = 2)
        Second argument handed to ``_stretch_im`` when ``stretch`` is truthy.
    ax : object (optional)
        Axes to draw on. When omitted, a new figure and axes are created.
    extent : tuple (optional)
        Forwarded unchanged to ``ax.imshow`` as ``extent``.
    title : string (optional)
        Title of the plot.
    stretch : Boolean (optional)
        If truthy, the selected bands are passed through ``_stretch_im``.
    dest_file : path-like (optional)
        Only used when this function created the figure itself. If given,
        the figure is saved there and then cleared. If None, the figure is
        shown with ``plt.show()`` instead of being saved.

    Returns
    ----------
    ax : axes object
        The axes the image was drawn on.

    Raises
    ------
    ValueError
        If ``arr`` does not have exactly three dimensions.
    """
    if len(arr.shape) != 3:
        raise ValueError("Input needs to be 3 dimensions and in rasterio "
                         "order with bands first")

    # Index bands for plotting and clean up data for matplotlib
    rgb_bands = arr[rgb, :, :]

    if stretch:
        rgb_bands = _stretch_im(rgb_bands, str_clip)

    # If type is masked array - add alpha channel for plotting
    if ma.is_masked(rgb_bands):
        # Build alpha channel from the first selected band's mask:
        # 255 where the value is unmasked, 0 where it is masked.
        mask = ~(np.ma.getmask(rgb_bands[0])) * 255

        # Add the mask to the array & swap the axes order from (bands,
        # rows, columns) to (rows, columns, bands) for plotting
        rgb_bands = np.vstack(
            (es.bytescale(rgb_bands), np.expand_dims(mask, axis=0))
        ).transpose([1, 2, 0])
    else:
        rgb_bands = es.bytescale(rgb_bands).transpose([1, 2, 0])

    # Only a figure created here is saved or shown; a caller-supplied axes
    # is left for the caller to render.
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    ax.imshow(rgb_bands, extent=extent)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])

    if fig is not None:
        if dest_file is None:
            # Without a destination there is nothing to write; passing None
            # to savefig would fail, so display the figure instead.
            plt.show()
        else:
            fig.savefig(dest_file)
            fig.clf()
    return ax
def test_low_high_vals_work(byte_arr):
    """The `low` and `high` arguments set the output array's min and max."""
    # A valid pair; this call also runs with cmin/cmax left unset.
    val_arr = es.bytescale(byte_arr, low=0, high=255)
    assert (val_arr.min(), val_arr.max()) == (0, 255)
def _plot_image(
    arr_im,
    cmap="Greys_r",
    title=None,
    extent=None,
    cbar=True,
    scale=True,
    vmin=None,
    vmax=None,
    ax=None,
):
    """
    Draw an image array onto an existing matplotlib axis.

    Parameters
    ----------
    arr_im : numpy array
        The array to display.
    cmap : str (default = "Greys_r")
        Colormap name handed to ``imshow``.
    title : str or list (optional)
        Title set on the axis when truthy.
    extent : tuple (optional)
        Forwarded unchanged to ``ax.imshow`` as ``extent``.
        NOTE(review): matplotlib documents this as (left, right, bottom,
        top) -- confirm callers use that order.
    cbar : Boolean (default = True)
        When truthy, ``colorbar`` is called on the drawn image.
    scale : Boolean (Default = True)
        When truthy, the array is run through ``es.bytescale`` first.
    vmin : Int (Optional)
        The vmin passed to ``imshow``.
    vmax : Int (Optional)
        The vmax passed to ``imshow``.
    ax : Matplotlib axes object
        Axis to draw on. Despite the ``None`` default it must be supplied,
        because the image is drawn with ``ax.imshow``.

    Returns
    ----------
    ax : axes object
        The axis that was passed in, now holding the image.
    """
    image_data = es.bytescale(arr_im) if scale else arr_im

    im = ax.imshow(
        image_data, extent=extent, vmax=vmax, vmin=vmin, cmap=cmap
    )

    if title:
        ax.set(title=title)
    if cbar:
        colorbar(im)

    # Hide the tick marks on both axes.
    ax.set(xticks=[], yticks=[])
    return ax
def test_high_val_range():
    """An int16 array holding the dtype's extreme values scales to 0-255.

    Covers the case of a user supplying an array whose values sit at the
    ends of its dtype's range. Bytescale should not fail.
    """
    # The valid range of dtype int16 values is -32768 to 32767
    rgb_bands = np.array([1, 32767, 3, -32768], dtype="int16")
    arr = es.bytescale(rgb_bands)

    assert (arr.min(), arr.max()) == (0, 255)
def _plot_image(
    arr_im,
    cmap="Greys_r",
    title=None,
    extent=None,
    cbar=True,
    scale=False,
    vmin=None,
    vmax=None,
    ax=None,
    alpha=1,
    norm=None,
):
    """
    Draw an image array onto an existing matplotlib axis.

    Parameters
    ----------
    arr_im : numpy array
        The array to display.
    cmap : str (default = "Greys_r")
        Colormap name handed to ``imshow``.
    title : str or list (optional)
        Title set on the axis when truthy.
    extent : tuple (optional)
        Forwarded unchanged to ``ax.imshow`` as ``extent``.
        NOTE(review): matplotlib documents this as (left, right, bottom,
        top) -- confirm callers use that order.
    cbar : Boolean (default = True)
        When truthy, ``colorbar`` is called on the drawn image.
    scale : Boolean (Default = False)
        When truthy, the array is run through ``es.bytescale`` first.
    vmin : Int (Optional)
        The vmin passed to ``imshow``.
    vmax : Int (Optional)
        The vmax passed to ``imshow``.
    ax : Matplotlib axes object
        Axis to draw on. Despite the ``None`` default it must be supplied,
        because the image is drawn with ``ax.imshow``.
    alpha : float (optional)
        Transparency value passed to ``imshow``.
    norm : matplotlib Normalize object (Optional)
        Normalization passed to ``imshow`` for custom value coloring.
        NOTE: leave ``scale`` set to False when using this argument;
        otherwise the values are bytescaled before being drawn.

    Returns
    ----------
    ax : matplotlib.axes object
        The axis that was passed in, now holding the image.
    """
    image_data = es.bytescale(arr_im) if scale else arr_im

    imshow_kwargs = {
        "cmap": cmap,
        "vmin": vmin,
        "vmax": vmax,
        "extent": extent,
        "alpha": alpha,
        "norm": norm,
    }
    im = ax.imshow(image_data, **imshow_kwargs)

    if title:
        ax.set(title=title)
    if cbar:
        colorbar(im)

    # Hide the tick marks on both axes.
    ax.set(xticks=[], yticks=[])
    return ax
def plot_rgb(
    arr,
    rgb=(0, 1, 2),
    figsize=(10, 10),
    str_clip=2,
    ax=None,
    extent=None,
    title="",
    stretch=None,
):
    """Render three bands of a numpy array as one RGB composite image.

    Parameters
    ----------
    arr : numpy array
        A 3-dimensional array in rasterio band order (bands, rows, columns)
        holding the layers to plot.
    rgb : list (default = (0, 1, 2))
        Indices of the three bands to be plotted.
    figsize : tuple (default = (10, 10)
        The x and y integer dimensions of the output plot.
    str_clip: int (default = 2)
        The percentage of clip to apply to the stretch. Default = 2 (2 and 98).
    ax : object (optional)
        The axes object to draw on. A new figure is created when omitted.
    extent : tuple (optional)
        The extent object that matplotlib expects (left, right, bottom, top).
    title : string (optional)
        The intended title of the plot.
    stretch : Boolean (optional)
        If set to True, a linear stretch is applied to the bands.

    Returns
    ----------
    ax : axes object
        The axes object associated with the 3 band image.

    Example
    -------

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import rasterio as rio
        >>> import earthpy.plot as ep
        >>> from earthpy.io import path_to_example
        >>> with rio.open(path_to_example('rmnp-rgb.tif')) as src:
        ...     img_array = src.read()
        >>> # Ensure the input array doesn't have nodata values like -9999
        >>> ep.plot_rgb(img_array)
        <AxesSubplot:>

    """
    if len(arr.shape) != 3:
        raise ValueError(
            "Input needs to be 3 dimensions and in rasterio "
            "order with bands first"
        )

    # Pull out the requested bands, stretching them when asked to.
    bands = arr[rgb, :, :]
    if stretch:
        bands = _stretch_im(bands, str_clip)

    # Mask any NaN values so they get the alpha treatment below.
    nan_locations = np.isnan(bands)
    if np.any(nan_locations):
        bands = np.ma.masked_array(bands, nan_locations)

    if ma.is_masked(bands):
        # Alpha channel from the first band's mask: 255 where unmasked,
        # 0 where masked.
        alpha_band = ~(np.ma.getmask(bands[0])) * 255
        layers = (es.bytescale(bands), np.expand_dims(alpha_band, axis=0))
        stacked = np.vstack(layers)
    else:
        stacked = es.bytescale(bands)

    # Reorder (bands, rows, columns) -> (rows, columns, bands) for imshow.
    image = stacked.transpose([1, 2, 0])

    # Only a figure created here gets shown here.
    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots(figsize=figsize)

    ax.imshow(image, extent=extent)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])

    # Multipanel won't work if plt.show is called prior to second plot def
    if created_figure:
        plt.show()
    return ax
def plot_bands(
    arr, cmap="Greys_r", figsize=(12, 12), cols=3, title=None, extent=None
):
    """Plot each band in a numpy array in its own axis.

    Assumes band order (band, row, col).

    Parameters
    ----------
    arr : numpy array
        An n-dimensional numpy array to plot.
    cmap : str (default = "Greys_r")
        Colormap name for plots.
    figsize : tuple (default = (12, 12))
        Figure size in inches.
    cols : int (default = 3)
        Number of columns for plot grid.
    title : str or list (optional)
        Title of one band or list of titles with one title per band. A
        plain string is treated as a one-item list.
    extent : tuple (optional)
        Passed to ``imshow`` as ``extent`` for every band that is drawn.

    Returns
    ----------
    tuple
        fig : figure object
            The figure of the plotted band(s).
        ax or axs : axes object(s)
            The axes object(s) associated with the plot.

    Raises
    ------
    AttributeError
        If ``arr`` has no ``ndim`` attribute (i.e. is not a numpy array).
    ValueError
        If the number of titles does not match the number of bands.

    Example
    -------
    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import earthpy.plot as ep
        >>> from earthpy.io import path_to_example
        >>> import rasterio as rio
        >>> titles = ['Red', 'Green', 'Blue']
        >>> with rio.open(path_to_example('rmnp-rgb.tif')) as src:
        ...     ep.plot_bands(src.read(),
        ...                   title=titles,
        ...                   figsize=(8, 3))
        (<Figure size ... with 3 Axes>, ...)
    """
    # Fail early with a clear message; previously the message was a bare
    # string expression that was never raised.
    if not hasattr(arr, "ndim"):
        raise AttributeError("Input arr should be a numpy array")

    # A single string is one title, not a sequence of characters.
    if isinstance(title, str):
        title = [title]

    if title:
        if arr.ndim == 2:
            # A 2-D array is one band; its first axis is rows, not bands,
            # so only the "more than one title" check applies.
            if len(title) > 1:
                raise ValueError(
                    "Plot_bands() expects one title for a single band array. \n"
                    "You have provided more than one title."
                )
        elif len(title) != arr.shape[0]:
            raise ValueError(
                "Plot_bands() expects the number of plot titles to equal "
                "the number of array raster layers."
            )

    # If the array is 3 dimensional setup grid plotting
    if arr.ndim > 2 and arr.shape[0] > 1:
        total_layers = arr.shape[0]
        # Calculate the total rows that will be required to plot each band
        plot_rows = int(np.ceil(total_layers / cols))

        fig, axs = plt.subplots(plot_rows, cols, figsize=figsize)
        axs_ravel = axs.ravel()

        # Plot all bands
        for i, ax in enumerate(axs_ravel[:total_layers]):
            ax.imshow(es.bytescale(arr[i]), cmap=cmap, extent=extent)
            band_title = title[i] if title else "Band %i" % (i + 1)
            ax.set(title=band_title)
            ax.set(xticks=[], yticks=[])

        # This loop clears out the plots for axes which are empty
        # A matplotlib axis grid is always uniform with x cols and x rows
        # eg: an 8 band plot with 3 cols will always be 3 x 3
        for ax in axs_ravel[total_layers:]:
            ax.set_axis_off()
            ax.set(xticks=[], yticks=[])
        plt.tight_layout()
        return fig, axs

    elif arr.ndim == 2 or arr.shape[0] == 1:
        # If it's a 2 dimensional array with a 3rd dimension
        arr = np.squeeze(arr)

        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(es.bytescale(arr), cmap=cmap, extent=extent)
        if title:
            # Use the title text itself rather than the wrapping list.
            ax.set(title=title[0])
        ax.set(xticks=[], yticks=[])
        return fig, ax
    # Any other shape matches neither branch and returns None, as before.
def test_bytescale_high_low_val():
    """Check bytescale's validation messages and its 0-255 output range."""
    arr = np.random.randint(300, size=(10, 10))

    # Invalid high/low combinations paired with the message each must raise:
    # `high` above 255, `low` below 0, and `high` smaller than `low`.
    invalid_ranges = (
        ({"high": 300}, "`high` should be less than or equal to 255."),
        ({"low": -100}, "`low` should be greater than or equal to 0."),
        (
            {"high": 100, "low": 150},
            "`high` should be greater than or equal to `low`.",
        ),
    )
    for bad_kwargs, expected_msg in invalid_ranges:
        with pytest.raises(ValueError, match=expected_msg):
            es.bytescale(arr, **bad_kwargs)

    # A valid high/low pair; this call also runs with cmin/cmax left unset.
    val_arr = es.bytescale(arr, low=0, high=255)
    assert (val_arr.min(), val_arr.max()) == (0, 255)

    # A `cmax` below `cmin` must be rejected.
    smaller_msg = "`cmax` should be larger than `cmin`."
    with pytest.raises(ValueError, match=smaller_msg):
        es.bytescale(arr, cmax=50, cmin=100)

    # A `cmax` equal to `cmin` must be rejected as well.
    equal_msg = "`cmax` and `cmin` should not be the same value. Please specify `cmax` > `cmin`"
    with pytest.raises(ValueError, match=equal_msg):
        es.bytescale(arr, cmax=100, cmin=100)

    # Explicit, valid cmin/cmax bounds still give the full 0-255 range.
    scale_arr = es.bytescale(arr, cmax=240, cmin=10)
    assert (scale_arr.min(), scale_arr.max()) == (0, 255)
def plot_rgb(
    arr,
    rgb=(0, 1, 2),
    ax=None,
    extent=None,
    title="",
    figsize=(10, 10),
    stretch=None,
    str_clip=2,
):
    """Plot three bands in a numpy array as a composite RGB image.

    Parameters
    ----------
    arr: numpy ndarray
        N-dimensional array in rasterio band order (bands, rows, columns)
    rgb: list
        Indices of the three bands to be plotted (default = 0,1,2)
    extent: tuple
        The extent object that matplotlib expects (left, right, bottom, top)
    title: string (optional)
        String representing the title of the plot
    ax: object
        The axes object where the ax element should be plotted. Default = none
    figsize: tuple (optional)
        The x and y integer dimensions of the output plot if preferred to set.
    stretch: Boolean
        If True a linear stretch will be applied
    str_clip: int (optional)
        The % of clip to apply to the stretch. Default = 2 (2 and 98)

    Returns
    ----------
    fig, ax : figure object, axes object
        The figure and axes object associated with the 3 band image. If the
        ax keyword is specified, the figure return will be None.

    Raises
    ------
    ValueError
        If ``arr`` does not have exactly three dimensions.

    Example
    -------

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import rasterio as rio
        >>> import earthpy.plot as ep
        >>> from earthpy.io import path_to_example
        >>> with rio.open(path_to_example('rmnp-rgb.tif')) as src:
        ...     img_array = src.read()
        >>> ep.plot_rgb(img_array) #doctest: +ELLIPSIS
        (<Figure size 1000x1000 with 1 Axes>, ...)

    """
    # ValueError is a subclass of Exception, so callers catching the
    # previously raised generic Exception still work.
    if len(arr.shape) != 3:
        raise ValueError(
            """Input needs to be 3 dimensions and in rasterio order with bands first"""
        )

    # Index bands for plotting and clean up data for matplotlib
    rgb_bands = arr[rgb, :, :]

    if stretch:
        # Rescale each band between its lower and upper clip percentiles.
        s_min = str_clip
        s_max = 100 - str_clip
        arr_rescaled = np.zeros_like(rgb_bands)
        for ii, band in enumerate(rgb_bands):
            lower, upper = np.percentile(band, (s_min, s_max))
            arr_rescaled[ii] = exposure.rescale_intensity(
                band, in_range=(lower, upper)
            )
        rgb_bands = arr_rescaled

    # If type is masked array - add alpha channel for plotting
    if ma.is_masked(rgb_bands):
        # Build alpha channel from the first selected band's mask:
        # 255 where the value is unmasked, 0 where it is masked.
        mask = ~(np.ma.getmask(rgb_bands[0])) * 255

        # Add the mask to the array & swap the axes order from (bands,
        # rows, columns) to (rows, columns, bands) for plotting
        rgb_bands = np.vstack(
            (es.bytescale(rgb_bands), np.expand_dims(mask, axis=0))
        ).transpose([1, 2, 0])
    else:
        rgb_bands = es.bytescale(rgb_bands).transpose([1, 2, 0])

    # Then plot. Define ax if it's default to none
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = None

    ax.imshow(rgb_bands, extent=extent)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])
    return fig, ax
#               title="Sat RGB Imagery - Band 2 - Green",
#               cbar=False)
# plt.show()

# Plot blue band using earthpy
# ep.plot_bands(tif_file[2],
#               title="Sat RGB Imagery - Band 3 - Blue",
#               cbar=False)
# plt.show()

# titles = ["Red Band", "Green Band", "Blue Band", "Near Infrared (NIR) Band"]

# Plot all bands using the earthpy function
# ep.plot_bands(tif_file,
#               figsize=(12, 5),
#               cols=2,
#               title=titles,
#               cbar=False)
# plt.show()

# Plot rgb bands combined into an image
# ep.plot_rgb(tif_file,
#             rgb=[0, 1, 2],
#             title="RGB Composite image - Sat")

# Plot the rgb bands combined into an image. The data is run through
# bytescale first; no stretch argument is passed to plot_rgb.
ep.plot_rgb(
    bytescale(tif_file),
    title="RGB Composite image - Sat",
    rgb=[0, 1, 2],
)
plt.show()
def test_cmax_higher_than_cmin(byte_arr):
    """Bytescale raises a ValueError when `cmax` is below `cmin`."""
    smaller_msg = "`cmax` should be larger than `cmin`."
    with pytest.raises(ValueError, match=smaller_msg):
        es.bytescale(byte_arr, cmax=50, cmin=100)
def test_high_val_greater_255(byte_arr):
    """A `high` value above 255 makes bytescale raise a ValueError."""
    high_msg = "`high` should be less than or equal to 255."
    with pytest.raises(ValueError, match=high_msg):
        es.bytescale(byte_arr, high=300)