def __init__(
    self,
    time_unit: u.Unit = "asec",
    show_electric_field=True,
    electric_field_unit="atomic_electric_field",
    show_vector_potential=False,
    vector_potential_unit="atomic_momentum",
    linewidth=3,
    show_y_label=False,
    show_ticks_bottom=True,
    show_ticks_top=False,
    show_ticks_right=True,
    show_ticks_left=True,
    grid_kwargs=None,
    legend_kwargs=None,
):
    """
    Record the display options for an axis that shows the electric field and/or vector potential.

    Parameters
    ----------
    time_unit
        Unit for the time axis, resolved through ``u.get_unit_value_and_latex``.
    show_electric_field, show_vector_potential
        Which curves are enabled. A warning is logged when both are ``False``.
    electric_field_unit, vector_potential_unit
        Units for the two curves, resolved through ``u.get_unit_value_and_latex``.
    linewidth
        Stored as ``self.linewidth``.
    show_y_label, show_ticks_bottom, show_ticks_top, show_ticks_right, show_ticks_left
        Stored unchanged on the instance.
    grid_kwargs
        Overrides layered on top of ``si.vis.DEFAULT_GRID_KWARGS``.
    legend_kwargs
        Overrides layered on top of the legend defaults defined below.
    """
    self.show_electric_field = show_electric_field
    self.show_vector_potential = show_vector_potential
    if not (show_electric_field or show_vector_potential):
        logger.warning(
            f"{self} has both show_electric_field and show_vector_potential set to False"
        )

    # Each unit is kept alongside its numeric value and LaTeX representation.
    self.time_unit = time_unit
    time_lookup = u.get_unit_value_and_latex(time_unit)
    self.time_unit_value, self.time_unit_latex = time_lookup

    self.electric_field_unit = electric_field_unit
    field_lookup = u.get_unit_value_and_latex(electric_field_unit)
    self.electric_field_unit_value, self.electric_field_unit_latex = field_lookup

    self.vector_potential_unit = vector_potential_unit
    potential_lookup = u.get_unit_value_and_latex(vector_potential_unit)
    self.vector_potential_unit_value, self.vector_potential_unit_latex = potential_lookup

    self.show_y_label = show_y_label
    self.show_ticks_bottom = show_ticks_bottom
    self.show_ticks_top = show_ticks_top
    self.show_ticks_right = show_ticks_right
    self.show_ticks_left = show_ticks_left
    self.linewidth = linewidth

    # Caller-supplied entries win over the defaults.
    legend_overrides = {} if legend_kwargs is None else legend_kwargs
    self.legend_kwargs = {
        **dict(loc="lower left", fontsize=30, fancybox=True, framealpha=0),
        **legend_overrides,
    }

    grid_overrides = {} if grid_kwargs is None else grid_kwargs
    self.grid_kwargs = {**si.vis.DEFAULT_GRID_KWARGS, **grid_overrides}

    super().__init__()
def plot_mesh(
    self,
    mesh: "meshes.ScalarMesh",
    name: str = "",
    title: Optional[str] = None,
    distance_unit: u.Unit = "bohr_radius",
    colormap=vis.COLORMAP_WAVEFUNCTION,
    norm=si.vis.AbsoluteRenormalize(),
    shading: si.vis.ColormapShader = si.vis.ColormapShader.FLAT,
    plot_limit=None,
    slicer="get_mesh_slicer",
    show_colorbar=True,
    show_title=True,
    show_axes=True,
    grid_kwargs=None,
    overlay_probability_current=False,
    **kwargs,
):
    """
    Plot ``mesh`` as a color mesh on a polar axis inside a ``si.vis.FigureManager``.

    Parameters
    ----------
    mesh
        The data to draw; forwarded to ``self.attach_mesh_to_axis``.
    name
        Suffix appended to ``self.spec.name`` to form the figure name.
    title
        Axis title. Only drawn when non-empty and both ``show_axes`` and ``show_title`` are set.
    distance_unit
        Unit for the radial coordinate.
    colormap, norm, shading, plot_limit, slicer
        Forwarded to ``self.attach_mesh_to_axis``.
    show_colorbar
        Add a colorbar (only when ``show_axes`` is also set).
    show_axes
        When ``False`` the axis is switched off after plotting.
    grid_kwargs
        Overrides looked up before ``si.vis.COLORMESH_GRID_KWARGS``.
    overlay_probability_current
        Accepted but not used by this method.
    kwargs
        Forwarded to ``si.vis.FigureManager``.
    """
    grid_kwargs = collections.ChainMap(grid_kwargs or {}, si.vis.COLORMESH_GRID_KWARGS)
    unit_value, unit_latex = u.get_unit_value_and_latex(distance_unit)

    with si.vis.FigureManager(name=f"{self.spec.name}__{name}", **kwargs) as figman:
        fig = figman.fig
        fig.set_tight_layout(True)

        polar_axis = plt.subplot(111, projection="polar")
        polar_axis.set_theta_zero_location("N")
        polar_axis.set_theta_direction("clockwise")

        color_mesh = self.attach_mesh_to_axis(
            polar_axis,
            mesh,
            distance_unit=distance_unit,
            colormap=colormap,
            norm=norm,
            shading=shading,
            plot_limit=plot_limit,
            slicer=slicer,
        )

        if title is not None and title != "" and show_axes and show_title:
            title = polar_axis.set_title(title, fontsize=15)
            # move title to the upper left corner
            title.set_x(0.03)
            title.set_y(0.97)

        if show_colorbar and show_axes:
            colorbar_axis = fig.add_axes([0.8, 0.1, 0.02, 0.8])
            plt.colorbar(mappable=color_mesh, cax=colorbar_axis)

        fmt_polar_axis(fig, polar_axis, colormap, grid_kwargs, unit_latex)

        # The radial limit is the smaller of plot_limit and the mesh edge, pulled in by half a radial step.
        if plot_limit is not None and plot_limit < self.mesh.r_max:
            outer_radius = plot_limit
        else:
            outer_radius = self.mesh.r_max
        polar_axis.set_rmax((outer_radius - (self.mesh.delta_r / 2)) / unit_value)

        if not show_axes:
            polar_axis.axis("off")
def attach_mesh_repr_to_axis(
    self,
    axis: plt.Axes,
    mesh: "meshes.ScalarMesh",
    distance_unit: str = "bohr_radius",
    colormap=plt.get_cmap("inferno"),
    norm=si.vis.AbsoluteRenormalize(),
    shading: si.vis.ColormapShader = si.vis.ColormapShader.FLAT,
    plot_limit: Optional[float] = None,
    slicer: str = "get_mesh_slicer",
    **kwargs,
):
    """
    Draw ``mesh`` on ``axis`` as a color mesh over ``self.mesh.l_mesh`` and ``self.mesh.r_mesh``.

    Parameters
    ----------
    axis
        The axis to draw on.
    mesh
        The data to draw; sliced with the same slice as the coordinate meshes.
    distance_unit
        Unit that the ``r_mesh`` coordinate is divided by.
    colormap, norm, shading
        Passed to ``pcolormesh`` as ``cmap``, ``norm`` and ``shading``.
    plot_limit
        Argument handed to the slicer method.
    slicer
        Name of the method on ``self.mesh`` that produces the slice.
    kwargs
        Extra keyword arguments for ``pcolormesh``.

    Returns
    -------
    The object returned by ``axis.pcolormesh``.
    """
    unit_value = u.get_unit_value_and_latex(distance_unit)[0]
    region = getattr(self.mesh, slicer)(plot_limit)

    ell_coordinates = self.mesh.l_mesh[region]
    radial_coordinates = self.mesh.r_mesh[region] / unit_value

    return axis.pcolormesh(
        ell_coordinates,
        radial_coordinates,
        mesh[region],
        shading=shading,
        cmap=colormap,
        norm=norm,
        **kwargs,
    )
def __init__(
    self,
    axman_lower_right=axes.ElectricPotentialPlotAxis(),
    axman_upper_right=axes.WavefunctionStackplotAxis(),
    axman_colorbar=axes.ColorBarAxis(),
    fig_dpi_scale=1,
    time_text_unit: u.Unit = "asec",
    **kwargs,
):
    """
    Set up the extra axis managers and the time-text unit.

    Parameters
    ----------
    axman_lower_right, axman_upper_right, axman_colorbar
        Axis managers for the extra panels. Pass ``None`` to leave a panel out;
        every non-``None`` manager is added to ``self.axis_managers``.
    fig_dpi_scale
        Stored as ``self.fig_dpi_scale``.
    time_text_unit
        Unit resolved through ``u.get_unit_value_and_latex`` and stored as
        ``self.time_unit_value`` / ``self.time_unit_latex``.
    kwargs
        Forwarded to the parent initializer.
    """
    # NOTE(review): the default axis managers are built once, when this function is defined,
    # so every instance created with the defaults shares the same manager objects — confirm this is intended.
    super().__init__(**kwargs)

    self.axman_lower_right = axman_lower_right
    self.axman_upper_right = axman_upper_right
    self.axman_colorbar = axman_colorbar

    optional_managers = (
        self.axman_lower_right,
        self.axman_upper_right,
        self.axman_colorbar,
    )
    enabled_managers = []
    for manager in optional_managers:
        if manager is not None:
            enabled_managers.append(manager)
    self.axis_managers += enabled_managers

    self.fig_dpi_scale = fig_dpi_scale

    time_lookup = u.get_unit_value_and_latex(time_text_unit)
    self.time_unit_value, self.time_unit_latex = time_lookup
