コード例 #1
0
def create_colorbar(im,
                    ax,
                    fig,
                    size='5%',
                    padding=0.05,
                    position='right',
                    divider=None,
                    use_ax=False):  #pragma: no cover
    """
    Attach a colorbar for `im` to `fig`, next to (or inside) `ax`.

    Parameters
    ----------
    im : mappable
        The object handed to ``fig.colorbar`` as the mappable.
    ax : axes
        Either the axes the colorbar is placed beside (``use_ax`` is False)
        or the axes the colorbar is drawn into (otherwise).
    fig : figure
        The figure whose ``colorbar`` method is called.
    size : str or float
        Forwarded as ``size`` to ``divider.append_axes``.
    padding : float
        Forwarded as ``pad`` to ``divider.append_axes``.
    position : str
        Side on which the colorbar axes is appended and on which its ticks
        and label are placed. ``'top'`` and ``'bottom'`` produce a horizontal
        colorbar; any other value keeps the default (vertical) orientation.
    divider : object or None
        An existing axes divider to reuse. When None and ``use_ax`` is False,
        a new one is created from `ax`.
    use_ax : bool
        When exactly ``False`` a new colorbar axes is appended through the
        divider; for any other value `ax` itself is used as colorbar axes.

    Returns
    -------
    tuple
        ``(colorbar, divider, cax)``. ``divider`` is returned as passed in
        (possibly None) when no new divider had to be created.
    """
    # Identity test kept on purpose: only the literal False triggers the
    # creation of a dedicated colorbar axes.
    if use_ax is False:
        if divider is None:
            divider = _make_axes_locatable(ax)
        cax = divider.append_axes(position, size=size, pad=padding)
    else:
        cax = ax

    # A colorbar above/below the plot has to be horizontal and carries its
    # ticks on the x-axis; the y-axis only accepts left/right positions.
    if position in ('top', 'bottom'):
        ca = fig.colorbar(im, cax=cax, orientation='horizontal')
        tick_axis = cax.xaxis
    else:
        ca = fig.colorbar(im, cax=cax)
        tick_axis = cax.yaxis

    tick_axis.set_ticks_position(position)
    tick_axis.set_label_position(position)
    # Rasterize the colour patches (presumably to keep vector output small).
    ca.solids.set_rasterized(True)
    return (ca, divider, cax)
