Esempio n. 1
0
def matrix(data_matrix, mode='abs', **kwargs):
    """
    Create a "skyscraper" plot and a 2d color-coded plot of a matrix.

    Both panels and the shared colorbar use one color normalization that
    spans 0 to the maximum of the processed data.

    Parameters
    ----------
    data_matrix: ndarray of float or complex
        2d matrix data (only ``len(data_matrix)`` is used to build the bar
        grid, so the matrix is assumed to be square)
    mode: str from `constants.MODE_FUNC_DICT`
        choice of processing function to be applied to data
    **kwargs: dict
        standard plotting option (see separate documentation). If
        ``fig_ax=(fig, (ax1, ax2))`` is given, ``ax1`` (a 3d axes) and
        ``ax2`` are drawn into instead of creating a new figure.

    Returns
    -------
    Figure, (Axes1, Axes2)
        figure and axes objects for further editing
    """
    if 'fig_ax' in kwargs:
        fig, (ax1, ax2) = kwargs['fig_ax']
    else:
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 2, 1, projection='3d')
        # attach to `fig` explicitly rather than to pyplot's "current figure"
        ax2 = fig.add_subplot(1, 2, 2)

    matsize = len(data_matrix)
    element_count = matsize**2  # num. of elements to plot

    xgrid, ygrid = np.meshgrid(range(matsize), range(matsize))
    xgrid = xgrid.T.flatten() - 0.5  # center bars on integer value of x-axis
    ygrid = ygrid.T.flatten() - 0.5  # center bars on integer value of y-axis

    zbottom = np.zeros(element_count)  # all bars start at z=0
    dx = 0.75 * np.ones(element_count)  # width of bars in x-direction
    dy = dx  # width of bars in y-direction (same as x-direction)

    modefunction = constants.MODE_FUNC_DICT[mode]
    # apply the processing function once; both panels show the same values
    processed = modefunction(data_matrix)
    zheight = processed.flatten()  # height of bars from matrix elements
    zmax = np.max(zheight)
    nrm = mpl.colors.Normalize(0, zmax)  # normalize colors to max. data
    colors = plt.cm.viridis(nrm(zheight))  # list of colors for each bar

    # skyscraper plot
    ax1.view_init(azim=210, elev=23)
    ax1.bar3d(xgrid, ygrid, zbottom, dx, dy, zheight, color=colors)
    ax1.axes.xaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set x-ticks to integers
    ax1.axes.yaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set y-ticks to integers
    ax1.set_zlim3d([0, zmax])

    # 2d plot: use the same `nrm` as the bars so that the colorbar below is a
    # correct legend for this panel too (matshow would otherwise autoscale to
    # the data's own min/max)
    ax2.matshow(processed, cmap=plt.cm.viridis, norm=nrm)
    cax, _ = mpl.colorbar.make_axes(
        ax2, shrink=.75, pad=.02)  # add colorbar with normalized range
    _ = mpl.colorbar.ColorbarBase(cax, cmap=plt.cm.viridis, norm=nrm)

    _process_options(fig, ax1, opts=defaults.matrix(), **kwargs)
    return fig, (ax1, ax2)
Esempio n. 2
0
def matrix_skyscraper(matrix, mode='abs', **kwargs):
    """Display a 3d skyscraper plot of the matrix

    Parameters
    ----------
    matrix: ndarray of float or complex
        2d matrix data (only ``len(matrix)`` is used to build the bar grid,
        so the matrix is assumed to be square)
    mode: str from `constants.MODE_FUNC_DICT`
        choice of processing function to be applied to data
    **kwargs: dict
        standard plotting option (see separate documentation). If
        ``fig_ax=(fig, axes)`` is given (``axes`` being a 3d axes), it is
        drawn into instead of creating a new figure.

    Returns
    -------
    Figure, Axes
        figure and axes objects for further editing
    """
    # `projection` is a per-axes option: plt.subplots forwards unknown
    # keywords to Figure(), which rejects it, so it must go via subplot_kw.
    fig, axes = kwargs.get('fig_ax') or plt.subplots(
        subplot_kw={'projection': '3d'})

    matsize = len(matrix)
    element_count = matsize**2  # num. of elements to plot

    xgrid, ygrid = np.meshgrid(range(matsize), range(matsize))
    xgrid = xgrid.T.flatten() - 0.5  # center bars on integer value of x-axis
    ygrid = ygrid.T.flatten() - 0.5  # center bars on integer value of y-axis

    zbottom = np.zeros(element_count)  # all bars start at z=0
    dx = 0.75 * np.ones(element_count)  # width of bars in x-direction
    dy = dx  # width of bars in y-direction (same as x-direction)

    modefunction = constants.MODE_FUNC_DICT[mode]
    zheight = modefunction(
        matrix).flatten()  # height of bars from matrix elements
    zmax = np.max(zheight)
    nrm = mpl.colors.Normalize(0, zmax)  # normalize colors to max. data
    colors = plt.cm.viridis(nrm(zheight))  # list of colors for each bar

    # skyscraper plot
    axes.view_init(azim=210, elev=23)
    axes.bar3d(xgrid, ygrid, zbottom, dx, dy, zheight, color=colors)
    axes.axes.xaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set x-ticks to integers
    axes.axes.yaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set y-ticks to integers
    axes.set_zlim3d([0, zmax])

    _process_options(fig, axes, opts=defaults.matrix(), **kwargs)
    return fig, axes
Esempio n. 3
0
def us(request):
    """Chart the recent daily closing prices of a US ticker and return its URL.

    Reads ``ticker`` from the GET query string, fetches the daily time series
    from Alpha Vantage, keeps the first 13 rows of the reply (presumably the
    most recent days -- confirm against the API ordering), plots the closing
    price and saves it to ``chart/static/images/<ticker>.png``.

    Returns an ``HttpResponse`` whose body is the URL of the saved image, a
    400 response for a missing/invalid ticker, or a 502 response when the
    upstream reply does not contain a daily time series.
    """
    import os
    import re

    # Read the data passed via GET on the URL
    ticker = request.GET.get("ticker", "")
    # The ticker is interpolated into a filesystem path below, so only plain
    # symbol characters are accepted (no path separators, no leading dot).
    if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9.\-]{0,14}", ticker):
        return HttpResponse("invalid ticker", status=400)

    # NOTE(review): the API key used to be hard-coded in the URL. Configure it
    # via ALPHAVANTAGE_API_KEY; the literal is only a backward-compatible
    # fallback and should be rotated and removed.
    api_key = os.environ.get("ALPHAVANTAGE_API_KEY", "2C8MUXABNVMED4DS")
    r = requests.get(
        'https://www.alphavantage.co/query',
        params={
            'function': 'TIME_SERIES_DAILY',
            'symbol': ticker,
            'apikey': api_key,
        },
        timeout=10)

    # Parse the JSON body (object -> dict, array -> list, string -> str) and
    # keep only the daily price series.
    try:
        series = json.loads(r.text)['Time Series (Daily)']
    except (ValueError, KeyError, TypeError):
        # Non-JSON body, or a reply without the series (presumably an API
        # error / rate-limit message): report it instead of crashing.
        return HttpResponse("no daily data for {}".format(ticker), status=502)

    df = pd.DataFrame(series)
    # Swap columns and rows so that each row is one day
    df = df.T
    # Rename the columns
    df.columns = ['open', 'high', 'low', 'close', 'volumn']
    # Take the first 13 rows, re-sort by date, convert to float
    df = df.head(13).sort_index().astype(float)
    print("*** padas dataframe ***\n{}".format(df))

    # Draw on a dedicated figure instead of pyplot's shared current figure,
    # and always release it afterwards.
    fig, ax = plt.subplots()
    try:
        ax.plot('close', data=df, marker='o')
        ax.set_title('{} (Daily)'.format(ticker.upper()))
        ax.grid(True)
        # One x tick every 4 positions
        ax.xaxis.set_major_locator(plt.IndexLocator(4, 0))
        fig.savefig("chart/static/images/{}.png".format(ticker))
    finally:
        plt.close(fig)

    url = "http://localhost:8000/static/images/{}.png".format(ticker)
    # Return the image URL
    return HttpResponse(url)