def attach_probability_current_to_axis(
    self,
    axis: plt.Axes,
    plot_limit: Optional[float] = None,
    distance_unit: u.Unit = "bohr_radius",
    rate_unit="per_asec",
):
    """
    Draw the probability current of ``self.mesh`` on ``axis`` as a quiver (arrow) plot.

    The arrow components are divided by the largest arrow magnitude found on the
    thinned mesh (or left unscaled if that magnitude is zero).

    Parameters
    ----------
    axis
        The axis to draw on.
    plot_limit
        Argument handed to ``self.mesh.get_mesh_slicer`` to restrict the plotted region.
    distance_unit
        Unit that the ``z_mesh`` and ``rho_mesh`` coordinates are divided by.
    rate_unit
        Looked up through ``u.get_unit_value_and_latex``, but the resulting value is
        not applied to the arrows.

    Returns
    -------
    The object returned by ``axis.quiver``.
    """
    distance_unit_value, _ = u.get_unit_value_and_latex(distance_unit)
    # Kept so that a rate_unit the lookup rejects still fails here, as it did before.
    u.get_unit_value_and_latex(rate_unit)

    (
        current_density_z,
        current_density_rho,
    ) = self.mesh.get_probability_current_density_vector_field()  # actually densities here

    # Multiply out-of-place so the arrays handed back by the mesh are never modified.
    current_mesh_z = current_density_z * self.mesh.delta_z
    current_mesh_rho = current_density_rho * self.mesh.delta_rho

    # Thin the arrows to roughly 50 per mesh direction. The stride is clamped to 1:
    # a mesh with fewer than 50 points in a direction would otherwise give a stride of 0,
    # and a slice step of zero raises ValueError.
    mesh_shape = self.mesh.z_mesh.shape
    skip = (
        slice(None, None, max(1, mesh_shape[0] // 50)),
        slice(None, None, max(1, mesh_shape[1] // 50)),
    )

    magnitudes = np.sqrt((current_mesh_z ** 2) + (current_mesh_rho ** 2))
    normalization = np.nanmax(magnitudes[skip])
    if normalization == 0:
        normalization = 1  # avoid dividing by zero when there is no current

    sli = self.mesh.get_mesh_slicer(plot_limit)

    return axis.quiver(
        self.mesh.z_mesh[sli][skip] / distance_unit_value,
        self.mesh.rho_mesh[sli][skip] / distance_unit_value,
        current_mesh_z[sli][skip] / normalization,
        current_mesh_rho[sli][skip] / normalization,
        pivot="middle",
        scale=10,
        units="width",
        scale_units="width",
        alpha=0.5,
        color="white",
    )
def attach_electric_potential_plot_to_axis(self,
                                           axis,
                                           time_unit="asec",
                                           show_electric_field=True,
                                           overlay_kicks=True):
    """
    Plot the electric potential of ``self.spec`` on ``axis``, either as a field curve or as kick stems.

    When ``show_electric_field`` is set and the potential is not a ``DeltaKicks`` instance,
    the field amplitude is drawn as a curve (optionally with a stem at every kick time).
    Otherwise only the kicks are drawn, with heights taken from ``kick.amplitude``.

    Parameters
    ----------
    axis
        The axis to draw on.
    time_unit
        Unit that the times are divided by; its LaTeX form is used in the x label.
    show_electric_field
        Whether the field curve is requested.
    overlay_kicks
        Whether stems are drawn on top of the field curve.
    """
    time_unit_value, time_unit_latex = u.get_unit_value_and_latex(time_unit)

    def draw_stem(kick, height):
        # A vertical segment from zero up to ``height`` at the kick time.
        axis.plot(
            [kick.time / time_unit_value, kick.time / time_unit_value],
            [0, height],
            linewidth=1.5,
            color=si.vis.PINK,
        )

    kicks_only = not show_electric_field or isinstance(
        self.spec.electric_potential, DeltaKicks)

    if kicks_only:
        for kick in self.spec.kicks:
            draw_stem(
                kick,
                kick.amplitude / (u.atomic_electric_field * u.atomic_time),
            )
        axis.set_ylabel(r"$ \eta $")
    else:
        axis.plot(
            self.times / time_unit_value,
            self.spec.electric_potential.get_electric_field_amplitude(self.times)
            / u.atomic_electric_field,
            color=vis.COLOR_EFIELD,
            linewidth=1.5,
        )
        if overlay_kicks:
            for kick in self.spec.kicks:
                draw_stem(
                    kick,
                    self.spec.electric_potential.get_electric_field_amplitude(kick.time)
                    / u.atomic_electric_field,
                )
        axis.set_ylabel(rf"$ {vis.LATEX_EFIELD}(t) $")

    axis.set_xlabel(f"Time $t$ (${time_unit_latex}$)", fontsize=13)
    axis.tick_params(labelright=True)
    axis.set_xlim(self.times[0] / time_unit_value,
                  self.times[-1] / time_unit_value)
    axis.grid(True, **si.vis.DEFAULT_GRID_KWARGS)
def attach_electric_potential_plot_to_axis(
    self,
    axis,
    time_unit="asec",
    legend_kwargs=None,
    show_y_label: bool = False,
    show_electric_field: bool = True,
    show_vector_potential: bool = True,
):
    """
    Plot the electric field and/or vector potential of ``self.spec.electric_potential`` against ``self.times``.

    Parameters
    ----------
    axis
        The axis to draw on.
    time_unit
        Unit that the times are divided by; its LaTeX form is used in the x label.
    legend_kwargs
        Overrides layered on top of the legend defaults defined below.
    show_y_label
        When set, the y label is the comma-joined list of the plotted curve labels.
    show_electric_field, show_vector_potential
        Select which curves are drawn.
    """
    time_unit_value, time_unit_latex = u.get_unit_value_and_latex(time_unit)

    legend_overrides = {} if legend_kwargs is None else legend_kwargs
    legend_kwargs = {
        **dict(loc="lower left", fontsize=10, fancybox=True, framealpha=0.3),
        **legend_overrides,
    }

    y_labels = []

    def draw_curve(values, color, label):
        # Every curve shares the same time axis and line width.
        axis.plot(
            self.times / time_unit_value,
            values,
            color=color,
            linewidth=1.5,
            label=label,
        )
        y_labels.append(label)

    potential = self.spec.electric_potential

    if show_electric_field:
        draw_curve(
            potential.get_electric_field_amplitude(self.times) / u.atomic_electric_field,
            vis.COLOR_EFIELD,
            rf"$ {vis.LATEX_EFIELD}(t) $",
        )

    if show_vector_potential:
        draw_curve(
            u.proton_charge
            * potential.get_vector_potential_amplitude_numeric_cumulative(self.times)
            / u.atomic_momentum,
            vis.COLOR_AFIELD,
            rf"$ e \, {vis.LATEX_AFIELD}(t) $",
        )

    if show_y_label:
        axis.set_ylabel(", ".join(y_labels), fontsize=13)

    axis.set_xlabel(f"Time $t$ (${time_unit_latex}$)", fontsize=13)
    axis.tick_params(labelright=True)
    axis.set_xlim(self.times[0] / time_unit_value,
                  self.times[-1] / time_unit_value)
    axis.legend(**legend_kwargs)
    axis.grid(True, **si.vis.DEFAULT_GRID_KWARGS)
def attach_electric_potential_plot_to_axis(
    self,
    axis: plt.Axes,
    show_electric_field: bool = True,
    show_vector_potential: bool = True,
    time_unit: u.Unit = "asec",
    legend_kwargs: Optional[dict] = None,
    show_y_label: bool = False,
):
    """
    Plot the electric field and/or vector potential recorded in ``self.sim.data`` against ``self.sim.data_times``.

    Parameters
    ----------
    axis
        The axis to draw on.
    show_electric_field, show_vector_potential
        Select which curves are drawn.
    time_unit
        Unit that the times are divided by; its LaTeX form is used in the x label.
    legend_kwargs
        Overrides layered on top of the legend defaults defined below.
    show_y_label
        When set, the y axis is labeled with the electric field symbol in the field color.
    """
    time_unit_value, time_unit_latex = u.get_unit_value_and_latex(time_unit)

    legend_overrides = {} if legend_kwargs is None else legend_kwargs
    legend_kwargs = {
        **dict(loc="lower left", fontsize=10, fancybox=True, framealpha=0.3),
        **legend_overrides,
    }

    def draw_curve(values, color, label):
        # Every curve shares the recorded data times and line width.
        axis.plot(
            self.sim.data_times / time_unit_value,
            values,
            color=color,
            linewidth=1.5,
            label=label,
        )

    if show_electric_field:
        draw_curve(
            self.sim.data.electric_field_amplitude / u.atomic_electric_field,
            vis.COLOR_EFIELD,
            rf"$ {vis.LATEX_EFIELD}(t) $",
        )

    if show_vector_potential:
        draw_curve(
            u.proton_charge * self.sim.data.vector_potential_amplitude / u.atomic_momentum,
            vis.COLOR_AFIELD,
            rf"$ e \, {vis.LATEX_AFIELD}(t) $",
        )

    if show_y_label:
        axis.set_ylabel(rf"${vis.LATEX_EFIELD}(t)$",
                        fontsize=13,
                        color=vis.COLOR_EFIELD)

    axis.set_xlabel(rf"Time $t$ (${time_unit_latex}$)", fontsize=13)
    axis.tick_params(labelright=True)

    # The x range comes from the full simulation times rather than the data times plotted above.
    first_time = self.sim.times[0] / time_unit_value
    last_time = self.sim.times[-1] / time_unit_value
    axis.set_xlim(first_time, last_time)

    axis.legend(**legend_kwargs)
    axis.grid(True, **si.vis.DEFAULT_GRID_KWARGS)
def initialize_axis(self):
    """
    Draw the initial line plot on ``self.axis`` and register the artists that must be redrawn.

    The line returned by ``self.attach_method`` is stored as ``self.mesh`` and added to
    ``self.redraw`` first; the gridlines are added afterwards so they are redrawn over it.
    """
    unit_value, unit_name = u.get_unit_value_and_latex(self.distance_unit)

    self.mesh = self.attach_method(
        self.axis,
        colormap=self.colormap,
        norm=self.norm,
        shading=self.shading,
        plot_limit=self.plot_limit,
        distance_unit=self.distance_unit,
        slicer=self.slicer,
        animated=True,
        linewidth=3,
    )
    self.redraw.append(self.mesh)

    if self.log:
        self.axis.set_yscale("log")
        self.axis.set_ylim(bottom=1e-15)

    # TODO: code for show_potential

    self.axis.grid(True, **self.grid_kwargs)

    self.axis.set_xlabel(rf"$ z $ ($ {unit_name} $)", fontsize=24)

    # y label for each supported value of self.which
    y_label_by_quantity = {
        "g2": r"$ \left| g \right|^2 $",
        "psi2": r"$ \left| \Psi \right|^2 $",
        "g": r"$ g $",
        "psi": r"$ \Psi $",
        "fft": r"$ \phi $",
    }
    self.axis.set_ylabel(y_label_by_quantity[self.which], fontsize=30)

    self.axis.tick_params(axis="both", which="major", labelsize=20)
    self.axis.tick_params(labelright=True, labeltop=True)

    # The x limits span the z coordinates inside the plotted region.
    mesh_slice = getattr(self.sim.mesh, self.slicer)(self.plot_limit)
    z_values = self.sim.mesh.z_mesh[mesh_slice]
    z_lower_limit = np.nanmin(z_values)
    z_upper_limit = np.nanmax(z_values)
    self.axis.set_xlim(z_lower_limit / unit_value, z_upper_limit / unit_value)

    # gridlines must be redrawn over the mesh (it's important that they're AFTER the mesh itself in self.redraw)
    x_gridlines = list(self.axis.xaxis.get_gridlines())
    y_gridlines = list(self.axis.yaxis.get_gridlines())
    self.redraw += x_gridlines + y_gridlines

    super().initialize_axis()
def initialize_axis(self):
    """
    Draw the initial color mesh on ``self.axis`` (z horizontally, rho vertically) and register redraw artists.

    The mesh returned by ``self.attach_method`` is stored as ``self.mesh`` and added to
    ``self.redraw`` first; gridlines and y tick labels are added after the parent
    initializer runs so that they are redrawn over the mesh.
    """
    unit_value, unit_name = u.get_unit_value_and_latex(self.distance_unit)

    if self.which == "g":
        largest_magnitude = np.max(
            np.abs(self.sim.mesh.g) / vis.DEFAULT_RICHARDSON_MAGNITUDE_DIVISOR)
        self.norm.equator_magnitude = largest_magnitude

    self.mesh = self.attach_method(
        self.axis,
        colormap=self.colormap,
        norm=self.norm,
        shading=self.shading,
        plot_limit=self.plot_limit,
        distance_unit=self.distance_unit,
        slicer=self.slicer,
        animated=True,
    )
    self.redraw.append(self.mesh)

    # change grid color to make it show up against the colormesh
    self.axis.grid(True, **self.grid_kwargs)

    self.axis.set_xlabel(rf"$z$ (${unit_name}$)", fontsize=24)
    self.axis.set_ylabel(rf"$\rho$ (${unit_name}$)", fontsize=24)

    self.axis.tick_params(axis="both", which="major", labelsize=20)
    self.axis.axis("tight")

    super().initialize_axis()

    # gridlines must be redrawn over the mesh (it's important that they're AFTER the mesh itself in self.redraw)
    overlay_artists = list(self.axis.xaxis.get_gridlines())
    overlay_artists += list(self.axis.yaxis.get_gridlines())
    overlay_artists += list(self.axis.yaxis.get_ticklabels())
    self.redraw += overlay_artists

    # NOTE(review): the warning below mentions show_colorbar, but no such flag is checked here — confirm.
    if self.which in ("g", "psi"):
        logger.warning(
            "show_colorbar cannot be used with nonlinear colormaps")
    else:
        divider = make_axes_locatable(self.axis)
        colorbar_axis = divider.append_axes("right", size="2%", pad=0.05)
        self.cbar = plt.colorbar(cax=colorbar_axis, mappable=self.mesh)
        self.cbar.ax.tick_params(labelsize=20)
def group_free_states_by_continuous_attr(
    self,
    attr="energy",
    divisions=10,
    cutoff_value=None,
    label_format_str=r"\phi_{{ {} \; \mathrm{{to}} \; {} \, {}, \ell }}",
    attr_unit: "u.Unit" = "eV",
):
    """
    Sort ``self.sim.free_states`` into consecutive intervals of a continuous attribute.

    Each state lands in the first interval (in boundary order) whose closed range
    contains its attribute value, so a state sitting exactly on a shared boundary
    belongs to the lower interval.

    Parameters
    ----------
    attr
        Name of the state attribute to group by.
    divisions
        Number of equal-width intervals between the smallest and largest value when
        ``cutoff_value`` is ``None``. With a cutoff, ``divisions`` boundaries are spread
        from the smallest value to the cutoff, and one final interval reaches from the
        cutoff to the largest value.
    cutoff_value
        Optional upper end of the evenly divided range.
    label_format_str
        Format string receiving the lower bound, the upper bound (both divided by the
        unit value and formatted to two decimals) and the unit's LaTeX form.
    attr_unit
        Unit used in the labels, resolved through ``u.get_unit_value_and_latex``.

    Returns
    -------
    grouped_states
        Mapping from ``(lower_boundary, upper_boundary)`` to the list of states inside it.
    group_labels
        Mapping from the same keys to their label strings.
        Both mappings are empty when there are no free states.
    """
    remaining_states = list(self.sim.free_states)
    spectrum = {getattr(s, attr) for s in remaining_states}

    # Return mappings (not lists) in the empty case: callers iterate the result with .items().
    if not spectrum:
        return {}, {}

    attr_min, attr_max = min(spectrum), max(spectrum)

    if cutoff_value is None:
        boundaries = np.linspace(attr_min, attr_max, num=divisions + 1)
    else:
        boundaries = np.linspace(attr_min, cutoff_value, num=divisions)
        boundaries = np.concatenate((boundaries, [attr_max]))

    label_unit_value, label_unit_latex = u.get_unit_value_and_latex(attr_unit)

    grouped_states = collections.defaultdict(list)
    group_labels = {}

    for lower_boundary, upper_boundary in zip(boundaries[:-1], boundaries[1:]):
        group = (lower_boundary, upper_boundary)
        group_labels[group] = label_format_str.format(
            f"{lower_boundary / label_unit_value:.2f}",
            f"{upper_boundary / label_unit_value:.2f}",
            label_unit_latex,
        )

        # One linear pass per interval: claim the matching states and carry the rest
        # forward, so no state is ever placed in more than one interval.
        unclaimed_states = []
        for s in remaining_states:
            if lower_boundary <= getattr(s, attr) <= upper_boundary:
                grouped_states[group].append(s)
            else:
                unclaimed_states.append(s)
        remaining_states = unclaimed_states

    return grouped_states, group_labels
def attach_mesh_to_axis(
    self,
    axis: plt.Axes,
    mesh: "meshes.ScalarMesh",
    distance_unit: u.Unit = "bohr_radius",
    norm=si.vis.AbsoluteRenormalize(),
    plot_limit=None,
    slicer="get_mesh_slicer",
    **kwargs,
):
    """
    Draw ``mesh`` on ``axis`` as a single line over ``self.mesh.z_mesh``.

    Parameters
    ----------
    axis
        The axis to draw on.
    mesh
        The data to draw; sliced like the coordinate mesh and passed through ``norm``.
    distance_unit
        Unit that the ``z_mesh`` coordinate is divided by.
    norm
        Callable applied to the sliced data before plotting.
    plot_limit
        Argument handed to the slicer method.
    slicer
        Name of the method on ``self.mesh`` that produces the slice.
    kwargs
        Extra keyword arguments for ``axis.plot``.

    Returns
    -------
    The single line object created by ``axis.plot``.
    """
    unit_value = u.get_unit_value_and_latex(distance_unit)[0]
    region = getattr(self.mesh, slicer)(plot_limit)

    z_coordinates = self.mesh.z_mesh[region] / unit_value
    plotted_lines = axis.plot(z_coordinates, norm(mesh[region]), **kwargs)

    # axis.plot returns a sequence; exactly one line is expected here.
    (line, ) = plotted_lines
    return line
def __init__(
    self,
    show_norm=True,
    time_unit: u.Unit = "asec",
    y_label=None,
    show_ticks_bottom=True,
    show_ticks_top=False,
    show_ticks_right=True,
    show_ticks_left=True,
    grid_kwargs=None,
    legend_kwargs=None,
):
    """
    Record the display options for a time-dependent metric axis.

    Parameters
    ----------
    show_norm
        Stored as ``self.show_norm``.
    time_unit
        Unit for the time axis, resolved through ``u.get_unit_value_and_latex``.
    y_label
        Stored as ``self.y_label``.
    show_ticks_bottom, show_ticks_top, show_ticks_right, show_ticks_left
        Stored unchanged on the instance.
    grid_kwargs
        Overrides layered on top of ``si.vis.DEFAULT_GRID_KWARGS``.
    legend_kwargs
        Overrides layered on top of the legend defaults defined below.
    """
    self.show_norm = show_norm

    self.time_unit = time_unit
    time_lookup = u.get_unit_value_and_latex(time_unit)
    self.time_unit_value, self.time_unit_latex = time_lookup

    self.y_label = y_label

    self.show_ticks_bottom = show_ticks_bottom
    self.show_ticks_top = show_ticks_top
    self.show_ticks_right = show_ticks_right
    self.show_ticks_left = show_ticks_left

    # Caller-supplied entries win over the defaults.
    legend_overrides = {} if legend_kwargs is None else legend_kwargs
    self.legend_kwargs = {
        **dict(loc="lower left", fontsize=30, fancybox=True, framealpha=0),
        **legend_overrides,
    }

    grid_overrides = {} if grid_kwargs is None else grid_kwargs
    self.grid_kwargs = {**si.vis.DEFAULT_GRID_KWARGS, **grid_overrides}

    super().__init__()
def state_overlaps_vs_time(
    self,
    states: Iterable[states.QuantumState] = None,
    log: bool = False,
    time_unit: u.Unit = "asec",
    show_electric_field: bool = True,
    show_vector_potential: bool = True,
    **kwargs,
):
    """
    Plot the norm and a stack of state overlaps against time, with the field in a panel below.

    Parameters
    ----------
    states
        Selects which entries of ``self.sim.data.state_overlaps`` are shown.
        ``None`` shows all of them. A callable is used as a predicate on each state.
        Any other iterable is turned into a set; a state is kept if it is in the set,
        or if it is numeric and its ``analytic_state`` is in the set.
    log
        Use a logarithmic y axis and append ``__log`` to the figure name.
    time_unit
        Unit for the shared time axis of both panels.
    show_electric_field, show_vector_potential
        Forwarded to ``self.attach_electric_potential_plot_to_axis``.
    kwargs
        Forwarded to ``si.vis.FigureManager``.
    """
    with si.vis.FigureManager(name=f"{self.spec.name}", **kwargs) as figman:
        time_unit_value, _ = u.get_unit_value_and_latex(time_unit)

        grid_spec = matplotlib.gridspec.GridSpec(2,
                                                 1,
                                                 height_ratios=[4, 1],
                                                 hspace=0.07)
        ax_overlaps = plt.subplot(grid_spec[0])
        ax_field = plt.subplot(grid_spec[1], sharex=ax_overlaps)

        # The two panels share their x axis, so the field panel must use the same
        # time unit as the overlaps drawn above it.
        self.attach_electric_potential_plot_to_axis(
            ax_field,
            show_electric_field=show_electric_field,
            show_vector_potential=show_vector_potential,
            time_unit=time_unit,
        )

        ax_overlaps.plot(
            self.sim.data_times / time_unit_value,
            self.sim.data.norm,
            label=r"$\left\langle \psi|\psi \right\rangle$",
            color="black",
            linewidth=2,
        )

        state_overlaps = self.sim.data.state_overlaps
        if states is not None:
            if callable(states):
                state_overlaps = {
                    state: overlap
                    for state, overlap in state_overlaps.items()
                    if states(state)
                }
            else:
                states = set(states)
                state_overlaps = {
                    state: overlap
                    for state, overlap in state_overlaps.items()
                    if state in states or (
                        state.numeric and state.analytic_state in states)
                }

        # Sort once and build the data and the labels from the same ordering.
        ordered_states = sorted(state_overlaps)
        overlaps = [state_overlaps[state] for state in ordered_states]
        labels = [
            rf"$ \left| \left\langle \psi | {{{state.tex}}} \right\rangle \right|^2 $"
            for state in ordered_states
        ]

        # With nothing selected there is nothing to stack; the norm curve is still shown.
        if overlaps:
            ax_overlaps.stackplot(
                self.sim.data_times / time_unit_value,
                *overlaps,
                labels=labels,
            )

        if log:
            ax_overlaps.set_yscale("log")
            # default=0.0 makes an empty selection fall back to the 1e-9 floor below.
            min_overlap = min((np.min(overlap) for overlap in overlaps),
                              default=0.0)
            ax_overlaps.set_ylim(bottom=max(1e-9, min_overlap * 0.1),
                                 top=1.0)
            ax_overlaps.grid(True,
                             which="both",
                             **si.vis.DEFAULT_GRID_KWARGS)
        else:
            ax_overlaps.set_ylim(0.0, 1.0)
            ax_overlaps.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            ax_overlaps.grid(True, **si.vis.DEFAULT_GRID_KWARGS)

        ax_overlaps.set_xlim(
            self.sim.times[0] / time_unit_value,
            self.sim.times[-1] / time_unit_value,
        )

        ax_overlaps.set_ylabel("Wavefunction Metric", fontsize=13)

        ax_overlaps.legend(
            bbox_to_anchor=(1.1, 1.1),
            loc="upper left",
            borderaxespad=0.075,
            fontsize=9,
            ncol=1 + (len(overlaps) // 10),
        )

        ax_overlaps.tick_params(labelright=True)
        ax_overlaps.xaxis.tick_top()

        # Find at most n+1 ticks on the y-axis at 'nice' locations
        max_yticks = 4
        yloc = plt.MaxNLocator(max_yticks, prune="upper")
        ax_field.yaxis.set_major_locator(yloc)

        max_xticks = 6
        xloc = plt.MaxNLocator(max_xticks, prune="both")
        ax_field.xaxis.set_major_locator(xloc)

        ax_field.tick_params(axis="both", which="major", labelsize=10)
        ax_overlaps.tick_params(axis="both", which="major", labelsize=10)

        postfix = ""
        if log:
            postfix += "__log"

        figman.name += postfix
def wavefunction_vs_time(
    self,
    log: bool = False,
    time_unit: u.Unit = "asec",
    bound_state_max_n: int = 5,
    collapse_bound_state_angular_momenta: bool = True,
    grouped_free_states=None,
    group_free_states_labels=None,
    show_title: bool = False,
    plot_name_from: str = "file_name",
    show_electric_field: bool = True,
    show_vector_potential: bool = True,
    **kwargs,
):
    """
    Plot the norm and a stack of bound/free state overlaps against time, with the field in a panel below.

    Bound states with ``n <= bound_state_max_n`` are shown individually (or summed over
    angular momentum for each ``n``); all higher bound states are summed into one band.
    Free states are summed group by group.

    Parameters
    ----------
    log
        Use a logarithmic y axis and append ``__log`` to the figure name.
    time_unit
        Unit for the shared time axis of both panels.
    bound_state_max_n
        Largest ``n`` that gets its own band(s).
    collapse_bound_state_angular_momenta
        Sum the overlaps of all bound states sharing the same ``n`` into one band.
    grouped_free_states
        Mapping from group key to the free states in that group. When ``None``, the
        groups and their labels come from ``self.group_free_states_by_continuous_attr``.
    group_free_states_labels
        Mapping from group key to label text; needed when ``grouped_free_states`` is given.
    show_title
        Put ``self.sim.name`` above the upper panel.
    plot_name_from
        Name of the attribute on ``self`` used as the figure name prefix.
    show_electric_field, show_vector_potential
        Forwarded to ``self.attach_electric_potential_plot_to_axis``.
    kwargs
        Forwarded to ``si.vis.FigureManager``.
    """
    with si.vis.FigureManager(name=getattr(self, plot_name_from) +
                              "__wavefunction_vs_time",
                              **kwargs) as figman:
        time_unit_value, _ = u.get_unit_value_and_latex(time_unit)

        grid_spec = matplotlib.gridspec.GridSpec(2,
                                                 1,
                                                 height_ratios=[4, 1],
                                                 hspace=0.07)
        ax_overlaps = plt.subplot(grid_spec[0])
        ax_field = plt.subplot(grid_spec[1], sharex=ax_overlaps)

        # The two panels share their x axis, so the field panel must use the same
        # time unit as the overlaps drawn above it.
        self.attach_electric_potential_plot_to_axis(
            ax_field,
            show_electric_field=show_electric_field,
            show_vector_potential=show_vector_potential,
            time_unit=time_unit,
        )

        ax_overlaps.plot(
            self.sim.data_times / time_unit_value,
            self.sim.data.norm,
            label=r"$\left\langle \Psi | \Psi \right\rangle$",
            color="black",
            linewidth=2,
        )

        if grouped_free_states is None:
            (
                grouped_free_states,
                group_free_states_labels,
            ) = self.group_free_states_by_continuous_attr("energy",
                                                          attr_unit="eV")

        overlaps = []
        labels = []
        colors = []

        # it's a property that would otherwise get evaluated every time we asked for it
        state_overlaps = self.sim.data.state_overlaps

        extra_bound_overlap = np.zeros(self.sim.data_time_steps)
        if collapse_bound_state_angular_momenta:
            # prepare arrays to sum over angular momenta in, one for each n
            overlaps_by_n = {
                n: np.zeros(self.sim.data_time_steps)
                for n in range(1, bound_state_max_n + 1)
            }

            for state in sorted(self.sim.bound_states):
                if state.n <= bound_state_max_n:
                    overlaps_by_n[state.n] += state_overlaps[state]
                else:
                    extra_bound_overlap += state_overlaps[state]

            for n in sorted(overlaps_by_n):
                overlaps.append(overlaps_by_n[n])
                labels.append(
                    rf"$ \left| \left\langle \Psi | \psi_{{ {n}, \ell }} \right\rangle \right|^2 $"
                )
                colors.append(
                    matplotlib.colors.to_rgba("C" + str(n - 1), alpha=1))
        else:
            for state in sorted(self.sim.bound_states):
                if state.n <= bound_state_max_n:
                    overlaps.append(state_overlaps[state])
                    labels.append(
                        rf"$ \left| \left\langle \Psi | {{{state.tex}}} \right\rangle \right|^2 $"
                    )
                    colors.append(
                        matplotlib.colors.to_rgba(
                            "C" + str((state.n - 1) % 10),
                            alpha=1 - state.l / state.n,
                        ))
                else:
                    extra_bound_overlap += state_overlaps[state]

        overlaps.append(extra_bound_overlap)
        labels.append(
            rf"$ \left| \left\langle \Psi | \psi_{{n \geq {bound_state_max_n + 1} }} \right\rangle \right|^2 $"
        )
        colors.append(".4")

        free_state_color_cycle = itertools.cycle([
            "#8dd3c7",
            "#ffffb3",
            "#bebada",
            "#fb8072",
            "#80b1d3",
            "#fdb462",
            "#b3de69",
            "#fccde5",
            "#d9d9d9",
            "#bc80bd",
            "#ccebc5",
            "#ffed6f",
        ])

        # `or {}` tolerates an empty, non-mapping result from the grouping helper
        # (it has returned empty lists when there were no free states).
        free_state_groups = grouped_free_states or {}
        for group, group_states in sorted(free_state_groups.items()):
            if len(group_states) != 0:
                # Builtin sum adds the overlap arrays elementwise; handing a generator
                # to np.sum is deprecated.
                overlaps.append(sum(state_overlaps[s] for s in group_states))
                labels.append(
                    rf"$\left| \left\langle \Psi | {{{group_free_states_labels[group]}}} \right\rangle \right|^2$"
                )
                colors.append(next(free_state_color_cycle))

        ax_overlaps.stackplot(
            self.sim.data_times / time_unit_value,
            *overlaps,
            labels=labels,
            colors=colors,
        )

        if log:
            ax_overlaps.set_yscale("log")
            # default=0.0 makes a simulation without overlaps fall back to the 1e-9 floor below.
            min_overlap = min(
                (np.min(overlap) for overlap in state_overlaps.values()),
                default=0.0,
            )
            ax_overlaps.set_ylim(bottom=max(1e-9, min_overlap * 0.1),
                                 top=1.0)
            ax_overlaps.grid(True,
                             which="both",
                             **si.vis.DEFAULT_GRID_KWARGS)
        else:
            ax_overlaps.set_ylim(0.0, 1.0)
            ax_overlaps.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            ax_overlaps.grid(True, **si.vis.DEFAULT_GRID_KWARGS)

        ax_overlaps.set_xlim(
            self.spec.time_initial / time_unit_value,
            self.spec.time_final / time_unit_value,
        )

        ax_overlaps.set_ylabel("Wavefunction Metric", fontsize=13)

        ax_overlaps.legend(
            bbox_to_anchor=(1.1, 1.1),
            loc="upper left",
            borderaxespad=0.075,
            fontsize=9,
            ncol=1 + (len(overlaps) // 12),
        )

        ax_overlaps.tick_params(
            labelleft=True,
            labelright=True,
            labeltop=True,
            labelbottom=False,
            bottom=True,
            top=True,
            left=True,
            right=True,
        )
        ax_field.tick_params(
            labelleft=True,
            labelright=True,
            labeltop=False,
            labelbottom=True,
            bottom=True,
            top=True,
            left=True,
            right=True,
        )

        # Find at most n+1 ticks on the y-axis at 'nice' locations
        max_yticks = 4
        yloc = plt.MaxNLocator(max_yticks, prune="upper")
        ax_field.yaxis.set_major_locator(yloc)

        max_xticks = 6
        xloc = plt.MaxNLocator(max_xticks, prune="both")
        ax_field.xaxis.set_major_locator(xloc)

        ax_field.tick_params(axis="both", which="major", labelsize=10)
        ax_overlaps.tick_params(axis="both", which="major", labelsize=10)

        if show_title:
            title = ax_overlaps.set_title(self.sim.name)
            title.set_y(1.15)

        postfix = ""
        if log:
            postfix += "__log"

        figman.name += postfix
def plot_b2_vs_time(
    self,
    log=False,
    time_unit="asec",
    show_vector_potential=False,
    show_title=False,
    **kwargs,
):
    """
    Plot ``self.b2`` against ``self.times``, with the electric potential in a panel below.

    Parameters
    ----------
    log
        Use a logarithmic y axis and append ``__log`` to the figure name.
    time_unit
        Unit for the shared time axis of both panels.
    show_vector_potential
        Forwarded to ``self.attach_electric_potential_plot_to_axis``.
    show_title
        Put ``self.name`` above the upper panel.
    kwargs
        Forwarded to ``si.vis.FigureManager``.
    """
    with si.vis.FigureManager(self.name + "__b2_vs_time", **kwargs) as figman:
        fig = figman.fig

        t_scale_unit, t_scale_name = u.get_unit_value_and_latex(time_unit)

        panels = matplotlib.gridspec.GridSpec(2,
                                              1,
                                              height_ratios=[4, 1],
                                              hspace=0.07)
        ax_b2 = plt.subplot(panels[0])
        ax_pot = plt.subplot(panels[1], sharex=ax_b2)

        self.attach_electric_potential_plot_to_axis(
            ax_pot,
            show_vector_potential=show_vector_potential,
            time_unit=time_unit)

        ax_b2.plot(self.times / t_scale_unit,
                   self.b2,
                   color="black",
                   linewidth=2)

        if not log:
            ax_b2.set_ylim(0.0, 1.0)
            ax_b2.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            ax_b2.grid(True, **si.vis.DEFAULT_GRID_KWARGS)
        else:
            ax_b2.set_yscale("log")
            smallest_b2 = np.min(self.b2)
            # keep the lower limit a decade below the data, but never under 1e-9
            ax_b2.set_ylim(bottom=max(1e-9, smallest_b2 * 0.1), top=1.0)
            ax_b2.grid(True, which="both", **si.vis.DEFAULT_GRID_KWARGS)

        ax_b2.set_xlim(
            self.spec.time_initial / t_scale_unit,
            self.spec.time_final / t_scale_unit,
        )

        ax_b2.set_ylabel(r"$\left| b(t) \right|^2$", fontsize=13)

        # Find at most n+1 ticks on the y-axis at 'nice' locations
        ax_pot.yaxis.set_major_locator(plt.MaxNLocator(4, prune="upper"))
        ax_pot.xaxis.set_major_locator(plt.MaxNLocator(6, prune="both"))

        for direction in ("x", "y"):
            ax_pot.tick_params(axis=direction, which="major", labelsize=10)
        ax_b2.tick_params(axis="both", which="major", labelsize=10)

        # Tick marks on every side of both panels; only the label placement differs.
        marks_on_all_sides = dict(bottom=True, top=True, left=True, right=True)
        ax_b2.tick_params(
            labelleft=True,
            labelright=True,
            labeltop=True,
            labelbottom=False,
            **marks_on_all_sides,
        )
        ax_pot.tick_params(
            labelleft=True,
            labelright=True,
            labeltop=False,
            labelbottom=True,
            **marks_on_all_sides,
        )

        if show_title:
            title = ax_b2.set_title(self.name)
            title.set_y(1.15)

        postfix = "__log" if log else ""
        figman.name += postfix
def energy_spectrum(
    self,
    states: str = "all",
    time_index: int = -1,
    energy_scale: str = "eV",
    time_scale: str = "asec",
    bins: int = 100,
    log: bool = False,
    energy_lower_bound: Optional[float] = None,
    energy_upper_bound: Optional[float] = None,
    group_angular_momentum: bool = True,
    angular_momentum_cutoff: Optional[int] = None,
    **kwargs,
):
    """
    Plot a histogram of state overlap against state energy at a single time index.

    Parameters
    ----------
    states
        Which states to include: ``"all"`` (``self.spec.test_states``),
        ``"bound"`` (``self.sim.bound_states``) or ``"free"`` (``self.sim.free_states``).
    time_index
        Index into each overlap-versus-time array (and into ``self.sim.times`` for the title).
    energy_scale
        Unit for the energy axis.
    time_scale
        Unit for the time shown in the title.
    bins
        Number of histogram bins.
    log
        Use a logarithmic count axis and append ``__log`` to the figure name.
    energy_lower_bound, energy_upper_bound
        Histogram range, in units of ``energy_scale``. Each defaults to the
        corresponding extreme of the plotted energies.
    group_angular_momentum
        Stack one histogram per angular momentum instead of a single histogram.
    angular_momentum_cutoff
        Angular momenta at or above this value are merged into one group.
        ``None`` means no merging: every angular momentum gets its own group.
    kwargs
        Forwarded to ``si.vis.FigureManager``.

    Raises
    ------
    ValueError
        If ``states`` is not one of the three accepted strings.
    """
    energy_unit, energy_unit_str = u.get_unit_value_and_latex(energy_scale)
    time_unit, time_unit_str = u.get_unit_value_and_latex(time_scale)

    if states == "all":
        state_list = self.spec.test_states
    elif states == "bound":
        state_list = self.sim.bound_states
    elif states == "free":
        state_list = self.sim.free_states
    else:
        raise ValueError("states must be one of 'all', 'bound', or 'free'")

    # filter down to just states in state_list
    all_state_overlaps = self.sim.data.state_overlaps
    state_overlaps = {k: all_state_overlaps[k] for k in state_list}

    if group_angular_momentum:
        overlap_by_angular_momentum_by_energy = collections.defaultdict(
            functools.partial(collections.defaultdict, float))

        for state, overlap_vs_time in state_overlaps.items():
            overlap_by_angular_momentum_by_energy[state.l][
                state.energy] += overlap_vs_time[time_index]

        # energies, overlaps and labels are appended together so that every
        # dataset handed to the histogram has exactly one matching label.
        energies = []
        overlaps = []
        labels = []
        cutoff_energies = np.array([])
        cutoff_overlaps = np.array([])
        for l, overlap_by_energy in sorted(
                overlap_by_angular_momentum_by_energy.items()):
            e, o = si.utils.dict_to_arrays(overlap_by_energy)
            # Convert every group, including the merged one, to the plot's energy unit.
            e = e / energy_unit
            if angular_momentum_cutoff is None or l < angular_momentum_cutoff:
                energies.append(e)
                overlaps.append(o)
                labels.append(rf"$ \ell = {l} $")
            else:
                cutoff_energies = np.append(cutoff_energies, e)
                cutoff_overlaps = np.append(cutoff_overlaps, o)

        if len(cutoff_energies) != 0:
            energies.append(cutoff_energies)
            overlaps.append(cutoff_overlaps)
            labels.append(rf"$ \ell \geq {angular_momentum_cutoff} $")

        if energy_lower_bound is None:
            energy_lower_bound = min(np.nanmin(e) for e in energies)
        if energy_upper_bound is None:
            energy_upper_bound = max(np.nanmax(e) for e in energies)
    else:
        overlap_by_energy = collections.defaultdict(float)
        for state, overlap_vs_time in state_overlaps.items():
            overlap_by_energy[state.energy] += overlap_vs_time[time_index]

        energies, overlaps = si.utils.dict_to_arrays(overlap_by_energy)
        energies = energies / energy_unit

        if energy_lower_bound is None:
            energy_lower_bound = np.nanmin(energies)
        if energy_upper_bound is None:
            energy_upper_bound = np.nanmax(energies)

        labels = None

    with si.vis.FigureManager(self.sim.name + "__energy_spectrum",
                              **kwargs) as figman:
        fig = figman.fig
        ax = fig.add_subplot(111)

        ax.hist(
            x=energies,
            weights=overlaps,
            bins=bins,
            stacked=True,
            log=log,
            range=(energy_lower_bound, energy_upper_bound),
            label=labels,
        )

        ax.grid(True, **si.vis.DEFAULT_GRID_KWARGS)

        # pad the x limits by 5% of the range on either side
        x_range = energy_upper_bound - energy_lower_bound
        ax.set_xlim(energy_lower_bound - 0.05 * x_range,
                    energy_upper_bound + 0.05 * x_range)

        ax.set_xlabel(rf"Energy $E$ (${energy_unit_str}$)")
        ax.set_ylabel(r"Wavefunction Overlap")
        ax.set_title(
            rf"Wavefunction Overlap by Energy at $ t = {self.sim.times[time_index]/ time_unit:.3f} \, {time_unit_str} $"
        )

        if group_angular_momentum:
            ax.legend(loc="best", ncol=1 + len(energies) // 8)

        ax.tick_params(axis="both", which="major", labelsize=10)

        figman.name += f"__{states}_states__index={time_index}"

        if log:
            figman.name += "__log"
        if group_angular_momentum:
            figman.name += "__grouped"
def plot_mesh_repr(
    self,
    mesh: "meshes.ScalarMesh",
    name: str = "",
    title: Optional[str] = None,
    distance_unit: str = "bohr_radius",
    colormap=vis.COLORMAP_WAVEFUNCTION,
    norm=si.vis.AbsoluteRenormalize(),
    shading: si.vis.ColormapShader = si.vis.ColormapShader.FLAT,
    plot_limit: Optional[float] = None,
    slicer: str = "get_mesh_slicer",
    aspect_ratio: float = si.vis.GOLDEN_RATIO,
    show_colorbar: bool = True,
    show_title: bool = True,
    show_axes: bool = True,
    title_y_adjust: float = 1.1,
    title_size: float = 12,
    axis_label_size: float = 12,
    tick_label_size: float = 10,
    grid_kwargs: Optional[dict] = None,
    **kwargs,
):
    """
    Plot ``mesh`` as a color mesh with angular momentum horizontally and radius vertically.

    Parameters
    ----------
    mesh
        The data to draw; forwarded to ``self.attach_mesh_repr_to_axis``.
    name
        Suffix appended to ``self.spec.name`` to form the figure name.
    title
        Axis title. Only drawn when non-empty and both ``show_axes`` and ``show_title`` are set.
    distance_unit
        Unit for the radial coordinate; its LaTeX form is used in the y label.
    colormap, norm, shading, plot_limit, slicer
        Forwarded to ``self.attach_mesh_repr_to_axis``.
    aspect_ratio
        Forwarded to ``si.vis.FigureManager``.
    show_colorbar
        Add a colorbar (only when ``show_axes`` is also set).
    show_axes
        When ``False`` the axis is switched off after plotting.
    title_y_adjust, title_size, axis_label_size, tick_label_size
        Title position and font sizes.
    grid_kwargs
        Overrides layered on top of ``si.vis.COLORMESH_GRID_KWARGS``.
    kwargs
        Forwarded to ``si.vis.FigureManager``.
    """
    if grid_kwargs is None:
        grid_kwargs = {}

    figure_name = f"{self.spec.name}__{name}"
    with si.vis.FigureManager(name=figure_name,
                              aspect_ratio=aspect_ratio,
                              **kwargs) as figman:
        fig = figman.fig
        fig.set_tight_layout(True)
        axis = plt.subplot(111)

        unit_value, unit_latex = u.get_unit_value_and_latex(distance_unit)

        color_mesh = self.attach_mesh_repr_to_axis(
            axis,
            mesh,
            distance_unit=distance_unit,
            colormap=colormap,
            norm=norm,
            shading=shading,
            plot_limit=plot_limit,
            slicer=slicer,
        )

        axis.set_xlabel(r"$\ell$", fontsize=axis_label_size)
        axis.set_ylabel(rf"$r$ (${unit_latex}$)", fontsize=axis_label_size)

        if title is not None and title != "" and show_axes and show_title:
            title = axis.set_title(title, fontsize=title_size)
            title.set_y(title_y_adjust)  # move title up a bit

        # make a colorbar
        if show_colorbar and show_axes:
            cbar = fig.colorbar(mappable=color_mesh, ax=axis)
            cbar.ax.tick_params(labelsize=tick_label_size)

        # change grid color to make it show up against the colormesh
        merged_grid_kwargs = {**si.vis.COLORMESH_GRID_KWARGS, **grid_kwargs}
        axis.grid(
            True,
            color=si.vis.CMAP_TO_OPPOSITE[colormap.name],
            **merged_grid_kwargs,
        )

        # ticks on all sides
        axis.tick_params(labelright=True, labeltop=True)
        # increase size of tick labels
        axis.tick_params(axis="both",
                         which="major",
                         labelsize=tick_label_size)

        # Hide the labels of the first and last major y ticks on both sides.
        y_ticks = axis.yaxis.get_major_ticks()
        for edge_tick in (y_ticks[0], y_ticks[-1]):
            edge_tick.label1.set_visible(False)
            edge_tick.label2.set_visible(False)

        axis.axis("tight")

        if not show_axes:
            axis.axis("off")
def plot_mesh(
    self,
    mesh: "meshes.ScalarMesh",
    name: str = "",
    title: Optional[str] = None,
    distance_unit: u.Unit = "bohr_radius",
    colormap=vis.COLORMAP_WAVEFUNCTION,
    norm=si.vis.AbsoluteRenormalize(),
    shading: si.vis.ColormapShader = si.vis.ColormapShader.FLAT,
    plot_limit=None,
    slicer="get_mesh_slicer",
    show_colorbar=True,
    show_title=True,
    show_axes=True,
    grid_kwargs=None,
    overlay_probability_current=False,
    **kwargs,
):
    """
    Plot ``mesh`` as a color mesh with z horizontally and rho vertically.

    Parameters
    ----------
    mesh
        The data to draw; forwarded to ``self.attach_mesh_to_axis``.
    name
        Suffix appended to ``self.spec.name`` to form the figure name.
    title
        Axis title. Only drawn when non-empty and both ``show_axes`` and ``show_title`` are set.
    distance_unit
        Unit for both coordinates; its LaTeX form is used in the axis labels.
    colormap, norm, shading, plot_limit, slicer
        Forwarded to ``self.attach_mesh_to_axis``.
    show_colorbar
        Add a colorbar (only when ``show_axes`` is also set).
    show_axes
        When ``False`` the axis is switched off after plotting.
    grid_kwargs
        Overrides looked up before ``si.vis.COLORMESH_GRID_KWARGS``.
    overlay_probability_current
        Draw the probability current arrows on top of the mesh, using the same
        ``plot_limit`` and ``distance_unit`` as the mesh itself.
    kwargs
        Forwarded to ``si.vis.FigureManager``.

    Returns
    -------
    The ``si.vis.FigureManager`` used to make the figure.
    """
    grid_kwargs = collections.ChainMap(grid_kwargs or {},
                                       si.vis.COLORMESH_GRID_KWARGS)
    _, unit_name = u.get_unit_value_and_latex(distance_unit)

    with si.vis.FigureManager(f"{self.spec.name}__{name}",
                              **kwargs) as figman:
        fig = figman.fig
        ax = plt.subplot(111)

        color_mesh = self.attach_mesh_to_axis(
            ax,
            mesh,
            distance_unit=distance_unit,
            colormap=colormap,
            norm=norm,
            shading=shading,
            plot_limit=plot_limit,
            slicer=slicer,
        )

        ax.set_xlabel(rf"$z$ (${unit_name}$)")
        ax.set_ylabel(rf"$\rho$ (${unit_name}$)")
        if title is not None and title != "" and show_axes and show_title:
            ax.set_title(title, y=1.1)

        if show_colorbar and show_axes:
            cax = fig.add_axes([1.0, 0.1, 0.02, 0.8])
            plt.colorbar(mappable=color_mesh, cax=cax)

        if overlay_probability_current:
            # The arrows must share the mesh's coordinates: without these arguments
            # they would be drawn over the full mesh in the default unit, regardless
            # of what was requested for the color mesh.
            self.attach_probability_current_to_axis(
                ax,
                plot_limit=plot_limit,
                distance_unit=distance_unit,
            )

        ax.axis("tight")  # removes blank space between color mesh and axes

        # change grid color to make it show up against the colormesh
        ax.grid(True, color=si.vis.CMAP_TO_OPPOSITE[colormap], **grid_kwargs)

        ax.tick_params(labelright=True, labeltop=True)  # ticks on all sides

        if not show_axes:
            ax.axis("off")

    return figman
def initialize_axis(self):
    """
    Draw the initial color mesh on the polar ``self.axis`` and register the artists that must be redrawn.

    The mesh returned by ``self.attach_method`` is stored as ``self.mesh`` and added to
    ``self.redraw`` first; gridlines and radial tick labels are added after the parent
    initializer runs so that they are redrawn over the mesh.
    """
    if self.which == "g":
        self.norm.equator_magnitude = np.max(
            np.abs(self.sim.mesh.g) / vis.DEFAULT_RICHARDSON_MAGNITUDE_DIVISOR)

    self.mesh = self.attach_method(
        self.axis,
        colormap=self.colormap,
        norm=self.norm,
        shading=self.shading,
        plot_limit=self.plot_limit,
        distance_unit=self.distance_unit,
        slicer=self.slicer,
        animated=True,
    )
    self.redraw.append(self.mesh)

    unit_value, unit_name = u.get_unit_value_and_latex(self.distance_unit)

    # Zero angle at the top of the plot, angles increasing clockwise.
    self.axis.set_theta_zero_location("N")
    self.axis.set_theta_direction("clockwise")
    self.axis.set_rlabel_position(80)

    self.axis.grid(
        True, **self.grid_kwargs
    )  # change grid color to make it show up against the colormesh
    # Labels run 0..180 degrees and back down, one per 30 degree gridline.
    angle_labels = [
        "{}\u00b0".format(s)
        for s in (0, 30, 60, 90, 120, 150, 180, 150, 120, 90, 60, 30)
    ]  # \u00b0 is unicode degree symbol
    # NOTE(review): `frac` looks like it is no longer accepted by set_thetagrids in
    # recent Matplotlib releases — confirm against the Matplotlib version this project pins.
    self.axis.set_thetagrids(np.arange(0, 359, 30),
                             frac=1.075,
                             labels=angle_labels)

    self.axis.tick_params(axis="both", which="major",
                          labelsize=20)  # increase size of tick labels
    self.axis.tick_params(
        axis="y",
        which="major",
        colors=si.vis.CMAP_TO_OPPOSITE[self.colormap.name],
        pad=3,
    )  # make r ticks a color that shows up against the colormesh
    # Same position as set above; repeated after the tick styling.
    self.axis.set_rlabel_position(80)

    # NOTE(review): the source's original indentation was lost; everything down to
    # set_yticklabels is assumed to belong to this `if` — confirm.
    if self.tick_labels is None:
        max_yticks = 5
        yloc = plt.MaxNLocator(max_yticks, symmetric=False, prune="both")
        self.axis.yaxis.set_major_locator(yloc)

        plt.gcf().canvas.draw()  # must draw early to modify the axis text

        # Append the unit's LaTeX form to each radial tick label.
        self.tick_labels = self.axis.get_yticklabels()
        for t in self.tick_labels:
            t.set_text(t.get_text() + r"${}$".format(unit_name))
        self.axis.set_yticklabels(self.tick_labels)

    # Radial limit is the mesh edge pulled in by half a radial step.
    self.axis.set_rmax(
        (self.sim.mesh.r_max - (self.sim.mesh.delta_r / 2)) / unit_value)

    self.axis.axis("tight")

    super().initialize_axis()

    self.redraw += [
        *self.axis.xaxis.get_gridlines(),
        *self.axis.yaxis.get_gridlines(),
        *self.axis.yaxis.get_ticklabels(),
    ]  # gridlines must be redrawn over the mesh (it's important that they're AFTER the mesh itself in self.redraw)
def test_get_unit_value_and_latex():
    """Looking up "m" yields the pair (u.m, upright-m LaTeX string)."""
    result = u.get_unit_value_and_latex("m")
    expected = (u.m, r"\mathrm{m}")

    assert result == expected
def radial_probability_current_vs_time__combined(
    self,
    r_upper_limit: Optional[float] = None,
    t_lower_limit: Optional[float] = None,
    t_upper_limit: Optional[float] = None,
    distance_unit: str = "bohr_radius",
    time_unit: u.Unit = "asec",
    current_unit: str = "per_asec",
    z_cut: float = 0.7,
    colormap=plt.get_cmap("coolwarm"),
    overlay_electric_field: bool = True,
    efield_unit: str = "atomic_electric_field",
    efield_color: str = "black",
    efield_label_fontsize: float = 12,
    title_fontsize: float = 12,
    y_axis_label_fontsize: float = 14,
    x_axis_label_fontsize: float = 12,
    cbar_label_fontsize: float = 12,
    aspect_ratio: float = 1.2,
    shading: str = "flat",
    use_name: bool = False,
    **kwargs,
):
    """
    Plot the radial probability current against time and radius.

    Two stacked color meshes share the time axis: the ``pos_z`` data on top
    (radius plotted upward) and the ``neg_z`` data below (radius plotted
    downward). Both share one color normalization and one colorbar.

    :param r_upper_limit: largest radius shown; defaults to the last radial point
    :param t_lower_limit: first time shown; defaults to the first data time
    :param t_upper_limit: last time shown; defaults to the last data time
    :param distance_unit: unit for the radial axes
    :param time_unit: unit for the time axis
    :param current_unit: unit for the plotted current and the colorbar
    :param z_cut: fraction of the largest absolute current used as the color scale limit
    :param colormap: colormap made current via ``plt.set_cmap``
    :param overlay_electric_field: True to overlay the electric field amplitude on its own axes
    :param efield_unit: unit for the electric field overlay
    :param efield_color: color of the overlay line, ticks, grid and label
    :param efield_label_fontsize: font size of the overlay's y label
    :param title_fontsize: font size of the figure title
    :param y_axis_label_fontsize: font size of the radial labels
    :param x_axis_label_fontsize: font size of the time label
    :param cbar_label_fontsize: font size of the colorbar label
    :param aspect_ratio: passed to si.vis.FigureManager
    :param shading: passed to ``pcolormesh``
    :param use_name: True to prefix the figure name with the simulation's name instead of its file name
    :param kwargs: kwargs are passed to si.vis.FigureManager
    """
    prefix = self.sim.file_name
    if use_name:
        prefix = self.sim.name

    distance_unit_value, distance_unit_latex = u.get_unit_value_and_latex(distance_unit)
    time_unit_value, time_unit_latex = u.get_unit_value_and_latex(time_unit)
    current_unit_value, current_unit_latex = u.get_unit_value_and_latex(current_unit)
    efield_unit_value, efield_unit_latex = u.get_unit_value_and_latex(efield_unit)

    if t_lower_limit is None:
        t_lower_limit = self.sim.data_times[0]
    if t_upper_limit is None:
        t_upper_limit = self.sim.data_times[-1]

    with si.vis.FigureManager(
        prefix + "__radial_probability_current_vs_time__combined",
        aspect_ratio=aspect_ratio,
        **kwargs,
    ) as figman:
        fig = figman.fig

        plt.set_cmap(colormap)

        gridspec = plt.GridSpec(2, 1, hspace=0.0)
        ax_pos = fig.add_subplot(gridspec[0])
        ax_neg = fig.add_subplot(gridspec[1], sharex=ax_pos)

        # TICKS, LEGEND, LABELS, and TITLE
        ax_pos.tick_params(
            labeltop=True,
            labelright=False,
            labelbottom=False,
            labelleft=True,
            bottom=False,
            right=False,
        )
        ax_neg.tick_params(
            labeltop=False,
            labelright=False,
            labelbottom=True,
            labelleft=True,
            top=False,
            right=False,
        )

        ax_pos.set_ylabel("$ z > 0 $", fontsize=y_axis_label_fontsize)
        ax_neg.set_ylabel("$ z < 0 $", fontsize=y_axis_label_fontsize)
        ax_pos.yaxis.set_label_coords(-0.12, 0.65)
        ax_neg.yaxis.set_label_coords(-0.12, 0.35)

        # one radius label for both panels, positioned in ax_pos axes coordinates
        ax_pos.text(
            -0.22,
            0.325,
            fr"Radius $ \pm r \; ({distance_unit_latex}) $",
            fontsize=y_axis_label_fontsize,
            rotation="vertical",
            transform=ax_pos.transAxes,
        )
        ax_neg.set_xlabel(
            rf"Time $ t \; ({time_unit_latex}) $",
            fontsize=x_axis_label_fontsize,
        )

        suptitle = fig.suptitle(
            "Radial Probability Current vs. Time and Radius",
            fontsize=title_fontsize,
        )
        suptitle.set_x(0.6)
        suptitle.set_y(1.01)

        # COLORMESHES
        try:
            r = self.sim.mesh.r
        except AttributeError:
            # no mesh radii available: rebuild them from the spec
            r = np.linspace(0, self.spec.r_bound, self.spec.r_points)
            delta_r = r[1] - r[0]
            r += delta_r / 2

        t_mesh, r_mesh = np.meshgrid(self.sim.data_times, r, indexing="ij")

        # Both halves are read from self.sim.data, the same arrays that are
        # plotted below. The neg_z maximum was previously read from ``self``
        # directly, unlike every other access in this method.
        current_pos_z = self.sim.data.radial_probability_current_vs_time__pos_z
        current_neg_z = self.sim.data.radial_probability_current_vs_time__neg_z
        z_max = max(
            np.nanmax(np.abs(current_pos_z)),
            np.nanmax(np.abs(current_neg_z)),
        )
        norm = matplotlib.colors.Normalize(
            vmin=-z_cut * z_max / current_unit_value,
            vmax=z_cut * z_max / current_unit_value,
        )

        pos_mesh = ax_pos.pcolormesh(
            t_mesh / time_unit_value,
            r_mesh / distance_unit_value,
            current_pos_z / current_unit_value,
            norm=norm,
            shading=shading,
        )
        ax_neg.pcolormesh(
            t_mesh / time_unit_value,
            -r_mesh / distance_unit_value,
            current_neg_z / current_unit_value,
            norm=norm,
            shading=shading,
        )

        # LIMITS AND GRIDS
        grid_kwargs = si.vis.DEFAULT_GRID_KWARGS
        for ax in [ax_pos, ax_neg]:
            ax.set_xlim(t_lower_limit / time_unit_value, t_upper_limit / time_unit_value)
            ax.grid(True, which="major", **grid_kwargs)

        if r_upper_limit is None:
            r_upper_limit = r[-1]
        ax_pos.set_ylim(0, r_upper_limit / distance_unit_value)
        ax_neg.set_ylim(-r_upper_limit / distance_unit_value, 0)

        # hide the label of the last major tick on the lower panel
        y_ticks_neg = ax_neg.yaxis.get_major_ticks()
        y_ticks_neg[-1].label1.set_visible(False)

        # COLORBAR
        ax_pos_position = ax_pos.get_position()
        ax_neg_position = ax_neg.get_position()
        # bounding box spanning both panels
        left, bottom, width, height = (
            ax_neg_position.x0,
            ax_neg_position.y0,
            ax_neg_position.x1 - ax_neg_position.x0,
            ax_pos_position.y1 - ax_neg_position.y0,
        )
        ax_cbar = fig.add_axes([left + width + 0.175, bottom, 0.05, height])
        cbar = plt.colorbar(mappable=pos_mesh, cax=ax_cbar, extend="both")
        cbar.set_label(
            rf"Radial Probability Current $ J_r \; ({current_unit_latex}) $",
            fontsize=cbar_label_fontsize,
        )

        # ELECTRIC FIELD OVERLAY
        if overlay_electric_field:
            # new axes occupying the same box as the two panels together
            ax_efield = fig.add_axes((left, bottom, width, height))

            ax_efield.tick_params(
                labeltop=False,
                labelright=True,
                labelbottom=False,
                labelleft=False,
                left=False,
                top=False,
                bottom=False,
                right=True,
            )
            ax_efield.tick_params(axis="y", colors=efield_color)
            ax_efield.tick_params(axis="x", colors=efield_color)

            ax_efield.plot(
                self.sim.data_times / time_unit_value,
                self.sim.data.electric_field_amplitude / efield_unit_value,
                color=efield_color,
                linestyle="-",
            )

            efield_grid_kwargs = {
                **si.vis.DEFAULT_GRID_KWARGS,
                "color": efield_color,
                "linestyle": "--",
            }
            ax_efield.yaxis.grid(True, **efield_grid_kwargs)

            max_efield = np.nanmax(np.abs(self.sim.data.electric_field_amplitude))

            ax_efield.set_xlim(t_lower_limit / time_unit_value, t_upper_limit / time_unit_value)
            # symmetric limits with a 5% margin around the largest field amplitude
            ax_efield.set_ylim(
                -1.05 * max_efield / efield_unit_value,
                1.05 * max_efield / efield_unit_value,
            )
            ax_efield.set_ylabel(
                rf"Electric Field Amplitude $ {vis.LATEX_EFIELD}(t) \; ({efield_unit_latex}) $",
                color=efield_color,
                fontsize=efield_label_fontsize,
            )
            ax_efield.yaxis.set_label_position("right")
def electron_momentum_spectrum_from_meshes(
    self,
    theta_mesh,
    r_mesh,
    inner_product_mesh,
    r_type: str,
    r_scale: float,
    log: bool = False,
    shading: si.vis.ColormapShader = si.vis.ColormapShader.FLAT,
    **kwargs,
):
    """
    Generate a polar plot of the wavefunction decomposed into plane waves.

    The color of each cell is the squared magnitude of ``inner_product_mesh``.

    :param theta_mesh: angular coordinates of the mesh
    :param r_mesh: radial coordinates of the mesh; only the real part is used
    :param inner_product_mesh: values whose squared magnitude is plotted
    :param r_type: type of unit for the radial axis ('wavenumber', 'energy', or 'momentum'); appended to the figure name
    :param r_scale: unit specification for the radial dimension
    :param log: True to color on a logarithmic scale, False otherwise (default: False)
    :param shading: passed to ``pcolormesh``
    :param kwargs: kwargs are passed to si.vis.FigureManager (``aspect_ratio`` defaults to 1)
    :return: the FigureManager generated during plot creation
    :raises ValueError: if ``r_type`` is not one of the three accepted values
    """
    accepted_r_types = ("wavenumber", "energy", "momentum")
    if r_type not in accepted_r_types:
        raise ValueError(
            "Invalid argument to plot_electron_spectrum: r_type must be either 'wavenumber', 'energy', or 'momentum'"
        )

    r_unit_value, r_unit_latex = u.get_unit_value_and_latex(r_scale)

    figure_options = {"aspect_ratio": 1, **kwargs}

    r_mesh = np.real(r_mesh)
    probability = np.abs(inner_product_mesh) ** 2

    with si.vis.FigureManager(self.sim.name + "__electron_spectrum", **figure_options) as figman:
        fig = figman.fig
        fig.set_tight_layout(True)

        # polar axes with zero angle pointing up and angles increasing clockwise
        axis = plt.subplot(111, projection="polar")
        axis.set_theta_zero_location("N")
        axis.set_theta_direction("clockwise")

        figman.name += f"__{r_type}"

        if log:
            norm = matplotlib.colors.LogNorm(
                vmin=np.nanmin(probability),
                vmax=np.nanmax(probability),
            )
            figman.name += "__log"
        else:
            norm = None

        color_mesh = axis.pcolormesh(
            theta_mesh,
            r_mesh / r_unit_value,
            probability,
            shading=shading,
            norm=norm,
            cmap="viridis",
        )

        # the colorbar gets its own axis so that the polar axis can stay square
        colorbar_axis = fig.add_axes([1.01, 0.1, 0.04, 0.8])
        colorbar = plt.colorbar(mappable=color_mesh, cax=colorbar_axis)
        colorbar.ax.tick_params(labelsize=10)

        # grid color is chosen to show up against the colormesh
        axis.grid(
            True,
            color=si.vis.COLOR_OPPOSITE_VIRIDIS,
            **si.vis.COLORMESH_GRID_KWARGS,
        )

        # labels count up to 180 and back down; \u00b0 is the unicode degree symbol
        degrees = (0, 30, 60, 90, 120, 150, 180, 150, 120, 90, 60, 30)
        angle_labels = [f"{degree}\u00b0" for degree in degrees]
        axis.set_thetagrids(np.arange(0, 359, 30), frac=1.075, labels=angle_labels)

        # tick label size
        axis.tick_params(axis="both", which="major", labelsize=8)
        # radial ticks get a color that shows up against the colormesh
        axis.tick_params(axis="y", which="major", colors=si.vis.COLOR_OPPOSITE_VIRIDIS, pad=3)
        axis.tick_params(axis="both", which="both", length=0)
        axis.set_rlabel_position(80)

        radial_locator = plt.MaxNLocator(5, symmetric=False, prune="both")
        axis.yaxis.set_major_locator(radial_locator)

        # the tick label text only exists once the canvas has been drawn
        fig.canvas.draw()

        radial_labels = axis.get_yticklabels()
        for label in radial_labels:
            label.set_text(label.get_text() + rf"${r_unit_latex}$")
        axis.set_yticklabels(radial_labels)

        axis.set_rmax(np.nanmax(r_mesh) / r_unit_value)

    return figman
def plot_b2_vs_time(
    self,
    log=False,
    show_electric_field=True,
    overlay_kicks=True,
    time_unit="asec",
    show_title=False,
    y_lower_limit=0,
    y_upper_limit=1,
    **kwargs,
):
    """
    Plot ``self.b2`` against time above a panel showing the electric potential.

    :param log: True to use a logarithmic y axis; also appends "__log" to the figure name
    :param show_electric_field: passed to ``attach_electric_potential_plot_to_axis``
    :param overlay_kicks: passed to ``attach_electric_potential_plot_to_axis``
    :param time_unit: unit for the shared time axis
    :param show_title: True to title the upper panel with ``self.name``
    :param y_lower_limit: lower y limit of the upper panel
    :param y_upper_limit: upper y limit of the upper panel
    :param kwargs: kwargs are passed to si.vis.FigureManager
    """
    with si.vis.FigureManager(self.name + "__b2_vs_time", **kwargs) as figman:
        fig = figman.fig

        t_scale_unit, t_scale_name = u.get_unit_value_and_latex(time_unit)

        # TODO: switch to fixed axis construction
        grid_spec = matplotlib.gridspec.GridSpec(2, 1, height_ratios=[4, 1], hspace=0.07)
        ax_b = plt.subplot(grid_spec[0])
        ax_pot = plt.subplot(grid_spec[1], sharex=ax_b)

        self.attach_electric_potential_plot_to_axis(
            ax_pot,
            show_electric_field=show_electric_field,
            overlay_kicks=overlay_kicks,
            time_unit=time_unit,
        )

        # repeating every sample twice produces the stair-step pattern
        stair_times = np.repeat(self.data_times, 2)[1:-1] / t_scale_unit
        stair_b2 = np.repeat(self.b2, 2)[:-2]
        ax_b.plot(
            stair_times,
            stair_b2,
            marker="o",
            markersize=2,
            markevery=2,
            linestyle=":",
            color="black",
            linewidth=1,
        )

        if not log:
            ax_b.set_ylim(0.0, 1.0)
            ax_b.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            ax_b.grid(True, **si.vis.DEFAULT_GRID_KWARGS)
        else:
            ax_b.set_yscale("log")
            min_overlap = np.min(self.b2)
            ax_b.set_ylim(bottom=max(1e-9, min_overlap * 0.1), top=1.0)
            ax_b.grid(True, which="both", **si.vis.DEFAULT_GRID_KWARGS)

        ax_b.set_xlim(
            self.spec.time_initial / t_scale_unit,
            self.spec.time_final / t_scale_unit,
        )
        ax_b.set_ylim(y_lower_limit, y_upper_limit)

        ax_b.set_ylabel(r"$\left| b(t) \right|^2$", fontsize=13)

        # Find at most n+1 ticks on each axis of the lower panel at 'nice' locations
        ax_pot.yaxis.set_major_locator(plt.MaxNLocator(4, prune="upper"))
        ax_pot.xaxis.set_major_locator(plt.MaxNLocator(6, prune="both"))

        for axis_name in ("x", "y"):
            ax_pot.tick_params(axis=axis_name, which="major", labelsize=10)
        ax_b.tick_params(axis="both", which="major", labelsize=10)

        # tick marks on all four sides of both panels; the panels differ only
        # in which of the top/bottom edges carries labels
        ticks_on_all_sides = dict(bottom=True, top=True, left=True, right=True)
        ax_b.tick_params(
            labelleft=True,
            labelright=True,
            labeltop=True,
            labelbottom=False,
            **ticks_on_all_sides,
        )
        ax_pot.tick_params(
            labelleft=True,
            labelright=True,
            labeltop=False,
            labelbottom=True,
            **ticks_on_all_sides,
        )

        if show_title:
            title = ax_b.set_title(self.name)
            title.set_y(1.15)

        figman.name += "__log" if log else ""
def initialize_axis(self):
    """
    Attach the mesh and an overlay of the internal potential to the axis, then style it.

    Both attached meshes are registered in ``self.redraw``. When the axis is
    left on, the gridlines are registered as well, after the meshes.
    """
    unit_value, unit_name = u.get_unit_value_and_latex(self.distance_unit)

    if self.which == "g":
        self.norm.equator_magnitude = np.max(
            np.abs(self.sim.mesh.g) / vis.DEFAULT_RICHARDSON_MAGNITUDE_DIVISOR
        )

    self.mesh = self.attach_method(
        self.axis,
        colormap=self.colormap,
        norm=self.norm,
        shading=self.shading,
        plot_limit=self.plot_limit,
        distance_unit=self.distance_unit,
        slicer=self.slicer,
        animated=True,
    )

    pot = self.spec.internal_potential(
        x=self.sim.mesh.x_mesh,
        z=self.sim.mesh.z_mesh,
        r=self.sim.mesh.r_mesh,
    )
    # Take the color scale BEFORE masking, with a NaN-ignoring maximum. The
    # masking below writes NaNs into ``pot``, and a plain np.max over an array
    # containing NaN is NaN, which used to give the norm NaN limits.
    pot_magnitude = np.abs(pot)
    pot_scale = np.nanmax(pot_magnitude)
    # blank out everything weaker than 10% of the strongest value
    # (np.nan rather than the np.NaN alias, which NumPy 2.0 removed)
    pot[pot_magnitude < 0.1 * pot_scale] = np.nan

    self.potential_mesh = self.sim.mesh.plot.attach_mesh_to_axis(
        self.axis,
        np.ma.masked_invalid(pot),
        colormap=plt.get_cmap("coolwarm_r"),
        norm=plt.Normalize(vmin=-pot_scale, vmax=pot_scale),
        shading=self.shading,
        plot_limit=self.plot_limit,
        distance_unit=self.distance_unit,
        slicer=self.slicer,
        animated=True,
    )

    self.redraw.append(self.mesh)
    self.redraw.append(self.potential_mesh)

    self.axis.set_xlabel(r"$x$ (${}$)".format(unit_name), fontsize=24)
    self.axis.set_ylabel(r"$z$ (${}$)".format(unit_name), fontsize=24)

    self.axis.tick_params(axis="both", which="major", labelsize=20)

    self.axis.axis("tight")

    super().initialize_axis()

    if self.which not in ("g", "psi"):
        # colorbar on a thin axis appended to the right of the main axis
        divider = make_axes_locatable(self.axis)
        cax = divider.append_axes("right", size="2%", pad=0.05)
        self.cbar = plt.colorbar(cax=cax, mappable=self.mesh)
        self.cbar.ax.tick_params(labelsize=20)

    if self.axis_off:
        self.axis.axis("off")
    else:
        # gridlines must be redrawn over the mesh, so they have to come AFTER
        # the mesh itself in self.redraw
        self.redraw += [
            *self.axis.xaxis.get_gridlines(),
            *self.axis.yaxis.get_gridlines(),
        ]

    if self.show_grid:
        # grid styling is chosen so that it shows up against the colormesh
        self.axis.grid(True, **self.grid_kwargs)