Exemplo n.º 1
0
def test_cmax_equals_cmin(byte_arr):
    """Fail gracefully when cmax is equal to cmin."""

    # Identical cmin/cmax would give a zero-width scaling range.
    with pytest.raises(
            ValueError,
            match="`cmax` and `cmin` should not be the same value. "):
        es.bytescale(byte_arr, cmin=100, cmax=100)
Exemplo n.º 2
0
def test_bytescale_high_low_val():
    """Unit tests for earthpy.spatial.bytescale."""

    arr = np.random.randint(300, size=(10, 10))

    # Bad high value (above 255)
    with pytest.raises(ValueError):
        es.bytescale(arr, high=300)

    # Bad low value (below 0)
    with pytest.raises(ValueError):
        es.bytescale(arr, low=-100)

    # High value is less than low value
    with pytest.raises(ValueError):
        es.bytescale(arr, high=100, low=150)

    # Valid case. should also take care of if statements for cmin/cmax
    val_arr = es.bytescale(arr, high=255, low=0)

    assert val_arr.min() == 0
    assert val_arr.max() == 255

    # Scale max (cmax) is less than scale min (cmin)
    with pytest.raises(ValueError):
        es.bytescale(arr, cmin=100, cmax=50)

    # TODO: write test case for cmax == cmin
    # Commented out because it breaks for unknown reasons.
    # es.bytescale(arr, cmin=100, cmax=100)

    # Valid cmin < cmax: output still spans the full 0-255 range
    scale_arr = es.bytescale(arr, cmin=10, cmax=240)
    assert scale_arr.min() == 0
    assert scale_arr.max() == 255
Exemplo n.º 3
0
def test_low_val_range(byte_arr):
    """Passing a negative `low` to bytescale must raise a ValueError."""
    expected_msg = "`low` should be greater than or equal to 0."
    too_low = -100

    # -100 is below the permitted floor of 0
    with pytest.raises(ValueError, match=expected_msg):
        es.bytescale(byte_arr, low=too_low)
Exemplo n.º 4
0
def test_high_lessthan_low(byte_arr):
    """bytescale rejects a `high` value that is below the `low` value."""
    msg = "`high` should be greater than or equal to `low`."
    inverted_bounds = dict(high=100, low=150)

    with pytest.raises(ValueError, match=msg):
        es.bytescale(byte_arr, **inverted_bounds)
Exemplo n.º 5
0
def test_cmax_cmin_work(byte_arr):
    """Explicit cmin/cmax values return an arr with the range 0-255."""

    scale_arr = es.bytescale(byte_arr, cmin=10, cmax=240)

    assert scale_arr.min() == 0
    assert scale_arr.max() == 255