def save_ability_wins_distribution(statistics, ability_wins):
    """Save a box plot of per-ability wins to ``/tmp/wins_destribution.png``.

    statistics: sequence of ``(key, value)`` pairs; the keys fix the order of
        the boxes and are used as the x tick labels.
    ability_wins: mapping from each key to the sample drawn in its box
        (presumably per-battle win counts -- confirm against the caller).
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)

        keys = [s[0] for s in statistics]
        data = [ability_wins[key] for key in keys]
        ax.boxplot(data)

        ax.set_xlim(0.5, len(statistics) + 0.5)
        ax.set_ylim(0, TEST_BATTLES_NUMBER * len(HERO_LEVELS))

        # boxplot places box i at x = i + 1; label those positions
        ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.5))
        ax.xaxis.set_major_formatter(plt.FixedFormatter(keys))
        plt.setp(plt.getp(ax, 'xticklabels'),
                 rotation=45,
                 fontsize=8,
                 horizontalalignment='right')

        ax.set_title('wins destribution')

        fig.tight_layout()
        fig.savefig('/tmp/wins_destribution.png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call leaks one open figure.
        plt.close(fig)
def save_ability_power_statistics(statistics):
    """Save a bar chart of per-ability values to ``/tmp/wins.png``.

    statistics: sequence of ``(label, value)`` pairs; one bar per pair, in
        the given order, with the labels used as x tick labels.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)

        labels = [s[0] for s in statistics]
        heights = [s[1] for s in statistics]
        positions = np.arange(len(statistics))
        ax.bar(positions, heights, width=0.8, align='center')

        ax.set_xlim(-0.5, len(statistics) + 0.5)
        if heights:
            # Scale to the tallest bar instead of assuming the first entry is
            # the largest (only true when statistics is sorted descending).
            ax.set_ylim(0, max(heights))

        ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.4))
        ax.xaxis.set_major_formatter(plt.FixedFormatter(labels))

        plt.setp(plt.getp(ax, 'xticklabels'),
                 rotation=45,
                 fontsize=8,
                 horizontalalignment='right')

        ax.set_title('Wins per ability')

        fig.tight_layout()
        fig.savefig('/tmp/wins.png')
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
def save_ability_mathces_statistics(statistics, matches):  # pylint: disable=R0914
    """Save a scatter plot of head-to-head results to ``/tmp/matches.png``.

    statistics: sequence of ``(key, value)`` pairs; the order of the keys
        fixes each ability's row/column index and its tick label.
    matches: mapping ``(key_a, key_b) -> (wins_a, wins_b)``. Each pair is
        drawn as two points, (a, b) and (b, a), whose size and color are the
        respective share of wins scaled to 0..1000. Pairs with no wins on
        either side are skipped.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)

        labels = [s[0] for s in statistics]
        index = {key: i for i, key in enumerate(labels)}

        xs, ys, powers = [], [], []
        for (first, second), (w_1, w_2) in matches.items():
            total = w_1 + w_2
            if not total:
                # nothing decided between this pair; the share would divide by 0
                continue
            for row, col, wins in ((first, second, w_1), (second, first, w_2)):
                xs.append(index[row])
                ys.append(index[col])
                powers.append(1000 * wins / float(total))

        if powers:
            ax.scatter(xs, ys, s=powers, marker='o', c=powers)

        ax.set_xlim(-0.5, len(statistics) + 0.5)
        ax.set_ylim(-0.5, len(statistics) + 0.5)

        # Fixed tick positions 0..n-1 keep every label on its own index even
        # when some ability never appears in `matches`; each axis gets its own
        # locator/formatter instance.
        ticks = range(len(labels))
        ax.xaxis.set_major_locator(plt.FixedLocator(ticks))
        ax.xaxis.set_major_formatter(plt.FixedFormatter(labels))
        plt.setp(plt.getp(ax, 'xticklabels'),
                 rotation=45,
                 fontsize=8,
                 horizontalalignment='right')

        ax.yaxis.set_major_locator(plt.FixedLocator(ticks))
        ax.yaxis.set_major_formatter(plt.FixedFormatter(labels))
        plt.setp(plt.getp(ax, 'yticklabels'), fontsize=8)

        ax.set_title('matches results')

        fig.tight_layout()
        fig.savefig('/tmp/matches.png')
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
Esempio n. 7
0
def plot(state,
         *,
         ax=None,
         truncate_levels=None,
         colorbar=True,
         amp_limits=None,
         phase_limits=None):
    """
    Plots the density matrix as a complex 3D histogram.

    Bar heights are the real parts of the trace-normalized density matrix
    elements; bar colors encode the complex phase of each element.

    Parameters
    ----------
    state : qs.State
        State to display
    ax : matplotlib.axes.Axes or None
        3D axes to plot onto. If None, new figure is created and returned.
    truncate_levels : None or int
        If not None, all the states higher than provided are discarded and a
        scaled identity is added to the state instead, so that total trace is
        preserved. This should emulate behaviour of tomography in the presence
        of leakage.
    colorbar : bool, optional
        If True, a colorbar is created and drawn to the figure axes, by default True
    amp_limits : list or tuple or None
        A list or tuple of two float numbers, corresponding to the lower and upper
        limit of the z-axis, in this case corresponding to the state amplitude.
        If None, the lower limit is set to 0, while the upper one to 1.
    phase_limits : list or tuple or None
        A list or tuple of two float numbers, corresponding to the lower and upper
        limit of phase-axis (the colorbar), in this case corresponding to the complex phase of the state.
        If None, the lower limit is set to :math:`-\\pi`, while the upper one to :math:`\\pi`.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The newly created figure, or None when ``ax`` was supplied.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    # Imported under an alias: a plain `from matplotlib import colorbar` would
    # rebind the `colorbar` argument and make `colorbar=False` ineffective.
    from matplotlib import colorbar as mpl_colorbar

    if ax is None:
        fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
    else:
        fig = None

    n_qubits = len(state.qubits)
    pv = state.pauli_vector
    _rho = pv.to_dm()
    # out-of-place, so whatever to_dm() returned is never modified
    _rho = _rho / np.trace(_rho)

    if truncate_levels is not None:
        # Tomo emulation: truncate leaked states and add a uniform mixture to
        # restore unit trace.
        dim_trunc = truncate_levels**n_qubits
        rho = (_rho.reshape(pv.dim_hilbert * 2)[
            (slice(0, truncate_levels), ) * (2 * n_qubits)].reshape(
                dim_trunc, dim_trunc))
        trace = np.trace(rho)
        # The identity must match rho's size (truncate_levels**n_qubits);
        # 2**n_qubits is only correct when truncate_levels == 2.
        rho = rho + (1 - trace) * np.identity(dim_trunc) / dim_trunc
        assert np.allclose(np.trace(rho), 1)
        dim = [truncate_levels] * n_qubits
    else:
        rho = _rho
        dim = pv.dim_hilbert

    def tuple_to_string(tup):
        digits = ''.join(str(x) for x in tup)
        return r'$\left| %s \right\rangle$' % digits

    labels = [tuple_to_string(x) for x in product(*(range(d) for d in dim))]

    if phase_limits and isinstance(phase_limits, (list, tuple)):
        assert len(phase_limits) == 2
    else:
        phase_limits = (-np.pi, np.pi)

    norm = Normalize(*phase_limits)
    cmap = plt.get_cmap('plasma')
    colors = cmap(norm(np.angle(rho.flatten())))

    if amp_limits and isinstance(amp_limits, (list, tuple)):
        assert len(amp_limits) == 2
    else:
        amp_limits = (0, 1)

    xpos, ypos = np.meshgrid(*(range(size) for size in rho.shape))
    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = 0
    dx = dy = 0.5 * np.ones(rho.size)
    dz = np.real(rho.flatten())

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors)
    ax.set_zlim3d(amp_limits)

    # `w_xaxis`/`w_yaxis` were removed from Axes3D; `xaxis`/`yaxis` are the
    # same objects.
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.25))
    ax.set_xticklabels(labels)

    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.25))
    ax.set_yticklabels(labels)

    plt.setp(ax.get_xticklabels(),
             rotation=45,
             ha="center",
             va='baseline',
             rotation_mode="anchor")

    plt.setp(ax.get_yticklabels(),
             rotation=-45,
             ha="center",
             va='baseline',
             rotation_mode="anchor")

    ax.set_zlabel('Amplitude')

    if colorbar:
        cax, _ = mpl_colorbar.make_axes(ax)
        cb = mpl_colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
        cb.set_ticks((-np.pi, 0, np.pi))
        cb.set_ticklabels((r'$-\pi$', r'$0$', r'$\pi$'))
        cb.set_label('Phase')

    return fig
