def iplot_state(quantum_state, method='city', figsize=None):
    """Plot the quantum state with an interactive (JavaScript) plotter.

    Deprecated: call the ``iplot_state_*`` functions (or
    ``iplot_bloch_multivector``) directly.

    Args:
        quantum_state (ndarray): statevector or density matrix
            representation of a quantum state.
        method (str): Plotting method to use: ``'city'``, ``'paulivec'``,
            ``'qsphere'``, ``'bloch'`` or ``'hinton'``.
        figsize (tuple): Figure size, forwarded unchanged to the selected
            plotting function.

    Raises:
        VisualizationError: if the input is not a statevector or density
            matrix, if the state is not a multi-qubit quantum state, or if
            ``method`` is not one of the supported methods.
    """
    # Implicit string concatenation instead of a backslash continuation
    # inside the literal, which would copy source indentation into the
    # message. stacklevel=2 attributes the warning to the caller.
    warnings.warn(
        "iplot_state is deprecated, and will be removed in the 0.9 "
        "release. Use the iplot_state_* functions instead.",
        DeprecationWarning, stacklevel=2)
    rho = _validate_input_state(quantum_state)
    if method == "city":
        iplot_state_city(rho, figsize=figsize)
    elif method == "paulivec":
        iplot_state_paulivec(rho, figsize=figsize)
    elif method == "qsphere":
        iplot_state_qsphere(rho, figsize=figsize)
    elif method == "bloch":
        iplot_bloch_multivector(rho, figsize=figsize)
    elif method == "hinton":
        iplot_state_hinton(rho, figsize=figsize)
    else:
        raise VisualizationError('Invalid plot state method.')
def plot_state_paulivec(rho, title="", figsize=None, color=None):
    """Plot the paulivec representation of a quantum state.

    Plot a bargraph of the mixed state rho over the pauli matrices: one bar
    per Pauli operator P, with height Re(Tr(P rho)).

    Args:
        rho (ndarray): Numpy array for state vector or density matrix
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches (default ``(7, 5)``).
        color (list or str): Color of the expectation value bars
            (default ``"#648fff"``).

    Returns:
        matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (7, 5)
    num = int(np.log2(len(rho)))
    # Build the Pauli group once and reuse it for both labels and values.
    paulis = list(pauli_group(num))
    labels = [pauli.to_label() for pauli in paulis]
    values = [np.real(np.trace(np.dot(pauli.to_matrix(), rho)))
              for pauli in paulis]
    numelem = len(values)
    if color is None:
        color = "#648fff"

    ind = np.arange(numelem)  # the x locations for the groups
    width = 0.5  # the width of the bars
    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(zorder=0, linewidth=1, linestyle='--')
    ax.bar(ind, values, width, color=color, zorder=2)
    ax.axhline(linewidth=1, color='k')
    # add some text for labels, title, and axes ticks
    ax.set_ylabel('Expectation value', fontsize=14)
    ax.set_xticks(ind)
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
    ax.set_xlabel('Pauli', fontsize=14)
    ax.set_ylim([-1, 1])
    ax.set_facecolor('#eeeeee')
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        # ``Tick.label`` was removed in Matplotlib 3.8; ``label1`` is the
        # same primary tick label and exists in older versions too.
        tick.label1.set_fontsize(14)
    ax.set_title(title, fontsize=16)
    plt.close(fig)
    return fig
def plot_bloch_multivector(rho, title='', figsize=None):
    """Plot one Bloch sphere per qubit of a quantum state.

    For every qubit, the expectation values of the single-qubit X, Y and Z
    Pauli operators (embedded in the full register) are computed and drawn
    as a Bloch vector on that qubit's own 3D subplot.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Does not size the figure (its size comes from
            ``plt.figaspect``); it is only forwarded to ``plot_bloch_vector``.

    Returns:
        Figure: A matplotlib figure instance.

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    num_qubits = int(np.log2(len(rho)))
    # One subplot per qubit laid out in a single row.
    fig_width, fig_height = plt.figaspect(1 / num_qubits)
    fig = plt.figure(figsize=(fig_width, fig_height))
    for qubit in range(num_qubits):
        axis = fig.add_subplot(1, num_qubits, qubit + 1, projection='3d')
        operators = [Pauli.pauli_single(num_qubits, qubit, label)
                     for label in ('X', 'Y', 'Z')]
        expectations = [np.real(np.trace(np.dot(op.to_matrix(), rho)))
                        for op in operators]
        plot_bloch_vector(expectations, "qubit " + str(qubit), ax=axis,
                          figsize=figsize)
    fig.suptitle(title, fontsize=16)
    plt.close(fig)
    return fig
