def plot_detuning_energy_levels(s_qs: StaticQubitSystem, crossings: np.ndarray, ax: Axes,
                                highlighted_indices: Sequence[int] = (-1, )):
    """Plot the Omega = 0 energy levels against detuning on ``ax``.

    A fresh ``StaticQubitSystem`` is built from ``s_qs.N``, ``s_qs.V`` and
    ``s_qs.geometry`` with ``Omega=0`` and 20 detuning samples spanning the
    x-limits derived from ``crossings``; its energies are what gets plotted.

    All decorations (title, axis labels, limits, pick handler, layout) are
    applied to ``ax`` and its figure rather than to pyplot's "current" axes,
    so the function also works when ``ax`` is not the most recently created
    subplot.

    :param s_qs: System providing N, V and geometry for the plotted system.
    :param crossings: Detuning values; only their min and max are used, to
        choose the plotted detuning range.
    :param ax: Axes to draw on.
    :param highlighted_indices: Indices into the states list whose energy
        levels are drawn thick and red (unless they are the ground state,
        which is always green).
    """
    # Kept from the original: makes sure the caller's system has its
    # energies computed, even though the plot uses the rebuilt system below.
    if s_qs.Omega_zero_energies is None:
        s_qs.get_energies()

    crossings_range = crossings.max() - crossings.min()
    # NOTE(review): the lower limit is measured from zero (-10% of the
    # crossing range), not from crossings.min() - confirm this is intended.
    xlims = crossings_range * -0.1, crossings.max() * 1.1

    s_qs = StaticQubitSystem(s_qs.N, s_qs.V, s_qs.geometry, Omega=0,
                             Delta=np.linspace(xlims[0], xlims[1], 20))
    s_qs.get_energies()

    # Presumably the single-qubit ground state; a many-body state counts as
    # the ground state when every one of its components equals it.
    g = states_quimb.get_ground_states(1)[0]
    # Resolve the highlighted states once instead of once per plotted level.
    highlighted_states = [s_qs.states[index] for index in highlighted_indices]

    for i, state in enumerate(s_qs.states):
        is_highlight_state = any(state is candidate for candidate in highlighted_states)
        is_ground_state = all((_state == g).all() for _state in state)
        emphasised = is_ground_state or is_highlight_state

        if is_ground_state:
            color = 'g'
        elif is_highlight_state:
            color = 'r'
        else:
            color = 'grey'

        ax.plot(s_qs.Delta, s_qs.Omega_zero_energies[:, i],
                color=color, alpha=0.6,
                lw=5 if emphasised else 1,
                zorder=2 if emphasised else 1,
                label=states_quimb.get_label_from_state(state),
                picker=3)

    def on_pick(event):
        # Report which energy level was clicked.
        line = event.artist
        print(f'Clicked on: {line.get_label()}')

    # Attach to the figure that owns ``ax`` (not necessarily plt.gcf()).
    ax.figure.canvas.mpl_connect('pick_event', on_pick)

    ax.grid()
    ax.xaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
    ax.yaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
    ax.locator_params(nbins=4)

    # Render V as "m x 10^s" for the title.
    _m, _s = f"{s_qs.V:0.2e}".split('e')
    V_text = rf"{_m:s} \times 10^{{{int(_s):d}}}"
    ax.set_title(rf"Energy spectrum with $N = {s_qs.N}$, $V = {V_text:s}$ Hz")
    ax.set_xlabel(r"Detuning $\Delta$")
    ax.set_ylabel("Eigenenergy")
    ax.set_xlim(xlims)
    ax.figure.tight_layout()