Exemplo n.º 6
0
def plot_rgb(arr,
             rgb=(0, 1, 2),
             figsize=(10, 10),
             str_clip=2,
             ax=None,
             extent=None,
             title="",
             stretch=None,
             dest_file=None):
    """Plot three bands of an array as an RGB composite, optionally saving it.

    Parameters
    ----------
    arr : numpy array
        3-dimensional array in rasterio band order (bands, rows, columns).
    rgb : tuple or list (default = (0, 1, 2))
        Indices of the three bands to plot.
    figsize : tuple (default = (10, 10))
        Figure size, used only when ``ax`` is None.
    str_clip : int (default = 2)
        Clip value passed to ``_stretch_im`` when ``stretch`` is truthy.
    ax : matplotlib axes object (optional)
        Axes to draw on. When None, a new figure and axes are created.
    extent : tuple (optional)
        Extent passed through to ``ax.imshow``.
    title : str (default = "")
        Title of the plot.
    stretch : bool (optional)
        If truthy, apply ``_stretch_im`` to the selected bands.
    dest_file : str or path-like (optional)
        Only used when this function creates the figure (``ax`` is None):
        the figure is saved to this path and then cleared. Ignored when
        ``ax`` is supplied.

    Returns
    -------
    ax : matplotlib axes object
        The axes the composite was drawn on.

    Raises
    ------
    ValueError
        If ``arr`` is not 3-dimensional.
    """
    if len(arr.shape) != 3:
        raise ValueError("Input needs to be 3 dimensions and in rasterio "
                         "order with bands first")

    # Index bands for plotting and clean up data for matplotlib
    rgb_bands = arr[rgb, :, :]

    if stretch:
        rgb_bands = _stretch_im(rgb_bands, str_clip)

    # If type is masked array - add alpha channel for plotting
    if ma.is_masked(rgb_bands):
        # Build alpha channel: 255 where band 0 is unmasked, 0 where masked
        mask = ~(np.ma.getmask(rgb_bands[0])) * 255

        # Add the mask to the array & swap the axes order from (bands,
        # rows, columns) to (rows, columns, bands) for plotting
        rgb_bands = np.vstack((es.bytescale(rgb_bands),
                               np.expand_dims(mask,
                                              axis=0))).transpose([1, 2, 0])
    else:
        # Index bands for plotting and clean up data for matplotlib
        rgb_bands = es.bytescale(rgb_bands).transpose([1, 2, 0])

    # Then plot. Define ax if it's undefined
    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(figsize=figsize)

    ax.imshow(rgb_bands, extent=extent)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])

    # Only save when there is a destination; None is not a valid target
    # for savefig, which previously broke the default call path.
    if created_fig is not None and dest_file is not None:
        created_fig.savefig(dest_file)
        created_fig.clf()
    return ax
Exemplo n.º 7
0
def test_low_high_vals_work(byte_arr):
    """bytescale output min/max equal the requested `low`/`high` values."""
    low, high = 0, 255

    # Valid case; also exercises the default cmin/cmax branches
    scaled = es.bytescale(byte_arr, high=high, low=low)

    assert scaled.min() == low
    assert scaled.max() == high
Exemplo n.º 8
0
def _plot_image(
    arr_im,
    cmap="Greys_r",
    title=None,
    extent=None,
    cbar=True,
    scale=True,
    vmin=None,
    vmax=None,
    ax=None,
):
    """
    Draw an image array on a matplotlib axis with the given extent.

    Parameters
    ----------
    arr_im : numpy array
        An n-dimensional numpy array to plot.
    cmap : str (default = "Greys_r")
        Colormap name for plots.
    title : str or list (optional)
        Title of one band or list of titles with one title per band.
    extent : tuple (optional)
        Bounding box that the data will fill: (minx, miny, maxx, maxy).
    cbar : Boolean (default = True)
        If True, add a colorbar for the image.
    scale : Boolean (Default = True)
        If True, bytescale the array before plotting; set False to plot
        the raw values.
    vmin : Int (Optional)
        Specify the vmin to scale imshow() plots.
    vmax : Int (Optional)
        Specify the vmax to scale imshow() plots.
    ax : Matplotlib axes object (Optional)
        Matplotlib axis object to plot image. When None, a new figure and
        axis are created.

    Returns
    ----------
    ax : axes object
        The axes object(s) associated with the plot.
    """

    if ax is None:
        # Previously ax=None crashed on ax.imshow; create an axis instead,
        # as the other plotting helpers do.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots()

    if scale:
        arr_im = es.bytescale(arr_im)

    im = ax.imshow(arr_im, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent)
    if title:
        ax.set(title=title)
    if cbar:
        colorbar(im)
    ax.set(xticks=[], yticks=[])

    return ax
Exemplo n.º 9
0
def test_high_val_range():
    """A 16 bit int arr with values at the range end should be scaled properly.

    This test explicitly hits the user case of someone providing an array of a
    dtype at the end of its range of values. Bytescale should not fail.
    """

    # The valid range of dtype int16 values is -32768 to 32767
    int16_arr = np.array([1, 32767, 3, -32768]).astype("int16")
    arr = es.bytescale(int16_arr)

    assert arr.min() == 0
    assert arr.max() == 255