コード例 #2
0
ファイル: burpfile.py プロジェクト: meteokid/python-rpn
def plot_burp(bf, code=None, cval=None, ax=None, level=0, mask=None, projection='cyl', cbar_opt={}, vals_opt={}, dparallel=30., dmeridian=60., fontsize=20, **kwargs):
    """
    Plots a BURP file. Will plot BUFR code if specified. Only plots a single level,
    which can be specified by the optional argument. Additional arguments not listed
    below will be passed to Basemap.scatter.

    Parameters
    ----------
      bf           BurpFile instance
      code         code to be plotted, if not set the observation locations will be plotted
      level        level to be plotted
      mask         mask to apply to data
      cval         directly supply plotted values instead of using code argument
      ax           subplot object, if None a new figure and subplot are created
      projection   projection used by Basemap for plotting
      cbar_opt     dictionary for colorbar options (read only, never modified)
      vals_opt     dictionary for get_rval options if code is supplied (read only)
      dparallel    spacing of parallels, if None will not plot parallels
      dmeridian    spacing of meridians, if None will not plot meridians
      fontsize     font size for labels of parallels and meridians

    Returns
    -------
      None if the file holds no reports (bf.nrep == 0), otherwise a
      dictionary with keys:
        m          Basemap instance the data were drawn on
        cbar       colorbar object, only present if code or cval was supplied
    """
    assert isinstance(bf, BurpFile), "First argument must be an instance of BurpFile"
    assert code is None or cval is None, "Only one of code, cval should be supplied as an argument"

    # Nothing to plot: leave before a figure is created so that no empty
    # figure is left behind.
    if bf.nrep == 0:
        return

    if ax is None:
        fig = _plt.figure(figsize=(18, 11))
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()

    has_values = code is not None or cval is not None

    # Scatter options: caller supplied values win over the defaults.
    opt = kwargs.copy()
    opt.setdefault('s', 10)
    opt.setdefault('edgecolors', 'None')
    if has_values:
        opt.setdefault('cmap', _plt.cm.jet)

    if code is not None:
        vals = bf.get_rval(code, **vals_opt)
        if len(vals.shape) > 1:
            vals = vals[:, level]
        opt['c'] = vals
    elif cval is not None:
        # Accept any sequence; the boolean selection below needs an array.
        opt['c'] = _np.asarray(cval)

    msk = _np.array([stn[:2] for stn in bf.stnids]) != '>>'  # don't plot resumes

    if mask is not None:
        msk = _np.logical_and(msk, mask)

    lon = bf.lon[msk]
    lat = bf.lat[msk]
    if has_values:
        opt['c'] = opt['c'][msk]

    basemap_opt = {'projection': projection, 'resolution': 'c'}

    # Projection specific map extent; other projections use Basemap defaults.
    if projection == 'cyl':
        basemap_opt.update({'llcrnrlat': -90, 'urcrnrlat': 90, 'llcrnrlon': -180, 'urcrnrlon': 180})
    elif projection == 'npstere':
        basemap_opt.update({'boundinglat': 10., 'lon_0': 0.})
    elif projection == 'spstere':
        basemap_opt.update({'boundinglat': -10., 'lon_0': 270.})

    m = _Basemap(ax=ax, **basemap_opt)

    m.drawcoastlines()

    # Convert lon/lat to map coordinates.
    xpt, ypt = m(lon, lat)

    sctr = m.scatter(xpt, ypt, **opt)

    if dparallel is not None:
        m.drawparallels(_np.arange(-90, 91, dparallel), labels=[1, 0, 0, 0], color='grey', fontsize=fontsize)
    if dmeridian is not None:
        m.drawmeridians(_np.arange(-180, 180, dmeridian), labels=[0, 0, 0, 1], color='grey', fontsize=fontsize)

    output = {'m': m}

    if has_values:
        # Dedicated narrow axes on the right so the colorbar matches the map height.
        divider = _make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.5)
        cbar = fig.colorbar(sctr, ax=ax, cax=cax, **cbar_opt)
        output['cbar'] = cbar

    return output
