コード例 #1
0
def iplot_state(quantum_state, method='city', figsize=None):
    """Plot the quantum state interactively (deprecated).

    Validates the input and dispatches to the interactive plotting function
    selected by ``method``.

    Args:
        quantum_state (ndarray): statevector or density matrix
                                 representation of a quantum state.
        method (str): Plotting method to use: 'city', 'paulivec', 'qsphere',
            'bloch' or 'hinton'.
        figsize (tuple): Figure size, forwarded unchanged to the selected
            ``iplot_*`` function.

    Raises:
        VisualizationError: if the input is not a statevector or density
            matrix, if the state is not a multi-qubit quantum state, or if
            ``method`` is not one of the supported names.
    """
    # Implicit concatenation keeps the message free of the indentation that
    # backslash continuations inside the literal used to embed.
    # stacklevel=2 attributes the warning to the caller.
    warnings.warn(
        "iplot_state is deprecated, and will be removed in the 0.9 release. "
        "Use the iplot_state_* functions instead.",
        DeprecationWarning,
        stacklevel=2)
    rho = _validate_input_state(quantum_state)
    if method == "city":
        iplot_state_city(rho, figsize=figsize)
    elif method == "paulivec":
        iplot_state_paulivec(rho, figsize=figsize)
    elif method == "qsphere":
        iplot_state_qsphere(rho, figsize=figsize)
    elif method == "bloch":
        iplot_bloch_multivector(rho, figsize=figsize)
    elif method == "hinton":
        iplot_state_hinton(rho, figsize=figsize)
    else:
        raise VisualizationError('Invalid plot state method.')
コード例 #2
0
def plot_state_paulivec(rho, title="", figsize=None, color=None):
    """Plot the paulivec representation of a quantum state.

    Plot a bar graph of the expectation value Tr(P rho) of every operator P
    in the n-qubit Pauli group.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches; (7, 5) when omitted.
        color (list or str): Color of the expectation value bars;
            "#648fff" when omitted.
    Returns:
         matplotlib.Figure: The matplotlib.Figure of the visualization
    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (7, 5)
    if color is None:
        color = "#648fff"
    num = int(np.log2(len(rho)))
    # Build the Pauli group once and reuse it for both labels and values;
    # list() guards against the helper ever returning a one-shot iterator.
    paulis = list(pauli_group(num))
    labels = [pauli.to_label() for pauli in paulis]
    values = [np.real(np.trace(np.dot(pauli.to_matrix(), rho)))
              for pauli in paulis]

    ind = np.arange(len(values))  # the x locations for the groups
    width = 0.5  # the width of the bars
    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(zorder=0, linewidth=1, linestyle='--')
    ax.bar(ind, values, width, color=color, zorder=2)
    ax.axhline(linewidth=1, color='k')
    # add some text for labels, title, and axes ticks
    ax.set_ylabel('Expectation value', fontsize=14)
    ax.set_xticks(ind)
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
    ax.set_xlabel('Pauli', fontsize=14)
    ax.set_ylim([-1, 1])
    ax.set_facecolor('#eeeeee')
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        # Tick.label was an alias of Tick.label1 and was removed in
        # Matplotlib 3.8; label1 works on old and new versions.
        tick.label1.set_fontsize(14)
    ax.set_title(title, fontsize=16)
    plt.close(fig)
    return fig
コード例 #3
0
def plot_bloch_multivector(rho, title='', figsize=None):
    """Draw one Bloch sphere per qubit of a quantum state.

    For qubit ``q`` the Bloch vector is the triple of expectation values
    Tr(P rho) with P the single-qubit X, Y and Z Paulis acting on ``q``; each
    vector is drawn by ``plot_bloch_vector`` in its own 3D subplot.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Kept for compatibility; it is only forwarded to
            ``plot_bloch_vector`` together with an explicit ``ax``, while the
            figure itself is sized with ``plt.figaspect``.

    Returns:
        Figure: A matplotlib figure instance.

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    n_qubits = int(np.log2(len(rho)))
    fig_w, fig_h = plt.figaspect(1 / n_qubits)
    fig = plt.figure(figsize=(fig_w, fig_h))
    for qubit in range(n_qubits):
        subplot = fig.add_subplot(1, n_qubits, qubit + 1, projection='3d')
        bloch_vector = []
        for pauli_label in ('X', 'Y', 'Z'):
            operator = Pauli.pauli_single(n_qubits, qubit, pauli_label)
            expectation = np.real(np.trace(np.dot(operator.to_matrix(), rho)))
            bloch_vector.append(expectation)
        plot_bloch_vector(bloch_vector,
                          "qubit " + str(qubit),
                          ax=subplot,
                          figsize=figsize)
    fig.suptitle(title, fontsize=16)
    plt.close(fig)
    return fig