def plot_state(quantum_state, method='city', figsize=None):
    """Plot the quantum state with Matplotlib.

    Deprecated: call the ``plot_state_*`` functions (or
    ``plot_bloch_multivector`` / ``plot_wigner_function``) directly.

    Args:
        quantum_state (ndarray): statevector or density matrix
            representation of a quantum state.
        method (str): Plotting method to use: ``'city'``, ``'paulivec'``,
            ``'qsphere'``, ``'bloch'``, ``'wigner'`` or ``'hinton'``.
        figsize (tuple): Figure size in inches (not used by ``'wigner'``).

    Returns:
        matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
        VisualizationError: if the input is not a statevector or density
            matrix, if the state is not a multi-qubit quantum state, or if
            ``method`` is not one of the supported methods.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    # Implicit concatenation keeps source indentation out of the message;
    # stacklevel=2 attributes the warning to the caller.
    warnings.warn(
        "plot_state is deprecated, and will be removed in the 0.9 "
        "release. Use the plot_state_* functions instead.",
        DeprecationWarning, stacklevel=2)
    # Check if input is a statevector, and convert to density matrix
    rho = _validate_input_state(quantum_state)
    if method == 'city':
        return plot_state_city(rho, figsize=figsize)
    if method == "paulivec":
        return plot_state_paulivec(rho, figsize=figsize)
    if method == "qsphere":
        return plot_state_qsphere(rho, figsize=figsize)
    if method == "bloch":
        # The figure used to be built and then discarded (None returned).
        return plot_bloch_multivector(rho, figsize=figsize)
    if method == "wigner":
        return plot_wigner_function(rho)
    if method == "hinton":
        return plot_state_hinton(rho, figsize=figsize)
    # Same error as iplot_state instead of silently returning None.
    raise VisualizationError('Invalid plot state method.')
def iplot_bloch_multivector(rho, figsize=None):
    """Create an interactive Bloch sphere representation.

    Graphical representation of the input array, using one Bloch sphere per
    qubit. Each sphere shows the expectation values of the single-qubit X, Y
    and Z Pauli operators, rendered by the remote ``qVisualization``
    JavaScript library.

    Args:
        rho (array): State vector or density matrix
        figsize (tuple): Figure size in pixels, passed to the JavaScript
            plotter as its ``width``/``height`` options.
    """

    # HTML
    html_template = Template("""
    <p>
        <div id="content_$divNumber" style="position: absolute; z-index: 1;">
            <div id="bloch_$divNumber"></div>
        </div>
    </p>
    """)

    # JavaScript
    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });
        data = $data;
        dataValues = [];
        for (var i = 0; i < data.length; i++) {
            // Coordinates
            var x = data[i][0];
            var y = data[i][1];
            var z = data[i][2];
            var point = {'x': x,
                        'y': y,
                        'z': z};
            dataValues.push(point);
        }

        require(["qVisualization"], function(qVisualizations) {
            // Plot figure
            qVisualizations.plotState("bloch_$divNumber",
                                      "bloch",
                                      dataValues,
                                      $options);
        });
    </script>
    """)
    rho = _validate_input_state(rho)
    if figsize is None:
        options = {}
    else:
        options = {'width': figsize[0], 'height': figsize[1]}

    # Process data and execute
    num = int(np.log2(len(rho)))

    bloch_data = []
    for i in range(num):
        pauli_singles = [
            Pauli.pauli_single(num, i, 'X'),
            Pauli.pauli_single(num, i, 'Y'),
            Pauli.pauli_single(num, i, 'Z')
        ]
        # Cast to built-in float: the list is rendered into the JavaScript
        # source via str(), and NumPy >= 2 prints scalars as "np.float64(..)".
        bloch_state = [float(np.real(np.trace(np.dot(x.to_matrix(), rho))))
                       for x in pauli_singles]
        bloch_data.append(bloch_state)

    # DOM id derived from the current timestamp (dot removed).
    div_number = str(time.time())
    div_number = re.sub('[.]', '', div_number)

    html = html_template.substitute({'divNumber': div_number})

    javascript = javascript_template.substitute({
        'data': bloch_data,
        'divNumber': div_number,
        'options': options
    })

    display(HTML(html + javascript))