コード例 #3
0
def show_transform_quantiles(img,
                             transform,
                             fraction=1.0,
                             area_mask=None,
                             ax=None):
    """
    Show a plot of the quantiles of the transform coefficients.

    The `fraction` of `transform` coefficients holding the most energy for the
    image `img` is considered. The four quantiles within this fraction of
    coefficients are illustrated in the `transform` domain by showing
    coefficients between the different quantiles in different colours. If an
    `area_mask` is specified, only this area in the plot is highlighted whereas
    the rest is darkened.

    Parameters
    ----------
    img : ndarray
        The image to show the transform quantiles for.
    transform : str
        The transform to use.
    fraction : float
        The fraction of coefficients used in the quantiles calculation (the
        default value is 1.0, which implies that all coefficients are used).
    area_mask : ndarray
        Bool array of the same shape as `img` which indicates the area of the
        image to highlight (the default value is None, which implies that no
        particular part of the image is highlighted).
    ax : matplotlib.axes.Axes
        The axes on which the image is displayed (the default is None, which
        implies that a separate figure is created).

    Returns
    -------
    coef_counts : dict
        Different counts of coefficients within the `area_mask` (only returned
        if `area_mask` is not None).

    Notes
    -----
    The ticks on the colorbar shown below the figure are the percentiles of the
    entire set of coefficients corresponding to the quantiles with fraction of
    coefficients. For instance, if fraction is 0.10, then the percentiles are
    92.5, 95.0, 97.5, 100.0, corresponding to the four quantiles within the 10
    percent coefficients holding the most energy.

    Coefficients are ranked by magnitude; coefficients whose magnitude is
    below ten times the floating point machine epsilon are treated as zero.

    The `coef_counts` dictionary holds the following keys:

    * C_total : Total number of considered coefficients.
    * Q_potential : Number of potential coefficients within `area_mask`.
    * P_fraction : The fraction of Q_potential to the pixel count in `img`.
    * Q_total : Total number of (considered) coefficients within `area_mask`.
    * Q_fraction : The fraction of Q_total to Q_potential
    * QC_fraction : The fraction of Q_total to C_total.
    * Q0_Q1 : Number of coefficients smaller than the first quantile.
    * Q1_Q2 : Number of coefficients between the first and second quantile.
    * Q2_Q3 : Number of coefficients between the second and third quantile.
    * Q3_Q4 : Number of coefficients between the third and fourth quantile.

    Each of the QX_QY holds a tuple containing two values:

    1. The number of coefficients.
    2. The fraction of the number of coefficients to Q_total.

    Examples
    --------
    For example, show quantiles for a fraction of 0.2 of the DCT coefficients:

    >>> import numpy as np
    >>> from magni.imaging.dictionaries import analysis as _a
    >>> img = np.arange(64).astype(float).reshape(8, 8)
    >>> transform = 'DCT'
    >>> fraction = 0.2
    >>> _a.show_transform_quantiles(img, transform, fraction=fraction)

    """
    # NOTE(review): the validators refer to variables by their string name
    # (presumably resolved against this function's scope), so the names
    # `img`, `transform`, `fraction`, `area_mask`, `ax` and `coef_counts`
    # must not be renamed -- confirm against the validation helpers.
    @_decorate_validation
    def validate_input():
        _numeric('img', ('integer', 'floating', 'complex'), shape=(-1, -1))
        _generic('transform', 'string', value_in=_utils.get_transform_names())
        _numeric('fraction', 'floating', range_='[0;1]')
        _numeric('area_mask', 'boolean', shape=img.shape, ignore_none=True)
        _generic('ax', mpl.axes.Axes, ignore_none=True)

    @_decorate_validation
    def validate_output():
        _generic('coef_counts',
                 'mapping',
                 has_keys=('Q_total', 'Q_potential', 'Q0_Q1', 'Q_fraction',
                           'Q3_Q4', 'Q2_Q3', 'Q1_Q2', 'C_total',
                           'QC_fraction'))

    validate_input()

    # Colorbrewer qualitative 5-class Set 1 as colormap
    colours = [(228, 26, 28), (55, 126, 184), (77, 175, 74), (152, 78, 163),
               (255, 127, 0)]
    norm_colours = [
        tuple([round(val / 255, 4) for val in colour])
        for colour in colours[::-1]
    ]
    # The first colour is repeated so that it spans a larger part of the
    # colormap than the four quantile colours.
    norm_colours = [norm_colours[0]] * 2 + norm_colours
    quantile_cmap = mpl.colors.ListedColormap(norm_colours)

    # Transform
    transform_matrix = _utils.get_function_handle('matrix',
                                                  transform)(img.shape)
    all_coefficients = _vec2mat(transform_matrix.conj().T.dot(_mat2vec(img)),
                                img.shape)
    # Force very low values to zero to avoid false visualisations. The test is
    # on the magnitude so that large negative coefficients are kept.
    tiny = np.finfo(float).eps * 10
    all_coefficients[np.abs(all_coefficients) < tiny] = 0

    # Masked coefficients: keep those strictly larger (in magnitude) than the
    # coefficient found at the `fraction` position of the descending ranking.
    sorted_coefficients = np.sort(np.abs(all_coefficients), axis=None)[::-1]
    mask = np.abs(all_coefficients) > sorted_coefficients[
        int(np.round(fraction * all_coefficients.size)) - 1]

    used_coefficients = np.zeros_like(all_coefficients, dtype=float)
    used_coefficients[mask] = np.abs(all_coefficients[mask])

    # Quantiles
    q_linspace = np.linspace((1 - fraction) * 100, 100, 5)
    q_percentiles = tuple(q_linspace[1:4])
    quantiles = np.percentile(used_coefficients, q_percentiles)
    # Label every used coefficient with its quantile bin (1..4); 0 = excluded.
    disp_coefficients = np.zeros_like(used_coefficients)
    disp_coefficients[(0 < used_coefficients)
                      & (used_coefficients <= quantiles[0])] = 1
    disp_coefficients[(quantiles[0] < used_coefficients)
                      & (used_coefficients <= quantiles[1])] = 2
    disp_coefficients[(quantiles[1] < used_coefficients)
                      & (used_coefficients <= quantiles[2])] = 3
    disp_coefficients[quantiles[2] < used_coefficients] = 4

    # Quantile figure
    disp, axes_extent = _utils.get_function_handle('visualisation',
                                                   transform)(img.shape)
    if ax is None:
        fig, axes = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
        axes = ax
    # anti-log10 of the bin labels; computed once and reused for the counts.
    displayed = disp(_mat2vec(10**disp_coefficients))
    im = _imshow(
        displayed,
        ax=axes,
        cmap=quantile_cmap,
        show_axis='top',
        interpolation='none',
        extent=axes_extent)
    divider = _make_axes_locatable(axes)
    c_bar_ax = divider.append_axes('bottom', '5%', pad='3%')
    cbar = fig.colorbar(im, c_bar_ax, orientation='horizontal')
    cbar.solids.set_edgecolor("face")
    cbar.set_ticks([0.85, 1.705, 2.278, 2.85, 3.419, 4.0])
    cbar.set_ticklabels(['Excluded'] + [str(q) for q in q_linspace])

    plt.tight_layout(rect=(0, 0, 1, 0.95))

    # Area mask
    if area_mask is not None:
        _imshow(np.ma.array(np.ones_like(disp_coefficients), mask=area_mask),
                ax=axes,
                cmap='gray',
                show_axis='top',
                interpolation='none',
                extent=axes_extent,
                alpha=0.15)

        # Count of coefficients
        in_area = displayed[area_mask]
        Q_total = np.sum(in_area != 0)
        Qs = [np.sum(in_area == k) for k in [1, 2, 3, 4]]
        coef_counts = {
            'Q' + str(k - 1) + '_Q' + str(k):
            (Qs[k - 1], round(Qs[k - 1] / Q_total, 2))
            for k in [1, 2, 3, 4]
        }
        coef_counts['Q_total'] = Q_total
        coef_counts['Q_potential'] = np.sum(area_mask)
        coef_counts['Q_fraction'] = round(Q_total / coef_counts['Q_potential'],
                                          2)
        coef_counts['C_total'] = np.sum(used_coefficients != 0)
        coef_counts['P_fraction'] = round(
            coef_counts['Q_potential'] / img.size, 2)
        coef_counts['QC_fraction'] = round(Q_total / coef_counts['C_total'], 2)

        validate_output()

        return coef_counts