コード例 #4
0
def plot_state(quantum_state, method='city', figsize=None):
    """Plot the quantum state (deprecated).

    Validates the input and dispatches to the ``plot_*`` function selected
    by ``method``.

    Args:
        quantum_state (ndarray): statevector or density matrix
                                 representation of a quantum state.
        method (str): Plotting method to use: 'city', 'paulivec', 'qsphere',
            'bloch', 'wigner' or 'hinton'.
        figsize (tuple): Figure size in inches (not used by 'wigner').

    Returns:
         matplotlib.Figure: The matplotlib.Figure of the visualization, or
         None if ``method`` is not one of the supported names.
    Raises:
        ImportError: Requires matplotlib.
        VisualizationError: if the input is not a statevector or density
            matrix, or if the state is not a multi-qubit quantum state.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    # Implicit concatenation keeps the message free of embedded indentation;
    # stacklevel=2 attributes the warning to the caller.
    warnings.warn(
        "plot_state is deprecated, and will be removed in the 0.9 release. "
        "Use the plot_state_* functions instead.",
        DeprecationWarning,
        stacklevel=2)
    # Check if input is a statevector, and convert to density matrix
    rho = _validate_input_state(quantum_state)
    fig = None
    if method == 'city':
        fig = plot_state_city(rho, figsize=figsize)
    elif method == "paulivec":
        fig = plot_state_paulivec(rho, figsize=figsize)
    elif method == "qsphere":
        fig = plot_state_qsphere(rho, figsize=figsize)
    elif method == "bloch":
        # The figure used to be discarded here, so 'bloch' returned None.
        fig = plot_bloch_multivector(rho, figsize=figsize)
    elif method == "wigner":
        fig = plot_wigner_function(rho)
    elif method == "hinton":
        fig = plot_state_hinton(rho, figsize=figsize)
    return fig
コード例 #5
0
def iplot_bloch_multivector(rho, figsize=None):
    """Create an interactive Bloch-sphere representation of a state.

    One Bloch sphere is drawn per qubit. Each sphere's vector is the triple
    of expectation values Tr(P rho) for the single-qubit X, Y and Z Paulis
    on that qubit. The data is rendered by the remote ``qVisualization``
    JavaScript library and displayed as HTML.

    Args:
        rho (array): State vector or density matrix
        figsize (tuple): Figure size in pixels, passed to the JavaScript
            plot as its ``width``/``height`` options.
    """

    # HTML
    html_template = Template("""
    <p>
        <div id="content_$divNumber" style="position: absolute; z-index: 1;">
            <div id="bloch_$divNumber"></div>
        </div>
    </p>
    """)

    # JavaScript
    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });
        data = $data;
        dataValues = [];
        for (var i = 0; i < data.length; i++) {
            // Coordinates
            var x = data[i][0];
            var y = data[i][1];
            var z = data[i][2];
            var point = {'x': x,
                        'y': y,
                        'z': z};
            dataValues.push(point);
        }

        require(["qVisualization"], function(qVisualizations) {
            // Plot figure
            qVisualizations.plotState("bloch_$divNumber",
                                      "bloch",
                                      dataValues,
                                      $options);
        });
    </script>
    """)
    rho = _validate_input_state(rho)
    if figsize is None:
        options = {}
    else:
        options = {'width': figsize[0], 'height': figsize[1]}

    # Process data and execute
    num = int(np.log2(len(rho)))

    bloch_data = []
    for i in range(num):
        pauli_singles = [
            Pauli.pauli_single(num, i, 'X'),
            Pauli.pauli_single(num, i, 'Y'),
            Pauli.pauli_single(num, i, 'Z')
        ]
        # Cast to built-in float: the nested list is interpolated into the
        # JavaScript source via str(), and NumPy >= 2 reprs scalars as
        # "np.float64(...)", which is not valid JavaScript.
        bloch_state = [
            float(np.real(np.trace(np.dot(x.to_matrix(), rho))))
            for x in pauli_singles
        ]
        bloch_data.append(bloch_state)

    # Time-based suffix so several plots in one notebook get distinct ids.
    div_number = str(time.time())
    div_number = re.sub('[.]', '', div_number)

    html = html_template.substitute({'divNumber': div_number})

    javascript = javascript_template.substitute({
        'data': bloch_data,
        'divNumber': div_number,
        'options': options
    })

    display(HTML(html + javascript))
