예제 #1
0
 def test_params(self):
     """Check the string pi_check produces for an expression of Parameters."""
     # Three symbolic parameters combined with pi-based coefficients.
     x, y, z = (Parameter(name) for name in ('x', 'y', 'z'))
     input_number = pi**3 * x + 3 / (4 * pi) * y - 13 * pi / 2 * z
     expected_string = 'pi**3*x + 3/4pi*y - 13pi/2*z'
     self.assertEqual(pi_check(input_number), expected_string)
예제 #2
0
 def test_params(self):
     """Check that pi_check renders a Parameter expression with the π glyph."""
     first = Parameter("x")
     second = Parameter("y")
     third = Parameter("z")

     # Symbolic input and the text pi_check is expected to produce for it.
     input_number = pi**3 * first + 3 / (4 * pi) * second - 13 * pi / 2 * third
     expected_string = "π**3*x + 3/4π*y - 13π/2*z"

     self.assertEqual(pi_check(input_number), expected_string)
예제 #3
0
    def latex(self):
        """Return LaTeX string representation of circuit."""

        self._initialize_latex_array()
        self._build_latex_array()
        header_1 = r"""% \documentclass[preview]{standalone}
% If the image is too large to fit on this documentclass use
\documentclass[draft]{beamer}
"""
        beamer_line = "\\usepackage[size=custom,height=%d,width=%d,scale=%.1f]{beamerposter}\n"
        header_2 = r"""% instead and customize the height and width (in cm) to fit.
% Large images may run out of memory quickly.
% To fix this use the LuaLaTeX compiler, which dynamically
% allocates memory.
\usepackage[braket, qm]{qcircuit}
\usepackage{amsmath}
\pdfmapfile{+sansmathaccent.map}
% \usepackage[landscape]{geometry}
% Comment out the above line if using the beamer documentclass.
\begin{document}
"""
        qcircuit_line = r"""
\begin{equation*}
    \Qcircuit @C=%.1fem @R=%.1fem @!R {
"""
        output = io.StringIO()
        output.write(header_1)
        output.write("%% img_width = %d, img_depth = %d\n" %
                     (self.img_width, self.img_depth))
        output.write(beamer_line % self._get_beamer_page())
        output.write(header_2)
        if self.global_phase:
            output.write(r"""{$\mathrm{%s} \mathrm{%s}$}""" %
                         ("global\\,phase:\\,",
                          pi_check(self.global_phase, output="latex")))
        output.write(qcircuit_line %
                     (self.column_separation, self.wire_separation))
        for i in range(self.img_width):
            output.write("\t \t")
            for j in range(self.img_depth + 1):
                output.write(self._latex[i][j])
                if j != self.img_depth:
                    output.write(" & ")
                else:
                    output.write(r"\\" + "\n")
        output.write("\t }\n")
        output.write("\\end{equation*}\n\n")
        output.write("\\end{document}")
        contents = output.getvalue()
        output.close()
        return contents
예제 #4
0
    def param_parse(v):
        """Format a sequence of parameters as one comma-separated string.

        Each entry is rendered with ``pi_check(..., output='mpl', ndigits=3)``;
        entries for which that raises ``TypeError`` are rendered with ``str``.
        A leading ``-`` is replaced by ``$-$``.
        """
        rendered = []
        for value in v:
            try:
                text = pi_check(value, output='mpl', ndigits=3)
            except TypeError:
                text = str(value)

            if text.startswith('-'):
                text = '$-$' + text[1:]
            rendered.append(text)

        return ', '.join(rendered)
예제 #5
0
    def _add_params_to_gate_text(self, op, gate_text):
        """Add the params to the end of the current gate_text"""

        # Must limit to 4 params or may get dimension too large error
        # from xy-pic xymatrix command
        if (len(op.op.params) > 0 and not any(
                isinstance(param, np.ndarray) for param in op.op.params)):
            gate_text += "\\,\\mathrm{(}"
            for param_count, param in enumerate(op.op.params):
                if param_count > 3:
                    gate_text += "...,"
                    break
                gate_text += "\\mathrm{%s}," % pi_check(param, output='latex', ndigits=4)
            gate_text = gate_text[:-1] + "\\mathrm{)}"
        return gate_text
예제 #6
0
    def params_for_label(instruction):
        """Return the instruction's params as label strings.

        The result is ``None`` when the operation has no ``params`` attribute
        or when any param is a numpy ``ndarray``. Otherwise each param is
        formatted with ``pi_check(..., ndigits=5)``, falling back to plain
        string formatting when that raises ``TypeError``.
        """
        operation = instruction.op
        if hasattr(operation, 'params') and not any(
                isinstance(value, ndarray) for value in operation.params):
            labels = []
            for value in operation.params:
                try:
                    text = '%s' % pi_check(value, ndigits=5)
                except TypeError:
                    text = '%s' % value
                labels.append(text)
            return labels
        return None
