Esempio n. 1
0
def plot_detuning_energy_levels(s_qs: StaticQubitSystem,
                                crossings: np.ndarray,
                                ax: Axes,
                                highlighted_indices: Sequence[int] = (-1, )):
    """Plot the Omega = 0 energy levels of a qubit system against detuning.

    A new ``StaticQubitSystem`` is built from ``s_qs.N``, ``s_qs.V`` and
    ``s_qs.geometry`` with ``Omega=0``. It uses 20 detuning points spanning
    ``-0.1 * (crossings.max() - crossings.min())`` to ``1.1 * crossings.max()``,
    and its energies are plotted on ``ax``.

    Styling:
    - A state whose every component equals
      ``states_quimb.get_ground_states(1)[0]`` (presumably the all-ground
      state) is drawn thick and green.
    - States selected by ``highlighted_indices`` are drawn thick and red.
    - All other states are drawn thin and grey.

    Clicking a line prints its state label.

    Title, axis labels and x limits are applied to ``ax`` itself, not to
    pyplot's "current" axes. This keeps them on the right subplot when ``ax``
    is one of several axes in a figure.

    :param s_qs: system whose ``N``, ``V`` and ``geometry`` are reused. Its
        ``get_energies()`` is called if ``Omega_zero_energies`` is ``None``.
    :param crossings: detuning values used to choose the plotted x range;
        expected to be finite (NaNs would propagate into the limits).
    :param ax: matplotlib axes to draw on.
    :param highlighted_indices: indices into the new system's ``states`` to
        highlight; negative indices count from the end.
    """
    if s_qs.Omega_zero_energies is None:
        s_qs.get_energies()

    crossings_range = crossings.max() - crossings.min()
    # Left edge slightly below zero, right edge 10 % past the last crossing.
    xlims = crossings_range * -0.1, crossings.max() * 1.1
    s_qs = StaticQubitSystem(s_qs.N,
                             s_qs.V,
                             s_qs.geometry,
                             Omega=0,
                             Delta=np.linspace(xlims[0], xlims[1], 20))
    s_qs.get_energies()

    states = s_qs.states
    g = states_quimb.get_ground_states(1)[0]
    # Resolve the highlighted indices once. States are compared by identity,
    # so a negative index and its positive equivalent select the same state.
    highlighted_states = [states[index] for index in highlighted_indices]
    for i, state in enumerate(states):
        is_highlight_state = any(state is highlighted
                                 for highlighted in highlighted_states)
        is_ground_state = all((_state == g).all() for _state in state)
        is_emphasised = is_ground_state or is_highlight_state

        color = 'g' if is_ground_state else 'r' if is_highlight_state else 'grey'
        ax.plot(s_qs.Delta,
                s_qs.Omega_zero_energies[:, i],
                color=color,
                alpha=0.6,
                lw=5 if is_emphasised else 1,
                zorder=2 if is_emphasised else 1,
                label=states_quimb.get_label_from_state(state),
                picker=3)

    def on_pick(event):
        # Report which energy level was clicked (lines are pickable above).
        line = event.artist
        print(f'Clicked on: {line.get_label()}')

    ax.figure.canvas.mpl_connect('pick_event', on_pick)

    ax.grid()
    ax.xaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
    ax.yaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
    ax.locator_params(nbins=4)

    # Render V as "m \times 10^{e}" for the LaTeX title.
    _m, _s = f"{s_qs.V:0.2e}".split('e')
    V_text = rf"{_m:s} \times 10^{{{int(_s):d}}}"
    ax.set_title(rf"Energy spectrum with $N = {s_qs.N}$, $V = {V_text:s}$ Hz")
    ax.set_xlabel(r"Detuning $\Delta$")
    ax.set_ylabel("Eigenenergy")

    ax.set_xlim(xlims)
    ax.figure.tight_layout()