Exemplo n.º 10
0
def _plot_image(
    arr_im,
    cmap="Greys_r",
    title=None,
    extent=None,
    cbar=True,
    scale=False,
    vmin=None,
    vmax=None,
    ax=None,
    alpha=1,
    norm=None,
):
    """
    Draw an image array on a matplotlib axis with the given extent.

    Parameters
    ----------
    arr_im : numpy array
        An n-dimensional numpy array to plot.
    cmap : str (default = "Greys_r")
        Colormap name for plots.
    title : str or list (optional)
        Title of one band or list of titles with one title per band.
    extent : tuple (optional)
        Bounding box that the data will fill: (minx, miny, maxx, maxy).
    cbar : Boolean (default = True)
        If True, add a colorbar for the image.
    scale : Boolean (Default = False)
        If True, bytescale the array (0-255) before plotting.
    vmin : Int (Optional)
        Specify the vmin to scale imshow() plots.
    vmax : Int (Optional)
        Specify the vmax to scale imshow() plots.
    ax : Matplotlib axes object (Optional)
        Matplotlib axis object to plot image. When None, a new figure and
        axis are created.
    alpha : float (optional)
        The alpha value for the plot. This will help adjust the transparency of
        the plot to the desired level.
    norm : matplotlib Normalize object (Optional)
        The normalized boundaries for custom values coloring. NOTE: For this
        argument to work, the scale argument MUST be set to false. Otherwise,
        the values will be scaled from 0-255.

    Returns
    ----------
    ax : matplotlib.axes object
        The axes object(s) associated with the plot.
    """

    if ax is None:
        # Previously ax=None crashed on ax.imshow; create an axis instead,
        # as the other plotting helpers do.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots()

    if scale:
        arr_im = es.bytescale(arr_im)

    im = ax.imshow(
        arr_im,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        extent=extent,
        alpha=alpha,
        norm=norm,
    )
    if title:
        ax.set(title=title)
    if cbar:
        colorbar(im)
    ax.set(xticks=[], yticks=[])

    return ax
Exemplo n.º 11
0
def plot_rgb(
    arr,
    rgb=(0, 1, 2),
    figsize=(10, 10),
    str_clip=2,
    ax=None,
    extent=None,
    title="",
    stretch=None,
):
    """Plot three bands in a numpy array as a composite RGB image.

    NaN values and masked pixels are rendered transparent: a pixel is hidden
    when it is masked (or NaN) in any of the three plotted bands.

    Parameters
    ----------
    arr : numpy array
        An n-dimensional array in rasterio band order (bands, rows, columns)
        containing the layers to plot.
    rgb : tuple or list (default = (0, 1, 2))
        Indices of the three bands to be plotted.
    figsize : tuple (default = (10, 10))
        The x and y integer dimensions of the output plot.
    str_clip: int (default = 2)
        The percentage of clip to apply to the stretch. Default = 2 (2 and 98).
    ax : object (optional)
        The axes object where the ax element should be plotted.
    extent : tuple (optional)
        The extent object that matplotlib expects (left, right, bottom, top).
    title : string (optional)
        The intended title of the plot.
    stretch : Boolean (optional)
        Application of a linear stretch. If set to True, a linear stretch will
        be applied.

    Returns
    ----------
    ax : axes object
        The axes object associated with the 3 band image.

    Raises
    ------
    ValueError
        If ``arr`` is not 3-dimensional.

    Example
    -------

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import rasterio as rio
        >>> import earthpy.plot as ep
        >>> from earthpy.io import path_to_example
        >>> with rio.open(path_to_example('rmnp-rgb.tif')) as src:
        ...     img_array = src.read()
        >>> # Ensure the input array doesn't have nodata values like -9999
        >>> ep.plot_rgb(img_array)
        <AxesSubplot:>

    """

    if len(arr.shape) != 3:
        raise ValueError(
            "Input needs to be 3 dimensions and in rasterio "
            "order with bands first"
        )

    # Index bands for plotting and clean up data for matplotlib
    rgb_bands = arr[rgb, :, :]

    if stretch:
        rgb_bands = _stretch_im(rgb_bands, str_clip)

    nan_check = np.isnan(rgb_bands)

    if np.any(nan_check):
        rgb_bands = np.ma.masked_array(rgb_bands, nan_check)

    # If type is masked array - add alpha channel for plotting
    if ma.is_masked(rgb_bands):
        # Build alpha channel: opaque (255) only where the pixel is unmasked
        # in every plotted band. Using band 0's mask alone left pixels that
        # were masked/NaN in another band visible with garbage values.
        masked_any_band = np.ma.getmaskarray(rgb_bands).any(axis=0)
        mask = ~masked_any_band * 255

        # Add the mask to the array & swap the axes order from (bands,
        # rows, columns) to (rows, columns, bands) for plotting
        rgb_bands = np.vstack(
            (es.bytescale(rgb_bands), np.expand_dims(mask, axis=0))
        ).transpose([1, 2, 0])
    else:
        # Index bands for plotting and clean up data for matplotlib
        rgb_bands = es.bytescale(rgb_bands).transpose([1, 2, 0])

    # Then plot. Define ax if it's undefined
    show = False
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
        show = True

    ax.imshow(rgb_bands, extent=extent)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])

    # Multipanel won't work if plt.show is called prior to second plot def
    if show:
        plt.show()
    return ax