コード例 #6
0
def iplot_state_hinton(rho, figsize=None):
    """Render an interactive Hinton diagram of a quantum state.

    The real and imaginary parts of the validated density matrix are sent,
    as nested lists of floats, to the remote ``qVisualization`` JavaScript
    library, and the result is displayed as HTML.

    Args:
        rho (array): Density matrix
        figsize (tuple): Optional (width, height) forwarded to the
            JavaScript plot as its ``width``/``height`` options (documented
            here as inches; other iplot helpers say pixels — TODO confirm).
    """
    html_template = Template("""
    <p>
        <div id="hinton_$divNumber"></div>
    </p>
    """)

    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });

        require(["qVisualization"], function(qVisualizations) {
            qVisualizations.plotState("hinton_$divNumber",
                                      "hinton",
                                      $executions,
                                      $options);
        });
    </script>
    """)
    rho = _validate_input_state(rho)
    options = {} if figsize is None else {'width': figsize[0],
                                          'height': figsize[1]}

    # Unique element id derived from the current timestamp (dot removed).
    div_number = str(time.time()).replace('.', '')

    # Plain Python floats so the interpolated lists are valid JavaScript.
    real_rows = [[float(entry) for entry in row.real] for row in rho]
    imag_rows = [[float(entry) for entry in row.imag] for row in rho]

    executions = [{'data': real_rows}, {'data': imag_rows}]
    html = html_template.substitute({'divNumber': div_number})
    javascript = javascript_template.substitute({
        'divNumber': div_number,
        'executions': executions,
        'options': options,
    })

    display(HTML(html + javascript))
コード例 #7
0
def iplot_state_qsphere(rho, figsize=None):
    """Create an interactive Q-sphere representation of a quantum state.

    The density matrix is diagonalised with ``linalg.eigh``. Eigenvectors are
    visited from the largest eigenvalue down, stopping once the largest
    remaining eigenvalue is at most 0.001. For each visited eigenvector one
    sphere is built: basis state ``i`` is placed at a height set by its
    Hamming weight, carrying its probability and phase (after removing a
    global phase). All spheres are rendered by the remote ``qVisualization``
    JavaScript library and displayed as HTML.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Figure size in pixels, passed to the JavaScript
            plot as its ``width``/``height`` options.
    """

    # HTML
    html_template = Template("""
    <p>
        <div id="content_$divNumber" style="position: absolute; z-index: 1;">
            <div id="qsphere_$divNumber"></div>
        </div>
    </p>
    """)

    # JavaScript
    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });
        require(["qVisualization"], function(qVisualizations) {
            data = $data;
            qVisualizations.plotState("qsphere_$divNumber",
                                      "qsphere",
                                      data,
                                      $options);
        });
    </script>

    """)
    rho = _validate_input_state(rho)
    if figsize is None:
        options = {}
    else:
        options = {'width': figsize[0], 'height': figsize[1]}

    qspheres_data = []
    num = int(np.log2(len(rho)))
    dim = 2**num

    # get the eigenvectors and eigenvalues
    weig, stateall = linalg.eigh(rho)

    for _ in range(dim):
        # start with the largest remaining eigenvalue
        probmix = weig.max()
        prob_location = weig.argmax()
        if probmix <= 0.001:
            # Processed eigenvalues are zeroed, so the maximum never grows:
            # no later iteration could pass the threshold either.
            break
        state = stateall[:, prob_location]
        # Lowest basis index whose magnitude matches the largest one
        # (within 0.001); its phase is taken as the global phase.
        loc = np.absolute(state).argmax()
        for j in range(dim):
            test = np.absolute(np.absolute(state[j]) - np.absolute(state[loc]))
            if test < 0.001:
                loc = j
                break
        # remove the global phase
        angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
        state = np.exp(-1j * angles) * state

        spherepoints = []
        for i in range(dim):
            # get x,y,z points
            element = bin(i)[2:].zfill(num)
            weight = element.count("1")

            number_of_divisions = n_choose_k(num, weight)
            weight_order = bit_string_index(element)

            angle = weight_order * 2 * np.pi / number_of_divisions

            zvalue = -2 * weight / num + 1
            xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
            yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

            # get prob and angle - prob will be shade and angle color
            prob = np.real(np.dot(state[i], state[i].conj()))
            phase = (np.angle(state[i]) + 2 * np.pi) % (2 * np.pi)
            # Built-in floats: the data is interpolated into JavaScript via
            # str(), and NumPy >= 2 reprs scalars as "np.float64(...)".
            spherepoints.append({
                'x': float(xvalue),
                'y': float(yvalue),
                'z': float(zvalue),
                'prob': float(prob),
                'phase': float(phase)
            })

        # Associate all points to one sphere and collect it
        qspheres_data.append({
            'points': spherepoints,
            'eigenvalue': float(probmix)
        })
        weig[prob_location] = 0

    # Time-based suffix so several plots in one notebook get distinct ids.
    div_number = str(time.time())
    div_number = re.sub('[.]', '', div_number)

    html = html_template.substitute({'divNumber': div_number})

    javascript = javascript_template.substitute({
        'data': qspheres_data,
        'divNumber': div_number,
        'options': options
    })

    display(HTML(html + javascript))
