def hue_brightness_plot(data: xr.Dataset, ax=None, out=None, **kwargs):
    """Draw intensity and polarization together as one hue/brightness image.

    The RGB image itself comes from ``polarization_intensity_to_color``, which
    receives ``data`` and every keyword argument; this function only places it
    on the axes using the coordinates of ``data.intensity``.

    :param data: Dataset holding both ``intensity`` and ``polarization``.
        The first dimension of ``intensity`` runs along y, the second along x.
    :param ax: Axes to draw into. A new figure is made when this is None.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param kwargs: ``figsize`` is read when a figure is created; all keywords
        are forwarded to ``polarization_intensity_to_color``.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``; ``fig``
        is None when ``ax`` was supplied by the caller.
    """
    assert ('intensity' in data and 'polarization' in data)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5)))

    row_dim, col_dim = data.intensity.dims[0], data.intensity.dims[1]
    row_values = data.coords[row_dim].values
    col_values = data.coords[col_dim].values
    image_extent = [col_values[0], col_values[-1], row_values[0], row_values[-1]]

    ax.imshow(
        polarization_intensity_to_color(data, **kwargs),
        extent=image_extent,
        aspect='auto',
        origin='lower',
    )
    ax.set_xlabel(col_dim)
    ax.set_ylabel(row_dim)
    ax.grid(False)

    if out is None:
        return fig, ax

    plt.savefig(path_for_plot(out), dpi=400)
    return path_for_plot(out)
def plot_data_to_bz2d(data: DataType, cell, rotate=None, shift=None, scale=None, ax=None,
                      mask=True, out=None, bz_number=None, **kwargs):
    """Overlay k-space data on a 2D Brillouin zone plot.

    The data coordinates may be rotated, scaled and shifted before plotting.
    When ``mask`` is set, points flagged by ``apply_mask_to_coords`` for the
    zone polygon are set to NaN (drawn transparent). The image is then
    translated by the integer combination ``bz_number`` of the rows of
    ``inv(cell).T``.

    :param data: Spectrum-like data; must already be converted to k-space.
    :param cell: Unit cell vectors (2 or 3). A 2-vector cell is padded to 3D
        with a unit z-vector before inverting.
    :param rotate: Optional rotation angle in radians.
    :param shift: Optional coordinate shift forwarded to ``shift_coords``.
    :param scale: Optional coordinate scale forwarded to ``scale_coords``.
    :param ax: Axes to draw into. When None, a new 9x9 figure is created and
        the zone boundaries are drawn with ``bz2d_plot``.
    :param mask: Whether to blank out points using the zone polygon.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param bz_number: Integer pair selecting the zone to translate into;
        defaults to ``(0, 0)``.
    :param kwargs: ``cmap`` (colormap name or instance, default Blues).
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    :raises AssertionError: If ``data`` is not in k-space.
    """
    import copy

    data = normalize_to_spectrum(data)

    # The message must be the second operand of ``assert``; the previous
    # ``assert('msg' and cond)`` form never displayed it.
    assert data.S.is_kspace, 'You must k-space convert data before plotting to BZs'

    if bz_number is None:
        bz_number = (0, 0)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(9, 9))
        bz2d_plot(cell, paths='all', ax=ax)

    if len(cell) == 2:
        cell = [list(c) + [0] for c in cell] + [[0, 0, 1]]

    icell = np.linalg.inv(cell).T

    # Prep coordinates and mask
    raveled = data.T.meshgrid(as_dataset=True)
    dims = data.dims
    if rotate is not None:
        c, s = np.cos(rotate), np.sin(rotate)
        rotation = np.array([(c, -s), (s, c)])
        raveled = raveled.T.transform_coords(dims, rotation)

    if scale is not None:
        raveled = raveled.T.scale_coords(dims, scale)

    if shift is not None:
        raveled = raveled.T.shift_coords(dims, shift)

    copied = data.values.copy()

    if mask:
        # NaN cannot be stored in an integer array, so promote first.
        if not np.issubdtype(copied.dtype, np.inexact):
            copied = copied.astype(float)
        built_mask = apply_mask_to_coords(raveled, build_2dbz_poly(cell=cell), dims)
        copied[built_mask.T] = np.nan

    cmap = kwargs.get('cmap', matplotlib.cm.Blues)
    if isinstance(cmap, str):
        cmap = matplotlib.cm.get_cmap(cmap)

    # Copy before ``set_bad`` so the shared colormap instance (for example
    # ``matplotlib.cm.Blues``) is not altered for every later plot.
    cmap = copy.copy(cmap)
    cmap.set_bad((1, 1, 1, 0))

    zone_index = np.array(bz_number)
    delta_x = np.dot(zone_index, icell[:2, 0])
    delta_y = np.dot(zone_index, icell[:2, 1])

    ax.pcolormesh(raveled.data_vars[dims[0]].values + delta_x,
                  raveled.data_vars[dims[1]].values + delta_y, copied.T, cmap=cmap)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def fermi_surface_slices(arr: xr.DataArray, n_slices=9, ev_per_slice=0.02, bin=0.01, out=None, **kwargs):
    """Build a HoloViews grid of constant-energy maps at and below eV = 0.

    Slice ``k`` first sums ``arr`` over every dimension that is not one of
    theta/beta/phi/eV/kp/kx/ky, then integrates ``eV`` over the window
    ``[-k * ev_per_slice - bin, -k * ev_per_slice]``.

    :param arr: Data with an ``eV`` dimension.
    :param n_slices: Number of energy windows.
    :param ev_per_slice: Spacing between the upper edges of successive windows.
    :param bin: Width of each integration window (name kept for compatibility).
    :param out: When given, render with the matplotlib backend and save.
    :param kwargs: Accepted but unused.
    :return: The saved filename when ``out`` is given, else the ``hv.Layout``
        arranged in three columns.
    """
    import holoviews as hv  # pylint: disable=import-error

    retained_dims = ('theta', 'beta', 'phi', 'eV', 'kp', 'kx', 'ky')

    images = []
    for index in range(n_slices):
        upper_edge = -ev_per_slice * index
        lower_edge = upper_edge - bin
        integrated = arr.sum([dim for dim in arr.dims if dim not in retained_dims])
        window = integrated.sel(eV=slice(lower_edge, upper_edge)).sum('eV')
        images.append(hv.Image(window, label='%g eV' % upper_edge))

    layout = hv.Layout(images).cols(3)
    if out is None:
        return layout

    renderer = hv.renderer('matplotlib').instance(fig='svg', holomap='gif')
    filename = path_for_plot(out)
    renderer.save(layout, path_for_holoviews(filename))
    return filename