def plot_basis_state_populations_2d(e_qs: EvolvingQubitSystem, log=False, log_limit=1e-5):
    """Plot product-basis populations at five snapshots of the evolution.

    One subplot is drawn for each of: the first solved state, the states at
    one, two and three quarters of ``len(e_qs.t_list)``, and the last solved
    state. Within each subplot the basis states are ordered by decreasing
    number of set bits in their basis index (ties keep ascending index).

    For every snapshot a line is printed with the number (and fraction) of
    basis states whose population exceeds ``log_limit``. The figure is shown
    with ``plt.show()``; nothing is returned.

    :param e_qs: Solved system; ``N``, ``t_list``, ``solved_t_list`` and
        ``solved_states`` are read.
    :param log: Use a logarithmic y-axis from ``log_limit`` to 1 instead of
        a linear axis from 0 to 1.
    :param log_limit: Lower y-limit in log mode and threshold for the printed
        count. It is also added to every plotted population (in both modes).
    """
    # NOTE(review): the quartile index comes from t_list but is used to index
    # solved_t_list / solved_states - assumes they have equal length; confirm.
    quartile_index = int(len(e_qs.t_list) / 4)
    indices = [0, quartile_index, quartile_index * 2, quartile_index * 3, -1]

    states = get_states(e_qs.N, sparse=True)
    labels = [get_label_from_state(state) for state in states]

    # Basis indices grouped by popcount, highest first. ``sorted`` is stable,
    # so indices with the same popcount stay in ascending order.
    state_product_basis_indices = np.array(
        sorted(range(2 ** e_qs.N), key=lambda index: -bin(index).count("1")))

    x = np.arange(len(labels))

    _, axs = plt.subplots(len(indices), 1, sharex='all', figsize=(14, 8))
    for ax, i in zip(axs, indices):
        if log:
            ax.set_ylim(log_limit, 1)
            # Base 10 is the default. The old ``basey`` keyword was renamed
            # to ``base`` in matplotlib, so passing neither works everywhere.
            ax.set_yscale('log')
        else:
            ax.set_ylim(0, 1)

        ax.grid(axis='y')
        ax.set_ylabel(f"{e_qs.solved_t_list[i]:.2e}s")

        solve_result_state_populations = np.abs(e_qs.solved_states[i].flatten()) ** 2
        # Offset by log_limit so zero populations remain drawable on a log axis.
        basis_state_populations = (
            solve_result_state_populations[state_product_basis_indices] + log_limit)

        above_limit = np.count_nonzero(solve_result_state_populations > log_limit)
        print(
            f"above limit {log_limit:.0e} \t count: {above_limit:4d} \t ({above_limit / 2 ** e_qs.N:5.1%})"
        )

        ax.fill_between(x, np.zeros_like(basis_state_populations),
                        basis_state_populations, step='mid')

    # With many basis states only the first and last tick are labelled.
    if len(x) > 20:
        label_indices = [0, -1]
        plt.xticks(
            [x[i] for i in label_indices],
            [labels[i] for i in label_indices],
        )
    else:
        plt.xticks(x, labels)

    plt.tight_layout()
    plt.show()