예제 #7
0
    def param_parse(self, v):
        """Format a sequence of parameters as one comma-separated string.

        Each entry is rendered with ``pi_check(..., output='mpl', ndigits=3)``;
        entries for which that raises ``TypeError`` are rendered with ``str``.
        All ``$`` characters are stripped from the joined text and every ``-``
        is replaced by U+02D7.
        """
        rendered = []
        for value in v:
            try:
                text = pi_check(value, output='mpl', ndigits=3)
            except TypeError:
                text = str(value)

            if text.startswith('-'):
                text = '$-$' + text[1:]
            rendered.append(text)

        joined = ', '.join(rendered)
        # remove $'s since "${}$".format will add them back on the outside
        return joined.replace('$', '').replace('-', u'\u02d7')
예제 #8
0
    def latex(self):
        """Assemble the LaTeX source that draws the circuit.

        The internal LaTeX cell array is rebuilt first. Its cells are then
        emitted row by row as a ``\\Qcircuit`` matrix inside a ``standalone``
        document, wrapped in a ``\\scalebox`` using ``self.scale``. When
        ``self.global_phase`` is truthy it is printed (formatted by
        ``pi_check``) right before the circuit.

        Returns:
            str: The generated LaTeX document.
        """
        self._initialize_latex_array()
        self._build_latex_array()

        # Templates are spelled with explicit escapes so that significant
        # whitespace (the trailing space after \begin{document}, the indent
        # after the document class line) is visible in the source.
        document_class = "\\documentclass[border=2px]{standalone}\n        "
        packages = (
            "\n"
            "\\usepackage[braket, qm]{qcircuit}\n"
            "\\usepackage{graphicx}\n"
            "\n"
            "\\begin{document} \n"
        )
        scale_open = "\\scalebox{{{}}}".format(self.scale) + "{"
        circuit_open = "\n\\Qcircuit @C=%.1fem @R=%.1fem @!R { \\\\\n"
        phase_template = r"""{$\mathrm{%s} \mathrm{%s}$}"""

        chunks = [document_class, packages, scale_open]
        if self.global_phase:
            chunks.append(phase_template % (
                "global\\,phase:\\,",
                pi_check(self.global_phase, output="latex")))
        chunks.append(circuit_open %
                      (self.column_separation, self.wire_separation))

        for row in range(self.img_width):
            chunks.append("\t \t")
            for col in range(self.img_depth + 1):
                chunks.append(self._latex[row][col])
                # Cells are separated by '&'; the last cell closes the row.
                if col == self.img_depth:
                    chunks.append(r"\\ " + "\n")
                else:
                    chunks.append(" & ")

        chunks.append(r"\\ " + "}}\n")
        chunks.append("\\end{document}")
        return "".join(chunks)
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    rho=None,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.
    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``; only
            used when a new Figure is created (``ax`` is ``None``).
        ax (matplotlib.axes.Axes): An optional Axes object whose parent
            Figure is used for the visualization output. If none is
            specified a new matplotlib Figure will be created and used.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        rho: Not read by this function; the density matrix is always rebuilt
            from ``state``. NOTE(review): presumably a deprecated alias of
            ``state`` handled outside this function - confirm.
        filename (str): Optional. When given, the figure is written with
            ``Figure.savefig(filename)`` instead of being returned.

    Returns:
        Figure: The matplotlib figure when ``filename`` is ``None``;
        otherwise the return value of ``Figure.savefig``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: if input is not a valid N-qubit state.

        QiskitError: Input statevector does not have valid dimensions.

    Example:
        .. jupyter-execute::

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere
           %matplotlib inline

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector.from_instruction(qc)
           plot_state_qsphere(state)
    """
    if not HAS_MATPLOTLIB:
        raise MissingOptionalLibraryError(
            libname="Matplotlib",
            name="plot_state_qsphere",
            pip_install="pip install matplotlib",
        )

    import matplotlib.gridspec as gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    from qiskit.visualization.bloch import Arrow3D

    try:
        import seaborn as sns
    except ImportError as ex:
        raise MissingOptionalLibraryError(
            libname="seaborn",
            name="plot_state_qsphere",
            pip_install="pip install seaborn",
        ) from ex
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()

    gs = gridspec.GridSpec(nrows=3, ncols=3)

    # The sphere is always drawn on a fresh 3D subplot of the chosen figure.
    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # MPL 3.2 or previous do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # The ``w_xaxis``/``w_yaxis``/``w_zaxis`` aliases were removed in
    # Matplotlib 3.8; ``xaxis``/``yaxis``/``zaxis`` refer to the same 3D axis
    # objects on old and new releases alike.
    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        if eigvals[idx] > 0.001:
            # get the max eigenvalue
            state = eigvecs[:, idx]
            loc = np.absolute(state).argmax()
            # remove the global phase from max element
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2 ** num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                weight = element.count("1")
                zvalue = -2 * weight / d + 1
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue ** 2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue ** 2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                if not np.isclose(prob, 0) and show_state_labels:
                    # Labels sit on a slightly larger sphere than the points.
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue ** 2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += "\n$%.1f^\\circ$" % (element_angle * 180 / np.pi)
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += "\n$%s$" % (element_angle)
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z ** 2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            # Eigenvalues are visited high->low, so the rest are negligible too.
            break

    # Phase colour wheel legend in the bottom-right grid cell.
    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
 def test_default(self, case):
     """Default pi_check({case[0]})='{case[1]}'"""
     # NOTE(review): the docstring above looks like a format template
     # (presumably filled in by a data-driven test decorator) - keep verbatim.
     input_number, expected_string = case[0], case[1]
     self.assertEqual(pi_check(input_number), expected_string)
예제 #11
0
    def lines(self, line_length=None):
        """Build the text drawing as a list of lines.

        Args:
            line_length (int): Optional. Breaks the circuit drawing to this length. This
                               is useful when the drawing does not fit in the console. If
                               None (default), ``self.line_length`` is used; if that is
                               also unset, the width is 80 inside an ipykernel session
                               (other than Spyder) and otherwise taken from
                               get_terminal_size(). If you don't want pagination
                               at all, set line_length=-1.

        Returns:
            list: A list of lines with the text drawing.
        """
        if line_length is None:
            line_length = self.line_length
        if not line_length:
            in_notebook = ('ipykernel' in sys.modules) and ('spyder' not in sys.modules)
            if in_notebook:
                line_length = 80
            else:
                line_length = get_terminal_size()[0]

        qubit_count = len(self.qregs)

        try:
            layers = self.build_layers()
        except TextDrawerCregBundle:
            # Bundled classical registers cannot express this circuit, so
            # rebuild the layers with individual classical wires.
            self.cregbundle = False
            warn('The parameter "cregbundle" was disable, since an instruction needs to refer to '
                 'individual classical wires', RuntimeWarning, 2)
            layers = self.build_layers()

        pages = [[]]
        remaining = line_length
        for index, layer in enumerate(layers):
            # Replace the Nones with EmptyWire
            layers[index] = EmptyWire.fillup_layer(layer, qubit_count)

            TextDrawing.normalize_width(layer)

            if line_length == -1:
                # Do not use pagination (aka line breaking. aka ignore line_length).
                pages[-1].append(layer)
                continue

            # chop the layer to the line_length (pager)
            width = layers[index][0].length

            if width < remaining:
                pages[-1].append(layer)
                remaining -= width
                continue

            # The layer does not fit: close the current page with a break
            # marker and open a new one with a marker plus the wire names.
            pages[-1].append(BreakWire.fillup_layer(len(layer), '»'))

            pages.append([BreakWire.fillup_layer(len(layer), '«')])
            remaining = line_length - pages[-1][-1][0].length

            pages[-1].append(
                InputWire.fillup_layer(self.wire_names(with_initial_state=False)))
            remaining -= pages[-1][-1][0].length

            pages[-1].append(layer)
            remaining -= pages[-1][-1][0].length

        text = []

        if self.global_phase:
            phase = pi_check(self.global_phase, ndigits=5)
            text.append('global phase: %s' % phase)

        for page in pages:
            text += self.draw_wires(list(zip(*page)))

        return text
예제 #12
0
def plot_state_qsphere(state,
                       figsize=None,
                       ax=None,
                       show_state_labels=True,
                       show_state_phases=False,
                       use_degrees=False,
                       *,
                       rho=None):
    """Plot the qsphere representation of a quantum state.
    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``; only
            used when a new Figure is created (``ax`` is ``None``).
        ax (matplotlib.axes.Axes): An optional Axes object whose parent
            Figure is used for the visualization output. If none is
            specified a new matplotlib Figure will be created and used.
            Additionally, if specified there will be no returned Figure
            since it is redundant.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        rho: Not read by this function; the density matrix is always rebuilt
            from ``state``. NOTE(review): presumably a deprecated alias of
            ``state`` handled outside this function - confirm.

    Returns:
        Figure: A matplotlib figure instance if the ``ax`` kwarg is not set

    Raises:
        ImportError: Requires matplotlib and seaborn.
        VisualizationError: if input is not a valid N-qubit state.

        QiskitError: Input statevector does not have valid dimensions.

    Example:
        .. jupyter-execute::

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere
           %matplotlib inline

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector.from_instruction(qc)
           plot_state_qsphere(state)
    """
    if not HAS_MATPLOTLIB:
        raise ImportError(
            'Must have Matplotlib installed. To install, run "pip install '
            'matplotlib".')
    try:
        import seaborn as sns
    except ImportError as ex:
        # Chain explicitly so the original import failure stays visible.
        raise ImportError(
            'Must have seaborn installed to use '
            'plot_state_qsphere. To install, run "pip install seaborn".') from ex
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()

    gs = gridspec.GridSpec(nrows=3, ncols=3)

    # The sphere is always drawn on a fresh 3D subplot of the chosen figure.
    ax = fig.add_subplot(gs[0:3, 0:3], projection='3d')
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x,
                    y,
                    z,
                    rstride=1,
                    cstride=1,
                    color='k',
                    alpha=0.05,
                    linewidth=0)

    # The ``w_xaxis``/``w_yaxis``/``w_zaxis`` aliases were removed in
    # Matplotlib 3.8; ``xaxis``/``yaxis``/``zaxis`` refer to the same 3D axis
    # objects on old and new releases alike.
    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        if eigvals[idx] > 0.001:
            # get the max eigenvalue
            state = eigvecs[:, idx]
            loc = np.absolute(state).argmax()
            # remove the global phase from max element
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2**num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                weight = element.count("1")
                zvalue = -2 * weight / d + 1
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + \
                        (weight_order * 2 * (np.pi / number_of_divisions))

                if (weight > d / 2) or (
                    ((weight == d / 2) and
                     (weight_order >= number_of_divisions / 2))):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                # Rounding can push the probability marginally above 1, and it
                # is used below as an alpha value, which must lie in [0, 1].
                prob = min(prob, 1)
                colorstate = phase_to_rgb(state[i])

                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                # Compare with a tolerance so numerical noise in a zero
                # amplitude does not produce a label for an absent state.
                if not np.isclose(prob, 0) and show_state_labels:
                    # Labels sit on a slightly larger sphere than the points.
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue**2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = '$\\vert' + element + '\\rangle$'
                    if show_state_phases:
                        element_angle = (np.angle(state[i]) +
                                         (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += '\n$%.1f^\\circ$' % (
                                element_angle * 180 / np.pi)
                        else:
                            element_angle = pi_check(element_angle,
                                                     ndigits=3).replace(
                                                         'pi', '\\pi')
                            element_text += '\n$%s$' % (element_angle)
                    ax.text(xvalue_text,
                            yvalue_text,
                            zvalue_text,
                            element_text,
                            ha='center',
                            va='center',
                            size=12)

                ax.plot([xvalue], [yvalue], [zvalue],
                        markerfacecolor=colorstate,
                        markeredgecolor=colorstate,
                        marker='o',
                        markersize=np.sqrt(prob) * 30,
                        alpha=alfa)

                a = Arrow3D([0, xvalue], [0, yvalue], [0, zvalue],
                            mutation_scale=20,
                            alpha=prob,
                            arrowstyle="-",
                            color=colorstate,
                            lw=2)
                ax.add_artist(a)

            # add weight lines
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z**2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(.5, .5, .5), lw=1, ls=':', alpha=.5)

            # add center point
            ax.plot([0], [0], [0],
                    markerfacecolor=(.5, .5, .5),
                    markeredgecolor=(.5, .5, .5),
                    marker='o',
                    markersize=3,
                    alpha=1)
        else:
            # Eigenvalues are visited high->low, so the rest are negligible too.
            break

    # Phase colour wheel legend in the bottom-right grid cell.
    n = 64
    theta = np.ones(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    ax2.pie(theta, colors=sns.color_palette("hls", n), radius=0.75)
    ax2.add_artist(Circle((0, 0), 0.5, color='white', zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ['Phase\n(Deg)', '0', '90', '180   ', '270']
    else:
        labels = ['Phase', '$0$', '$\\pi/2$', '$\\pi$', '$3\\pi/2$']

    ax2.text(0,
             0,
             labels[0],
             horizontalalignment='center',
             verticalalignment='center',
             fontsize=14)
    ax2.text(offset,
             0,
             labels[1],
             horizontalalignment='center',
             verticalalignment='center',
             fontsize=14)
    ax2.text(0,
             offset,
             labels[2],
             horizontalalignment='center',
             verticalalignment='center',
             fontsize=14)
    ax2.text(-offset,
             0,
             labels[3],
             horizontalalignment='center',
             verticalalignment='center',
             fontsize=14)
    ax2.text(0,
             -offset,
             labels[4],
             horizontalalignment='center',
             verticalalignment='center',
             fontsize=14)

    if return_fig:
        # Close the figure under inline/notebook backends so it is not
        # displayed a second time in addition to the returned object.
        if get_backend() in [
                'module://ipykernel.pylab.backend_inline', 'nbAgg'
        ]:
            plt.close(fig)
        return fig