Esempio n. 8
0
# Scree plot: explained variance vs. number of PCs, one line per gamma value.
# NOTE(review): `ax`, `scree_df`, `palette` and `stashfig` are assumed to be
# defined earlier in this notebook -- confirm.
sns.lineplot(
    data=scree_df,
    x="n_components",
    y="explained_variance",
    hue="gamma",
    ax=ax,
    marker="o",
    palette=palette,
)
# Drop the automatic seaborn legend and redraw it anchored at the axes'
# top-right corner (bbox_to_anchor=(1, 1), loc="upper left"), titled "Gamma".
ax.get_legend().remove()
ax.legend(bbox_to_anchor=(1, 1), loc="upper left", title="Gamma")
# ax.legend().set_title("Gamma")
ax.set(ylabel="Cumulative explained variance", xlabel="# of PCs")
# Sparse y ticks (MaxNLocator(3)); x ticks every 5 (IndexLocator base=5, offset=-1).
ax.yaxis.set_major_locator(plt.MaxNLocator(3))
ax.xaxis.set_major_locator(plt.IndexLocator(base=5, offset=-1))
# Presumably stores/saves the current figure under "screeplot" -- see stashfig.
stashfig("screeplot")

#%%
# Same line plot on a fresh 8x4 figure, with `n_nonzero` on the x-axis.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))

sns.lineplot(
    data=scree_df,
    x="n_nonzero",
    y="explained_variance",
    hue="gamma",
    ax=ax,
    marker="o",
    palette=palette,
)
ax.get_legend().remove()
Esempio n. 9
0
def matrix_histogram(M,
                     xlabels=None,
                     ylabels=None,
                     title=None,
                     limits=None,
                     colorbar=True,
                     fig=None,
                     ax=None):
    """
    Draw a histogram for the matrix M, with the given x and y labels and title.

    Bar heights are the real parts of the elements. Bars are colored with the
    'RdBu' colormap on a fixed [-1, 1] scale, which the optional colorbar
    also shows.

    Parameters
    ----------
    M : Matrix or Qobj
        The matrix to visualize

    xlabels : list of strings
        list of x labels

    ylabels : list of strings
        list of y labels

    title : string
        title of the plot (optional); only set when a figure is available,
        i.e. ``fig`` was passed or the figure was created here

    limits : list/tuple/array with two float numbers
        The z-axis limits [min, max] (optional); defaults to [-1, 1]

    colorbar : bool
        Whether to draw a colorbar next to the plot.

    fig : a matplotlib figure instance (optional)

    ax : a matplotlib axes instance
        The 3d axes context in which the plot will be drawn. A new figure
        and axes are created when None.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.
    """
    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()
    M = np.asarray(M)

    n = np.size(M)
    if ax is None:
        fig = plt.figure(figsize=(8, 6))
        # change seeing angle
        ax = fig.add_subplot(1, 1, 1, projection='3d', azim=-36, elev=36)

    xpos, ypos = np.meshgrid(np.arange(M.shape[0]), np.arange(M.shape[1]))

    xpos = xpos.T.flatten()
    ypos = ypos.T.flatten()
    zpos = np.zeros(n)
    dx = dy = 0.7 * np.ones(n)
    dz = np.real(M.flatten())

    # Fixed color scale so that colors are comparable between plots.
    norm = mpl.colors.Normalize(-1, 1)
    # plt.get_cmap is used because `cm.get_cmap` was removed from Matplotlib
    cmap = plt.get_cmap('RdBu')
    colors = cmap(norm(dz), alpha=1)

    ax.bar3d(xpos,
             ypos,
             zpos,
             dx,
             dy,
             dz,
             alpha=1,
             color=colors,
             edgecolor='black',
             linewidth=0.4,
             shade=True)
    if title and fig:
        ax.set_title(title)

    # x axis (`w_xaxis` etc. were removed from Axes3D; use `xaxis`)
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.35))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=10, rotation=45)

    # y axis
    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.35))
    if ylabels:
        ax.set_yticklabels(ylabels, rotation=45)
    ax.tick_params(axis='y', labelsize=10, rotation=-45)

    # z axis: honor the documented `limits`, otherwise keep [-1, 1]
    ax.zaxis.set_major_locator(plt.IndexLocator(0.5, -1))
    if (isinstance(limits, (list, tuple, np.ndarray))
            and len(limits) == 2):
        ax.set_zlim3d([limits[0], limits[1]])
    else:
        ax.set_zlim3d([-1, 1])

    # color axis
    if colorbar:
        cax, _ = mpl.colorbar.make_axes(ax, shrink=.75, pad=.1)
        mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)

    return fig, ax