def iplot_state_hinton(rho, figsize=None):
    """Create an interactive Hinton-style representation of a state.

    The real and imaginary parts of the density matrix are sent, as two
    nested lists of floats, to the remote ``qVisualization`` JavaScript
    library and drawn in the notebook.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Optional ``(width, height)`` forwarded to the
            JavaScript plotter as its ``width``/``height`` options.
    """
    html_template = Template("""
    <p>
        <div id="hinton_$divNumber"></div>
    </p>
    """)

    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });

        require(["qVisualization"], function(qVisualizations) {
            qVisualizations.plotState("hinton_$divNumber",
                                      "hinton",
                                      $executions,
                                      $options);
        });
    </script>
    """)

    density = _validate_input_state(rho)
    options = ({} if figsize is None
               else {'width': figsize[0], 'height': figsize[1]})

    # DOM id built from the current timestamp with the dot stripped.
    div_id = re.sub('[.]', '', str(time.time()))

    # Plain Python floats so the nested lists render as JavaScript arrays.
    real_rows = [[float(entry) for entry in row.real] for row in density]
    imag_rows = [[float(entry) for entry in row.imag] for row in density]

    markup = html_template.substitute({'divNumber': div_id})
    script = javascript_template.substitute({
        'divNumber': div_id,
        'executions': [{'data': real_rows}, {'data': imag_rows}],
        'options': options,
    })
    display(HTML(markup + script))
def iplot_state_qsphere(rho, figsize=None):
    """Create an interactive Q sphere representation.

    The density matrix is eigendecomposed. Every eigenvector whose
    eigenvalue exceeds 0.001 is drawn on its own Q sphere, processed from
    the largest eigenvalue down. The remote ``qVisualization`` JavaScript
    library does the rendering.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Figure size in pixels, passed to the JavaScript
            plotter as its ``width``/``height`` options.
    """

    # HTML
    html_template = Template("""
    <p>
        <div id="content_$divNumber" style="position: absolute; z-index: 1;">
            <div id="qsphere_$divNumber"></div>
        </div>
    </p>
    """)

    # JavaScript
    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });
        require(["qVisualization"], function(qVisualizations) {
            data = $data;
            qVisualizations.plotState("qsphere_$divNumber",
                                      "qsphere",
                                      data,
                                      $options);
        });
    </script>
    """)
    rho = _validate_input_state(rho)
    if figsize is None:
        options = {}
    else:
        options = {'width': figsize[0], 'height': figsize[1]}

    qspheres_data = []
    # Process data and execute
    num = int(np.log2(len(rho)))

    # get the eigenvectors and eigenvalues
    weig, stateall = linalg.eigh(rho)

    for _ in range(2**num):
        # start with the max
        probmix = weig.max()
        prob_location = weig.argmax()
        if probmix <= 0.001:
            # Every remaining eigenvalue is <= this one, so no further
            # sphere would be produced.
            break

        # get the max eigenvalue
        state = stateall[:, prob_location]
        loc = np.absolute(state).argmax()
        # get the element location closest to lowest bin representation.
        for j in range(2**num):
            test = np.absolute(
                np.absolute(state[j]) - np.absolute(state[loc]))
            if test < 0.001:
                loc = j
                break
        # remove the global phase
        angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
        angleset = np.exp(-1j * angles)
        state = angleset * state

        spherepoints = []
        for i in range(2**num):
            # get x,y,z points
            element = bin(i)[2:].zfill(num)
            weight = element.count("1")
            number_of_divisions = n_choose_k(num, weight)
            weight_order = bit_string_index(element)
            angle = weight_order * 2 * np.pi / number_of_divisions
            zvalue = -2 * weight / num + 1
            xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
            yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

            # get prob and angle - prob will be shade and angle color
            prob = np.real(np.dot(state[i], state[i].conj()))
            phase = (np.angle(state[i]) + 2 * np.pi) % (2 * np.pi)
            # Built-in floats: this dict is rendered into JavaScript via
            # str(), and NumPy >= 2 prints scalars as "np.float64(..)".
            qpoint = {
                'x': float(xvalue),
                'y': float(yvalue),
                'z': float(zvalue),
                'prob': float(prob),
                'phase': float(phase)
            }
            spherepoints.append(qpoint)

        # Associate all points to one sphere
        sphere = {'points': spherepoints, 'eigenvalue': float(probmix)}

        # Add sphere to the spheres array
        qspheres_data.append(sphere)
        weig[prob_location] = 0

    # DOM id derived from the current timestamp (dot removed).
    div_number = str(time.time())
    div_number = re.sub('[.]', '', div_number)

    html = html_template.substitute({'divNumber': div_number})

    javascript = javascript_template.substitute({
        'data': qspheres_data,
        'divNumber': div_number,
        'options': options
    })

    display(HTML(html + javascript))