コード例 #8
0
def iplot_state_city(rho, figsize=None):
    """Render an interactive "cities" plot of a quantum state.

    The real and imaginary parts of the validated density matrix are sent,
    as nested lists of floats, to the remote ``qVisualization`` JavaScript
    library together with the matrix dimension, and the result is displayed
    as HTML.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Optional (width, height) forwarded to the
            JavaScript plot as its ``width``/``height`` options (documented
            here as inches; other iplot helpers say pixels — TODO confirm).
    """
    html_template = Template("""
    <p>
        <div id="content_$divNumber" style="position: absolute; z-index: 1;">
            <div id="cities_$divNumber"></div>
        </div>
    </p>
    """)

    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });

        require(["qVisualization"], function(qVisualizations) {
            data = {
                real: $real,
                titleReal: "Real.[rho]",
                imaginary: $imag,
                titleImaginary: "Im.[rho]",
                qbits: $qbits
            };
            qVisualizations.plotState("cities_$divNumber",
                                      "cities",
                                      data,
                                      $options);
        });
    </script>
    """)
    rho = _validate_input_state(rho)
    options = {} if figsize is None else {'width': figsize[0],
                                          'height': figsize[1]}

    # Plain Python floats so the interpolated lists are valid JavaScript.
    real_rows = [[float(entry) for entry in row.real] for row in rho]
    imag_rows = [[float(entry) for entry in row.imag] for row in rho]

    # Unique element id derived from the current timestamp (dot removed).
    div_number = str(time.time()).replace('.', '')

    html = html_template.substitute({'divNumber': div_number})
    javascript = javascript_template.substitute({
        'real': real_rows,
        'imag': imag_rows,
        'qbits': len(real_rows),
        'divNumber': div_number,
        'options': options,
    })

    display(HTML(html + javascript))
コード例 #9
0
def iplot_state_paulivec(rho, figsize=None, slider=False, show_legend=False):
    """Render an interactive paulivec chart of a quantum state.

    The validated state is converted to plot data by ``process_data`` and
    handed to the remote ``qVisualization`` JavaScript library; the result is
    displayed as HTML.

    Args:
        rho (array): State vector or density matrix.
        figsize (tuple): Figure size in pixels; ``(7, 5)`` when omitted.
        slider (bool): activate slider
        show_legend (bool): show legend of graph content
    """
    html_template = Template("""
    <p>
        <div id="paulivec_$divNumber"></div>
    </p>
    """)

    javascript_template = Template("""
    <script>
        requirejs.config({
            paths: {
                qVisualization: "https://qvisualization.mybluemix.net/q-visualizations"
            }
        });

        require(["qVisualization"], function(qVisualizations) {
            qVisualizations.plotState("paulivec_$divNumber",
                                      "paulivec",
                                      $executions,
                                      $options);
        });
    </script>
    """)
    rho = _validate_input_state(rho)
    size = (7, 5) if figsize is None else figsize

    # Booleans are sent as 0/1 integers to the JavaScript side.
    options = {
        'width': size[0],
        'height': size[1],
        'slider': int(slider),
        'show_legend': int(show_legend)
    }

    # Unique element id derived from the current timestamp (dot removed).
    div_number = str(time.time()).replace('.', '')

    executions = [dict(data=process_data(rho))]

    html = html_template.substitute({'divNumber': div_number})
    javascript = javascript_template.substitute({
        'divNumber': div_number,
        'executions': executions,
        'options': options,
    })

    display(HTML(html + javascript))
コード例 #10
0
def _plot_hinton_panel(ax, data, max_weight, row_names, column_names,
                       panel_title, show_labels):
    """Draw one Hinton panel of the 2D array ``data`` onto ``ax``."""
    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (row, col), weight in np.ndenumerate(data):
        color = 'white' if weight > 0 else 'black'
        size = np.sqrt(np.abs(weight) / max_weight)
        # Column index -> horizontal position, row index -> vertical position
        # (the y axis is inverted below, so row 0 is at the top). Placing the
        # row index horizontally drew the transpose, which for the
        # antisymmetric imaginary part of a Hermitian matrix flips every sign.
        rect = plt.Rectangle([col - size / 2, row - size / 2],
                             size,
                             size,
                             facecolor=color,
                             edgecolor=color)
        ax.add_patch(rect)

    if show_labels:
        # Exactly one tick per matrix entry: the former arange(0, n + 0.5, 1)
        # produced n + 1 ticks for n labels, which newer Matplotlib rejects
        # and older versions rendered as an extra empty column.
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_yticks(np.arange(data.shape[0]))
        ax.set_yticklabels(row_names, fontsize=14)
        ax.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax.autoscale_view()
    ax.invert_yaxis()
    ax.set_title(panel_title, fontsize=14)


def plot_state_hinton(rho, title='', figsize=None):
    """Plot a hinton diagram for the quantum state.

    Draws side-by-side Hinton diagrams of the real and imaginary parts of the
    density matrix. Each entry is a square whose area is proportional to its
    magnitude, white for positive and black otherwise. Element ``rho[i, j]``
    appears in row ``i`` and column ``j``.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches; (8, 5) when omitted.
    Returns:
         matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (8, 5)
    num = int(np.log2(len(rho)))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    # Smallest power of two >= max |rho_ij|; square sizes are relative to it.
    max_weight = 2**np.ceil(np.log(np.abs(rho).max()) / np.log(2))
    datareal = np.real(rho)
    dataimag = np.imag(rho)
    basis_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    _plot_hinton_panel(ax1, datareal, max_weight, basis_names, basis_names,
                       'Real[rho]', True)
    # Tick labels on the imaginary panel only when it has non-zero entries.
    _plot_hinton_panel(ax2, dataimag, max_weight, basis_names, basis_names,
                       'Imag[rho]', bool(np.any(dataimag != 0)))

    if title:
        fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.close(fig)
    return fig