Exemplo n.º 12
0
def plot_bands(
    arr, cmap="Greys_r", figsize=(12, 12), cols=3, title=None, extent=None
):
    """Plot each band in a numpy array in its own axis.

    Assumes band order (band, row, col).

    Parameters
    ----------
    arr : numpy array
        An n-dimensional numpy array to plot.
    cmap : str (default = "Greys_r")
        Colormap name for plots.
    figsize : tuple (default = (12, 12))
        Figure size in inches.
    cols : int (default = 3)
        Number of columns for plot grid.
    title : str or list (optional)
        Title of one band (a string or a one-item list) or list of titles
        with one title per band.
    extent : tuple (optional)
        Bounding box that the data will fill: (minx, miny, maxx, maxy).
        Applied to every band's axis.

    Returns
    ----------
    tuple
        fig : figure object
            The figure of the plotted band(s).
        ax or axs : axes object(s)
            The axes object(s) associated with the plot.

    Raises
    ------
    AttributeError
        If ``arr`` has no ``ndim`` attribute (is not array-like).
    ValueError
        If the number of titles does not match the number of bands.

    Example
    -------
    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import earthpy.plot as ep
        >>> from earthpy.io import path_to_example
        >>> import rasterio as rio
        >>> titles = ['Red', 'Green', 'Blue']
        >>> with rio.open(path_to_example('rmnp-rgb.tif')) as src:
        ...     ep.plot_bands(src.read(),
        ...                   title=titles,
        ...                   figsize=(8, 3))
        (<Figure size ... with 3 Axes>, ...)
    """

    # The old try/except only evaluated a string and swallowed the error;
    # raise it with a helpful message instead.
    if not hasattr(arr, "ndim"):
        raise AttributeError("Input arr should be a numpy array")

    # A 2-D array is a single band; otherwise bands are on axis 0.
    n_bands = 1 if arr.ndim == 2 else arr.shape[0]

    if title:
        # Treat a single string as one title so len() counts titles, not
        # characters.
        if isinstance(title, str):
            title = [title]
        if (arr.ndim == 2) and (len(title) > 1):
            raise ValueError(
                """Plot_bands() expects one title for a single
                             band array. You have provided more than one
                             title."""
            )
        elif len(title) != n_bands:
            raise ValueError(
                """Plot_bands() expects the number of plot titles
                             to equal the number of array raster layers."""
            )

    # If the array is 3 dimensional setup grid plotting
    if arr.ndim > 2 and arr.shape[0] > 1:
        total_layers = arr.shape[0]
        # Calculate the total rows that will be required to plot each band
        plot_rows = int(np.ceil(total_layers / cols))

        # Plot all bands
        fig, axs = plt.subplots(plot_rows, cols, figsize=figsize)
        axs_ravel = axs.ravel()
        for i, ax in enumerate(axs_ravel[:total_layers]):
            ax.imshow(es.bytescale(arr[i]), cmap=cmap, extent=extent)
            if title:
                ax.set(title=title[i])
            else:
                ax.set(title="Band %i" % (i + 1))
            ax.set(xticks=[], yticks=[])
        # This loop clears out the plots for axes which are empty
        # A matplotlib axis grid is always uniform with x cols and x rows
        # eg: an 8 band plot with 3 cols will always be 3 x 3
        for ax in axs_ravel[total_layers:]:
            ax.set_axis_off()
            ax.set(xticks=[], yticks=[])
        plt.tight_layout()
        return fig, axs

    elif arr.ndim == 2 or arr.shape[0] == 1:
        # If it's a 2 dimensional array with a 3rd dimension
        arr = np.squeeze(arr)

        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(es.bytescale(arr), cmap=cmap, extent=extent)
        if title:
            # title was normalized to a one-item list above
            ax.set(title=title[0])
        ax.set(xticks=[], yticks=[])
        return fig, ax