def iplot_state_city(rho, figsize=None):
    """Create an interactive cityscape representation of a state.

    The real and imaginary parts of the density matrix are sent, as nested
    lists of floats, to the remote ``qVisualization`` JavaScript library
    and drawn as 3D bar ("city") plots in the notebook.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Optional ``(width, height)`` forwarded to the
            JavaScript plotter as its ``width``/``height`` options.
    """
    html_template = Template("""
    <p>
        <div id="content_$divNumber" style="position: absolute; z-index: 1;">
            <div id="cities_$divNumber"></div>
        </div>
    </p>
    """)

    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });

        require(["qVisualization"], function(qVisualizations) {
            data = {
                real: $real,
                titleReal: "Real.[rho]",
                imaginary: $imag,
                titleImaginary: "Im.[rho]",
                qbits: $qbits
            };
            qVisualizations.plotState("cities_$divNumber",
                                      "cities",
                                      data,
                                      $options);
        });
    </script>
    """)

    density = _validate_input_state(rho)
    options = ({} if figsize is None
               else {'width': figsize[0], 'height': figsize[1]})

    # Plain Python floats so the nested lists render as JavaScript arrays.
    real_rows = [[float(entry) for entry in row.real] for row in density]
    imag_rows = [[float(entry) for entry in row.imag] for row in density]

    # DOM id built from the current timestamp with the dot stripped.
    div_id = re.sub('[.]', '', str(time.time()))

    markup = html_template.substitute({'divNumber': div_id})
    script = javascript_template.substitute({
        'real': real_rows,
        'imag': imag_rows,
        # Number of matrix rows, passed through as the "qbits" field.
        'qbits': len(real_rows),
        'divNumber': div_id,
        'options': options,
    })
    display(HTML(markup + script))
def iplot_state_paulivec(rho, figsize=None, slider=False, show_legend=False):
    """Create an interactive paulivec representation of a state.

    The state is converted by ``process_data`` and handed to the remote
    ``qVisualization`` JavaScript library for rendering in the notebook.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Figure size in pixels; ``(7, 5)`` when omitted.
        slider (bool): activate slider
        show_legend (bool): show legend of graph content
    """
    html_template = Template("""
    <p>
        <div id="paulivec_$divNumber"></div>
    </p>
    """)

    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });

        require(["qVisualization"], function(qVisualizations) {
            qVisualizations.plotState("paulivec_$divNumber",
                                      "paulivec",
                                      $executions,
                                      $options);
        });
    </script>
    """)

    density = _validate_input_state(rho)
    size = (7, 5) if figsize is None else figsize
    # Booleans are sent as 0/1 integers to the JavaScript side.
    options = {
        'width': size[0],
        'height': size[1],
        'slider': int(slider),
        'show_legend': int(show_legend),
    }

    # DOM id built from the current timestamp with the dot stripped.
    div_id = re.sub('[.]', '', str(time.time()))

    executions = [dict(data=process_data(density))]

    markup = html_template.substitute({'divNumber': div_id})
    script = javascript_template.substitute({
        'divNumber': div_id,
        'executions': executions,
        'options': options,
    })
    display(HTML(markup + script))
