def patched_plot(self, *args, **kwargs):
    """
    PyARPES patch for `lmfit` summary plots.

    Scientists like to have LaTeX in their plots, but because underscores outside TeX
    environments crash matplotlib renders, we need to do some fiddling with titles and
    axis labels in order to prevent users having to switch TeX on and off all the time.

    Additionally, this patch provides better support for multidimensional curve fitting:
    when the model reports ``n_dims != 1`` a 2x2 grid (best fit, data, initial fit and
    residual) is drawn instead of lmfit's own summary plot.

    :param self: The fit result being plotted.
    :param args: Forwarded to the original `lmfit` plot (unused for multidimensional fits).
    :param kwargs: Forwarded to the original `lmfit` plot (unused for multidimensional fits).
    :return: The 2x2 array of axes for multidimensional fits, otherwise whatever the
             original `lmfit` plot returns.
    """
    from arpes.plotting.utils import transform_labels

    # Models without an ``n_dims`` attribute are treated as one-dimensional. Only this
    # lookup is guarded, so AttributeErrors raised while drawing the multidimensional
    # summary propagate instead of silently falling back to the 1D plot.
    try:
        n_dims = self.model.n_dims
    except AttributeError:
        n_dims = 1

    if n_dims != 1:
        from arpes.plotting.utils import fancy_labels

        _, ax = plt.subplots(2, 2, figsize=(10, 8))

        def to_dr(flat_data):
            # Reshape a flattened fit array back onto the grid of independent variables.
            shape = [len(self.independent[d]) for d in self.independent_order]
            return xr.DataArray(flat_data.reshape(shape), coords=self.independent,
                                dims=self.independent_order)

        to_dr(self.init_fit).plot(ax=ax[1][0])
        to_dr(self.data).plot(ax=ax[0][1])
        to_dr(self.best_fit).plot(ax=ax[0][0])
        to_dr(self.residual).plot(ax=ax[1][1])

        ax[0][0].set_title('Best fit')
        ax[0][1].set_title('Data')
        ax[1][0].set_title('Initial fit')
        ax[1][1].set_title('Residual (Data - Best fit)')

        for axi in ax.ravel():
            fancy_labels(axi)

        plt.tight_layout()
        return ax

    ret = original_plot(self, *args, **kwargs)
    # Rewrite lmfit's titles/labels so they render without requiring TeX mode.
    transform_labels(transform_lmfit_titles)
    return ret
def data(self, new_data):
    """
    Replace the displayed data and refresh the plot.

    On the first assignment the plot artist is created: an image via ``imshow_arr`` for
    two-dimensional data, or a line plot for one-dimensional data. Every assignment
    (including the first) then updates the image extent and pixel data in place (2D), or
    redraws the line in its previous color and rescales the y-limits with a 10% margin
    (1D). Finally ``self.autoscale()`` is called when ``self.auto_autoscale`` is set.

    :param new_data: The new data to show; its dimensionality is fixed by the first
                     assignment (``self.n_dims``).
    """
    self._data = new_data

    if not self._initialized:
        self._initialized = True
        self.n_dims = len(new_data.dims)

        if self.n_dims == 2:
            self._axis_image = imshow_arr(self._data, ax=self.ax, **self.ax_kwargs)[1]
            fancy_labels(self.ax)
        else:
            # 'cmap' is an image-only option, so drop it before making a line plot.
            self.ax_kwargs.pop('cmap', None)
            x, y = self.data.coords[self.data.dims[0]].values, self.data.values
            self._axis_image = self.ax.plot(x, y, **self.ax_kwargs)
            self.ax.set_xlabel(self.data.dims[0])
            cs = self.data.coords[self.data.dims[0]].values
            self.ax.set_xlim([np.min(cs), np.max(cs)])
            fancy_labels(self.ax)

    if self.n_dims == 2:
        x, y = (self._data.coords[self._data.dims[0]].values,
                self._data.coords[self._data.dims[1]].values)
        extent = [y[0], y[-1], x[0], x[-1]]
        self._axis_image.set_extent(extent)
        self._axis_image.set_data(self._data.values)
    else:
        old_line = self.ax.lines[0]
        color = old_line.get_color()
        # ``Axes.lines`` is an immutable ArtistList since matplotlib 3.5 (mutation removed
        # in 3.7); ``Artist.remove`` is the supported way to take a line off the axes.
        old_line.remove()

        x, y = self.data.coords[self.data.dims[0]].values, self.data.values
        l, h = np.min(y), np.max(y)
        self._axis_image = self.ax.plot(x, y, c=color, **self.ax_kwargs)
        # Pad the y-range by 10% of its span on each side.
        self.ax.set_ylim([l - 0.1 * (h - l), h + 0.1 * (h - l)])

    if self.auto_autoscale:
        self.autoscale()