def scatter_with_std(data: DataType, name_to_plot=None, ax=None, fmt='o', out=None, **kwargs):
    """Plot a variable as points with ``<name>_std`` error bars.

    :param data: Dataset holding ``name_to_plot`` and ``name_to_plot + '_std'``.
    :param name_to_plot: Variable to plot. When None, ``data`` must contain
        exactly one variable whose name does not contain ``'_std'``.
    :param ax: Axes to draw into; a new 7x5 figure is created when None.
    :param fmt: Marker format forwarded to ``Axes.errorbar``.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param kwargs: ``figsize`` (used only when creating a figure); everything
        else is forwarded to ``Axes.errorbar``.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    """
    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
    assert (name_to_plot + '_std') in data.data_vars.keys()

    # Always strip ``figsize``: it only applies to figure creation and would
    # otherwise be forwarded to ``errorbar``, which rejects it.
    figsize = kwargs.pop('figsize', (7, 5))

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    x, y = data.data_vars[name_to_plot].T.to_arrays()
    std = data.data_vars[name_to_plot + '_std'].values
    ax.errorbar(x, y, yerr=std, fmt=fmt, markeredgecolor='black', **kwargs)

    # Set limits before saving so the written file matches the returned axes.
    ax.set_xlim([np.min(x), np.max(x)])

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def false_color_plot(data_r: xr.Dataset, data_g: xr.Dataset, data_b: xr.Dataset, ax=None, out=None,
                     invert=False, pmin=0, pmax=1, **kwargs):
    """Combine three spectra into one false-colour RGB image.

    Each channel is processed the same way, then the three are stacked as red,
    green and blue. The processing has three steps:

    1. Shift the channel so its ``pmin`` percentile becomes 0.
    2. Clip it at its (shifted) ``pmax`` percentile.
    3. Divide it by its maximum.

    :param data_r: Red channel (passed through ``normalize_to_spectrum``).
    :param data_g: Green channel; must share the red channel's shape.
    :param data_b: Blue channel; must share the red channel's shape.
    :param ax: Axes to draw into; a new 7x5 figure is created when None.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param invert: Replace each pixel's HSV value by ``1 - value``. The HSV
        step indexes ``[:, :, 2]``, so this assumes 2D channels.
    :param pmin: Lower percentile as a fraction in [0, 1].
    :param pmax: Upper percentile as a fraction in [0, 1].
    :param kwargs: ``figsize`` when creating a figure.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    """
    data_r, data_g, data_b = [normalize_to_spectrum(d) for d in (data_r, data_g, data_b)]

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5)))

    def normalize_channel(channel):
        # Build a new array rather than using ``-=``. ``channel`` is the
        # ``.values`` buffer of the caller's data, which must not be modified.
        # In-place float subtraction also fails on integer arrays.
        channel = channel - np.percentile(channel, 100 * pmin)
        upper = np.percentile(channel, 100 * pmax)
        channel[channel > upper] = upper
        return channel / np.max(channel)

    cs = dict(data_r.coords)
    cs['dim_color'] = [1, 2, 3]

    arr = xr.DataArray(
        np.stack([normalize_channel(data_r.values),
                  normalize_channel(data_g.values),
                  normalize_channel(data_b.values)], axis=-1),
        coords=cs,
        dims=list(data_r.dims) + ['dim_color'],
    )

    if invert:
        vs = arr.values
        vs[vs > 1] = 1
        hsv = matplotlib.colors.rgb_to_hsv(vs)
        hsv[:, :, 2] = 1 - hsv[:, :, 2]
        arr.values = matplotlib.colors.hsv_to_rgb(hsv)

    imshow_arr(arr, ax=ax)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def plot_with_std(data: DataType, name_to_plot=None, ax=None, out=None, **kwargs):
    """Plot a variable as a line with a shaded ``+/- <name>_std`` band.

    :param data: Dataset holding ``name_to_plot`` and ``name_to_plot + '_std'``.
    :param name_to_plot: Variable to plot. When None, ``data`` must contain
        exactly one variable whose name does not contain ``'_std'``.
    :param ax: Axes to draw into; a new 7x5 figure is created when None.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param kwargs: ``figsize`` (used only when creating a figure). Everything
        else is forwarded to both ``DataArray.plot`` and ``Axes.fill_between``.
        A caller-supplied ``alpha`` overrides the band's default of 0.3.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    """
    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
    assert (name_to_plot + '_std') in data.data_vars.keys()

    # Always strip ``figsize``. When ``ax`` is supplied, xarray's ``plot``
    # refuses to accept ``figsize`` alongside it.
    figsize = kwargs.pop('figsize', (7, 5))

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    data.data_vars[name_to_plot].plot(ax=ax, **kwargs)

    x, y = data.data_vars[name_to_plot].T.to_arrays()
    std = data.data_vars[name_to_plot + '_std'].values
    # Merge so a caller ``alpha`` wins instead of raising a duplicate-keyword error.
    ax.fill_between(x, y - std, y + std, **{'alpha': 0.3, **kwargs})

    # Set limits before saving so the written file matches the returned axes.
    ax.set_xlim([np.min(x), np.max(x)])

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def spin_difference_spectrum(spin_dr, title=None, ax=None, out=None, scatter=False, **kwargs):
    """Plot a 1D spin-resolved spectrum as an intensity curve coloured by polarization.

    ``spin_dr`` is converted with ``to_intensity_polarization``. If that raises
    ``AssertionError``, ``spin_dr`` is used as-is and must already carry
    ``intensity`` and ``polarization`` variables. Drawing only happens when
    ``intensity`` is 1D; otherwise nothing is plotted on ``ax``.

    :param spin_dr: Spin-resolved data, or data already in intensity/polarization form.
    :param title: Axes title; defaults to ``'Spin Polarization'``.
    :param ax: Axes to draw into; a new 6x4 figure is created when None.
    :param out: When given, save via ``savefig(out, dpi=400)``, clear the
        figure and return ``path_for_plot(out)``.
    :param scatter: Draw coloured points instead of coloured line segments.
    :param kwargs: Accepted but unused.
    :return: ``path_for_plot(out)`` when ``out`` is given; otherwise calls
        ``plt.show()`` and returns None.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))

    try:
        as_intensity = to_intensity_polarization(spin_dr)
    except AssertionError:
        as_intensity = spin_dr
    intensity = as_intensity.intensity
    # Deep copy because the clipping below writes into ``pol.values``.
    pol = as_intensity.polarization.copy(deep=True)

    if len(intensity.dims) == 1:
        inset_ax = inset_axes(ax, width="30%", height="5%", loc=1)
        coord = intensity.coords[intensity.dims[0]]
        # Shape (n, 1, 2): one (x, y) vertex per sample, ready for segment building.
        points = np.array([coord.values, intensity.values]).T.reshape(-1, 1, 2)
        # Treat missing polarization as zero and clip to [-1, 1].
        pol.values[np.isnan(pol.values)] = 0
        pol.values[pol.values > 1] = 1
        pol.values[pol.values < -1] = -1
        # One colour per line segment (n - 1 segments for n points).
        # NOTE(review): a matplotlib Colormap called on floats maps [0, 1];
        # values in [-1, 0) all receive the colormap's "under" colour. Confirm
        # whether (pol + 1) / 2 was intended to match polarization_colorbar.
        pol_colors = cm.get_cmap('RdBu')(pol.values[:-1])

        if scatter:
            # Scatter needs one colour per point rather than per segment.
            pol_colors = cm.get_cmap('RdBu')(pol.values)
            ax.scatter(coord.values, intensity.values, c=pol_colors, s=1.5)
        else:
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, colors=pol_colors)
            ax.add_collection(lc)

        # Limits are set explicitly rather than relying on autoscaling.
        ax.set_xlim(coord.min().item(), coord.max().item())
        ax.set_ylim(0, intensity.max().item() * 1.15)
        ax.set_ylabel('ARPES Spectrum Intensity (arb.)')
        ax.set_xlabel(label_for_dim(spin_dr, dim_name=intensity.dims[0]))
        ax.set_title(title if title is not None else 'Spin Polarization')
        polarization_colorbar(inset_ax)

    if out is not None:
        savefig(out, dpi=400)
        plt.clf()
        return path_for_plot(out)
    else:
        plt.show()
def plot_movie(data: xr.DataArray, time_dim, interval=None, fig=None, ax=None, out=None, **kwargs):
    """Animate ``data`` frame by frame along ``time_dim``.

    The initial image is the mean over ``time_dim``. Each animation frame then
    replaces the image data with ``data.sel({time_dim: value})``.

    :param data: DataArray with ``time_dim`` plus two image dimensions.
    :param time_dim: Dimension to animate over.
    :param interval: Delay between frames in milliseconds (default 100).
    :param fig: Figure hosting the animation; taken from ``ax`` when omitted.
    :param ax: Axes to draw into; a new 7x7 figure is created when None.
    :param out: When given, save with the ffmpeg writer to ``path_for_plot(out)``.
    :param kwargs: ``vmin``, ``vmax`` and ``cmap`` override the automatic
        colour scaling; the rest is forwarded to ``DataArray.plot``.
    :return: The saved path if ``out`` is given, else the ``FuncAnimation``.
    :raises TypeError: If ``data`` is not a DataArray.
    """
    if not isinstance(data, xr.DataArray):
        raise TypeError('You must provide a DataArray')

    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))
    elif fig is None:
        # FuncAnimation needs the figure that actually contains ``ax``.
        fig = ax.figure

    cmap = arpes.config.SETTINGS.get('interactive', {}).get('palette', 'viridis')
    vmax = data.max().item()
    vmin = data.min().item()

    if data.S.is_subtracted:
        # Difference data: use a diverging map, symmetric about zero.
        cmap = 'RdBu'
        vmax = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -vmax

    vmax = kwargs.pop('vmax', vmax)
    vmin = kwargs.pop('vmin', vmin)
    # Popping ``cmap`` avoids a duplicate-keyword error in ``plot`` below.
    cmap = kwargs.pop('cmap', cmap)

    # Draw on ``ax`` explicitly. Otherwise xarray uses the pyplot current
    # axes, which is not ``ax`` when the caller supplies one.
    plot = data.mean(time_dim).transpose().plot(ax=ax, vmax=vmax, vmin=vmin, cmap=cmap, **kwargs)

    def init():
        plot.set_array(np.asarray([]))
        return plot,

    animation_coords = data.coords[time_dim].values

    def animate(i):
        data_for_plot = data.sel({time_dim: animation_coords[i]})
        # ``.T`` matches the ``.transpose()`` applied to the initial image.
        plot.set_array(data_for_plot.values.T.ravel())
        return plot,

    computed_interval = interval or 100

    anim = animation.FuncAnimation(fig, animate, init_func=init, repeat=True,
                                   frames=len(animation_coords), interval=computed_interval, blit=True)

    if out is not None:
        # Only look up ffmpeg when saving, so returning an in-memory
        # animation does not require ffmpeg to be installed.
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=1000 / computed_interval, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(path_for_plot(out), writer=writer)
        return path_for_plot(out)

    return anim
def magnify_circular_regions_plot(data: DataType, magnified_points, mag=10, radius=0.05, cmap='viridis', color=None,
                                  edgecolor='red', out=None, ax=None, **kwargs):
    """Plot ``data`` and redraw selected circular regions with a compressed colour scale.

    Inside the regions around ``magnified_points``, the data is redrawn on top
    of the base plot. The overlay's upper colour limit is divided by ``mag``,
    which brings out weak features in those regions.

    :param data: 2D spectrum-like data.
    :param magnified_points: Region centres, used directly as ``Ellipse``
        centres in the axes' data coordinates.
    :param mag: Factor by which the overlay's upper colour limit is reduced.
    :param radius: Outline size as a fraction of the larger axes side in
        pixels. It is used as the Ellipse width/height, i.e. a diameter.
    :param cmap: Colormap for both the base plot and the magnified overlay.
    :param color: Outline colour, or a list with one per point.
    :param edgecolor: Outline edge colour, or a list with one per point.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param ax: Axes to draw into; a new 7x5 figure is created when None.
    :param kwargs: ``figsize`` when creating a figure.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    """
    import copy

    data = normalize_to_spectrum(data)
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5)))

    mesh = data.plot(ax=ax, cmap=cmap)
    clim = list(mesh.get_clim())
    clim[1] = clim[1] / mag

    n_values = len(data.values.ravel())
    mask = np.zeros(shape=(n_values,), dtype=bool)
    pts = np.zeros(shape=(n_values, 2))

    raveled = data.T.ravel()
    pts[:, 0] = raveled[data.dims[0]]
    pts[:, 1] = raveled[data.dims[1]]

    x0, y0 = ax.transAxes.transform((0, 0))  # lower left in pixels
    x1, y1 = ax.transAxes.transform((1, 1))  # upper right in pixels
    dx = x1 - x0
    dy = y1 - y0
    maxd = max(dx, dy)
    xlim, ylim = ax.get_xlim(), ax.get_ylim()

    # Convert the pixel-relative size to data units per axis, so the outline
    # looks circular on screen even when the axes are not square.
    width = radius * maxd / dx * (xlim[1] - xlim[0])
    height = radius * maxd / dy * (ylim[1] - ylim[0])

    if not isinstance(edgecolor, list):
        edgecolor = [edgecolor for _ in range(len(magnified_points))]

    if not isinstance(color, list):
        color = [color for _ in range(len(magnified_points))]

    pts[:, 1] = (pts[:, 1]) / (xlim[1] - xlim[0])
    pts[:, 0] = (pts[:, 0]) / (ylim[1] - ylim[0])

    for c, ec, point in zip(color, edgecolor, magnified_points):
        patch = matplotlib.patches.Ellipse(point, width, height, color=c, edgecolor=ec, fill=False, linewidth=2,
                                           zorder=4)
        # Invisible patch used only for the point-in-region test. It swaps the
        # point's components, presumably to match the (dims[0], dims[1])
        # column order of ``pts``.
        # NOTE(review): ``pts`` is scaled by the axis ranges but the centre is
        # not; confirm the tested region matches the drawn outline.
        patchfake = matplotlib.patches.Ellipse([point[1], point[0]], radius, radius)
        ax.add_patch(patch)
        mask = np.logical_or(mask, patchfake.contains_points(pts))

    data_masked = data.copy(deep=True)
    data_masked.values = np.array(data_masked.values, dtype=np.float32)

    # Use the caller's colormap so the overlay matches the base plot. Copy it
    # so ``set_bad`` does not modify a shared colormap instance.
    overlay_cmap = copy.copy(matplotlib.cm.get_cmap(cmap))
    overlay_cmap.set_bad(color=(1, 1, 1, 0))
    # Everything outside the magnified regions becomes NaN, i.e. transparent.
    data_masked.values[np.swapaxes(
        np.logical_not(mask.reshape(data.values.shape[::-1])), 0, 1)] = np.nan

    aspect = ax.get_aspect()
    extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
    ax.imshow(data_masked.values, cmap=overlay_cmap, extent=extent, zorder=3, clim=clim, origin='lower')
    # imshow may change the aspect; restore the base plot's.
    ax.set_aspect(aspect)

    for spine in ['left', 'top', 'right', 'bottom']:
        ax.spines[spine].set_zorder(5)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def overlapped_stack_dispersion_plot(data: DataType, stack_axis=None, ax=None, title=None, out=None,
                                     max_stacks=100, use_constant_correction=False, transpose=False,
                                     negate=False, s=1, scale_factor=None, linewidth=1, palette=None, **kwargs):
    """Draw every cut along ``stack_axis`` as a curve raised to its own coordinate.

    Each cut goes through four steps before drawing; neighbouring curves may
    overlap:

    1. Baseline-correct it. The baseline is either a constant (its first
       value) or a straight line between its first and last values.
    2. Divide by the maximum of ``data``.
    3. Multiply by ``scale_factor``.
    4. Add its ``stack_axis`` coordinate.

    :param data: 2D spectrum-like data.
    :param stack_axis: Dimension to stack along; defaults to the first dimension.
    :param ax: Axes to draw into; a new 7x7 figure is created when None.
    :param title: Axes title; defaults to ``'<label> Stack'``.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param max_stacks: Rebin along ``stack_axis`` so at most this many curves are drawn.
    :param use_constant_correction: Subtract the first value instead of a linear baseline.
    :param transpose: Swap the plotted x and y.
    :param negate: Flip the sign of every cut before correction.
    :param s: Marker size in palette (scatter) mode.
    :param scale_factor: Vertical scale. When None, it is chosen so the
        largest baseline deviation maps to 2% of the ``other_axis`` range.
    :param linewidth: Line width in plain (non-palette) mode.
    :param palette: Colormap or name. When set, points are coloured by
        magnitude and drawn with ``scatter``.
    :param kwargs: Forwarded to ``Axes.plot`` / ``Axes.scatter``.
    :return: The saved path if ``out`` is given. Otherwise calls
        ``plt.show()`` and returns ``(fig, ax)``.
    """
    data = normalize_to_spectrum(data)

    if stack_axis is None:
        stack_axis = data.dims[0]

    other_axes = list(data.dims)
    other_axes.remove(stack_axis)
    other_axis = other_axes[0]

    stack_coord = data.coords[stack_axis]
    if len(stack_coord.values) > max_stacks:
        data = rebin(data, reduction={stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks))})

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))

    if title is None:
        title = '{} Stack'.format(data.S.label.replace('_', ' '))

    max_over_stacks = np.max(data.values)
    cvalues = data.coords[other_axis].values

    def baseline_corrected(marginal):
        # Remove a constant (first value) or linear (first-to-last) baseline.
        marginal_values = -marginal.values if negate else marginal.values
        left, right = marginal_values[0], marginal_values[-1]
        if use_constant_correction:
            return marginal_values - left
        return marginal_values - np.linspace(left, right, len(marginal_values))

    if scale_factor is None:
        maximum_deviation = -np.inf
        for _, marginal in data.T.iterate_axis(stack_axis):
            maximum_deviation = np.max([maximum_deviation, np.max(np.abs(baseline_corrected(marginal)))])
        scale_factor = 0.02 * (np.max(cvalues) - np.min(cvalues)) / maximum_deviation

    if palette and isinstance(palette, str):
        palette = cm.get_cmap(palette)

    iteration_order = -1  # might need to fiddle with this in certain cases
    for coord_dict, marginal in list(data.T.iterate_axis(stack_axis))[::iteration_order]:
        coord_value = coord_dict[stack_axis]

        xs = cvalues
        true_ys = baseline_corrected(marginal) / max_over_stacks
        ys = scale_factor * true_ys + coord_value

        raw_colors = 'black'
        if palette:
            # NOTE(review): true_ys is already divided by max_over_stacks, so
            # this divides a second time; kept as-is pending confirmation.
            raw_colors = palette(np.abs(true_ys / max_over_stacks))

        if transpose:
            xs, ys = ys, xs

        # Draw on ``ax`` (not the pyplot current axes) so that a
        # caller-supplied axes receives the curves.
        if isinstance(raw_colors, str):
            ax.plot(xs, ys, linewidth=linewidth, color=raw_colors, **kwargs)
        else:
            ax.scatter(xs, ys, color=raw_colors, s=s, **kwargs)

    x_label = other_axis
    y_label = stack_axis
    if transpose:
        x_label, y_label = y_label, x_label

    ax.set_xlabel(label_for_dim(data, x_label))
    ax.set_ylabel(label_for_dim(data, y_label))
    ax.set_title(title)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    plt.show()

    return fig, ax
def stack_dispersion_plot(data: DataType, stack_axis=None, ax=None, title=None, out=None, max_stacks=100,
                          transpose=False, use_constant_correction=False, correction_side=None, color=None,
                          c=None, label=None, shift=0, no_scatter=False, negate=False, s=1, scale_factor=None,
                          linewidth=1, palette=None, zero_offset=False, uniform = False, **kwargs):
    """Draw each cut along ``stack_axis`` as a curve raised to its coordinate.

    Per cut, a baseline is removed according to the mode flags, checked in
    this order:

    - ``use_constant_correction``: subtract the first value, or the last
      value when ``correction_side='right'``.
    - ``zero_offset``: subtract nothing.
    - ``uniform``: subtract nothing, and stack at the loop index ``i``
      instead of the coordinate.
    - Otherwise: subtract a straight line between the first and last values.

    The result is divided by the maximum of ``data``, multiplied by
    ``scale_factor`` and added to the stacking position.

    :param data: 2D spectrum-like data.
    :param stack_axis: Dimension to stack along; defaults to the first dimension.
    :param ax: Axes to draw into; a new 7x7 figure is created when None.
    :param title: Axes title; defaults to ``'<label> Stack'``.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param max_stacks: Rebin along ``stack_axis`` so at most this many curves are drawn.
    :param transpose: Swap the plotted x and y.
    :param use_constant_correction: See the baseline modes above.
    :param correction_side: ``'right'`` to use the last value for the constant correction.
    :param color: Curve colour. A callable is called with each cut's
        coordinate, and (being neither str nor tuple) selects scatter mode
        unless ``no_scatter`` is set.
    :param c: Used when ``color`` is falsy.
    :param label: Legend label, attached only to the first curve drawn.
    :param shift: The plotted x values (after any transpose) of the ``i``-th
        curve are shifted by ``-i * shift``.
    :param no_scatter: Always use ``Axes.plot``, even with per-point colours.
    :param negate: Flip the sign of every cut before correction.
    :param s: Marker size in scatter mode.
    :param scale_factor: Vertical scale. When None, it is chosen so the
        largest baseline deviation maps to 2% of the ``other_axis`` range.
    :param linewidth: Line width in line mode.
    :param palette: Colormap or name used to colour points by magnitude.
    :param zero_offset: See the baseline modes above.
    :param uniform: See the baseline modes above.
    :param kwargs: Forwarded to ``Axes.plot`` / ``Axes.scatter``.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    """
    data = normalize_to_spectrum(data)

    if stack_axis is None:
        stack_axis = data.dims[0]

    other_axes = list(data.dims)
    other_axes.remove(stack_axis)
    other_axis = other_axes[0]

    stack_coord = data.coords[stack_axis]
    if len(stack_coord.values) > max_stacks:
        # Rebin so that at most ``max_stacks`` curves are drawn.
        data = rebin(data, reduction=dict([[
            stack_axis, int(np.ceil(len(stack_coord.values) / max_stacks))]
        ]))

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))

    if title is None:
        title = '{} Stack'.format(data.S.label.replace('_', ' '))

    max_over_stacks = np.max(data.values)

    cvalues = data.coords[other_axis].values
    if scale_factor is None:
        # Scale so the largest deviation spans 2% of the ``other_axis`` range.
        # The estimate uses un-normalized deviations, while the drawing loop
        # below also divides by ``max_over_stacks``.
        # NOTE(review): ``uniform`` and ``correction_side`` are not considered
        # here; ``uniform`` falls through to the linear-baseline estimate.
        maximum_deviation = -np.inf

        for _, marginal in data.T.iterate_axis(stack_axis):
            marginal_values = -marginal.values if negate else marginal.values
            marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

            if use_constant_correction:
                true_ys = (marginal_values - marginal_offset)
            elif zero_offset:
                true_ys = marginal_values
            else:
                true_ys = (marginal_values - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values)))

            maximum_deviation = np.max([maximum_deviation] + list(np.abs(true_ys)))

        scale_factor = 0.02 * (np.max(cvalues) - np.min(cvalues)) / maximum_deviation

    iteration_order = -1  # might need to fiddle with this in certain cases
    # Running intersection of the plotted x ranges of all curves.
    lim = [-np.inf, np.inf]
    labeled = False
    for i, (coord_dict, marginal) in enumerate(list(data.T.iterate_axis(stack_axis))[::iteration_order]):
        coord_value = coord_dict[stack_axis]

        xs = cvalues
        marginal_values = -marginal.values if negate else marginal.values
        marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

        if use_constant_correction:
            offset = right_marginal_offset if correction_side == 'right' else marginal_offset
            true_ys = (marginal_values - offset) / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        elif zero_offset:
            true_ys = marginal_values / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        elif uniform:
            # Stack at evenly spaced positions (loop index) instead of the coordinate.
            true_ys = marginal_values / max_over_stacks
            ys = scale_factor * true_ys + i
        else:
            true_ys = (marginal_values - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values))) \
                      / max_over_stacks
            ys = scale_factor * true_ys + coord_value

        raw_colors = color or c or 'black'

        if palette:
            if isinstance(palette, str):
                palette = cm.get_cmap(palette)
            # NOTE(review): true_ys is already divided by max_over_stacks, so
            # this divides a second time; confirm intended.
            raw_colors = palette(np.abs(true_ys / max_over_stacks))

        if transpose:
            xs, ys = ys, xs

        xs = xs - i * shift

        # NOTE(review): with transpose=True, xs holds the stacked values, yet
        # ``lim`` is applied with set_ylim below; confirm intended.
        lim = [max(lim[0], np.min(xs)), min(lim[1], np.max(xs))]

        # Only the first curve carries the legend label.
        label_for = '_nolegend_'
        if not labeled:
            labeled = True
            label_for = label

        color_for_plot = raw_colors
        if callable(color_for_plot):
            color_for_plot = color_for_plot(coord_value)

        if isinstance(raw_colors, (str, tuple)) or no_scatter:
            ax.plot(xs, ys, linewidth=linewidth, color=color_for_plot, label=label_for, **kwargs)
        else:
            ax.scatter(xs, ys, color=color_for_plot, s=s, label=label_for, **kwargs)

    x_label = other_axis
    y_label = stack_axis

    if transpose:
        x_label, y_label = y_label, x_label

    ax.set_xlabel(label_for_dim(data, x_label))
    ax.set_ylabel(label_for_dim(data, y_label))

    if transpose:
        ax.set_ylim(lim)
    else:
        ax.set_xlim(lim)

    ax.set_title(title)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def offset_scatter_plot(data: DataType, name_to_plot=None, stack_axis=None, fermi_level=True, cbarmap=None, ax=None,
                        out=None, scale_coordinate=0.5, ylim=None, aux_errorbars=True, **kwargs):
    """Scatter each cut along ``stack_axis`` with error bars, shifting successive cuts.

    The ``i``-th cut is moved along the remaining dimension by
    ``i * stride * scale_coordinate / 10``. Here ``stride`` is the value
    reported by ``data.T.stride`` for that dimension. The cut is then drawn
    with ``scatter_with_std``, coloured by its ``stack_axis`` coordinate.

    :param data: Dataset with ``name_to_plot`` (2D) and ``name_to_plot + '_std'``.
    :param name_to_plot: Variable to plot. When None, ``data`` must contain
        exactly one variable whose name does not contain ``'_std'``.
    :param stack_axis: Dimension to iterate over; defaults to the variable's first dimension.
    :param fermi_level: Draw the eV = 0 guides when ``eV`` is the non-stacked dimension.
    :param cbarmap: ``(colorbar_fn, cmap)`` pair. When None it is looked up in
        ``colorbarmaps_for_axis`` or generated for the stack coordinate.
    :param ax: Axes to draw into; a new 11x5 figure is created when None.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param scale_coordinate: Multiplier for the per-cut shift.
    :param ylim: y limits. Required when ``aux_errorbars`` is set.
    :param aux_errorbars: Repeat each cut's error bars on a flat line at ``ylim[0]``.
    :param kwargs: ``figsize`` and ``ticks``; also forwarded to the colorbar function.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    :raises ValueError: If the plotted variable is not 2D.
    """
    assert isinstance(data, xr.Dataset)

    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
    assert (name_to_plot + '_std') in data.data_vars.keys()

    if len(data.data_vars[name_to_plot].dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like.'
                         'Passed data included dimensions: {}'.format(data.data_vars[name_to_plot].dims))

    if aux_errorbars:
        # The auxiliary error bars are drawn at ``ylim[0]``.
        assert ylim is not None, 'aux_errorbars=True requires ylim to be provided'

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (11, 5)))

    inset_ax = inset_axes(ax, width='40%', height='5%', loc='upper left')

    if stack_axis is None:
        stack_axis = data.data_vars[name_to_plot].dims[0]

    skip_colorbar = True
    if cbarmap is None:
        skip_colorbar = False
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap
    if not isinstance(cmap, matplotlib.colors.Colormap):
        # Best effort: some entries are factories that build the colormap.
        try:
            cmap = cmap()
        except Exception:
            # Might still be usable as-is.
            pass

    # The variable is 2D, so exactly one dimension remains besides ``stack_axis``.
    other_dim = [d for d in data.dims if d != stack_axis][0]

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        ax.axhline(0, linestyle='--', color='red')
        ax.fill_betweenx([-1e6, 1e6], 0, 0.2, color='black', alpha=0.07)

    # Only set limits when given; ``set_ylim(None)`` would disable autoscaling.
    if ylim is not None:
        ax.set_ylim(ylim)

    delta = data.T.stride(generic_dim_names=False)[other_dim]

    # real plotting here
    for i, (coord, value) in enumerate(data.T.iterate_axis(stack_axis)):
        point_color = cmap(coord[stack_axis])

        # Use assign_coords: assigning ``.values`` on a dimension coordinate
        # in place is rejected by xarray.
        shifted_coord = value.coords[other_dim].values - i * delta * scale_coordinate / 10
        data_for = value.copy(deep=True).assign_coords(**{other_dim: shifted_coord})

        scatter_with_std(data_for, name_to_plot, ax=ax, color=point_color)

        if aux_errorbars:
            flattened = data_for.data_vars[name_to_plot].copy(deep=True)
            flattened.values = ylim[0] * np.ones(flattened.values.shape)
            aux_data = data_for.assign(**{name_to_plot: flattened})
            scatter_with_std(aux_data, name_to_plot, ax=ax, color=point_color, fmt='none')

    ax.set_xlabel(other_dim)
    ax.set_ylabel(name_to_plot)
    fancy_labels(ax)

    try:
        if inset_ax and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)
            fancy_labels(inset_ax)
            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # colorbar already rendered
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def flat_stack_plot(data: DataType, stack_axis=None, fermi_level=True, cbarmap=None, ax=None, mode='line',
                    title=None, out=None, transpose=False, **kwargs):
    """Overlay every cut along ``stack_axis`` on one axes, coloured by its coordinate.

    :param data: 2D spectrum-like data.
    :param stack_axis: Dimension to iterate over; defaults to the first dimension.
    :param fermi_level: Mark eV = 0 with a dashed line when ``eV`` is the non-stacked dimension.
    :param cbarmap: ``(colorbar_fn, cmap)`` pair. When None it is looked up in
        ``colorbarmaps_for_axis`` or generated for the stack coordinate.
    :param ax: Axes to draw into. When None, a new 7x5 figure with a colorbar
        inset is created.
    :param mode: ``'line'`` or ``'scatter'``. Scatter is not implemented with ``transpose``.
    :param title: Axes title.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param transpose: Put intensity on x and the non-stacked coordinate on y.
    :param kwargs: ``figsize`` and ``ticks`` configure the figure and colorbar;
        the rest goes to the line/scatter calls. All kwargs are passed to the
        colorbar function.
    :return: The saved path if ``out`` is given, else ``(fig, ax)``.
    :raises ValueError: If ``data`` is not 2D.
    """
    data = normalize_to_spectrum(data)
    if len(data.dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like.'
                         'Passed data included dimensions: {}'.format(data.dims))

    fig = None
    inset_ax = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5)))
        inset_ax = inset_axes(ax, width='40%', height='5%', loc=1)

    if stack_axis is None:
        stack_axis = data.dims[0]

    skip_colorbar = True
    if cbarmap is None:
        skip_colorbar = False
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap

    # data is 2D, so exactly one dimension remains besides ``stack_axis``.
    other_dim = [d for d in data.dims if d != stack_axis][0]
    other_coord = data.coords[other_dim]

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # Best effort: some entries are factories that build the colormap.
        try:
            cmap = cmap()
        except Exception:
            # Might still be usable as-is.
            pass

    # ``figsize``/``ticks`` belong to the figure and colorbar. Line and
    # scatter calls reject them.
    plot_kwargs = {k: v for k, v in kwargs.items() if k not in ('figsize', 'ticks')}

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        if transpose:
            ax.axhline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)
        else:
            ax.axvline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)

    # meat of the plotting
    for coord_dict, marginal in data.T.iterate_axis(stack_axis):
        line_color = cmap(coord_dict[stack_axis])
        if transpose:
            if mode == 'line':
                ax.plot(marginal.values, marginal.coords[marginal.dims[0]].values, color=line_color, **plot_kwargs)
            else:
                assert mode == 'scatter'
                raise NotImplementedError()
        else:
            if mode == 'line':
                marginal.plot(ax=ax, color=line_color, **plot_kwargs)
            else:
                assert mode == 'scatter'
                ax.scatter(*marginal.T.to_arrays(), color=line_color, **plot_kwargs)

    # The non-stacked coordinate is on x normally and on y when transposed,
    # so labels and limits must follow it.
    coord_label = label_for_dim(data, other_dim)
    coord_limits = [other_coord.min().item(), other_coord.max().item()]
    if transpose:
        ax.set_xlabel('Spectrum Intensity (arb).')
        ax.set_ylabel(coord_label)
        ax.set_ylim(coord_limits)
    else:
        ax.set_xlabel(coord_label)
        ax.set_ylabel('Spectrum Intensity (arb).')
        ax.set_xlim(coord_limits)
    ax.set_title(title, fontsize=14)

    try:
        if inset_ax is not None and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)
            fancy_labels(inset_ax)
            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # already rendered
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def plot_spatial_reference(reference_map: DataType, data_list: List[DataType],
                           offset_list: Optional[List[Dict[str, Any]]] = None,
                           annotation_list: Optional[List[str]] = None, out: Optional[str] = None,
                           plot_refs: bool = True):
    """
    Plot the locations of several datasets on top of a reference spatial map.

    This shows where each dataset was taken. It can be used early in an
    analysis to locate datasets against core levels, etc. Each dataset is
    drawn as a point, line or rectangle, depending on whether it spans 0, 1 or
    2 of the map's coordinates, and is labelled with its annotation. When
    ``plot_refs`` is set, each dataset is also shown in its own small panel,
    framed in the matching colour.

    :param reference_map: A scanning-photoemission-like dataset. If
        ``.S.spectra`` is available, its first entry is used. All dimensions
        other than ``x``, ``y``, ``z`` are averaged away, and exactly two must remain.
    :param data_list: Datasets whose relative locations should be plotted.
    :param offset_list: Optional per-dataset offsets as coordinate dicts. A
        ``None`` entry is computed as the difference of ``logical_offsets``
        with the reference map, falling back to no offset on ``ValueError``.
    :param annotation_list: Optional text annotations; defaults to "1", "2", ...
    :param out: Where to save the figure if outputting to disk.
    :param plot_refs: Whether to draw a reference panel for each entry of ``data_list``.
    :return: ``path_for_plot(out)`` when ``out`` is given, else
        ``(fig, [ax] + ax_refs)`` with the map axes first.
    """
    if offset_list is None:
        offset_list = [None] * len(data_list)

    if annotation_list is None:
        annotation_list = [str(i + 1) for i in range(len(data_list))]

    # NOTE(review): the return value is discarded, so this only matters if
    # normalize_to_spectrum raises on bad input; confirm whether reassignment was intended.
    normalize_to_spectrum(reference_map)

    n_references = len(data_list)
    if n_references == 1 and plot_refs:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5,))
        ax = axes[0]
        ax_refs = [axes[1]]
    elif plot_refs:
        # The map takes a 2x2 block on the left. Reference panels fill the
        # remaining 2 * n_extra_axes columns row by row over two rows.
        n_extra_axes = 1 + (n_references // 4)
        fig = plt.figure(figsize=(6 * (1 + n_extra_axes), 5,))
        spec = gridspec.GridSpec(ncols=2 * (1 + n_extra_axes), nrows=2, figure=fig)
        ax = fig.add_subplot(spec[:2, :2])

        ax_refs = [
            fig.add_subplot(spec[i // (2 * n_extra_axes), 2 + i % (2 * n_extra_axes)])
            for i in range(n_references)
        ]
    else:
        ax_refs = []
        fig, ax = plt.subplots(1, 1, figsize=(6, 5,))

    # Use the first spectrum when available; otherwise keep reference_map as-is.
    try:
        reference_map = reference_map.S.spectra[0]
    except Exception:
        pass

    reference_map = reference_map.S.mean_other(['x', 'y', 'z'])

    # Reversed dims: presumably ref_dims[0] is the dimension drawn on the
    # map's x axis — confirm against .S.plot.
    ref_dims = reference_map.dims[::-1]

    assert len(reference_map.dims) == 2
    reference_map.S.plot(ax=ax, cmap='Blues')

    cmap = cm.get_cmap('Reds')
    rendered_annotations = []
    for i, (data, offset, annotation) in enumerate(zip(data_list, offset_list, annotation_list)):
        if offset is None:
            try:
                offset = data.S.logical_offsets - reference_map.S.logical_offsets
            except ValueError:
                offset = {}

        coords = {c: unwrap_xarray_item(data.coords[c]) for c in ref_dims}
        # Number of map coordinates this dataset spans (0 = point, 1 = line, 2 = area).
        n_array_coords = len([
            cv for cv in coords.values() if isinstance(cv, (np.ndarray, xr.DataArray))
        ])
        color = cmap(0.4 + (0.5 * i / len(data_list)))
        x = coords[ref_dims[0]] + offset.get(ref_dims[0], 0)
        y = coords[ref_dims[1]] + offset.get(ref_dims[1], 0)

        # (ref_x, ref_y) is the anchor for the text label; (off_x, off_y)
        # picks the direction in which the label is nudged from it.
        ref_x, ref_y = x, y
        off_x, off_y = 0, 0
        scale = 0.03

        if n_array_coords == 0:
            off_y = 1
            ax.scatter([x], [y], s=60, color=color)
        if n_array_coords == 1:
            if isinstance(x, (np.ndarray, xr.DataArray)):
                y = [y] * len(x)
                ref_x = np.min(x)
                off_x = -1
            else:
                x = [x] * len(y)
                ref_y = np.max(y)
                off_y = 1

            ax.plot(x, y, color=color, linewidth=3)
        if n_array_coords == 2:
            off_y = 1
            min_x, max_x = np.min(x), np.max(x)
            min_y, max_y = np.min(y), np.max(y)
            ref_x, ref_y = min_x, max_y

            # Semi-transparent fill for the rectangle; the opaque colour is
            # restored afterwards for the reference panel frame.
            color = cmap(0.4 + (0.5 * i / len(data_list)), alpha=0.5)
            rect = patches.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y, facecolor=color)
            color = cmap(0.4 + (0.5 * i / len(data_list)))

            ax.add_patch(rect)

        # Nudge the label by 3% of ``ddata_daxis_units(ax)`` in the chosen direction.
        dp = ddata_daxis_units(ax)
        text_location = np.asarray([ref_x, ref_y,]) + dp * scale * np.asarray([off_x, off_y])
        text = ax.annotate(annotation, text_location, color='black', size=15)
        rendered_annotations.append(text)
        # White outline keeps the black label readable on the map.
        text.set_path_effects([path_effects.Stroke(linewidth=2, foreground='white'), path_effects.Normal()])
        if plot_refs:
            ax_ref = ax_refs[i]

            # Keep the first two available dimensions, in this preference order, for the panel.
            keep_preference = list(ref_dims) + [
                'eV',
                'temperature',
                'kz',
                'hv',
                'kp',
                'kx',
                'ky',
                'phi',
                'theta',
                'beta',
                'pixel',
            ]
            keep = [d for d in keep_preference if d in data.dims][:2]
            data.S.mean_other(keep).S.plot(ax=ax_ref)
            ax_ref.set_title(annotation)
            fancy_labels(ax_ref)
            frame_with(ax_ref, color=color, linewidth=3)

    ax.set_title('')
    remove_colorbars()
    fancy_labels(ax)
    plt.tight_layout()

    # adjustText is optional; it only declutters overlapping labels when installed.
    try:
        from adjustText import adjust_text
        adjust_text(rendered_annotations, ax=ax, avoid_points=False, avoid_objects=False, avoid_self=False,
                    autoalign='xy')
    except ImportError:
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, [ax] + ax_refs
def reference_scan_spatial(data, out=None, **kwargs):
    """Show a spatial map at several energy windows, labelled with referenced scans.

    Panel 0 sums over every ``cycle``/``phi``/``eV`` dimension present. The
    five remaining panels each sum one energy window, counting down from
    ``min(max(eV), 0)``. Windows are 0.2 eV wide, or one fifth of the energy
    range when that range exceeds 3 eV.

    Each scan in ``data.S.referenced_scans`` is loaded with ``simple_load``
    and labelled at its ``sample_pos``. Scans closer than 3% of the first
    panel's axis ranges share one label.

    :param data: Spatial scan with an ``eV`` coordinate.
    :param out: When given, save the figure to ``path_for_plot(out)``.
    :param kwargs: Accepted but unused.
    :return: The saved path if ``out`` is given, else ``(fig, ax)`` where
        ``ax`` is the 3x2 axes array.
    """
    data = normalize_to_spectrum(data)

    dims = [d for d in data.dims if d in {'cycle', 'phi', 'eV'}]

    summed_data = data.sum(dims, keep_attrs=True)

    fig, ax = plt.subplots(3, 2, figsize=(15, 15))
    flat_axes = list(itertools.chain.from_iterable(ax))

    summed_data.plot(ax=flat_axes[0])
    flat_axes[0].set_title(r'Full \textbf{eV} range')

    dims_except_eV = [d for d in dims if d != 'eV']
    summed_data = data.sum(dims_except_eV)

    e_max = data.coords['eV'].max().item()
    rng = e_max - data.coords['eV'].min().item()
    # Windows start at eV = 0, or at the top of the data if that is lower.
    offset = min(e_max, 0)
    mul = rng / 5. if rng > 3 else 0.2

    for i in range(5):
        low_e, high_e = -mul * (i + 1) + offset, -mul * i + offset
        title = r'\textbf{eV}' + ': {:.2g} to {:.2g}'.format(low_e, high_e)
        summed_data.sel(eV=slice(low_e, high_e)).sum('eV').plot(ax=flat_axes[i + 1])
        flat_axes[i + 1].set_title(title)

    y_range = flat_axes[0].get_ylim()
    x_range = flat_axes[0].get_xlim()
    delta_one_percent = ((x_range[1] - x_range[0]) / 100, (y_range[1] - y_range[0]) / 100)

    # Label offset passed to annotate_point: 2% of the x range and -1.5% of
    # the y range. The y part previously reused the x percentage.
    smart_delta = (2 * delta_one_percent[0], -1.5 * delta_one_percent[1])

    referenced = data.S.referenced_scans

    # Group scan positions that lie close together, so each group gets a
    # single annotation listing all of its scans.
    condensed = []
    cutoff = 3  # percent of the axis range
    for index in referenced.index:
        x, y, _ = simple_load(index).S.sample_pos
        for cx, cy, cl in condensed:
            if (abs(cx - x) < cutoff * abs(delta_one_percent[0])
                    and abs(cy - y) < cutoff * abs(delta_one_percent[1])):
                cl.append(index)
                break
        else:
            condensed.append((x, y, [index]))

    for fax in flat_axes:
        for cx, cy, cl in condensed:
            annotate_point(fax, (cx, cy), ','.join(str(scan) for scan in cl), delta=smart_delta, fontsize='large')

    plt.tight_layout()

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def spin_polarized_spectrum(spin_dr, title=None, ax=None, out=None, component='y', scatter=False, stats=False,
                            norm=None):
    """Plot spin-resolved spectra on ``ax[0]`` and the spin polarization on ``ax[1]``.

    :param spin_dr: Spin-resolved data with ``up``/``down`` variables and an ``eV`` coordinate.
    :param title: Title for the spectrum axes; defaults to ``'Spin spectrum '``.
    :param ax: Pair of axes; a new 2x1 figure with shared x is created when None.
    :param out: When given, save via ``savefig(out, dpi=400)``, clear the
        figure and return ``path_for_plot(out)``.
    :param component: Spin component label shown in the polarization title.
    :param scatter: With ``stats``, draw points with error bars instead of shaded bands.
    :param stats: Bootstrap ``spin_dr`` (N=100) and show mean +/- deviation.
    :param norm: Accepted but unused.
    :return: ``path_for_plot(out)`` when ``out`` is given, else ``ax``.
    """
    if ax is None:
        _, ax = plt.subplots(2, 1, sharex=True)

    if stats:
        spin_dr = bootstrap(lambda x: x)(spin_dr, N=100)
        pol = mean_and_deviation(to_intensity_polarization(spin_dr))
        counts = mean_and_deviation(spin_dr)
    else:
        counts = spin_dr
        pol = to_intensity_polarization(counts)

    ax_left = ax[0]
    ax_right = ax[1]

    # NOTE(review): ``up`` is read from ``counts.down`` and ``down`` from
    # ``counts.up``. The non-stats branch therefore draws ``counts.down`` in
    # red, while the stats branch draws ``counts.up`` in red. Confirm which
    # colour convention is intended before changing either branch.
    up = counts.down.data
    down = counts.up.data

    energies = spin_dr.coords['eV'].values
    min_e, max_e = np.min(energies), np.max(energies)

    # Plot the spectra
    if stats:
        if scatter:
            scatter_with_std(counts, 'up', color='red', ax=ax_left)
            scatter_with_std(counts, 'down', color='blue', ax=ax_left)
        else:
            v, s = counts.up.values, counts.up_std.values
            ax_left.plot(energies, v, 'r')
            ax_left.fill_between(energies, v - s, v + s, color='r', alpha=0.25)

            v, s = counts.down.values, counts.down_std.values
            ax_left.plot(energies, v, 'b')
            ax_left.fill_between(energies, v - s, v + s, color='b', alpha=0.25)
    else:
        ax_left.plot(energies, up, 'r')
        ax_left.plot(energies, down, 'b')

    ax_left.set_title(title if title is not None else 'Spin spectrum {}'.format(''))
    ax_left.set_ylabel(r'\textbf{Spectrum Intensity}')
    ax_left.set_xlabel(r'\textbf{Kinetic energy} (eV)')
    ax_left.set_xlim(min_e, max_e)

    # Leave 20% headroom above the taller of the two channels.
    max_up = np.max(up)
    max_down = np.max(down)
    ax_left.set_ylim(0, max(max_down, max_up) * 1.2)

    # Plot the polarization and associated statistical error bars
    if stats:
        if scatter:
            scatter_with_std(pol, 'polarization', ax=ax_right, color='black')
        else:
            v = pol.polarization.data
            s = pol.polarization_std.data
            ax_right.plot(energies, v, color='black')
            ax_right.fill_between(energies, v - s, v + s, color='black', alpha=0.25)
    else:
        ax_right.plot(energies, pol.polarization.data, color='black')

    # Background shading: positive polarization region blue, negative red.
    ax_right.fill_between(energies, 0, 1, facecolor='blue', alpha=0.1)
    ax_right.fill_between(energies, -1, 0, facecolor='red', alpha=0.1)

    ax_right.set_title('Spin polarization, $\\text{S}_\\textbf{' + component + '}$')
    ax_right.set_ylabel(r'\textbf{Polarization}')
    ax_right.set_xlabel(r'\textbf{Kinetic Energy} (eV)')
    ax_right.set_xlim(min_e, max_e)
    ax_right.axhline(0, color='white', linestyle=':')

    ax_right.set_ylim(-1, 1)
    ax_right.grid(True, axis='y')

    plt.tight_layout()

    if out is not None:
        savefig(out, dpi=400)
        plt.clf()
        return path_for_plot(out)
    else:
        pass

    return ax