def plot_state_hinton(rho, title='', figsize=None):
    """Plot a hinton diagram for the quantum state.

    Two panels show the real and imaginary parts of the density matrix.
    Each element is a square whose area is proportional to its magnitude:
    white for positive values, black otherwise. The column index runs along
    the horizontal axis and the row index down the vertical axis.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches (default ``(8, 5)``).

    Returns:
        matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (8, 5)
    num = int(np.log2(len(rho)))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    # Squares are scaled against the power of two at or above the largest
    # |rho_ij|, so no square exceeds its unit cell.
    max_weight = 2**np.ceil(np.log(np.abs(rho).max()) / np.log(2))
    basis_labels = [bin(i)[2:].zfill(num) for i in range(2**num)]

    def _draw_panel(ax, data, panel_title, label_ticks):
        """Draw one Hinton panel for the 2D array ``data`` on ``ax``."""
        ax.patch.set_facecolor('gray')
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        for (row, col), w in np.ndenumerate(data):
            color = 'white' if w > 0 else 'black'
            size = np.sqrt(np.abs(w) / max_weight)
            # np.ndenumerate yields (row, col): put the column on the
            # horizontal axis so the panel is not drawn transposed.
            rect = plt.Rectangle([col - size / 2, row - size / 2], size, size,
                                 facecolor=color, edgecolor=color)
            ax.add_patch(rect)
        if label_ticks:
            # Exactly one tick per basis state, centred on its squares; a
            # tick/label count mismatch is a ValueError in newer Matplotlib.
            ax.set_xticks(np.arange(data.shape[1]))
            ax.set_yticks(np.arange(data.shape[0]))
            ax.set_yticklabels(basis_labels, fontsize=14)
            ax.set_xticklabels(basis_labels, fontsize=14, rotation=90)
        ax.autoscale_view()
        ax.invert_yaxis()
        ax.set_title(panel_title, fontsize=14)

    # Real
    _draw_panel(ax1, np.real(rho), 'Real[rho]', True)
    # Imaginary: tick labels only when there is something to read.
    dataimag = np.imag(rho)
    _draw_panel(ax2, dataimag, 'Imag[rho]', np.any(dataimag != 0))
    if title:
        fig.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.close(fig)
    return fig
def plot_state_qsphere(rho, figsize=None):
    """Plot the qsphere representation of a quantum state.

    The density matrix is eigendecomposed. Starting from the largest
    eigenvalue, each eigenvector whose eigenvalue exceeds 0.001 is drawn on a
    new Q-sphere figure. The dominant eigenvector is always drawn. Each basis
    state is a point whose height is set by its Hamming weight. An arrow to
    the point has opacity equal to the amplitude's probability and colour
    given by its phase, after the global phase is removed.

    Args:
        rho (ndarray): State vector or density matrix representation
            of quantum state.
        figsize (tuple): Figure size in inches (default ``(7, 7)``).

    Returns:
        Figure: A matplotlib figure instance (the one drawn for the last
        eigenvector processed).

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (7, 7)
    num = int(np.log2(len(rho)))
    grey = (.5, .5, .5)
    transparent = (1.0, 1.0, 1.0, 0.0)

    # get the eigenvectors and eigenvalues
    we, stateall = linalg.eigh(rho)
    fig = None
    for _ in range(2**num):
        # start with the max
        probmix = we.max()
        prob_location = we.argmax()
        # Stop once eigenvalues become negligible, but always draw at least
        # one sphere: previously `fig` stayed unbound (crash) when every
        # eigenvalue was <= 0.001.
        if probmix <= 0.001 and fig is not None:
            break

        # get the max eigenvalue
        state = stateall[:, prob_location]
        loc = np.absolute(state).argmax()
        # get the element location closest to lowest bin representation.
        for j in range(2**num):
            test = np.absolute(
                np.absolute(state[j]) - np.absolute(state[loc]))
            if test < 0.001:
                loc = j
                break
        # remove the global phase
        angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
        angleset = np.exp(-1j * angles)
        state = angleset * state

        # start the plotting
        # NOTE(review): a new figure is created per eigenvector; only the
        # last one is closed and returned, earlier ones stay open in pyplot.
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
        ax.axes.set_xlim3d(-1.0, 1.0)
        ax.axes.set_ylim3d(-1.0, 1.0)
        ax.axes.set_zlim3d(-1.0, 1.0)
        try:
            ax.set_aspect("equal")
        except NotImplementedError:
            # Matplotlib 3.1-3.5 refuses aspect settings on 3D axes; keep
            # the default aspect there instead of failing.
            pass
        ax.axes.grid(False)
        # Plot semi-transparent sphere
        u = np.linspace(0, 2 * np.pi, 25)
        v = np.linspace(0, np.pi, 25)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))
        ax.plot_surface(x, y, z, rstride=1, cstride=1, color='k',
                        alpha=0.05, linewidth=0)
        # Get rid of the panes and spines. ``ax.xaxis`` etc. are the 3D
        # axes themselves; the ``w_*axis`` aliases were removed in 3.8.
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_pane_color(transparent)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.line.set_color(transparent)
        # Get rid of the ticks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

        for i in range(2**num):
            # get x,y,z points: height from the Hamming weight, azimuth
            # from the index among strings of that weight.
            element = bin(i)[2:].zfill(num)
            weight = element.count("1")
            zvalue = -2 * weight / num + 1
            number_of_divisions = n_choose_k(num, weight)
            weight_order = bit_string_index(element)
            angle = weight_order * 2 * np.pi / number_of_divisions
            xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
            yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)
            ax.plot([xvalue], [yvalue], [zvalue],
                    markerfacecolor=grey,
                    markeredgecolor=grey,
                    marker='o', markersize=10, alpha=1)
            # get prob and angle - prob will be shade and angle color
            prob = np.real(np.dot(state[i], state[i].conj()))
            colorstate = phase_to_color_wheel(state[i])
            # Clip: rounding can push |amp|^2 marginally above 1, which
            # newer Matplotlib rejects as an alpha value.
            a = Arrow3D([0, xvalue], [0, yvalue], [0, zvalue],
                        mutation_scale=20,
                        alpha=float(np.clip(prob, 0.0, 1.0)),
                        arrowstyle="-",
                        color=colorstate, lw=10)
            ax.add_artist(a)
        # add weight lines
        for weight in range(num + 1):
            theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
            z = -2 * weight / num + 1
            r = np.sqrt(1 - z**2)
            x = r * np.cos(theta)
            y = r * np.sin(theta)
            ax.plot(x, y, z, color=grey)
        # add center point
        ax.plot([0], [0], [0], markerfacecolor=grey,
                markeredgecolor=grey, marker='o', markersize=10,
                alpha=1)
        we[prob_location] = 0
    plt.tight_layout()
    plt.close(fig)
    return fig