コード例 #4
0
def plotContour(x, y, z, xlabel=None, ylabel=None, zlabel=None):
    """A simple function to create two-dimensional contour plots with matplotlib.

       The (x, y, z) records do not need to lie on a regular grid: the z
       values are linearly interpolated onto a 1000 x 1000 grid that spans
       the range of the x and y data. If the lists differ in length they are
       truncated to the shortest one and a warning is issued.

       Parameters
       ----------

       x : list
           A list of x data values.

       y : list
           A list of y data values.

       z : list
           A list of z data values.

       xlabel : str
           The x axis label string.

       ylabel : str
           The y axis label string.

       zlabel : str
           The z axis (colour bar) label string.

       Raises
       ------

       TypeError
           If a data argument is not a list (or tuple), if its values are of
           mixed type, or if a label is not a string.

       ValueError
           If the data cannot be interpolated onto a grid.
    """

    import numpy as _np
    import scipy.interpolate as _interp

    from mpl_toolkits.axes_grid1 import make_axes_locatable as _make_axes_locatable

    # Make sure we're running interactively.
    if not _is_interactive:
        _warn("You can only use BioSimSpace.Notebook.plot when running interactively.")
        return None

    # Matplotlib failed to import.
    if not _has_matplotlib and _has_display:
        _warn("BioSimSpace.Notebook.plot is disabled as matplotlib failed "
            "to load. Please check your matplotlib installation.")
        return None

    def _check_data(data, name):
        """Validate one list of records; return (records, has_units)."""
        # Convert tuple to a list.
        if type(data) is tuple:
            data = list(data)

        if type(data) is not list:
            raise TypeError("'%s' must be of type 'list'" % name)

        # Make sure all records are of the same type.
        _type = type(data[0])
        if not all(isinstance(value, _type) for value in data):
            raise TypeError("All '%s' data values must be of same type" % name)

        # Convert int to float.
        if _type is int:
            data = [float(value) for value in data]

        # Does this type have units?
        return data, isinstance(data[0], _Type)

    def _check_label(label, data, name):
        """Validate a user label, or derive one from data that has units."""
        if label is not None:
            if type(label) is not str:
                raise TypeError("'%s' must be of type 'str'" % name)
            return label

        first = data[0]
        if isinstance(first, _Type):
            return first.__class__.__qualname__ + " (" + first._print_format[first.unit()] + ")"

        return None

    # Validate the data. The flags record whether the records carry units
    # and so need converting to plain floats before plotting.
    x, is_unit_x = _check_data(x, "x")
    y, is_unit_y = _check_data(y, "y")
    z, is_unit_z = _check_data(z, "z")

    # Lists must contain the same number of records.
    # Truncate the longer lists to the length of the shortest.
    if not (len(x) == len(y) == len(z)):
        _warn("Mismatch in list sizes: len(x) = %d, len(y) = %d, len(z) = %d"
            % (len(x), len(y), len(z)))

        min_len = min(len(x), len(y), len(z))

        x = x[:min_len]
        y = y[:min_len]
        z = z[:min_len]

    # Labels are resolved while the records still carry their units.
    xlabel = _check_label(xlabel, x, "xlabel")
    ylabel = _check_label(ylabel, y, "ylabel")
    zlabel = _check_label(zlabel, z, "zlabel")

    # Convert the x, y, and z values to floats.
    if is_unit_x:
        x = [value.magnitude() for value in x]
    if is_unit_y:
        y = [value.magnitude() for value in y]
    if is_unit_z:
        z = [value.magnitude() for value in z]

    # Convert to two-dimensional arrays. We don't assume the data is on a grid,
    # so we interpolate the z values.
    try:
        X, Y = _np.meshgrid(_np.linspace(_np.min(x), _np.max(x), 1000),
                            _np.linspace(_np.min(y), _np.max(y), 1000))
        Z = _interp.griddata((x, y), z, (X, Y), method="linear")
    except Exception as e:
        raise ValueError("Unable to interpolate x, y, and z data to a grid.") from e

    # Set the figure size.
    _plt.figure(figsize=(8, 8))

    # Create the contour plot.
    cp = _plt.contourf(X, Y, Z)

    # Add axis labels.
    if xlabel is not None:
        _plt.xlabel(xlabel)
    if ylabel is not None:
        _plt.ylabel(ylabel)

    # Get the current axes.
    ax = _plt.gca()

    # Make sure the axes are equal.
    ax.set_aspect("equal", adjustable="box")

    # Make sure the colour bar matches size of the axes.
    divider = _make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)

    # Add a colour bar and label it.
    cbar = _plt.colorbar(cp, cax=cax)
    if zlabel is not None:
        cbar.set_label(zlabel)

    return _plt.show()