Exemplo n.º 13
0
def test_bytescale_high_low_val():
    """Unit tests for earthpy.spatial.bytescale."""
    arr = np.random.randint(300, size=(10, 10))

    # Bad high value
    with pytest.raises(ValueError,
                       match="`high` should be less than or equal to 255."):
        es.bytescale(arr, high=300)

    # Bad low value
    with pytest.raises(ValueError,
                       match="`low` should be greater than or equal to 0."):
        es.bytescale(arr, low=-100)

    # High value is less than low value
    with pytest.raises(
            ValueError,
            match="`high` should be greater than or equal to `low`."):
        es.bytescale(arr, high=100, low=150)

    # Valid case. should also take care of if statements for cmin/cmax
    val_arr = es.bytescale(arr, high=255, low=0)

    assert val_arr.min() == 0
    assert val_arr.max() == 255

    # Scale max (cmax) is less than scale min (cmin)
    with pytest.raises(ValueError,
                       match="`cmax` should be larger than `cmin`."):
        es.bytescale(arr, cmin=100, cmax=50)

    # Scale max (cmax) equals scale min (cmin): zero-width range
    equal_msg = ("`cmax` and `cmin` should not be the same value. "
                 "Please specify `cmax` > `cmin`")
    with pytest.raises(ValueError, match=equal_msg):
        es.bytescale(arr, cmin=100, cmax=100)

    # Valid cmin < cmax: output still spans the full 0-255 range
    scale_arr = es.bytescale(arr, cmin=10, cmax=240)

    assert scale_arr.min() == 0
    assert scale_arr.max() == 255