Esempio n. 10
0
def matrix_histogram_complex(M,
                             xlabels=None,
                             ylabels=None,
                             title=None,
                             limits=None,
                             phase_limits=None,
                             colorbar=True,
                             fig=None,
                             ax=None,
                             threshold=None):
    """
    Draw a 3D histogram of ``abs(M)`` with bars colored by the phase of M.

    Parameters
    ----------
    M : Matrix or Qobj
        The matrix to visualize.
    xlabels, ylabels : list of str, optional
        Tick labels for the x and y axes.
    title : str, optional
        Plot title; only set when a figure is available (``fig`` passed in,
        or created here because ``ax`` was None).
    limits : list/tuple of two floats, optional
        z-axis limits [min, max]; defaults to [0, 1].
    phase_limits : list/tuple of two floats, optional
        Range of the phase color scale; defaults to [-pi, pi].
    colorbar : bool
        Whether to draw a phase colorbar.
    fig : matplotlib figure, optional
    ax : 3D matplotlib axes, optional
        Axes to draw into; a new figure and axes are created when None.
    threshold : float, optional
        Bars whose height is not above this value are made fully transparent.

    Returns
    -------
    fig, ax : tuple
        The figure (None if ``ax`` was given without ``fig``) and the axes.
    """
    def complex_phase_cmap():
        # cyclic colormap: red at both ends so that -pi and pi look the same
        cdict = {
            'blue': ((0.00, 0.0, 0.0), (0.25, 0.0, 0.0), (0.50, 1.0, 1.0),
                     (0.75, 1.0, 1.0), (1.00, 0.0, 0.0)),
            'green': ((0.00, 0.0, 0.0), (0.25, 1.0, 1.0), (0.50, 0.0, 0.0),
                      (0.75, 1.0, 1.0), (1.00, 0.0, 0.0)),
            'red': ((0.00, 1.0, 1.0), (0.25, 0.5, 0.5), (0.50, 0.0, 0.0),
                    (0.75, 0.0, 0.0), (1.00, 1.0, 1.0))
        }

        cmap = mpl.colors.LinearSegmentedColormap('phase_colormap', cdict, 256)
        return cmap

    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()

    n = np.size(M)
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    xpos = xpos.T.flatten() - 0.5
    ypos = ypos.T.flatten() - 0.5
    zpos = np.zeros(n)
    dx = dy = 0.3 * np.ones(n)
    Mvec = M.flatten()
    dz = abs(Mvec)

    # make small numbers real, to avoid random colors
    idx, = np.where(abs(Mvec) < 0.001)
    Mvec[idx] = abs(Mvec[idx])

    if phase_limits:
        phase_min = phase_limits[0]
        phase_max = phase_limits[1]
    else:
        phase_min = -pi
        phase_max = pi

    norm = mpl.colors.Normalize(phase_min, phase_max)
    cmap = complex_phase_cmap()

    colors = cmap(norm(angle(Mvec)))
    if threshold is not None:
        # alpha 1 above the threshold, 0 (invisible) otherwise
        colors[:, 3] = 1 * (dz > threshold)

    if ax is None:
        fig = plt.figure()
        # Axes3D(fig, ...) is no longer attached to the figure automatically
        # in recent Matplotlib (the figure stayed blank); add_axes with the
        # same full-figure rectangle keeps the original layout.
        ax = fig.add_axes([0, 0, 1, 1], projection='3d', azim=-35, elev=35)

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors)

    if title and fig:
        ax.set_title(title)

    # x axis (`w_xaxis` etc. were removed from Axes3D; use `xaxis`)
    ax.xaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=12)

    # y axis
    ax.yaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if ylabels:
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=12)

    # z axis
    if limits and isinstance(limits, (list, tuple)):
        ax.set_zlim3d(limits)
    else:
        ax.set_zlim3d([0, 1])

    # color axis
    if colorbar:
        cax, _ = mpl.colorbar.make_axes(ax, shrink=.75, pad=.0)
        cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
        cb.set_ticks([-pi, -pi / 2, 0, pi / 2, pi])
        cb.set_ticklabels(
            (r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'))
        cb.set_label('arg')

    return fig, ax
Esempio n. 11
0
def plot_density_matrix(M,
                        xlabels=None,
                        ylabels=None,
                        title=None,
                        limits=None,
                        phase_limits=None,
                        fig=None,
                        axis_vals=None,
                        threshold=None):
    """
    Plot ``abs(M)`` as 3D bars colored by the phase of M, then call plt.show().

    Parameters
    ----------
    M : ndarray
        2d (density) matrix to visualize.
    xlabels, ylabels : list of str, optional
        Tick labels for the x and y axes.
    title : str, optional
        Plot title; only set when a figure is available.
    limits : list/tuple of two floats, optional
        z-axis limits. By default the z axis ends at ``index_array[int(m / 0.2)]``
        (one of 0.2, 0.4, ..., 1.0), where m is the largest ``abs(M)`` element,
        or at m itself when m >= 1.
    phase_limits : list/tuple of two floats, optional
        Range of the phase color scale; defaults to [-pi, pi].
    fig : matplotlib figure, optional
    axis_vals : 3D matplotlib axes, optional
        Axes to draw into; a new figure and axes are created when None.
    threshold : float, optional
        Bars whose height is not above this value are made fully transparent.

    Returns
    -------
    fig, axis_vals : tuple
        The figure (None if axes were given without ``fig``) and the axes.
    """
    index_array = [0.2, 0.4, 0.6, 0.8, 1.0]

    # largest element magnitude (0 for an empty matrix)
    magnitudes = np.abs(np.asarray(M))
    key = float(magnitudes.max()) if magnitudes.size else 0.0

    step = int(key / 0.2)
    if step < len(index_array):
        z_axis_limit = index_array[step]
    else:
        # max |M| >= 1.0 (e.g. a pure state has a diagonal element equal to
        # 1): indexing index_array would raise IndexError, so fit the data.
        z_axis_limit = key

    n = np.size(M)
    position_x, position_y = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    position_x = position_x.T.flatten() - 0.5
    position_y = position_y.T.flatten() - 0.5
    zpos = np.zeros(n)
    dx = dy = 0.8 * np.ones(n)
    vectors = M.flatten()
    dz = abs(vectors)

    # make small numbers real, to avoid random colors
    idx, = np.where(abs(vectors) < 0.001)
    vectors[idx] = abs(vectors[idx])

    if phase_limits:
        phase_min = phase_limits[0]
        phase_max = phase_limits[1]
    else:
        phase_min = -pi
        phase_max = pi

    norm = mpl.colors.Normalize(phase_min, phase_max)
    cmap = complex_phase_cmap()

    colors = cmap(norm(np.angle(vectors)))
    if threshold is not None:
        # alpha 1 above the threshold, 0 (invisible) otherwise
        colors[:, 3] = 1 * (dz > threshold)

    if axis_vals is None:
        fig = plt.figure()
        # Axes3D(fig, ...) is no longer attached to the figure automatically
        # in recent Matplotlib (the figure stayed blank); add_axes with the
        # same full-figure rectangle keeps the original layout.
        axis_vals = fig.add_axes([0, 0, 1, 1], projection='3d',
                                 azim=-35, elev=35)

    axis_vals.bar3d(position_x, position_y, zpos, dx, dy, dz, color=colors)

    if title and fig:
        axis_vals.set_title(title)

    # x axis (`w_xaxis` etc. were removed from Axes3D; use `xaxis`)
    axis_vals.xaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if xlabels:
        axis_vals.set_xticklabels(xlabels)
    axis_vals.tick_params(axis='x', labelsize=12)

    # y axis
    axis_vals.yaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if ylabels:
        axis_vals.set_yticklabels(ylabels)
    axis_vals.tick_params(axis='y', labelsize=12)

    # z axis
    if limits and isinstance(limits, (list, tuple)):
        axis_vals.set_zlim3d(limits)
    else:
        axis_vals.set_zlim3d([0, z_axis_limit])

    cax, _ = mpl.colorbar.make_axes(axis_vals, shrink=.75, pad=.0)
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb.set_ticks([-pi, -pi / 2, 0, pi / 2, pi])
    cb.set_ticklabels((r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'))
    cb.set_label('arg', rotation='horizontal')

    plt.show()
    return fig, axis_vals
Esempio n. 12
0
def matrix_gradient(M,
                    xlabels=None,
                    ylabels=None,
                    title=None,
                    limits=None,
                    colorbar=True,
                    fig=None,
                    ax=None):
    """
    Draw the real part of matrix M as 3D bars shaded with a vertical gradient.

    Each element becomes a square-based bar drawn with ``plot_surface``. The
    surface is colored by its z value using the 'jet' colormap on a fixed
    [-1, 1] scale, the same scale that the optional colorbar shows.

    Parameters
    ----------
    M : Matrix or Qobj
        The matrix to visualize.
    xlabels, ylabels : list of str, optional
        Tick labels for the x and y axes.
    title : str, optional
        Plot title; only set when a figure is available.
    limits : list/tuple/array of two floats, optional
        z-axis limits [min, max]; defaults to [-1, 1].
    colorbar : bool
        Whether to draw a colorbar next to the plot.
    fig : matplotlib figure, optional
    ax : 3D matplotlib axes, optional
        Axes to draw into; a new figure and axes are created when None.

    Returns
    -------
    fig, ax : tuple
        The figure (None if ``ax`` was given without ``fig``) and the axes.
    """
    def make_bar(ax,
                 x0=0,
                 y0=0,
                 width=0.5,
                 height=1,
                 cmap="jet",
                 norm=None,
                 **kwargs):
        """Draw one bar spanning x0 +/- width, y0 +/- width, z in [0, height]."""
        if norm is None:
            # built per call instead of as a default argument, which would
            # need `matplotlib` bound when this function is defined
            norm = mpl.colors.Normalize(vmin=-1, vmax=1)
        # Make data: a sphere-like parametrization clipped into a box shape
        u = np.linspace(0, 2 * np.pi, 4 + 1) + np.pi / 4.
        v_ = np.linspace(np.pi / 4., 3. / 4 * np.pi, 100)
        v = np.linspace(0, np.pi, len(v_) + 2)
        v[0] = 0
        v[-1] = np.pi
        v[1:-1] = v_
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))

        xthr = np.sin(np.pi / 4.)**2
        zthr = np.sin(np.pi / 4.)
        x[x > xthr] = xthr
        x[x < -xthr] = -xthr
        y[y > xthr] = xthr
        y[y < -xthr] = -xthr
        z[z > zthr] = zthr
        z[z < -zthr] = -zthr

        x *= 1. / xthr * width
        y *= 1. / xthr * width
        z += zthr
        z *= height / (2. * zthr)
        # translate
        x += x0
        y += y0
        # plot; the surface is colored by its z value through `norm`
        ax.plot_surface(x, y, z, cmap=cmap, norm=norm, **kwargs)

    def make_bars(ax, x, y, height, width=0.3, cmap="jet", norm=None):
        """Draw one gradient bar per (x, y, height) triple."""
        x = np.array(x).flatten()
        y = np.array(y).flatten()
        h = np.array(height).flatten()
        w = (np.array(width) * np.ones_like(x)).flatten()
        for x0, y0, w0, h0 in zip(x, y, w, h):
            make_bar(ax, x0=x0, y0=y0, width=w0, height=h0,
                     cmap=cmap, norm=norm)

    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()
    M = np.asarray(M)

    if ax is None:
        fig = plt.figure(figsize=(8, 6))
        # change seeing angle
        ax = fig.add_subplot(1, 1, 1, projection='3d', azim=-36, elev=36)

    xpos, ypos = np.meshgrid(np.arange(M.shape[0]), np.arange(M.shape[1]))

    xpos = xpos.T.flatten()
    ypos = ypos.T.flatten()
    dz = np.real(M.flatten())

    # One fixed scale shared by the bars and the colorbar; previously the
    # bars were normalized to their own min/max, so the colorbar did not
    # describe them.
    norm = mpl.colors.Normalize(-1, 1)
    # plt.get_cmap is used because `cm.get_cmap` was removed from Matplotlib
    cmap = plt.get_cmap('jet')

    make_bars(ax, x=xpos, y=ypos, height=dz, width=0.3, cmap=cmap, norm=norm)

    if title and fig:
        ax.set_title(title)

    # x axis (`w_xaxis` etc. were removed from Axes3D; use `xaxis`)
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.3))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=10, rotation=45)

    # y axis
    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.3))
    if ylabels:
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=10, rotation=-45)

    # z axis: honor the documented `limits`, otherwise keep [-1, 1]
    ax.zaxis.set_major_locator(plt.IndexLocator(0.5, -1))
    if (isinstance(limits, (list, tuple, np.ndarray))
            and len(limits) == 2):
        ax.set_zlim3d([limits[0], limits[1]])
    else:
        ax.set_zlim3d([-1, 1])

    # color axis
    if colorbar:
        cax, _ = mpl.colorbar.make_axes(ax, shrink=.75, pad=.1)
        mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)

    return fig, ax