def offset_scatter_plot(data: DataType, name_to_plot=None, stack_axis=None, fermi_level=True,
                        cbarmap=None, ax=None, out=None, scale_coordinate=0.5, ylim=None,
                        aux_errorbars=True, **kwargs):
    """
    Plot each slice of a 2D Dataset variable as an offset scatter plot with error bars.

    Slices are taken along ``stack_axis`` and drawn with ``scatter_with_std``. Each is
    colored by its ``stack_axis`` coordinate and shifted along the other dimension by
    ``i * stride * scale_coordinate / 10``.

    :param data: Dataset containing ``name_to_plot`` (and its ``name_to_plot + '_std'``).
    :param name_to_plot: Variable to plot; inferred when exactly one non-``_std`` variable exists.
    :param stack_axis: Dimension to stack along; defaults to the variable's first dimension.
    :param fermi_level: When ``eV`` is the non-stacked dimension, draw a dashed line at 0,
                        a light shaded band, and apply ``ylim``.
    :param cbarmap: Optional ``(colorbar, colormap)`` pair. When given, no colorbar is drawn.
    :param ax: Axes to draw on; a new figure is created when omitted.
    :param out: If given, the figure is saved there and the saved path is returned.
    :param scale_coordinate: Scale factor applied to the per-slice offset.
    :param ylim: Y limits; ``ylim[0]`` is also the height of the auxiliary error bars.
    :param aux_errorbars: Also draw each slice's error bars at ``ylim[0]``; requires ``ylim``.
    :param kwargs: ``figsize`` and ``ticks`` are read here; all kwargs go to the colorbar.
    :return: The saved path when ``out`` is set, otherwise ``(fig, ax)`` (``fig`` is None
             when ``ax`` was supplied).
    :raises ValueError: If the variable is not 2D, or ``aux_errorbars`` is set without ``ylim``.
    """
    assert isinstance(data, xr.Dataset)

    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
        assert (name_to_plot + '_std') in data.data_vars.keys()

    if len(data.data_vars[name_to_plot].dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like. '
                         'Passed data included dimensions: {}'.format(data.data_vars[name_to_plot].dims))

    if aux_errorbars and ylim is None:
        # The auxiliary error bars are drawn at ``ylim[0]``; fail before creating a figure.
        raise ValueError('`ylim` must be provided when `aux_errorbars` is True.')

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (11, 5,)))

    # The colorbar inset is created even when the caller supplied ``ax``.
    inset_ax = inset_axes(ax, width='40%', height='5%', loc='upper left')

    if stack_axis is None:
        stack_axis = data.data_vars[name_to_plot].dims[0]

    # A caller-provided cbarmap suppresses drawing the colorbar at the end.
    skip_colorbar = cbarmap is not None
    if cbarmap is None:
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax,
                                                   ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap
    if not isinstance(cmap, matplotlib.colors.Colormap):
        # ``cmap`` may be a factory; if it cannot be called without arguments, assume it
        # is already a usable value -> color callable.
        try:
            cmap = cmap()
        except TypeError:
            pass

    # There is exactly one dimension besides ``stack_axis``.
    other_dim = [d for d in data.dims if d != stack_axis][0]

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        ax.axhline(0, linestyle='--', color='red')
        ax.fill_betweenx([-1e6, 1e6], 0, 0.2, color='black', alpha=0.07)
        ax.set_ylim(ylim)

    # Spacing of the non-stacked coordinate, used to offset successive slices.
    delta = data.T.stride(generic_dim_names=False)[other_dim]

    for i, (coord, value) in enumerate(data.T.iterate_axis(stack_axis)):
        color = cmap(coord[stack_axis])
        shift = i * delta * scale_coordinate / 10

        # Dimension (index) coordinates cannot be modified in place in current xarray,
        # so build a shifted copy instead.
        data_for = value.assign_coords(**{other_dim: value.coords[other_dim].values - shift})
        scatter_with_std(data_for, name_to_plot, ax=ax, color=color)

        if aux_errorbars:
            # Repeat only the error bars, pinned to the bottom limit of the plot.
            aux = data_for.copy(deep=True)
            flattened = aux.data_vars[name_to_plot].copy(deep=True)
            flattened.values = ylim[0] * np.ones(flattened.values.shape)
            aux = aux.assign(**{name_to_plot: flattened})
            scatter_with_std(aux, name_to_plot, ax=ax, color=color, fmt='none')

    ax.set_xlabel(other_dim)
    ax.set_ylabel(name_to_plot)
    fancy_labels(ax)

    if not skip_colorbar:
        inset_ax.set_xlabel(stack_axis, fontsize=16)
        fancy_labels(inset_ax)
        try:
            cbar(ax=inset_ax, **kwargs)
        except TypeError:
            # ``cbar`` is an already rendered colorbar rather than a factory.
            pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def flat_stack_plot(data: DataType, stack_axis=None, fermi_level=True, cbarmap=None, ax=None,
                    mode='line', title=None, out=None, transpose=False, **kwargs):
    """
    Overlay every 1D cut of a 2D spectrum taken along ``stack_axis`` on one set of axes.

    Each cut is colored by its ``stack_axis`` coordinate. With ``transpose=True`` the cuts
    are drawn with intensity on x and the remaining coordinate on y.

    :param data: Image-like data; passed through ``normalize_to_spectrum``.
    :param stack_axis: Dimension to iterate over; defaults to the first dimension.
    :param fermi_level: When ``eV`` is the non-stacked dimension, mark 0 with a dashed line.
    :param cbarmap: Optional ``(colorbar, colormap)`` pair. When given, no colorbar is drawn.
    :param ax: Axes to draw on. A new figure and a colorbar inset are created when omitted.
    :param mode: ``'line'`` or ``'scatter'`` (scatter is not implemented for ``transpose``).
    :param title: Axes title.
    :param out: If given, the figure is saved there and the saved path is returned.
    :param transpose: Put intensity on x and the remaining coordinate on y.
    :param kwargs: ``figsize`` and ``ticks`` configure the figure/colorbar. All other
                   kwargs are forwarded to the line/scatter calls. All kwargs are forwarded
                   to the colorbar.
    :return: The saved path when ``out`` is set, otherwise ``(fig, ax)`` (``fig`` is None
             when ``ax`` was supplied).
    :raises ValueError: If ``data`` is not two-dimensional.
    """
    data = normalize_to_spectrum(data)
    if len(data.dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like. '
                         'Passed data included dimensions: {}'.format(data.dims))

    fig = None
    inset_ax = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5,)))
        inset_ax = inset_axes(ax, width='40%', height='5%', loc=1)

    if stack_axis is None:
        stack_axis = data.dims[0]

    # A caller-provided cbarmap suppresses drawing the colorbar at the end.
    skip_colorbar = cbarmap is not None
    if cbarmap is None:
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax,
                                                   ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap

    # There is exactly one dimension besides ``stack_axis``.
    other_dim = [d for d in data.dims if d != stack_axis][0]
    other_coord = data.coords[other_dim]

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # ``cmap`` may be a factory; if it cannot be called without arguments, assume it
        # is already a usable value -> color callable.
        try:
            cmap = cmap()
        except TypeError:
            pass

    # ``figsize``/``ticks`` configure the figure and colorbar, not the plotted artists.
    plot_kwargs = {k: v for k, v in kwargs.items() if k not in ('figsize', 'ticks')}

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        if transpose:
            ax.axhline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)
        else:
            ax.axvline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)

    for coord_dict, marginal in data.T.iterate_axis(stack_axis):
        color = cmap(coord_dict[stack_axis])
        if mode == 'line':
            if transpose:
                ax.plot(marginal.values, marginal.coords[marginal.dims[0]].values,
                        color=color, **plot_kwargs)
            else:
                marginal.plot(ax=ax, color=color, **plot_kwargs)
        else:
            assert mode == 'scatter'
            if transpose:
                raise NotImplementedError()
            ax.scatter(*marginal.T.to_arrays(), color=color, **plot_kwargs)

    coord_label = label_for_dim(data, other_dim)
    coord_limits = [other_coord.min().item(), other_coord.max().item()]
    if transpose:
        # Transposed: the non-stacked coordinate runs along y, intensity along x.
        ax.set_ylabel(coord_label)
        ax.set_xlabel('Spectrum Intensity (arb).')
        ax.set_ylim(coord_limits)
    else:
        ax.set_xlabel(coord_label)
        ax.set_ylabel('Spectrum Intensity (arb).')
        ax.set_xlim(coord_limits)
    ax.set_title(title, fontsize=14)

    if inset_ax is not None and not skip_colorbar:
        inset_ax.set_xlabel(stack_axis, fontsize=16)
        fancy_labels(inset_ax)
        try:
            cbar(ax=inset_ax, **kwargs)
        except TypeError:
            # ``cbar`` is an already rendered colorbar rather than a factory.
            pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def plot_spatial_reference(reference_map: DataType, data_list: List[DataType],
                           offset_list: Optional[List[Dict[str, Any]]] = None,
                           annotation_list: Optional[List[str]] = None,
                           out: Optional[str] = None, plot_refs: bool = True):
    """
    Helpfully plots data against a reference scanning dataset. This is essential to understand
    where data was taken and can be used early in the analysis phase in order to highlight the
    location of your datasets against core levels, etc.

    Each dataset is marked on the reference map according to how many of the map's two
    coordinates it spans:
      - a point for a scalar position;
      - a line when it spans one of them;
      - a translucent rectangle when it spans both.
    Each mark carries a text annotation. When ``plot_refs`` is set, a small panel of each
    dataset is drawn beside the map, framed in the same color.

    :param reference_map: A scanning photoemission like dataset
    :param data_list: A list of datasets you want to plot the relative locations of
    :param offset_list: Optionally, offsets given as coordinate dicts
    :param annotation_list: Optionally, text annotations for the data
    :param out: Where to save the figure if we are outputting to disk
    :param plot_refs: Whether to plot reference figures for each of the pieces of data in `data_list`
    :return: The saved path when ``out`` is set, otherwise ``(fig, [ax] + ax_refs)``.
    """
    if offset_list is None:
        offset_list = [None] * len(data_list)

    if annotation_list is None:
        annotation_list = [str(i + 1) for i in range(len(data_list))]

    # NOTE(review): the return value is discarded, so this call does not change
    # ``reference_map`` here; kept as-is — confirm whether it was meant to be assigned.
    normalize_to_spectrum(reference_map)

    n_references = len(data_list)
    if n_references == 1 and plot_refs:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        ax = axes[0]
        ax_refs = [axes[1]]
    elif plot_refs:
        # The map occupies a 2x2 block on the left; reference panels fill the
        # ``2 * n_extra_axes`` columns to its right over two rows.
        n_extra_axes = 1 + (n_references // 4)
        fig = plt.figure(figsize=(6 * (1 + n_extra_axes), 5))
        spec = gridspec.GridSpec(ncols=2 * (1 + n_extra_axes), nrows=2, figure=fig)
        ax = fig.add_subplot(spec[:2, :2])
        n_ref_cols = 2 * n_extra_axes
        ax_refs = [
            fig.add_subplot(spec[i // n_ref_cols, 2 + i % n_ref_cols])
            for i in range(n_references)
        ]
    else:
        ax_refs = []
        fig, ax = plt.subplots(1, 1, figsize=(6, 5))

    # Best effort: use the first spectrum when available, otherwise use the map as given.
    try:
        reference_map = reference_map.S.spectra[0]
    except Exception:
        pass

    reference_map = reference_map.S.mean_other(['x', 'y', 'z'])
    ref_dims = reference_map.dims[::-1]

    assert len(reference_map.dims) == 2
    reference_map.S.plot(ax=ax, cmap='Blues')

    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; ``pyplot.get_cmap`` remains.
    cmap = plt.get_cmap('Reds')
    rendered_annotations = []

    # Dimensions to keep (in order of preference) when summarizing each dataset for its panel.
    keep_preference = list(ref_dims) + [
        'eV', 'temperature', 'kz', 'hv', 'kp', 'kx', 'ky', 'phi', 'theta', 'beta', 'pixel',
    ]

    for i, (data, offset, annotation) in enumerate(zip(data_list, offset_list, annotation_list)):
        if offset is None:
            try:
                offset = data.S.logical_offsets - reference_map.S.logical_offsets
            except ValueError:
                offset = {}

        coords = {c: unwrap_xarray_item(data.coords[c]) for c in ref_dims}
        n_array_coords = len([
            cv for cv in coords.values() if isinstance(cv, (np.ndarray, xr.DataArray))
        ])
        color_level = 0.4 + (0.5 * i / len(data_list))
        color = cmap(color_level)

        x = coords[ref_dims[0]] + offset.get(ref_dims[0], 0)
        y = coords[ref_dims[1]] + offset.get(ref_dims[1], 0)

        # (ref_x, ref_y) anchors the annotation; (off_x, off_y) nudges it away from the mark.
        ref_x, ref_y = x, y
        off_x, off_y = 0, 0
        scale = 0.03

        if n_array_coords == 0:
            off_y = 1
            ax.scatter([x], [y], s=60, color=color)
        elif n_array_coords == 1:
            if isinstance(x, (np.ndarray, xr.DataArray)):
                y = [y] * len(x)
                ref_x = np.min(x)
                off_x = -1
            else:
                x = [x] * len(y)
                ref_y = np.max(y)
                off_y = 1

            ax.plot(x, y, color=color, linewidth=3)
        elif n_array_coords == 2:
            off_y = 1
            min_x, max_x = np.min(x), np.max(x)
            min_y, max_y = np.min(y), np.max(y)
            ref_x, ref_y = min_x, max_y

            rect = patches.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y,
                                     facecolor=cmap(color_level, alpha=0.5))
            ax.add_patch(rect)

        dp = ddata_daxis_units(ax)
        text_location = np.asarray([ref_x, ref_y]) + dp * scale * np.asarray([off_x, off_y])
        text = ax.annotate(annotation, text_location, color='black', size=15)
        rendered_annotations.append(text)
        text.set_path_effects([
            path_effects.Stroke(linewidth=2, foreground='white'),
            path_effects.Normal(),
        ])

        if plot_refs:
            ax_ref = ax_refs[i]
            keep = [d for d in keep_preference if d in data.dims][:2]
            data.S.mean_other(keep).S.plot(ax=ax_ref)
            ax_ref.set_title(annotation)
            fancy_labels(ax_ref)
            frame_with(ax_ref, color=color, linewidth=3)

    ax.set_title('')
    remove_colorbars()
    fancy_labels(ax)
    plt.tight_layout()

    # adjustText is optional; without it annotations stay where they were placed.
    try:
        from adjustText import adjust_text
        adjust_text(rendered_annotations, ax=ax, avoid_points=False, avoid_objects=False,
                    avoid_self=False, autoalign='xy')
    except ImportError:
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, [ax] + ax_refs