def plot_basis_states_overlaps(self, ax, plot_title: bool = True, plot_others_as_sum: bool = False):
    """Plot the population of basis states over the solved evolution on ``ax``.

    For every considered state the squared magnitude of the corresponding
    component of each entry in ``self.solved_states`` is plotted against
    ``self.solved_t_list``. States whose label does not contain both 'e' and
    'g' are additionally summed into a dotted "P_E + P_G" curve.

    :param ax: Axes to draw on.
    :param plot_title: Whether to set the axes title.
    :param plot_others_as_sum: If True, only the states returned by
        ``get_excited_states`` / ``get_ground_states`` are plotted and all
        remaining population is drawn as one dotted curve
        ``1 - (P_E + P_G)``. If False, every state from ``get_states`` is
        plotted individually.
    """
    if plot_others_as_sum:
        states = [states_quimb.get_excited_states(self.N),
                  states_quimb.get_ground_states(self.N)]
    else:
        states = states_quimb.get_states(self.N)

    fidelities = []
    plot_individual_orthogonal_state_labels = len(states) <= 4
    plotted_others = False

    for state in states:
        label = states_quimb.get_label_from_state(state)
        state_product_basis_index = states_quimb.get_product_basis_states_index(state)
        state_fidelities = [
            np.abs(_instantaneous_state.flatten()[state_product_basis_index]) ** 2
            for _instantaneous_state in self.solved_states
        ]

        has_excited = 'e' in label
        has_ground = 'g' in label
        # "Uniform" = label lacks 'e' or lacks 'g' (i.e. not a mixed label).
        is_uniform = not (has_excited and has_ground)

        if is_uniform:
            fidelities.append(state_fidelities)
            plot_label = r"$P_{" + f"{label.upper()[0]}" + "}$"
        elif plot_individual_orthogonal_state_labels:
            plot_label = r"$P_{" + f"{label.upper()}" + "}$"
        elif plotted_others:
            # Only the first of the grouped mixed states gets a legend entry.
            plot_label = None
        else:
            plot_label = 'Others'
            plotted_others = True

        if not has_excited:
            color = 'g'
        elif not has_ground:
            color = 'r'
        else:
            color = 'k'

        ax.plot(
            self.solved_t_list,
            state_fidelities,
            label=plot_label,
            color=color,
            linewidth=1 if is_uniform else 0.5,
            alpha=0.5)

    fidelities_sum = np.array(fidelities).sum(axis=0)
    ax.plot(self.solved_t_list, fidelities_sum,
            label="$P_{E} + P_{G}$",
            color='C0', linestyle=":", linewidth=1, alpha=0.7)

    if plot_others_as_sum:
        others_sum = 1 - fidelities_sum
        ax.plot(self.solved_t_list, others_sum,
                label=r"$\sum{\textrm{Others}}$",
                color='C1', linestyle=":", linewidth=1, alpha=0.7)

    ax.set_ylabel("Population")
    if plot_title:
        ax.set_title("Basis state populations")
    ax.set_ylim((-0.1, 1.1))
    ax.yaxis.set_ticks([0, 0.5, 1])

    # Sort legend entries by label only. Sorting the (label, handle) pairs
    # directly would compare the artist handles whenever two labels are equal
    # (e.g. when this method is called twice on the same axes) and raise
    # TypeError.
    handles, labels = ax.get_legend_handles_labels()
    ordered = sorted(zip(labels, handles), key=lambda pair: pair[0])
    ax.legend([handle for _, handle in ordered],
              [text for text, _ in ordered])
def plot_detuning_energy_levels(
        self, plot_state_names: bool,
        fig_kwargs: dict = None,
        plot_title: bool = True,
        ylim: Tuple[float, float] = None,
        highlight_states_by_label: List[str] = None):
    """Plot the energy of every state against detuning in a new figure.

    The ``Omega_zero_energies`` columns are drawn as solid lines: green and
    thick for states whose label contains no 'e', red and thick for
    highlighted states, thin grey otherwise. If ``self.Omega`` is non-zero,
    ``Omega_non_zero_energies`` is overlaid as dotted lines.

    :param plot_state_names: Write each state's label next to its line and
        add a legend.
    :param fig_kwargs: Extra keyword arguments for ``plt.figure``; they
        override the defaults ``figsize=(15, 7)`` and ``num="Energy Levels"``.
    :param plot_title: Whether to add the title containing N and V.
    :param ylim: Optional y-limits passed to ``plt.ylim``.
    :param highlight_states_by_label: Labels of the states to highlight.
        Defaults to the single label made of ``self.N`` 'e' characters.
    """
    if self.Omega_zero_energies is None:
        self.get_energies()

    if highlight_states_by_label is None:
        # A one-element list, so ``label in ...`` below is a membership test
        # as with caller-supplied lists (a bare string would turn it into a
        # substring test).
        highlight_states_by_label = ['e' * self.N]

    if fig_kwargs is None:
        fig_kwargs = {}
    fig_kwargs = {"figsize": (15, 7), "num": "Energy Levels", **fig_kwargs}
    plt.figure(**fig_kwargs)

    plot_points = len(self.Delta)
    state_count = len(self.states)

    for i in reversed(range(state_count)):
        label = states_quimb.get_label_from_state(self.states[i])
        is_highlight_state = label in highlight_states_by_label
        is_ground_state = 'e' not in label
        emphasised = is_ground_state or is_highlight_state

        if is_ground_state:
            color = 'g'
        elif is_highlight_state:
            color = 'r'
        else:
            color = 'grey'

        plt.plot(self.Delta, self.Omega_zero_energies[:, i],
                 color=color, label=label, alpha=0.6,
                 lw=5 if emphasised else 1,
                 zorder=2 if emphasised else 1)

        if self.Omega != 0:
            plt.plot(self.Delta, self.Omega_non_zero_energies[:, i],
                     color=f'C{i}', ls=':', alpha=0.6)

        if plot_state_names:
            # Spread the labels along the x-axis: state i is annotated in the
            # middle of the i-th of ``state_count`` equal slices of Delta.
            Delta_index = int(plot_points / state_count) * i + int(
                plot_points / 2 / state_count)
            text_x = self.Delta[Delta_index]
            text_y = self.Omega_zero_energies[Delta_index, i]
            plt.text(text_x, text_y, label, ha='center', color=f'C{i}',
                     fontsize=16, fontweight='bold')

    if plot_state_names:
        plt.legend()

    plt.grid()
    ax = plt.gca()
    ax.xaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
    ax.yaxis.set_major_formatter(ticker.EngFormatter(unit="Hz"))
    plt.locator_params(nbins=4)

    if plot_title:
        # Render V as "m x 10^s"; only formatted when the title is drawn.
        _m, _s = f"{self.V:0.2e}".split('e')
        V_text = rf"{_m:s} \times 10^{{{int(_s):d}}}"
        plt.title(rf"Energy spectrum with $N = {self.N}$, $V = {V_text:s}$ Hz")

    plt.xlabel(r"Detuning $\Delta$")
    plt.ylabel("Eigenenergy")
    plt.xlim((self.Delta.min(), self.Delta.max()))
    if ylim:
        plt.ylim(ylim)
    plt.tight_layout()