Esempio n. 13
0
def matrix_histogram_complex(M,
                             xlabels=None,
                             ylabels=None,
                             title=None,
                             limits=None,
                             phase_limits=None,
                             colorbar=True,
                             fig=None,
                             ax=None,
                             threshold=None):
    """
    Draw a histogram for the amplitudes of matrix M, using the argument
    of each element for coloring the bars, with the given x and y labels
    and title.

    Parameters
    ----------
    M : Matrix or Qobj
        The matrix to visualize

    xlabels : list of strings
        list of x labels

    ylabels : list of strings
        list of y labels

    title : string
        title of the plot (optional); only set when a figure is available,
        i.e. ``fig`` was passed or the figure was created here

    limits : list/tuple with two float numbers
        The z-axis limits [min, max] (optional); defaults to [0, 1]

    phase_limits : list/array with two float numbers
        The phase-axis (colorbar) limits [min, max] (optional); defaults
        to [-pi, pi]

    colorbar : bool
        Whether to draw a phase colorbar.

    fig : a matplotlib figure instance (optional)

    ax : a matplotlib axes instance
        The 3d axes context in which the plot will be drawn. A new figure
        and axes are created when None.

    threshold: float (None)
        Threshold for when bars of smaller height should be transparent. If
        not set, all bars are colored according to the color map.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.
    """

    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()

    n = np.size(M)
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    xpos = xpos.T.flatten() - 0.5
    ypos = ypos.T.flatten() - 0.5
    zpos = np.zeros(n)
    dx = dy = 0.8 * np.ones(n)
    Mvec = M.flatten()
    dz = abs(Mvec)

    # make small numbers real, to avoid random colors
    idx, = np.where(abs(Mvec) < 0.001)
    Mvec[idx] = abs(Mvec[idx])

    if phase_limits:
        phase_min = phase_limits[0]
        phase_max = phase_limits[1]
    else:
        phase_min = -pi
        phase_max = pi

    norm = mpl.colors.Normalize(phase_min, phase_max)
    cmap = complex_phase_cmap()

    colors = cmap(norm(angle(Mvec)))
    if threshold is not None:
        # alpha 1 above the threshold, 0 (invisible) otherwise
        colors[:, 3] = 1 * (dz > threshold)

    if ax is None:
        fig = plt.figure()
        # Axes3D(fig, ...) is no longer attached to the figure automatically
        # in recent Matplotlib (the figure stayed blank); add_axes with the
        # same full-figure rectangle keeps the original layout.
        ax = fig.add_axes([0, 0, 1, 1], projection='3d', azim=-35, elev=35)

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors)

    if title and fig:
        ax.set_title(title)

    # x axis (`w_xaxis` etc. were removed from Axes3D; use `xaxis`)
    ax.xaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=12)

    # y axis
    ax.yaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if ylabels:
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=12)

    # z axis
    if limits and isinstance(limits, (list, tuple)):
        ax.set_zlim3d(limits)
    else:
        ax.set_zlim3d([0, 1])

    # color axis
    if colorbar:
        cax, _ = mpl.colorbar.make_axes(ax, shrink=.75, pad=.0)
        cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
        cb.set_ticks([-pi, -pi / 2, 0, pi / 2, pi])
        cb.set_ticklabels(
            (r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'))
        cb.set_label('arg')

    return fig, ax
Esempio n. 14
0
def hinton_real(matrix: np.ndarray,
                max_weight: float = None,
                xlabels: List[str] = None,
                ylabels: List[str] = None,
                title: str = None,
                ax=None,
                cmap=None,
                label_top: bool = True):
    '''
    Draw Hinton diagram for visualizing a real valued weight matrix.

    In the traditional Hinton diagram positive and negative values are represented by white and
    black squares respectively. The size of each square represents the magnitude of each value.
    The traditional Hinton diagram can be recovered by setting cmap = cm.Greys_r.

    :param matrix: The matrix to be visualized. It is indexed as ``matrix[x, y]`` with
        ``x`` running over the horizontal axis, so it is expected to be square.
    :param max_weight: normalize size to this scalar. Defaults to 1.25 times the largest
        absolute diagonal entry, or 1.0 if that is not positive.
    :param xlabels: The labels for the operator basis.
    :param ylabels: The labels for the operator basis.
    :param title: The title for the plot.
    :param ax: The matplotlib axes. If omitted, a new 8x6 figure and axes are created.
    :param cmap: A matplotlib colormap to use when plotting. Only its two end colors are
        used (for negative and positive entries); near-zero is drawn in light gray.
    :param label_top: If True, x-axis labels will be placed on top, otherwise they will appear
        below the plot.
    :return: A tuple of the matplotlib figure and axes instances used to produce the figure.
        The figure is None when ``ax`` is supplied by the caller.
    '''
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    else:
        # Previously ``fig`` was left unbound on this path, so every call with
        # an explicit ``ax`` raised NameError at the return statement.
        fig = None

    # Grays in increasing darkness: smokewhite, gainsboro, lightgrey, lightgray, silver
    backgnd_gray = 'gainsboro'

    if cmap is None:
        cmap = cm.RdBu
    # Collapse the colormap to three discrete colors, indexed below as
    # 0 = negative, 1 = background (near zero), 2 = positive.
    cneg = cmap(0)
    cpos = cmap(256)
    cmap = mpl.colors.ListedColormap([cneg, backgnd_gray, cpos])

    if title:
        ax.set_title(title, y=1.1, fontsize=18)

    ax.set_aspect('equal', 'box')
    ax.set_frame_on(False)

    height, width = matrix.shape
    if max_weight is None:
        max_weight = 1.25 * max(abs(np.diag(np.asarray(matrix))))
        if max_weight <= 0.0:
            max_weight = 1.0

    # Colorbar boundaries: a thin band around zero maps to the background color.
    bounds = [-max_weight, -0.0001, 0.0001, max_weight]
    tick_loc = [-max_weight / 2, 0, max_weight / 2]

    # Background rectangle covering the whole diagram.
    ax.fill(np.array([0, width, width, 0]),
            np.array([0, 0, height, height]),
            color=cmap(1))
    for x in range(width):
        for y in range(height):
            magnitude = abs(matrix[x, y])
            # Strictly positive real part -> positive color; zero counts as negative.
            positive = np.real(matrix[x, y]) > 0.0
            # Blob centers: x + 0.5 horizontally, rows drawn top-down.
            _blob(x + 0.5,
                  height - y - 0.5,
                  magnitude if positive else -magnitude,
                  max_weight,
                  min(1, magnitude / max_weight),
                  cmap=cmap(2) if positive else cmap(0))

    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    cax, kw = mpl.colorbar.make_axes(ax, shrink=0.75, pad=.1)
    mpl.colorbar.ColorbarBase(cax,
                              norm=norm,
                              cmap=cmap,
                              boundaries=bounds,
                              ticks=tick_loc).set_ticklabels(
                                  ['$-$', '$0$', '$+$'])
    cax.tick_params(labelsize=14)
    # x axis
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if xlabels:
        ax.set_xticklabels(xlabels)
        if label_top:
            ax.xaxis.tick_top()
    ax.tick_params(axis='x', labelsize=14)
    # y axis: rows are drawn top-down, so labels are reversed to match.
    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if ylabels:
        ax.set_yticklabels(list(reversed(ylabels)))
    ax.tick_params(axis='y', labelsize=14)

    return fig, ax
Esempio n. 15
0
def matrix_histogram(M,
                     xlabels=None,
                     ylabels=None,
                     title=None,
                     limits=None,
                     colorbar=True,
                     fig=None,
                     ax=None):
    """
    Draw a histogram for the matrix M, with the given x and y labels and title.

    Each matrix element is drawn as a 3D bar whose height is the real part of
    the element. Bars are colored with the 'jet' colormap normalized to the
    z range.

    Parameters
    ----------
    M : Matrix of Qobj
        The matrix to visualize

    xlabels : list of strings
        list of x labels

    ylabels : list of strings
        list of y labels

    title : string
        title of the plot (optional)

    limits : list/array with two float numbers
        The z-axis limits [min, max] (optional). If omitted, the data range
        is used, widened by 0.1 on each side when all values are equal.

    colorbar : bool
        Whether to draw a colorbar next to the plot.

    fig : a matplotlib figure instance
        Returned unchanged when ``ax`` is given; ignored (replaced by a new
        figure) when ``ax`` is None.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn. ``bar3d`` is
        called on it, so it must be a 3D axes.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    """

    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()

    n = np.size(M)
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    xpos = xpos.T.flatten() - 0.5
    ypos = ypos.T.flatten() - 0.5
    zpos = np.zeros(n)
    dx = dy = 0.8 * np.ones(n)
    dz = np.real(M.flatten())

    # Accept any two-element sequence (list, tuple, ndarray) as documented.
    if limits is not None and len(limits) == 2:
        z_min = limits[0]
        z_max = limits[1]
    else:
        z_min = min(dz)
        z_max = max(dz)
        if z_min == z_max:
            # Avoid a degenerate (zero-width) color normalization.
            z_min -= 0.1
            z_max += 0.1

    norm = mpl.colors.Normalize(z_min, z_max)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in mpl 3.9).
    cmap = plt.get_cmap('jet')  # Spectral
    colors = cmap(norm(dz))

    if ax is None:
        fig = plt.figure()
        # Axes3D(fig) no longer adds itself to the figure (mpl >= 3.4), which
        # produced an empty plot; add a full-figure 3D axes explicitly instead.
        ax = fig.add_axes([0, 0, 1, 1], projection='3d', azim=-35, elev=35)

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors)

    if title:
        ax.set_title(title)

    # x axis (w_xaxis/w_yaxis/w_zaxis were removed in mpl 3.8; use xaxis etc.)
    ax.xaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=14)

    # y axis
    ax.yaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if ylabels:
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=14)

    # z axis: always include 0 so bars are drawn from the base.
    ax.zaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    ax.set_zlim3d([min(z_min, 0), z_max])

    # color axis
    if colorbar:
        cax, kw = mpl.colorbar.make_axes(ax, shrink=.75, pad=.0)
        mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)

    return fig, ax