Exemplo n.º 14
0
def plot_rgb(
        arr,
        rgb=(0, 1, 2),
        ax=None,
        extent=None,
        title="",
        figsize=(10, 10),
        stretch=None,
        str_clip=2,
):
    """Plot three bands in a numpy array as a composite RGB image.

    Parameters
    ----------
    arr: numpy ndarray
        N-dimensional array in rasterio band order (bands, rows, columns)
    rgb: tuple or list
        Indices of the three bands to be plotted (default = 0,1,2)
    extent: tuple
        The extent object that matplotlib expects (left, right, bottom, top)
    title: string (optional)
        String representing the title of the plot
    ax: object
        The axes object where the ax element should be plotted. Default = none
    figsize: tuple (optional)
        The x and y integer dimensions of the output plot if preferred to set.
    stretch: Boolean
        If True a linear stretch will be applied
    str_clip: int (optional)
        The % of clip to apply to the stretch. Default = 2 (2 and 98)

    Returns
    ----------
    fig, ax : figure object, axes object
        The figure and axes object associated with the 3 band image. If the
        ax keyword is specified, the figure return will be None.

    Raises
    ------
    ValueError
        If ``arr`` is not 3-dimensional.

    Example
    -------

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import rasterio as rio
        >>> import earthpy.plot as ep
        >>> from earthpy.io import path_to_example
        >>> with rio.open(path_to_example('rmnp-rgb.tif')) as src:
        ...     img_array = src.read()
        >>> ep.plot_rgb(img_array) #doctest: +ELLIPSIS
        (<Figure size 1000x1000 with 1 Axes>, ...)

    """

    # ValueError (a subclass of Exception) keeps existing handlers working
    # and matches the other plot_rgb implementations.
    if len(arr.shape) != 3:
        raise ValueError("Input needs to be 3 dimensions and in rasterio "
                         "order with bands first")

    # Index bands for plotting and clean up data for matplotlib
    rgb_bands = arr[rgb, :, :]

    if stretch:
        # Per-band linear stretch between the str_clip and 100 - str_clip
        # percentiles.
        s_min = str_clip
        s_max = 100 - str_clip
        arr_rescaled = np.zeros_like(rgb_bands)
        for ii, band in enumerate(rgb_bands):
            lower, upper = np.percentile(band, (s_min, s_max))
            arr_rescaled[ii] = exposure.rescale_intensity(band,
                                                          in_range=(lower,
                                                                    upper))
        # arr_rescaled is a fresh local array, so no copy is needed
        rgb_bands = arr_rescaled

    # If type is masked array - add alpha channel for plotting
    if ma.is_masked(rgb_bands):
        # Build alpha channel: 255 where band 0 is unmasked, 0 where masked
        mask = ~(np.ma.getmask(rgb_bands[0])) * 255

        # Add the mask to the array & swap the axes order from (bands,
        # rows, columns) to (rows, columns, bands) for plotting
        rgb_bands = np.vstack((es.bytescale(rgb_bands),
                               np.expand_dims(mask,
                                              axis=0))).transpose([1, 2, 0])
    else:
        # Index bands for plotting and clean up data for matplotlib
        rgb_bands = es.bytescale(rgb_bands).transpose([1, 2, 0])

    # Then plot. Define ax if it's default to none
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = None
    ax.imshow(rgb_bands, extent=extent)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])
    return fig, ax
Exemplo n.º 15
0
#               title="Sat RGB Imagery - Band 2 - Green",
#               cbar=False)
# plt.show()

#Plot blue band using earthpy
# ep.plot_bands(tif_file[2],
#               title="Sat RGB Imagery - Band 3 - Blue",
#               cbar=False)
# plt.show()

# titles = ["Red Band", "Green Band", "Blue Band", "Near Infrared (NIR) Band"]

# Plot all bands using the earthpy function
# ep.plot_bands(tif_file,
#               figsize=(12, 5),
#               cols=2,
#               title=titles,
#               cbar=False)
# plt.show()

#Plot rgb bands combined into an image
# ep.plot_rgb(tif_file,
#            rgb=[0, 1, 2],
#            title="RGB Composite image - Sat")

# Plot rgb bands combined into an image, but apply stretch to increase contrast
# (the stretch only happens when stretch=True is passed to plot_rgb)
ep.plot_rgb(bytescale(tif_file),
            rgb=[0, 1, 2],
            title="RGB Composite image - Sat",
            stretch=True)
plt.show()
Exemplo n.º 16
0
def test_cmax_higher_than_cmin(byte_arr):
    """A cmax smaller than cmin must make bytescale raise a ValueError."""
    cmin, cmax = 100, 50
    expected = "`cmax` should be larger than `cmin`."

    with pytest.raises(ValueError, match=expected):
        es.bytescale(byte_arr, cmin=cmin, cmax=cmax)
Exemplo n.º 17
0
def test_high_val_greater_255(byte_arr):
    """bytescale refuses a `high` value above 255."""
    too_high = 300
    expected = "`high` should be less than or equal to 255."

    with pytest.raises(ValueError, match=expected):
        es.bytescale(byte_arr, high=too_high)