Esempio n. 1
0
def matrix(data_matrix, mode='abs', **kwargs):
    """
    Create a "skyscraper" plot and a 2d color-coded plot of a matrix.

    Both panels use the same color normalization (0 to the maximum of the
    processed data), so the colorbar drawn next to the 2d plot is valid for
    the bar colors as well.

    Parameters
    ----------
    data_matrix: ndarray of float or complex
        2d matrix data (need not be square)
    mode: str from `constants.MODE_FUNC_DICT`
        choice of processing function to be applied to data
    **kwargs: dict
        standard plotting option (see separate documentation); if
        ``fig_ax=(fig, (ax1, ax2))`` is given, the plots are drawn into those
        objects and ax1 must be a 3d axes

    Returns
    -------
    Figure, (Axes1, Axes2)
        figure and axes objects for further editing
    """
    if 'fig_ax' in kwargs:
        fig, (ax1, ax2) = kwargs['fig_ax']
    else:
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 2, 1, projection='3d')
        # attach explicitly to `fig` instead of relying on pyplot's notion of
        # the "current" figure
        ax2 = fig.add_subplot(1, 2, 2)

    modefunction = constants.MODE_FUNC_DICT[mode]
    # processed data, computed once and reused for both panels
    zmatrix = np.asarray(modefunction(data_matrix))
    row_count, col_count = zmatrix.shape
    element_count = row_count * col_count  # num. of elements to plot

    # bar for element [i, j] sits at x=i (row), y=j (column); 'ij' indexing
    # keeps the flattened grids aligned with zmatrix.flatten()
    xgrid, ygrid = np.meshgrid(range(row_count), range(col_count),
                               indexing='ij')
    xgrid = xgrid.flatten() - 0.5  # shift bars toward integer x-ticks
    ygrid = ygrid.flatten() - 0.5  # shift bars toward integer y-ticks

    zbottom = np.zeros(element_count)  # all bars start at z=0
    dx = 0.75 * np.ones(element_count)  # width of bars in x-direction
    dy = dx  # width of bars in y-direction (same as x-direction)

    zheight = zmatrix.flatten()  # height of bars from matrix elements
    zmax = zheight.max()
    nrm = mpl.colors.Normalize(0, zmax)  # normalize colors to max. data
    colors = plt.cm.viridis(nrm(zheight))  # list of colors for each bar

    # skyscraper plot
    ax1.view_init(azim=210, elev=23)
    ax1.bar3d(xgrid, ygrid, zbottom, dx, dy, zheight, color=colors)
    ax1.axes.xaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set x-ticks to integers
    ax1.axes.yaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set y-ticks to integers
    ax1.set_zlim3d([0, zmax])

    # 2d plot: use the same norm as the colorbar, otherwise matshow would
    # autoscale to [min, max] while the colorbar claims [0, max]
    ax2.matshow(zmatrix, cmap=plt.cm.viridis, norm=nrm)
    cax, _ = mpl.colorbar.make_axes(
        ax2, shrink=.75, pad=.02)  # add colorbar with normalized range
    _ = mpl.colorbar.ColorbarBase(cax, cmap=plt.cm.viridis, norm=nrm)

    _process_options(fig, ax1, opts=defaults.matrix(), **kwargs)
    return fig, (ax1, ax2)