Esempio n. 16
0
def hinton(rho,
           xlabels=None,
           ylabels=None,
           title=None,
           ax=None,
           cmap=None,
           label_top=True):
    """Draws a Hinton diagram for visualizing a density matrix or superoperator.

    Parameters
    ----------
    rho : qobj
        Input density matrix or superoperator. A non-Qobj 2d array is also
        accepted and plotted as-is.

    xlabels : list of strings or False
        list of x labels

    ylabels : list of strings or False
        list of y labels

    title : string
        title of the plot (optional)

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    cmap : a matplotlib colormap instance
        Color map to use when plotting.

    label_top : bool
        If True, x-axis labels will be placed on top, otherwise
        they will appear below the plot.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure. ``fig`` is None when ``ax`` is supplied.

    Raises
    ------
    ValueError
        If ``rho`` is a Qobj that is not an operator, operator-ket,
        operator-bra or superoperator, or is a superoperator whose
        dimensions are not qubit dimensions.

    """

    # Apply default colormaps.
    # TODO: abstract this away into something that makes default
    #       colormaps.
    cmap = ((cm.Greys_r if settings.colorblind_safe else cm.RdBu)
            if cmap is None else cmap)

    # Extract plotting data W from the input.
    if isinstance(rho, Qobj):
        if rho.isoper:
            W = rho.full()

            # Create default labels if none are given.
            if xlabels is None or ylabels is None:
                labels = _cb_labels(rho.dims[0])
                xlabels = xlabels if xlabels is not None else list(labels[0])
                ylabels = ylabels if ylabels is not None else list(labels[1])

        elif rho.isoperket:
            W = vector_to_operator(rho).full()
        elif rho.isoperbra:
            W = vector_to_operator(rho.dag()).full()
        elif rho.issuper:
            if not _isqubitdims(rho.dims):
                raise ValueError("Hinton plots of superoperators are "
                                 "currently only supported for qubits.")
            # Convert to a superoperator in the Pauli basis,
            # so that all the elements are real.
            sqobj = _super_to_superpauli(rho)
            nq = int(log2(sqobj.shape[0]) / 2)
            W = sqobj.full().T
            # Create default labels, too.
            if (xlabels is None) or (ylabels is None):
                labels = list(map("".join, it.product("IXYZ", repeat=nq)))
                xlabels = xlabels if xlabels is not None else labels
                ylabels = ylabels if ylabels is not None else labels

        else:
            raise ValueError(
                "Input quantum object must be an operator or superoperator.")

    else:
        W = rho

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    else:
        fig = None

    # The documented ``title`` argument was previously never applied.
    if title:
        ax.set_title(title)

    if not (xlabels or ylabels):
        ax.axis('off')

    ax.axis('equal')
    ax.set_frame_on(False)

    height, width = W.shape

    # Blob sizes are normalized to 1.25x the largest diagonal magnitude.
    # np.asarray replaces the deprecated np.matrix; np.diag yields 1-D either way.
    w_max = 1.25 * max(abs(np.diag(np.asarray(W))))
    if w_max <= 0.0:
        w_max = 1.0

    # Background rectangle in the colormap's middle color.
    ax.fill(array([0, width, width, 0]),
            array([0, 0, height, height]),
            color=cmap(128))
    for x in range(width):
        for y in range(height):
            magnitude = abs(W[x, y])
            # The blob is signed by the real part; zero counts as negative.
            value = magnitude if np.real(W[x, y]) > 0.0 else -magnitude
            _blob(x + 0.5,
                  height - y - 0.5,
                  value,
                  w_max,
                  min(1, magnitude / w_max),
                  cmap=cmap)

    # color axis
    norm = mpl.colors.Normalize(-abs(W).max(), abs(W).max())
    cax, kw = mpl.colorbar.make_axes(ax, shrink=0.75, pad=.1)
    mpl.colorbar.ColorbarBase(cax, norm=norm, cmap=cmap)

    # x axis
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.5))

    if xlabels:
        ax.set_xticklabels(xlabels)
        if label_top:
            ax.xaxis.tick_top()
    ax.tick_params(axis='x', labelsize=14)

    # y axis: rows are drawn top-down, so labels are reversed to match.
    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if ylabels:
        ax.set_yticklabels(list(reversed(ylabels)))
    ax.tick_params(axis='y', labelsize=14)

    return fig, ax