コード例 #5
0
def plot_burp(bf,
              code=None,
              cval=None,
              ax=None,
              level=0,
              mask=None,
              projection='cyl',
              cbar_opt={},
              vals_opt={},
              dparallel=30.,
              dmeridian=60.,
              fontsize=20,
              **kwargs):
    """
    Plots a BURP file. Will plot BUFR code if specified. Only plots a single level,
    which can be specified by the optional argument. Additional arguments not listed
    below will be passed to Basemap.scatter.

    Args:
        bf         : BurpFile instance
        code       : code to be plotted, if not set the observation locations will be plotted
        level      : level to be plotted
        mask       : mask to apply to data
        cval       : directly supply plotted values instead of using code argument
        ax         : subplot object, if None a new figure and subplot are created
        projection : projection used by Basemap for plotting
        cbar_opt   : dictionary for colorbar options (read only, never modified)
        vals_opt   : dictionary for get_rval options if code is supplied (read only)
        dparallel  : spacing of parallels, if None will not plot parallels
        dmeridian  : spacing of meridians, if None will not plot meridians
        fontsize   : font size for labels of parallels and meridians

    Returns:
        None if the file holds no reports (bf.nrep == 0), otherwise
        {
            'm'    : Basemap instance the data were drawn on
            'cbar' : colorbar object, only present if code or cval was supplied
        }
    """
    assert isinstance(
        bf, BurpFile), "First argument must be an instance of BurpFile"
    assert code is None or cval is None, "Only one of code, cval should be supplied as an argument"

    # Nothing to plot: leave before a figure is created so that no empty
    # figure is left behind.
    if bf.nrep == 0:
        return

    if ax is None:
        fig = _plt.figure(figsize=(18, 11))
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()

    coloured = code is not None or cval is not None

    # Scatter options: caller supplied values win over the defaults.
    opt = kwargs.copy()
    opt.setdefault('s', 10)
    opt.setdefault('edgecolors', 'None')
    if coloured:
        opt.setdefault('cmap', _plt.cm.jet)

    if code is not None:
        vals = bf.get_rval(code, **vals_opt)
        if len(vals.shape) > 1:
            vals = vals[:, level]
        opt['c'] = vals
    elif cval is not None:
        # Accept any sequence; the boolean selection below needs an array.
        opt['c'] = _np.asarray(cval)

    msk = _np.array([stn[:2]
                     for stn in bf.stnids]) != '>>'  # don't plot resumes

    if mask is not None:
        msk = _np.logical_and(msk, mask)

    lon = bf.lon[msk]
    lat = bf.lat[msk]
    if coloured:
        opt['c'] = opt['c'][msk]

    basemap_opt = {'projection': projection, 'resolution': 'c'}

    # Projection specific map extent; other projections use Basemap defaults.
    if projection == 'cyl':
        basemap_opt.update({
            'llcrnrlat': -90,
            'urcrnrlat': 90,
            'llcrnrlon': -180,
            'urcrnrlon': 180
        })
    elif projection == 'npstere':
        basemap_opt.update({'boundinglat': 10., 'lon_0': 0.})
    elif projection == 'spstere':
        basemap_opt.update({'boundinglat': -10., 'lon_0': 270.})

    m = _Basemap(ax=ax, **basemap_opt)

    m.drawcoastlines()

    # Convert lon/lat to map coordinates.
    xpt, ypt = m(lon, lat)

    sctr = m.scatter(xpt, ypt, **opt)

    if dparallel is not None:
        m.drawparallels(_np.arange(-90, 91, dparallel),
                        labels=[1, 0, 0, 0],
                        color='grey',
                        fontsize=fontsize)
    if dmeridian is not None:
        m.drawmeridians(_np.arange(-180, 180, dmeridian),
                        labels=[0, 0, 0, 1],
                        color='grey',
                        fontsize=fontsize)

    output = {'m': m}

    if coloured:
        # Dedicated narrow axes on the right so the colorbar matches the map height.
        divider = _make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.5)
        cbar = fig.colorbar(sctr, ax=ax, cax=cax, **cbar_opt)
        output['cbar'] = cbar

    return output