Esempio n. 2
0
def matrix_skyscraper(matrix: np.ndarray,
                      mode: str = "abs",
                      **kwargs) -> Tuple[Figure, Axes]:
    """Display a 3d skyscraper plot of the matrix

    Parameters
    ----------
    matrix:
        2d matrix data
    mode:
        choice from `constants.MODE_FUNC_DICT` for processing function to be applied to
        data
    **kwargs:
        standard plotting option (see separate documentation); if
        ``fig_ax=(fig, axes)`` is given, the plot is drawn into those objects
        and `axes` must be a 3d axes

    Returns
    -------
    Figure, Axes
        figure and axes objects for further editing
    """
    # `projection` must go through `subplot_kw`; passed directly it would be
    # forwarded to `plt.figure`, which does not accept it
    fig, axes = kwargs.get("fig_ax") or plt.subplots(
        subplot_kw={"projection": "3d"})

    y_count, x_count = matrix.shape  # We label the columns as "x", while rows as "y"
    element_count = x_count * y_count  # total num. of elements to plot

    xgrid, ygrid = np.meshgrid(range(x_count), range(y_count))
    xgrid = xgrid.flatten()
    ygrid = ygrid.flatten()

    zbottom = np.zeros(element_count)  # all bars start at z=0
    dx, dy = 0.75, 0.75  # width of bars in x and y directions

    modefunction = constants.MODE_FUNC_DICT[mode]
    zheight = modefunction(
        matrix).flatten()  # height of bars from matrix elements
    zmin, zmax = zheight.min(), zheight.max()

    if mode in ("abs", "abs_sqr"):
        nrm = mpl.colors.Normalize(0, zmax)  # normalize colors between 0 and max. data
    else:
        nrm = mpl.colors.Normalize(zmin, zmax)  # normalize colors between min. and max. of data

    colors = plt.cm.viridis(nrm(zheight))  # list of colors for each bar

    # skyscraper plot
    axes.view_init(azim=210, elev=23)
    axes.bar3d(xgrid, ygrid, zbottom, dx, dy, zheight, color=colors)
    # bars grow from z=0 in either direction, so the z-range must contain 0
    # as well as the extreme heights (negative values occur for e.g. real
    # or imaginary parts); for non-negative data this is [0, max]
    axes.set_zlim3d([min(0, zmin), max(0, zmax)])

    for axis, locs in [
        (axes.xaxis, np.arange(x_count)),
        (axes.yaxis, np.arange(y_count)),
    ]:
        axis.set_ticks(locs + 0.5, minor=True)
        axis.set(ticks=locs + 0.5, ticklabels=locs)

    _process_options(fig, axes, opts=defaults.matrix(), **kwargs)

    return fig, axes
Esempio n. 3
0
def matrix_skyscraper(matrix, mode='abs', **kwargs):
    """Display a 3d skyscraper plot of the matrix

    Parameters
    ----------
    matrix: ndarray of float or complex
        2d matrix data (need not be square)
    mode: str from `constants.MODE_FUNC_DICT`
        choice of processing function to be applied to data
    **kwargs: dict
        standard plotting option (see separate documentation); if
        ``fig_ax=(fig, axes)`` is given, the plot is drawn into those objects
        and `axes` must be a 3d axes

    Returns
    -------
    Figure, Axes
        figure and axes objects for further editing
    """
    # `projection` must go through `subplot_kw`; passed directly it would be
    # forwarded to `plt.figure`, which does not accept it
    fig, axes = kwargs.get('fig_ax') or plt.subplots(
        subplot_kw={'projection': '3d'})

    modefunction = constants.MODE_FUNC_DICT[mode]
    zmatrix = np.asarray(modefunction(matrix))  # processed matrix data
    row_count, col_count = zmatrix.shape
    element_count = row_count * col_count  # num. of elements to plot

    # bar for element [i, j] sits at x=i (row), y=j (column); 'ij' indexing
    # keeps the flattened grids aligned with zmatrix.flatten()
    xgrid, ygrid = np.meshgrid(range(row_count), range(col_count),
                               indexing='ij')
    xgrid = xgrid.flatten() - 0.5  # shift bars toward integer x-ticks
    ygrid = ygrid.flatten() - 0.5  # shift bars toward integer y-ticks

    zbottom = np.zeros(element_count)  # all bars start at z=0
    dx = 0.75 * np.ones(element_count)  # width of bars in x-direction
    dy = dx  # width of bars in y-direction (same as x-direction)

    zheight = zmatrix.flatten()  # height of bars from matrix elements
    zmax = zheight.max()
    nrm = mpl.colors.Normalize(0, zmax)  # normalize colors to max. data
    colors = plt.cm.viridis(nrm(zheight))  # list of colors for each bar

    # skyscraper plot
    axes.view_init(azim=210, elev=23)
    axes.bar3d(xgrid, ygrid, zbottom, dx, dy, zheight, color=colors)
    axes.xaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set x-ticks to integers
    axes.yaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set y-ticks to integers
    axes.set_zlim3d([0, zmax])

    _process_options(fig, axes, opts=defaults.matrix(), **kwargs)
    return fig, axes