Esempio n. 17
0
def hinton(rho, xlabels=None, ylabels=None, title=None, ax=None):
    """Draws a Hinton diagram for visualizing a density matrix.

    Parameters
    ----------
    rho : qobj
        Input density matrix. A non-Qobj 2d array is also accepted and
        plotted as-is.

    xlabels : list of strings
        list of x labels

    ylabels : list of strings
        list of y labels

    title : string
        title of the plot (optional)

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure. ``fig`` is None when ``ax`` is supplied.

    Raises
    ------
    ValueError
        If ``rho`` is a Qobj ket or bra rather than an operator.

    """

    if isinstance(rho, Qobj):
        if isket(rho) or isbra(rho):
            raise ValueError("argument must be a quantum operator")

        W = rho.full()
    else:
        W = rho

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    else:
        fig = None

    # The documented ``title`` argument was previously never applied.
    if title:
        ax.set_title(title)

    if not (xlabels or ylabels):
        ax.axis('off')

    ax.axis('equal')
    ax.set_frame_on(False)

    height, width = W.shape

    # Blob sizes are normalized to 1.25x the largest diagonal magnitude.
    # np.asarray replaces the deprecated np.matrix; np.diag yields 1-D either way.
    w_max = 1.25 * max(abs(np.diag(np.asarray(W))))
    if w_max <= 0.0:
        w_max = 1.0

    # x axis
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=14)

    # y axis: rows are drawn top-down, so labels are reversed to match.
    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if ylabels:
        ax.set_yticklabels(list(reversed(ylabels)))
    ax.tick_params(axis='y', labelsize=14)

    # Background rectangle in the colormap's middle color.
    ax.fill(array([0, width, width, 0]), array([0, 0, height, height]),
            color=cm.RdBu(128))
    for x in range(width):
        for y in range(height):
            magnitude = abs(W[x, y])
            # The blob is signed by the real part; zero counts as negative.
            value = magnitude if np.real(W[x, y]) > 0.0 else -magnitude
            _blob(x + 0.5, height - y - 0.5, value, w_max,
                  min(1, magnitude / w_max))

    # color axis
    norm = mpl.colors.Normalize(-abs(W).max(), abs(W).max())
    cax, kw = mpl.colorbar.make_axes(ax, shrink=0.75, pad=.1)
    mpl.colorbar.ColorbarBase(cax, norm=norm, cmap=cm.RdBu)

    return fig, ax
Esempio n. 18
0
def plotMat(M,
            xlabels=None,
            ylabels=None,
            title=None,
            limits=None,
            colorbar=True,
            fig=None,
            ax=None,
            cmap='RdBu'):
    """Draw a 3D bar chart of the real part of the 2d array ``M``.

    Each element becomes a bar of width 0.6, centered on its integer grid
    position, with height equal to the element's real part. Bars are colored
    by ``cmap`` normalized to the z range.

    Parameters
    ----------
    M : ndarray
        2d matrix to visualize (only its real part is plotted).
    xlabels, ylabels : list of str, optional
        Tick labels for the x and y axes.
    title : str, optional
        Title of the plot.
    limits : sequence of two floats, optional
        z range [min, max] used for coloring and the upper z limit. If
        omitted, the data range is used (widened by 0.1 on each side when all
        values are equal).
    colorbar : bool
        Whether to draw a colorbar next to the plot.
    fig : matplotlib Figure, optional
        Returned unchanged when ``ax`` is given; replaced by a new figure
        when ``ax`` is None.
    ax : 3D matplotlib axes, optional
        Axes to draw into (``bar3d`` is called on it). A new full-figure 3D
        axes is created when None.
    cmap : str or Colormap
        Colormap name or instance, resolved via ``plt.get_cmap``.

    Returns
    -------
    fig, ax : tuple
        Figure and axes used for the plot.
    """
    # Importing Axes3D registers the '3d' projection on older matplotlib.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import matplotlib as mpl

    # matplotlib.cm.get_cmap was removed in mpl 3.9; pyplot.get_cmap accepts
    # both a colormap name and an existing Colormap instance.
    cmap = plt.get_cmap(cmap)
    n = np.size(M)
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    bar_width = 0.6
    # Offset by half a bar so each bar is centered on its integer position.
    xpos = xpos.T.flatten() - bar_width / 2.
    ypos = ypos.T.flatten() - bar_width / 2.
    zpos = np.zeros(n)
    dx = dy = bar_width * np.ones(n)
    dz = np.real(M.flatten())

    # Accept any two-element sequence (list, tuple, ndarray).
    if limits is not None and len(limits) == 2:
        z_min = limits[0]
        z_max = limits[1]
    else:
        z_min = min(dz)
        z_max = max(dz)
        if z_min == z_max:
            # Avoid a degenerate (zero-width) color normalization.
            z_min -= 0.1
            z_max += 0.1

    norm = mpl.colors.Normalize(z_min, z_max)

    colors = cmap(norm(dz))

    if ax is None:
        fig = plt.figure()
        # Axes3D(fig) no longer adds itself to the figure (mpl >= 3.4);
        # add a full-figure 3D axes explicitly instead.
        ax = fig.add_axes([0, 0, 1, 1], projection='3d', azim=-37, elev=59)

    ax.bar3d(xpos,
             ypos,
             zpos,
             dx,
             dy,
             dz,
             color=colors,
             edgecolors='white',
             linewidth=1,
             zsort='max')

    if title:
        ax.set_title(title)

    # x axis (w_xaxis/w_yaxis/w_zaxis were removed in mpl 3.8)
    ax.xaxis.set_major_locator(plt.IndexLocator(1, bar_width / 2.))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=14)

    # y axis
    ax.yaxis.set_major_locator(plt.IndexLocator(1, bar_width / 2.))
    if ylabels:
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=14)

    # z axis: always include 0 so bars are drawn from the base.
    ax.zaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    ax.set_zlim3d([min(z_min, 0), z_max])

    # color axis
    if colorbar:
        cax, kw = mpl.colorbar.make_axes(ax, shrink=.75, pad=.0)
        mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)

    return fig, ax
Esempio n. 19
0
# Scree plot comparing PCA and SCA.
# Assumes `scree_df` (presumably an empty pandas DataFrame), `ax`, `pca`, `sca`,
# `n_components`, `sns` and `stashfig` come from earlier cells — not visible here.
# Layout is "long": n_components rows for PCA followed by n_components for SCA.
scree_df["n_components"] = np.tile(np.arange(1, n_components + 1), 2)
# PCA's per-component ratios are cumulated here; SCA's are used as-is.
# NOTE(review): the y-label says "Cumulative", so this presumes
# sca.explained_variance_ratio_ is already cumulative — confirm.
scree_df["explained_variance"] = np.concatenate(
    (np.cumsum(pca.explained_variance_ratio_), sca.explained_variance_ratio_))
scree_df["method"] = n_components * ["PCA"] + n_components * ["SCA"]
sns.lineplot(
    data=scree_df,
    x="n_components",
    y="explained_variance",
    hue="method",
    ax=ax,
    marker="o",
)
ax.legend().set_title("Method")
ax.set(ylabel="Cumulative explained variance", xlabel="# of PCs")
ax.yaxis.set_major_locator(plt.MaxNLocator(3))
# One x tick every 4 components, offset by 1 from the data minimum.
ax.xaxis.set_major_locator(plt.IndexLocator(base=4, offset=1))
stashfig("screeplot")