Esempio n. 2
0
def plot_basis_state_populations_2d(e_qs: EvolvingQubitSystem,
                                    log=False,
                                    log_limit=1e-5):
    """Plot product-basis state populations at five snapshots of an evolution.

    There is one subplot per snapshot: index 0, the three quartile indices of
    ``e_qs.t_list``, and the last index. In each subplot the populations
    ``|amplitude|**2`` of all ``2**N`` product-basis states are drawn as a
    stepped filled area. The states are ordered by descending number of set
    bits in their basis index, ascending index within each group. For each
    snapshot, the number of populations above ``log_limit`` is printed.

    :param e_qs: evolved system; ``N``, ``t_list``, ``solved_t_list`` and
        ``solved_states`` are read.
    :param log: if True, use a logarithmic y axis from ``log_limit`` to 1.
    :param log_limit: offset added to every plotted population (keeps log
        plots finite) and the threshold used for the printed count.
    """
    quartile_index = int(len(e_qs.t_list) / 4)
    indices = [0, quartile_index, quartile_index * 2, quartile_index * 3, -1]

    states = get_states(e_qs.N, sparse=True)

    # Product-basis indices grouped by number of set bits (most first),
    # ascending within each group.
    state_product_basis_indices = np.array(
        sorted(range(2**e_qs.N),
               key=lambda index: (-bin(index).count('1'), index)))

    # NOTE(review): labels follow get_states() order while the populations
    # follow the bit-count order above; confirm these two orders agree.
    labels = [get_label_from_state(state) for state in states]
    x = np.arange(len(labels))

    _, axs = plt.subplots(len(indices), 1, sharex='all', figsize=(14, 8))

    for ax, i in zip(axs, indices):
        if not log:
            ax.set_ylim(0, 1)
        else:
            ax.set_ylim(log_limit, 1)
            # Base 10 is the default; the old `basey` keyword was removed in
            # Matplotlib 3.5 and raises TypeError there.
            ax.set_yscale('log')
        ax.grid(axis='y')
        ax.set_ylabel(f"{e_qs.solved_t_list[i]:.2e}s")

        solve_result_state_populations = np.abs(
            e_qs.solved_states[i].flatten())**2
        # Fancy indexing copies, so the in-place offset below does not touch
        # solve_result_state_populations (used unshifted for the count).
        basis_state_populations = solve_result_state_populations[
            state_product_basis_indices]
        basis_state_populations += log_limit

        above_limit = np.count_nonzero(
            solve_result_state_populations > log_limit)
        print(
            f"above limit {log_limit:.0e} \t count: {above_limit:4d} \t ({above_limit / 2 ** e_qs.N:5.1%})"
        )

        ax.fill_between(x,
                        np.zeros_like(basis_state_populations),
                        basis_state_populations,
                        step='mid')

    # Too many states to label legibly: only label the first and last.
    if len(x) > 20:
        label_indices = [0, -1]
        plt.xticks(
            [x[i] for i in label_indices],
            [labels[i] for i in label_indices],
        )
    else:
        plt.xticks(x, labels)
    plt.tight_layout()
    plt.show()
Esempio n. 3
0
    def plot_basis_states_overlaps(self,
                                   ax,
                                   plot_title: bool = True,
                                   plot_others_as_sum: bool = False):
        """Plot basis-state populations of the solved evolution on ``ax``.

        For each state, its population is ``abs(amplitude)**2``, where the
        amplitude is read from each entry of ``self.solved_states`` at the
        state's product-basis index. It is plotted against
        ``self.solved_t_list``.

        Styling:
        - A label with no 'e' (all-ground) is drawn green; a label with no
          'g' (all-excited) is drawn red. Both are labelled ``P_G`` /
          ``P_E``, use width 1, and are summed into a dotted
          ``P_E + P_G`` curve.
        - Every other state is drawn black with width 0.5. It is labelled
          individually when there are at most four states; otherwise all of
          them share a single "Others" legend entry.

        :param ax: matplotlib axes to draw on.
        :param plot_title: whether to set the axes title.
        :param plot_others_as_sum: plot only the all-excited and all-ground
            states, plus ``1 - (P_E + P_G)`` as the summed population of all
            other states.
        """
        if plot_others_as_sum:
            states = [states_quimb.get_excited_states(self.N),
                      states_quimb.get_ground_states(self.N)]
        else:
            states = states_quimb.get_states(self.N)
        fidelities = []

        plot_individual_orthogonal_state_labels = len(states) <= 4
        plotted_others = False
        for state in states:
            label = states_quimb.get_label_from_state(state)
            state_product_basis_index = \
                states_quimb.get_product_basis_states_index(state)
            state_fidelities = [
                np.abs(_instantaneous_state.flatten()
                       [state_product_basis_index])**2
                for _instantaneous_state in self.solved_states
            ]

            # A label with no 'e' is all-ground, one with no 'g' all-excited.
            is_ghz_component = 'e' not in label or 'g' not in label
            if is_ghz_component:
                fidelities.append(state_fidelities)
                plot_label = r"$P_{" + f"{label.upper()[0]}" + "}$"
            elif plot_individual_orthogonal_state_labels:
                plot_label = r"$P_{" + f"{label.upper()}" + "}$"
            elif plotted_others:
                # Only the first "Others" line gets a legend entry.
                plot_label = None
            else:
                plot_label = 'Others'
                plotted_others = True

            ax.plot(
                self.solved_t_list,
                state_fidelities,
                label=plot_label,
                color='g'
                if 'e' not in label else 'r' if 'g' not in label else 'k',
                linewidth=1 if is_ghz_component else 0.5,
                alpha=0.5)

        fidelities_sum = np.array(fidelities).sum(axis=0)
        ax.plot(self.solved_t_list,
                fidelities_sum,
                label="$P_{E} + P_{G}$",
                color='C0',
                linestyle=":",
                linewidth=1,
                alpha=0.7)

        if plot_others_as_sum:
            others_sum = 1 - fidelities_sum

            ax.plot(self.solved_t_list,
                    others_sum,
                    label=r"$\sum{\textrm{Others}}$",
                    color='C1',
                    linestyle=":",
                    linewidth=1,
                    alpha=0.7)

        ax.set_ylabel("Population")
        if plot_title:
            ax.set_title("Basis state populations")
        ax.set_ylim((-0.1, 1.1))
        ax.yaxis.set_ticks([0, 0.5, 1])

        handles, labels = ax.get_legend_handles_labels()
        if handles:
            # Sort by label text only. Sorting the (label, handle) tuples would
            # compare Line2D handles on equal labels and raise TypeError.
            labels, handles = zip(
                *sorted(zip(labels, handles), key=lambda pair: pair[0]))
            ax.legend(handles, labels)
