def matrix_histogram_complex(M, xlabels=None, ylabels=None,
                             title=None, limits=None, phase_limits=None,
                             colorbar=True, fig=None, ax=None,
                             threshold=None):
    """
    Draw a histogram for the amplitudes of matrix M, using the argument
    of each element for coloring the bars, with the given x and y labels
    and title.

    Parameters
    ----------
    M : Matrix of Qobj
        The matrix to visualize. A Qobj is converted with ``full()``;
        anything else is converted with ``np.asarray`` and must be 2-D.

    xlabels : list of strings
        list of x labels (one per row of M)

    ylabels : list of strings
        list of y labels (one per column of M)

    title : string
        title of the plot (optional)

    limits : list/array with two float numbers
        The z-axis limits [min, max] (optional). Defaults to [0, 1].

    phase_limits : list/array with two float numbers
        The phase-axis (colorbar) limits [min, max] (optional).
        Defaults to [-pi, pi].

    colorbar : bool (True)
        Whether to draw a colorbar showing the phase color scale.

    fig : a matplotlib Figure instance
        The figure in which the 3D axes is created when `ax` is not given.
        If omitted, a new figure is created, or the figure owning `ax`
        is used.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn. It must support
        ``bar3d`` / ``set_zlim3d`` (i.e. a 3D axes).

    threshold: float (None)
        Threshold for when bars of smaller height should be transparent. If
        not set, all bars are colored according to the color map.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    ValueError
        Input argument is not valid (e.g. M is not two-dimensional).

    """
    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()
    # np.asarray turns np.matrix / nested lists into a plain ndarray, so that
    # flatten() below yields a 1-D vector (np.matrix.flatten stays 2-D).
    M = np.asarray(M)
    if M.ndim != 2:
        raise ValueError("M must be a 2-D matrix, got an array with "
                         "%d dimension(s)" % M.ndim)

    n = np.size(M)
    # Bar positions: row index along x, column index along y, in the same
    # row-major order as M.flatten().
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    xpos = xpos.T.flatten() - 0.5
    ypos = ypos.T.flatten() - 0.5
    zpos = np.zeros(n)
    dx = 0.8 * np.ones(n)
    dy = 0.8 * np.ones(n)
    Mvec = M.flatten()  # a copy, so the caller's data is never modified
    dz = np.abs(Mvec)

    # make small numbers real, to avoid random colors from the numerically
    # meaningless phase of near-zero entries
    small = dz < 0.001
    Mvec[small] = dz[small]

    # `is not None` + len() instead of plain truthiness so that ndarray
    # limits work; an empty sequence still falls back to the default.
    if phase_limits is not None and len(phase_limits) > 0:
        phase_min, phase_max = phase_limits[0], phase_limits[1]
    else:
        phase_min, phase_max = -pi, pi

    norm = mpl.colors.Normalize(phase_min, phase_max)
    cmap = complex_phase_cmap()

    colors = cmap(norm(angle(Mvec)))
    if threshold is not None:
        # alpha channel: opaque above the threshold, transparent otherwise
        colors[:, 3] = 1 * (dz > threshold)

    if ax is None:
        # Respect a caller-supplied figure; add_subplot attaches the 3D axes
        # to the figure (a bare Axes3D(fig) no longer does in recent mpl).
        if fig is None:
            fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d', azim=-35, elev=35)
    elif fig is None:
        # Return the figure actually used instead of None.
        fig = ax.figure

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors)

    if title:
        ax.set_title(title)

    # x axis: one tick per row, centred on the bars
    ax.xaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if xlabels:
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=12)

    # y axis: one tick per column, centred on the bars
    ax.yaxis.set_major_locator(plt.IndexLocator(1, -0.5))
    if ylabels:
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=12)

    # z axis (amplitudes)
    if limits is not None and len(limits) > 0:
        ax.set_zlim3d(limits[0], limits[1])
    else:
        ax.set_zlim3d(0, 1)

    # color axis (phase)
    if colorbar:
        cax, _ = mpl.colorbar.make_axes(ax, shrink=.75, pad=.0)
        cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
        # Only label the multiples of pi/2 that fall inside the phase range,
        # so ticks and labels stay paired when phase_limits is narrower
        # than [-pi, pi].
        tick_specs = [(-pi, r'$-\pi$'), (-pi / 2, r'$-\pi/2$'), (0, r'$0$'),
                      (pi / 2, r'$\pi/2$'), (pi, r'$\pi$')]
        visible = [(tick, label) for tick, label in tick_specs
                   if phase_min <= tick <= phase_max]
        if visible:
            cb.set_ticks([tick for tick, _ in visible])
            cb.set_ticklabels([label for _, label in visible])
        cb.set_label('arg')

    return fig, ax