#%%
# Count non-zero loadings in the first i components of each model,
# for i = 1..n_components, collected as one row per (method, i).
# NOTE: the loop variables stay defined at module level afterwards.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
rows = []
for method, model in zip(["PCA", "SCA"], [pca, sca]):
    for i in range(1, n_components + 1):
        components = model.components_[:i]
        n_nonzero = np.count_nonzero(components)
        rows.append({
            "n_components": i,
            "n_nonzero": n_nonzero,
            "method": method
        })
nonzero_df = pd.DataFrame(rows)
Esempio n. 20
0
def matrix(data_matrix,
           mode='abs',
           xlabel='',
           ylabel='',
           zlabel='',
           filename=None,
           fig_ax=None):
    """
    Create a "skyscraper" plot and a 2d color-coded plot of a matrix.

    Parameters
    ----------
    data_matrix: ndarray of float or complex
        2d matrix data (square; its size is taken from ``len(data_matrix)``)
    mode: str from `constants.MODE_FUNC_DICT`
        choice of processing function to be applied to data
    xlabel, ylabel, zlabel: str, optional
        axis labels for the 3d (skyscraper) plot
    filename: str, optional
        if given, the figure is written as PDF to exactly this path (no
        suffix is appended)
    fig_ax: tuple(Figure, (Axes, Axes)), optional
        fig and ax objects for matplotlib figure addition; the first axes
        must be a 3d axes

    Returns
    -------
    Figure, (Axes1, Axes2)
        figure and axes objects for further editing
    """
    if fig_ax is None:
        fig = plt.figure(figsize=(10, 5))
        ax1 = fig.add_subplot(1, 2, 1, projection='3d')
        # Use the figure directly rather than pyplot's global current figure.
        ax2 = fig.add_subplot(1, 2, 2)
    else:
        fig, (ax1, ax2) = fig_ax

    matsize = len(data_matrix)
    element_count = matsize**2  # num. of elements to plot

    xgrid, ygrid = np.meshgrid(range(matsize), range(matsize))
    xgrid = xgrid.T.flatten() - 0.5  # center bars on integer value of x-axis
    ygrid = ygrid.T.flatten() - 0.5  # center bars on integer value of y-axis

    zbottom = np.zeros(element_count)  # all bars start at z=0
    dx = 0.75 * np.ones(element_count)  # width of bars in x-direction
    dy = dx  # width of bars in y-direction (same as x-direction)

    modefunction = constants.MODE_FUNC_DICT[mode]
    zheight = modefunction(
        data_matrix).flatten()  # height of bars from matrix elements
    nrm = mpl.colors.Normalize(
        0, max(zheight))  # <-- normalize colors to max. data
    colors = plt.cm.viridis(nrm(zheight))  # list of colors for each bar

    # skyscraper plot
    ax1.view_init(azim=210, elev=23)
    ax1.bar3d(xgrid, ygrid, zbottom, dx, dy, zheight, color=colors)
    # w_xaxis/w_yaxis were removed in matplotlib 3.8; xaxis/yaxis work everywhere.
    ax1.axes.xaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set x-ticks to integers
    ax1.axes.yaxis.set_major_locator(plt.IndexLocator(
        1, -0.5))  # set y-ticks to integers
    ax1.set_zlim3d([0, max(zheight)])
    ax1.set_xlabel(xlabel)
    ax1.set_ylabel(ylabel)
    ax1.set_zlabel(zlabel)

    # 2d plot
    ax2.matshow(modefunction(data_matrix), cmap=plt.cm.viridis)

    cax, _ = mpl.colorbar.make_axes(
        ax2, shrink=.75, pad=.02)  # add colorbar with normalized range
    mpl.colorbar.ColorbarBase(cax, cmap=plt.cm.viridis, norm=nrm)

    if filename:
        # Save *this* figure explicitly: a bare savefig() writes pyplot's
        # current figure, which need not be ``fig`` when fig_ax is supplied.
        # The context manager also closes the file if saving fails.
        with mplpdf.PdfPages(filename) as out_file:
            out_file.savefig(fig)

    return fig, (ax1, ax2)
Esempio n. 21
0
def hinton(W,
           xlabels=None,
           ylabels=None,
           labelsize=9,
           title=None,
           fig=None,
           ax=None,
           cmap=None):
    """Draw a Hinton-style diagram of a (complex) 2d array, colored by phase.

    Each element W[i, j] is drawn by ``_blob`` (defined elsewhere) at grid
    position (i + 0.5, height - j - 0.5). The blob is given the element's
    phase ``np.angle(W[i, j])`` together with the range [-pi, pi] and
    ``cmap``, and a size of ``|W[i, j]| / max|W| * 0.7``.

    Parameters
    ----------
    W : ndarray
        2d array to visualize.
    xlabels, ylabels : list of str, optional
        Tick labels. If both are omitted, the axis is turned off. x labels
        are drawn on top, rotated vertically.
    labelsize : int
        Font size of the tick labels.
    title : str, optional
        Text drawn near the top-left corner of the axes.
    fig : matplotlib Figure, optional
        Returned unchanged when ``ax`` is given; replaced when ``ax`` is None.
    ax : matplotlib Axes, optional
        Axes to draw into; a new 4x4 figure and axes are created when None.
    cmap : Colormap, optional
        Phase colormap; defaults to 'twilight'.

    Returns
    -------
    fig, ax : tuple
        Figure and axes used for the plot.
    """
    if cmap is None:
        cmap = plt.get_cmap('twilight')

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    if not (xlabels or ylabels):
        ax.axis('off')

    ax.axis('equal')
    ax.set_frame_on(False)

    height, width = W.shape
    ax.set(xlim=(0, width), ylim=(0, height))

    max_abs = np.max(np.abs(W))
    if max_abs == 0:
        # All-zero input: avoid 0/0 (NaN sizes); every blob gets size 0.
        max_abs = 1.0
    scale = 0.7

    for i in range(width):
        for j in range(height):
            x = i + 1 - 0.5
            y = j + 1 - 0.5
            _blob(x,
                  height - y,
                  np.angle(W[i, j]),
                  -np.pi,
                  np.pi,
                  np.abs(W[i, j]) / max_abs * scale,
                  cmap=cmap,
                  ax=ax)

    # x axis
    ax.xaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if xlabels:
        ax.set_xticklabels(xlabels, rotation='vertical')
        ax.xaxis.tick_top()
    ax.tick_params(axis='x', labelsize=labelsize, pad=0)
    ax.xaxis.set_ticks_position('none')

    # y axis: rows are drawn top-down, so labels are reversed to match.
    ax.yaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    ax.yaxis.set_ticks_position('none')
    if ylabels:
        ax.set_yticklabels(list(reversed(ylabels)))
    ax.tick_params(axis='y', labelsize=labelsize, pad=0)

    # color axis: a phase bar without ticks, labelled only at its two ends.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='4%', pad='2%')
    mpl.colorbar.ColorbarBase(cax,
                              cmap=cmap,
                              norm=mpl.colors.Normalize(-np.pi, np.pi),
                              ticks=[])
    # Raw strings: '\p' is an invalid escape sequence (SyntaxWarning on 3.12+);
    # the runtime text is unchanged.
    cax.text(0.5,
             0.0,
             r'$-\pi$',
             transform=cax.transAxes,
             va='top',
             ha='center')
    cax.text(0.5,
             1.0,
             r'$+\pi$',
             transform=cax.transAxes,
             va='bottom',
             ha='center')

    # Make title in corner. Use ax.text, not plt.text: after append_axes the
    # pyplot "current axes" is the colorbar (or an unrelated axes when ``ax``
    # was passed in), so plt.text attached the title to the wrong axes.
    if title is not None:
        ax.text(-.07,
                1.05,
                title,
                ha='center',
                va='center',
                fontsize=22,
                transform=ax.transAxes)

    return fig, ax