Esempio n. 4
0
    def plot_detuning_energy_levels(
            self,
            plot_state_names: bool,
            fig_kwargs: dict = None,
            plot_title: bool = True,
            ylim: Tuple[float, float] = None,
            highlight_states_by_label: List[str] = None):
        """Plot every eigenenergy level against detuning in a new figure.

        Solid lines show ``self.Omega_zero_energies``. When ``self.Omega`` is
        non-zero, ``self.Omega_non_zero_energies`` are overlaid as dotted
        lines.

        Styling:
        - States whose label has no 'e' are drawn thick and green.
        - States whose label is in ``highlight_states_by_label`` are drawn
          thick and red.
        - All other states are drawn thin and grey.

        :param plot_state_names: annotate each level with its label and show
            a legend.
        :param fig_kwargs: extra ``plt.figure`` keyword arguments; they
            override the defaults ``figsize=(15, 7)`` and
            ``num="Energy Levels"``.
        :param plot_title: whether to add a title showing ``N`` and ``V``.
        :param ylim: optional y-axis limits (applied when truthy).
        :param highlight_states_by_label: labels of states to highlight;
            defaults to the all-excited label ``'e' * N``.
        """
        if self.Omega_zero_energies is None:
            self.get_energies()

        if highlight_states_by_label is None:
            # Use a list, as annotated. A bare string would turn the `in` test
            # below into a substring check instead of a label match.
            highlight_states_by_label = ['e' * self.N]

        if fig_kwargs is None:
            fig_kwargs = {}
        fig_kwargs = {
            **dict(figsize=(15, 7), num="Energy Levels"),
            **fig_kwargs
        }

        plt.figure(**fig_kwargs)

        plot_points = len(self.Delta)
        for i in reversed(range(len(self.states))):
            label = states_quimb.get_label_from_state(self.states[i])
            is_highlight_state = label in highlight_states_by_label
            is_ground_state = 'e' not in label
            color = 'g' if is_ground_state else 'r' if is_highlight_state else 'grey'
            linewidth = 5 if is_ground_state or is_highlight_state else 1
            z_order = 2 if is_ground_state or is_highlight_state else 1
            plt.plot(self.Delta,
                     self.Omega_zero_energies[:, i],
                     color=color,
                     label=label,
                     alpha=0.6,
                     lw=linewidth,
                     zorder=z_order)
            if self.Omega != 0:
                plt.plot(self.Delta,
                         self.Omega_non_zero_energies[:, i],
                         color=f'C{i}',
                         ls=':',
                         alpha=0.6)

            if plot_state_names:
                # Spread the name tags along the x axis: state i is annotated
                # at the middle of the i-th of len(states) equal Delta slices.
                Delta_index = int(plot_points / len(self.states)) * i + int(
                    plot_points / 2 / len(self.states))
                text_x = self.Delta[Delta_index]
                text_y = self.Omega_zero_energies[Delta_index, i]
                plt.text(text_x,
                         text_y,
                         label,
                         ha='center',
                         color=f'C{i}',
                         fontsize=16,
                         fontweight='bold')

        if plot_state_names:
            plt.legend()

        plt.grid()
        ax = plt.gca()
        ax.xaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
        ax.yaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
        plt.locator_params(nbins=4)

        if plot_title:
            # Render V as "m \times 10^{e}" for the LaTeX title.
            _m, _s = f"{self.V:0.2e}".split('e')
            V_text = rf"{_m:s} \times 10^{{{int(_s):d}}}"
            plt.title(
                rf"Energy spectrum with $N = {self.N}$, $V = {V_text:s}$ Hz")
        plt.xlabel(r"Detuning $\Delta$")
        plt.ylabel("Eigenenergy")
        plt.xlim((self.Delta.min(), self.Delta.max()))
        if ylim:
            plt.ylim(ylim)
        plt.tight_layout()