def plot_state_city(rho, title="", figsize=None, color=None):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho.

    NOTE(review): a later ``plot_state_city`` definition in this module
    (with an ``alpha`` parameter) rebinds this name, so this version appears
    to be unreachable under that name.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches (default ``(15, 5)``).
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. ``None`` entries use
            the default colour; the caller's sequence is not modified.

    Returns:
        matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)

    num = int(np.log2(len(rho)))

    # get the real and imag parts of rho
    datareal = np.real(rho)
    dataimag = np.imag(rho)

    # get the labels
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    lx = len(datareal[0])  # Work out matrix dimensions
    ly = len(datareal[:, 0])
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
    ypos = np.arange(0, ly, 1)
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)

    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    if color is None:
        color = ["#648fff", "#648fff"]
    else:
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        # Copy so filling in defaults neither mutates the caller's list nor
        # fails when a tuple is passed.
        color = list(color)
        if color[0] is None:
            color[0] = "#648fff"
        if color[1] is None:
            color[1] = "#648fff"

    # set default figure size
    if figsize is None:
        figsize = (15, 5)

    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')
    ax1.bar3d(xpos, ypos, zpos, dx, dy, dzr, color=color[0], alpha=0.5)
    ax2 = fig.add_subplot(1, 2, 2, projection='3d')
    ax2.bar3d(xpos, ypos, zpos, dx, dy, dzi, color=color[1], alpha=0.5)

    # Real part. ``xaxis``/``yaxis`` replace the ``w_*axis`` aliases removed
    # in Matplotlib 3.8, and ``Tick.label1`` replaces ``Tick.label``.
    ax1.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax1.set_yticks(np.arange(0.5, ly + 0.5, 1))
    # The 1e-9 keeps the z-range non-degenerate when all values are equal.
    ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr) + 1e-9)
    ax1.zaxis.set_major_locator(MaxNLocator(5))
    ax1.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax1.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax1.set_zlabel("Real[rho]", fontsize=14)
    for tick in ax1.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)

    # Imaginary part.
    ax2.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax2.set_yticks(np.arange(0.5, ly + 0.5, 1))
    if np.min(dzi) != np.max(dzi):
        eps = 0
        ax2.zaxis.set_major_locator(MaxNLocator(5))
    else:
        # Constant imaginary part (typically all zeros): single 0 tick and
        # a tiny range so the z-limits are not degenerate.
        ax2.set_zticks([0])
        eps = 1e-9
    ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi) + eps)
    ax2.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax2.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax2.set_zlabel("Imag[rho]", fontsize=14)
    for tick in ax2.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.close(fig)
    return fig