def calculate_ghz_crossings(s_qs: StaticQubitSystem, other_highlighted_labels: List[str] = ()):
    """Find where each state's Omega = 0 energy crosses zero and plot the result.

    For every state the detuning at which its ``Omega_zero_energies`` column
    changes sign is estimated by linear interpolation. Two stacked subplots
    sharing the x-axis are then shown: the energy levels (via
    ``plot_detuning_energy_levels``) and, below, how many states share each
    distinct crossing value.

    :param s_qs: System to analyse; ``N`` must be even.
    :param other_highlighted_labels: Labels of additional states to
        highlight. For each label the one with 'e' and 'g' swapped is
        highlighted as well. Any sequence of strings is accepted.
    :raises AssertionError: If ``s_qs.N`` is odd.
    """
    assert s_qs.N % 2 == 0, f"N has to be even, not {s_qs.N}"

    s_qs.get_energies()

    # Normalise to a list first: the default is a tuple, and tuple + list
    # raises TypeError.
    requested_labels = list(other_highlighted_labels)
    swap_e_g = str.maketrans({'e': 'g', 'g': 'e'})
    complementary_labels = [label.translate(swap_e_g) for label in requested_labels]
    highlighted_labels = requested_labels + [
        label for label in complementary_labels if label not in requested_labels
    ]

    other_highlighted_indices = [
        i for i, state in enumerate(s_qs.states)
        if states_quimb.get_label_from_state(state) in highlighted_labels
    ]

    def find_root(x: np.ndarray, y: np.ndarray):
        """
        Finds crossing (where y equals 0), given that x, y is roughly linear.

        Returns NaN when y is zero everywhere. Otherwise interpolates
        linearly between the first sample with y < 0 and the sample before it.
        NOTE(review): if y is never negative, or is already negative at index
        0, the "previous" sample wraps around to the last one - callers are
        presumably expected to supply a range that contains the sign change.
        """
        if (y == 0).all():
            return np.nan
        _right_bound = (y < 0).argmax()
        _left_bound = _right_bound - 1
        crossing = y[_left_bound] / (y[_left_bound] - y[_right_bound]) \
            * (x[_right_bound] - x[_left_bound]) + x[_left_bound]
        return crossing

    all_crossings = [
        find_root(s_qs.Delta, s_qs.Omega_zero_energies[:, i])
        for i in range(len(s_qs.states))
    ]
    # 1: removes first "crossing" of GGG (nan)
    crossings = np.array(all_crossings[1:])
    unique_crossings, counts = np.unique(crossings, return_counts=True)

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(15, 8))
    # -1 additionally highlights the last state in the list.
    plot_detuning_energy_levels(s_qs, crossings, ax1,
                                highlighted_indices=other_highlighted_indices + [-1])

    ax2.plot(unique_crossings, counts, 'x', alpha=0.4)
    ax2.grid()

    plt.show()