Esempio n. 5
0
def calculate_ghz_crossings(s_qs: StaticQubitSystem,
                            other_highlighted_labels: List[str] = ()):
    """Locate zero crossings of the Omega = 0 energy levels and plot them.

    Steps:
    1. For every state, find by linear interpolation the detuning at which
       its ``Omega_zero_energies`` column changes sign from non-negative to
       negative.
    2. Draw the energy levels with ``plot_detuning_energy_levels`` on the
       top axes. The last state and the states labelled in
       ``other_highlighted_labels`` (plus their e<->g complements) are
       highlighted.
    3. On the bottom axes, show how many states share each crossing
       detuning.

    The figure is displayed with ``plt.show()``.

    :param s_qs: system to analyse; ``get_energies()`` is called on it.
    :param other_highlighted_labels: extra state labels to highlight; any
        iterable of strings (the default empty tuple is accepted).
    :raises AssertionError: if ``s_qs.N`` is odd.
    """
    assert s_qs.N % 2 == 0, f"N has to be even, not {s_qs.N}"
    s_qs.get_energies()

    EEE_index = len(s_qs.states) - 1

    # Copy into a list: the default is a tuple, and tuple + list raises
    # TypeError. Copying also avoids mutating the caller's list.
    other_highlighted_labels = list(other_highlighted_labels)
    swap_e_g = str.maketrans({'e': 'g', 'g': 'e'})
    complementary_labels = [
        label.translate(swap_e_g) for label in other_highlighted_labels
    ]
    # The right-hand side is evaluated before `+=` extends the list, so the
    # membership test only sees the caller-supplied labels.
    other_highlighted_labels += [
        label for label in complementary_labels
        if label not in other_highlighted_labels
    ]
    highlighted_label_set = set(other_highlighted_labels)
    other_highlighted_indices = [
        i for i, state in enumerate(s_qs.states)
        if states_quimb.get_label_from_state(state) in highlighted_label_set
    ]

    def find_root(x: np.ndarray, y: np.ndarray):
        """
        Finds crossing (where y equals 0), given that x, y is roughly linear.

        Returns nan when y is identically zero, never becomes negative, or is
        already negative at the first sample. In those cases there is no
        bracketing pair, and using index -1 would silently wrap around to the
        last sample.
        """
        if (y == 0).all():
            return np.nan
        is_negative = y < 0
        if not is_negative.any() or is_negative[0]:
            return np.nan
        _right_bound = is_negative.argmax()
        _left_bound = _right_bound - 1
        crossing = y[_left_bound] / (y[_left_bound] - y[_right_bound]) \
                   * (x[_right_bound] - x[_left_bound]) + x[_left_bound]
        return crossing

    crossings = [
        find_root(s_qs.Delta, s_qs.Omega_zero_energies[:, i])
        for i in range(len(s_qs.states))
    ]
    # Crossing of GGG... with EEE...
    # NOTE(review): currently unused; kept for reference.
    standard_GHZ_crossing = crossings[EEE_index]

    # Crossing of GGG... with EGEG...
    # NOTE(review): currently unused; kept for reference.
    other_highlighted_crossings = [
        crossings[i] for i in other_highlighted_indices
    ]

    # Drop the first entry (the GGG level, expected to yield nan). Also drop
    # any level that does not cross within the Delta range, so the plot
    # limits and the histogram only see finite values.
    crossings = np.array(crossings[1:])
    crossings = crossings[~np.isnan(crossings)]
    unique_crossings, counts = np.unique(crossings, return_counts=True)

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(15, 8))
    plot_detuning_energy_levels(s_qs,
                                crossings,
                                ax1,
                                highlighted_indices=other_highlighted_indices +
                                [-1])

    ax2.plot(unique_crossings, counts, 'x', alpha=0.4)
    ax2.grid()
    plt.show()