コード例 #11
0
def plot_state_qsphere(rho, figsize=None):
    """Plot the qsphere representation of a quantum state.

    The density matrix is diagonalised with ``linalg.eigh``. Eigenvectors are
    visited from the largest eigenvalue down, stopping once the largest
    remaining eigenvalue is at most 0.001, and a qsphere is drawn for each
    visited eigenvector. On it, basis state ``i`` sits at a height set by its
    Hamming weight, and a line from the centre is drawn with opacity equal to
    its probability, coloured via ``phase_to_color_wheel``.

    Only the figure of the last eigenvector drawn is returned; for a pure
    state that is the only one. Figures for earlier eigenvectors are closed
    instead of being left open.

    Args:
        rho (ndarray): State vector or density matrix representation
            of a quantum state.
        figsize (tuple): Figure size in inches; (7, 7) when omitted.

    Returns:
        Figure: A matplotlib figure instance.

    Raises:
        ImportError: Requires matplotlib.
        ValueError: If no eigenvalue of the state exceeds 0.001.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (7, 7)
    num = int(np.log2(len(rho)))
    dim = 2**num
    # get the eigenvectors and eigenvalues
    we, stateall = linalg.eigh(rho)
    fig = None
    for _ in range(dim):
        # start with the largest remaining eigenvalue
        probmix = we.max()
        prob_location = we.argmax()
        if probmix <= 0.001:
            # Processed eigenvalues are zeroed, so the maximum never grows:
            # no later iteration could pass the threshold either.
            break
        state = stateall[:, prob_location]
        # Lowest basis index whose magnitude matches the largest one
        # (within 0.001); its phase is taken as the global phase.
        loc = np.absolute(state).argmax()
        for j in range(dim):
            test = np.absolute(np.absolute(state[j]) - np.absolute(state[loc]))
            if test < 0.001:
                loc = j
                break
        # remove the global phase
        angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
        state = np.exp(-1j * angles) * state

        # Only the last figure is returned; close the superseded one so it
        # does not stay registered with pyplot.
        if fig is not None:
            plt.close(fig)
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
        ax.axes.set_xlim3d(-1.0, 1.0)
        ax.axes.set_ylim3d(-1.0, 1.0)
        ax.axes.set_zlim3d(-1.0, 1.0)
        try:
            ax.set_aspect("equal")
        except NotImplementedError:
            # Matplotlib 3.1-3.5 only accept 'auto' on 3D axes; the equal
            # axis limits above already give a round sphere.
            pass
        ax.axes.grid(False)
        # Plot semi-transparent sphere
        u = np.linspace(0, 2 * np.pi, 25)
        v = np.linspace(0, np.pi, 25)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))
        ax.plot_surface(x,
                        y,
                        z,
                        rstride=1,
                        cstride=1,
                        color='k',
                        alpha=0.05,
                        linewidth=0)
        # Hide the panes and spines. ax.xaxis is the object the removed
        # (Matplotlib 3.8) w_xaxis alias pointed to.
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
            axis.line.set_color((1.0, 1.0, 1.0, 0.0))
        # Get rid of the ticks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

        for i in range(dim):
            # get x,y,z points
            element = bin(i)[2:].zfill(num)
            weight = element.count("1")
            zvalue = -2 * weight / num + 1
            number_of_divisions = n_choose_k(num, weight)
            weight_order = bit_string_index(element)
            angle = weight_order * 2 * np.pi / number_of_divisions
            xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
            yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)
            ax.plot([xvalue], [yvalue], [zvalue],
                    markerfacecolor=(.5, .5, .5),
                    markeredgecolor=(.5, .5, .5),
                    marker='o',
                    markersize=10,
                    alpha=1)
            # get prob and angle - prob will be shade and angle color
            prob = np.real(np.dot(state[i], state[i].conj()))
            colorstate = phase_to_color_wheel(state[i])
            # |amplitude|^2 can overshoot 1 by rounding error; Matplotlib
            # rejects alpha values outside [0, 1].
            a = Arrow3D([0, xvalue], [0, yvalue], [0, zvalue],
                        mutation_scale=20,
                        alpha=float(np.clip(prob, 0.0, 1.0)),
                        arrowstyle="-",
                        color=colorstate,
                        lw=10)
            ax.add_artist(a)
        # add weight lines
        for weight in range(num + 1):
            theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
            z = -2 * weight / num + 1
            r = np.sqrt(1 - z**2)
            x = r * np.cos(theta)
            y = r * np.sin(theta)
            ax.plot(x, y, z, color=(.5, .5, .5))
        # add center point
        ax.plot([0], [0], [0],
                markerfacecolor=(.5, .5, .5),
                markeredgecolor=(.5, .5, .5),
                marker='o',
                markersize=10,
                alpha=1)
        we[prob_location] = 0
    if fig is None:
        # Previously this path crashed with UnboundLocalError.
        raise ValueError('No eigenvalue of the state exceeds 0.001; '
                         'nothing to plot.')
    fig.tight_layout()
    plt.close(fig)
    return fig
コード例 #12
0
def plot_state_city(rho, title="", figsize=None, color=None):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches; (15, 5) when omitted.
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. ``None`` entries fall
            back to "#648fff". The caller's sequence is not modified.
    Returns:
         matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)

    num = int(np.log2(len(rho)))
    # get the real and imag parts of rho
    datareal = np.real(rho)
    dataimag = np.imag(rho)

    # get the labels
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    lx = len(datareal[0])  # Work out matrix dimensions
    ly = len(datareal[:, 0])
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
    ypos = np.arange(0, ly, 1)
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)

    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    default_color = "#648fff"
    if color is None:
        color = [default_color, default_color]
    else:
        # Work on a copy: filling in defaults used to mutate the caller's
        # list, and failed outright when a tuple was passed.
        color = list(color)
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        color = [default_color if c is None else c for c in color]

    # set default figure size
    if figsize is None:
        figsize = (15, 5)

    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')
    ax1.bar3d(xpos, ypos, zpos, dx, dy, dzr, color=color[0], alpha=0.5)
    ax2 = fig.add_subplot(1, 2, 2, projection='3d')
    ax2.bar3d(xpos, ypos, zpos, dx, dy, dzi, color=color[1], alpha=0.5)

    # ax.xaxis / tick.label1 replace the w_xaxis / tick.label aliases that
    # were removed in Matplotlib 3.8 (they refer to the same objects).
    ax1.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax1.set_yticks(np.arange(0.5, ly + 0.5, 1))
    ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr) + 1e-9)
    ax1.zaxis.set_major_locator(MaxNLocator(5))
    ax1.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax1.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax1.set_zlabel("Real[rho]", fontsize=14)
    for tick in ax1.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)

    ax2.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax2.set_yticks(np.arange(0.5, ly + 0.5, 1))
    if np.min(dzi) != np.max(dzi):
        eps = 0
        ax2.zaxis.set_major_locator(MaxNLocator(5))
    else:
        # Constant imaginary part: a single tick and a tiny non-empty range.
        ax2.set_zticks([0])
        eps = 1e-9
    ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi) + eps)
    ax2.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax2.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax2.set_zlabel("Imag[rho]", fontsize=14)
    for tick in ax2.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.close(fig)
    return fig
コード例 #13
0
def plot_state_city(rho, title="", figsize=None, color=None, alpha=1):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho. When a panel has both positive and
    negative values, a translucent plane is drawn at z = 0.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches; (15, 5) when omitted.
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. ``None`` entries fall
            back to "#648fff". The caller's sequence is not modified.
        alpha (float): Transparency value for bars
    Returns:
         matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)

    num = int(np.log2(len(rho)))
    # get the real and imag parts of rho
    datareal = np.real(rho)
    dataimag = np.imag(rho)

    # get the labels
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    lx = len(datareal[0])  # Work out matrix dimensions
    ly = len(datareal[:, 0])
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
    ypos = np.arange(0, ly, 1)
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)

    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    default_color = "#648fff"
    if color is None:
        color = [default_color, default_color]
    else:
        # Work on a copy: filling in defaults used to mutate the caller's
        # list, and failed outright when a tuple was passed.
        color = list(color)
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        color = [default_color if c is None else c for c in color]

    # set default figure size
    if figsize is None:
        figsize = (15, 5)

    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')

    # Corners of the z = 0 plane spanning the bar grid.
    x = [0, max(xpos) + 0.5, max(xpos) + 0.5, 0]
    y = [0, 0, max(ypos) + 0.5, max(ypos) + 0.5]
    z = [0, 0, 0, 0]
    verts = [list(zip(x, y, z))]

    # Each bar3d box has 6 faces, hence the 6-wide slices of facecolors.
    # Positive bars get zorder 2 so they draw above the plane (zorder 1).
    fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
    for idx, cur_zpos in enumerate(zpos):
        zorder = 2 if dzr[idx] > 0 else 0
        b1 = ax1.bar3d(xpos[idx],
                       ypos[idx],
                       cur_zpos,
                       dx[idx],
                       dy[idx],
                       dzr[idx],
                       alpha=alpha,
                       zorder=zorder)
        b1.set_facecolors(fc1[6 * idx:6 * idx + 6])

    pc1 = Poly3DCollection(verts,
                           alpha=0.15,
                           facecolor='k',
                           linewidths=1,
                           zorder=1)

    if min(dzr) < 0 < max(dzr):
        ax1.add_collection3d(pc1)

    ax2 = fig.add_subplot(1, 2, 2, projection='3d')
    fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
    for idx, cur_zpos in enumerate(zpos):
        zorder = 2 if dzi[idx] > 0 else 0
        b2 = ax2.bar3d(xpos[idx],
                       ypos[idx],
                       cur_zpos,
                       dx[idx],
                       dy[idx],
                       dzi[idx],
                       alpha=alpha,
                       zorder=zorder)
        b2.set_facecolors(fc2[6 * idx:6 * idx + 6])

    pc2 = Poly3DCollection(verts,
                           alpha=0.2,
                           facecolor='k',
                           linewidths=1,
                           zorder=1)

    if min(dzi) < 0 < max(dzi):
        ax2.add_collection3d(pc2)

    # ax.xaxis / tick.label1 replace the w_xaxis / tick.label aliases that
    # were removed in Matplotlib 3.8 (they refer to the same objects).
    ax1.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax1.set_yticks(np.arange(0.5, ly + 0.5, 1))
    max_dzr = max(dzr)
    min_dzr = min(dzr)
    if max_dzr != min_dzr or min_dzr == 0:
        ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr) + 1e-9)
    else:
        ax1.axes.set_zlim3d(auto=True)
    ax1.zaxis.set_major_locator(MaxNLocator(5))
    ax1.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax1.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax1.set_zlabel("Real[rho]", fontsize=14)
    for tick in ax1.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)

    ax2.set_xticks(np.arange(0.5, lx + 0.5, 1))
    ax2.set_yticks(np.arange(0.5, ly + 0.5, 1))
    min_dzi = np.min(dzi)
    max_dzi = np.max(dzi)
    if min_dzi != max_dzi:
        ax2.zaxis.set_major_locator(MaxNLocator(5))
        ax2.axes.set_zlim3d(min_dzi, max_dzi)
    elif min_dzi == 0:
        # All-zero imaginary part: a single tick and a tiny non-empty range.
        ax2.set_zticks([0])
        ax2.axes.set_zlim3d(min_dzi, max_dzi + 1e-9)
    else:
        ax2.axes.set_zlim3d(auto=True)
    ax2.xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
    ax2.yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
    ax2.set_zlabel("Imag[rho]", fontsize=14)
    for tick in ax2.zaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.close(fig)
    return fig