def plot_state_city(rho, title="", figsize=None, color=None, alpha=1):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho. When a panel has both positive and
    negative values, a translucent plane is drawn at z=0.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches (default ``(15, 5)``).
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. ``None`` entries use
            the default colour; the caller's sequence is not modified.
        alpha (float): Transparency value for bars

    Returns:
        matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)

    num = int(np.log2(len(rho)))

    # get the real and imag parts of rho
    datareal = np.real(rho)
    dataimag = np.imag(rho)

    # get the labels
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    lx = len(datareal[0])  # Work out matrix dimensions
    ly = len(datareal[:, 0])
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
    ypos = np.arange(0, ly, 1)
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)

    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    if color is None:
        color = ["#648fff", "#648fff"]
    else:
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        # Copy so filling in defaults neither mutates the caller's list nor
        # fails when a tuple is passed.
        color = list(color)
        if color[0] is None:
            color[0] = "#648fff"
        if color[1] is None:
            color[1] = "#648fff"

    # set default figure size
    if figsize is None:
        figsize = (15, 5)

    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')

    # Corners of the z=0 plane covering the whole bar grid.
    x = [0, max(xpos) + 0.5, max(xpos) + 0.5, 0]
    y = [0, 0, max(ypos) + 0.5, max(ypos) + 0.5]
    z = [0, 0, 0, 0]
    verts = [list(zip(x, y, z))]

    fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
    for idx, cur_zpos in enumerate(zpos):
        # Positive bars sit above the zero plane (zorder 1), others below.
        zorder = 2 if dzr[idx] > 0 else 0
        b1 = ax1.bar3d(xpos[idx], ypos[idx], cur_zpos,
                       dx[idx], dy[idx], dzr[idx],
                       alpha=alpha, zorder=zorder)
        # fc1 is consumed six entries per bar (presumably one per face).
        b1.set_facecolors(fc1[6 * idx:6 * idx + 6])

    pc1 = Poly3DCollection(verts, alpha=0.15, facecolor='k',
                           linewidths=1, zorder=1)

    if min(dzr) < 0 < max(dzr):
        ax1.add_collection3d(pc1)

    ax2 = fig.add_subplot(1, 2, 2, projection='3d')
    fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
    for idx, cur_zpos in enumerate(zpos):
        zorder = 2 if dzi[idx] > 0 else 0
        b2 = ax2.bar3d(xpos[idx], ypos[idx], cur_zpos,
                       dx[idx], dy[idx], dzi[idx],
                       alpha=alpha, zorder=zorder)
        b2.set_facecolors(fc2[6 * idx:6 * idx + 6])

    pc2 = Poly3DCollection(verts, alpha=0.2, facecolor='k',
                           linewidths=1, zorder=1)

    if min(dzi) < 0 < max(dzi):
        ax2.add_collection3d(pc2)

    # Real part. ``xaxis``/``yaxis`` replace the ``w_*axis`` aliases removed
    # in Matplotlib 3.8, and ``Tick.label1`` replaces ``Tick.label``.
    ax1.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax1.set_yticks(np.arange(0.5, ly + 0.5, 1))
    max_dzr = max(dzr)
    min_dzr = min(dzr)
    if max_dzr != min_dzr or min_dzr == 0:
        # The 1e-9 keeps the z-range non-degenerate for an all-zero panel.
        ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr) + 1e-9)
    else:
        ax1.axes.set_zlim3d(auto=True)
    ax1.zaxis.set_major_locator(MaxNLocator(5))
    ax1.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax1.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax1.set_zlabel("Real[rho]", fontsize=14)
    for tick in ax1.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)

    # Imaginary part.
    ax2.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax2.set_yticks(np.arange(0.5, ly + 0.5, 1))
    min_dzi = np.min(dzi)
    max_dzi = np.max(dzi)
    if min_dzi != max_dzi:
        ax2.zaxis.set_major_locator(MaxNLocator(5))
        ax2.axes.set_zlim3d(min_dzi, max_dzi)
    elif min_dzi == 0:
        # All-zero imaginary part: a single 0 tick and a tiny range.
        ax2.set_zticks([0])
        ax2.axes.set_zlim3d(min_dzi, max_dzi + 1e-9)
    else:
        ax2.axes.set_zlim3d(auto=True)
    ax2.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax2.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax2.set_zlabel("Imag[rho]", fontsize=14)
    for tick in ax2.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.close(fig)
    return fig