Beispiel #1
0
def patched_plot(self, *args, **kwargs):
    """
    PyARPES patch for `lmfit` summary plots. Scientists like to have LaTeX in their plots,
    but because underscores outside TeX environments crash matplotlib renders, we need to do
    some fiddling with titles and axis labels in order to prevent users having to switch TeX
    on and off all the time.

    Additionally, this patch provides better support for multidimensional curve fitting:
    when the fit's model reports ``n_dims != 1``, a 2x2 grid showing the best fit, the data,
    the initial fit and the residual is drawn instead of the stock lmfit summary.

    :param self: the fit result being plotted
    :param args: forwarded to the original lmfit plot method (used only in the 1D case)
    :param kwargs: forwarded to the original lmfit plot method (used only in the 1D case)
    :return: the 2x2 array of axes in the multidimensional case, otherwise whatever the
        original plot method returns
    """
    from arpes.plotting.utils import transform_labels

    # Only the dimensionality lookup is allowed to fail: results whose model has no
    # ``n_dims`` attribute use the stock lmfit plot. Errors raised while drawing the
    # multidimensional summary are no longer masked by falling back to the 1D plot.
    try:
        n_dims = self.model.n_dims
    except AttributeError:
        n_dims = 1

    if n_dims != 1:
        from arpes.plotting.utils import fancy_labels

        _, ax = plt.subplots(2, 2, figsize=(10, 8))

        def to_dr(flat_data):
            # The fit arrays are flat; reshape them onto the grid spanned by the
            # independent variables so they can be plotted as images.
            shape = [len(self.independent[d]) for d in self.independent_order]
            return xr.DataArray(flat_data.reshape(shape), coords=self.independent,
                                dims=self.independent_order)

        to_dr(self.init_fit).plot(ax=ax[1][0])
        to_dr(self.data).plot(ax=ax[0][1])
        to_dr(self.best_fit).plot(ax=ax[0][0])
        to_dr(self.residual).plot(ax=ax[1][1])

        ax[0][0].set_title('Best fit')
        ax[0][1].set_title('Data')
        ax[1][0].set_title('Initial fit')
        ax[1][1].set_title('Residual (Data - Best fit)')

        for axi in ax.ravel():
            fancy_labels(axi)

        plt.tight_layout()
        return ax

    ret = original_plot(self, *args, **kwargs)
    transform_labels(transform_lmfit_titles)
    return ret
    def data(self, new_data):
        """
        Set the displayed data and refresh the plot.

        On the first call the plot artist is created -- an image via ``imshow_arr`` when the
        data has two dimensions, otherwise a line via ``self.ax.plot`` -- and ``self.n_dims``
        is recorded. Every call then refreshes that artist with the new values and, when
        ``self.auto_autoscale`` is set, calls ``self.autoscale()``.

        :param new_data: the data to display; must expose ``dims``, ``coords`` and ``values``
            (presumably an ``xr.DataArray`` -- TODO confirm against callers)
        """
        self._data = new_data

        if not self._initialized:
            self._initialized = True
            self.n_dims = len(new_data.dims)
            if self.n_dims == 2:
                self._axis_image = imshow_arr(self._data, ax=self.ax, **self.ax_kwargs)[1]
                fancy_labels(self.ax)
            else:
                # `ax.plot` does not accept a `cmap` keyword, so drop it for line plots.
                self.ax_kwargs.pop('cmap', None)
                x, y = self.data.coords[self.data.dims[0]].values, self.data.values
                self._axis_image = self.ax.plot(x, y, **self.ax_kwargs)
                self.ax.set_xlabel(self.data.dims[0])
                self.ax.set_xlim([np.min(x), np.max(x)])
                fancy_labels(self.ax)

        if self.n_dims == 2:
            x = self._data.coords[self._data.dims[0]].values
            y = self._data.coords[self._data.dims[1]].values
            # imshow extent is (left, right, bottom, top): the second dimension runs along
            # the horizontal axis, the first along the vertical axis.
            self._axis_image.set_extent([y[0], y[-1], x[0], x[-1]])
            self._axis_image.set_data(self._data.values)
        else:
            # `Axes.lines` became a read-only ArtistList in matplotlib 3.5 (its list-mutation
            # methods were deprecated and later removed), so remove the artist itself.
            old_line = self.ax.lines[0]
            color = old_line.get_color()
            old_line.remove()
            x, y = self.data.coords[self.data.dims[0]].values, self.data.values
            low, high = np.min(y), np.max(y)
            self._axis_image = self.ax.plot(x, y, c=color, **self.ax_kwargs)
            # pad the y range by 10% of the data span on each side
            self.ax.set_ylim([low - 0.1 * (high - low), high + 0.1 * (high - low)])

        if self.auto_autoscale:
            self.autoscale()
Beispiel #3
0
def offset_scatter_plot(data: DataType, name_to_plot=None, stack_axis=None, fermi_level=True, cbarmap=None, ax=None,
                        out=None, scale_coordinate=0.5, ylim=None, aux_errorbars=True, **kwargs):
    """
    Scatter-plot (with error bars) every slice of a 2D variable along ``stack_axis`` on a
    single axes, shifting each successive slice along the other dimension so they separate.

    :param data: an ``xr.Dataset`` holding ``name_to_plot`` and ``name_to_plot + '_std'``
    :param name_to_plot: variable to plot; inferred when exactly one variable name does not
        contain ``_std``
    :param stack_axis: dimension to iterate over, defaults to the variable's first dimension
    :param fermi_level: when True and ``eV`` is a dimension other than ``stack_axis``, draw a
        dashed red line at y=0, shade x in [0, 0.2] and apply ``ylim``
    :param cbarmap: ``(colorbar, colormap)`` pair; looked up from ``stack_axis`` when omitted
    :param ax: axes to draw into, created when omitted
    :param out: if given, save the figure there and return its path
    :param scale_coordinate: scales the per-slice shift along the non-stacked dimension
    :param ylim: y-limits; required when ``aux_errorbars`` is set
    :param aux_errorbars: additionally draw bare error bars (no markers) at ``ylim[0]``
    :return: ``(fig, ax)`` (``fig`` is None when ``ax`` was supplied), or the output path
        when ``out`` is given
    :raises ValueError: if the variable is not 2D, or ``aux_errorbars`` is set without ``ylim``
    """
    assert isinstance(data, xr.Dataset)

    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
        assert (name_to_plot + '_std') in data.data_vars.keys()

    if len(data.data_vars[name_to_plot].dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like. '
                         'Passed data included dimensions: {}'.format(data.data_vars[name_to_plot].dims))

    # Validate up front rather than failing part-way through drawing: the auxiliary error
    # bars are placed at ylim[0].
    if aux_errorbars and ylim is None:
        raise ValueError('`ylim` must be provided when `aux_errorbars` is set.')

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (11, 5,)))

    inset_ax = inset_axes(ax, width='40%', height='5%', loc='upper left')

    if stack_axis is None:
        stack_axis = data.data_vars[name_to_plot].dims[0]

    skip_colorbar = True
    if cbarmap is None:
        skip_colorbar = False
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # do our best: some entries are colormap factories rather than colormaps
        try:
            cmap = cmap()
        except Exception:
            # might still be fine (e.g. already a plain callable)
            pass

    # should be exactly two
    other_dim = [d for d in data.dims if d != stack_axis][0]

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        ax.axhline(0, linestyle='--', color='red')
        ax.fill_betweenx([-1e6, 1e6], 0, 0.2, color='black', alpha=0.07)
        ax.set_ylim(ylim)

    # the sampling step along the non-stacked dimension is the same for every slice
    delta = data.T.stride(generic_dim_names=False)[other_dim]

    # real plotting here
    for i, (coord, value) in enumerate(data.T.iterate_axis(stack_axis)):
        color = cmap(coord[stack_axis])

        # Dimension coordinates cannot be modified in place in xarray, so build a shifted
        # copy of the coordinate and attach it to a copy of the slice.
        shifted = value.coords[other_dim].values - i * delta * scale_coordinate / 10
        data_for = value.copy(deep=True).assign_coords(**{other_dim: shifted})

        scatter_with_std(data_for, name_to_plot, ax=ax, color=color)

        if aux_errorbars:
            flattened = data_for.data_vars[name_to_plot].copy(deep=True)
            flattened.values = ylim[0] * np.ones(flattened.values.shape)
            aux_data = data_for.assign(**{name_to_plot: flattened})
            scatter_with_std(aux_data, name_to_plot, ax=ax, color=color, fmt='none')

    ax.set_xlabel(other_dim)
    ax.set_ylabel(name_to_plot)
    fancy_labels(ax)

    try:
        if not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)

            fancy_labels(inset_ax)
            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # colorbar already rendered
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Beispiel #4
0
def flat_stack_plot(data: DataType, stack_axis=None, fermi_level=True, cbarmap=None, ax=None,
                    mode='line', title=None, out=None, transpose=False, **kwargs):
    """
    Plot every 1D cut of a 2D spectrum along ``stack_axis`` on one axes without vertical
    offsets, coloring each cut by its coordinate along ``stack_axis``.

    :param data: image-like (2D) data, passed through ``normalize_to_spectrum``
    :param stack_axis: dimension to iterate over, defaults to the first dimension
    :param fermi_level: when True and ``eV`` is a dimension other than ``stack_axis``, draw a
        dashed red line at 0 along the ``eV`` axis
    :param cbarmap: ``(colorbar, colormap)`` pair; looked up from ``stack_axis`` when omitted
    :param ax: axes to draw into; when omitted, a figure and a colorbar inset are created
    :param mode: ``'line'`` or ``'scatter'`` (scatter is not implemented with ``transpose``)
    :param title: axes title
    :param out: if given, save the figure there and return its path
    :param transpose: put the intensity on the x axis and the coordinate on the y axis
    :param kwargs: ``figsize`` sizes a newly created figure, ``ticks`` configures a generated
        colorbar; all other keywords are forwarded to the plotting calls. All keywords are
        forwarded to the colorbar.
    :return: ``(fig, ax)`` (``fig`` is None when ``ax`` was supplied), or the output path
        when ``out`` is given
    :raises ValueError: if the data is not 2D
    """
    data = normalize_to_spectrum(data)
    if len(data.dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like. '
                         'Passed data included dimensions: {}'.format(data.dims))

    # `figsize` and `ticks` configure the figure and the colorbar. They are not artist
    # properties: xarray refuses `figsize` together with `ax`, and matplotlib rejects
    # unknown properties, so keep them out of the per-cut plotting calls.
    plot_kwargs = {k: v for k, v in kwargs.items() if k not in ('figsize', 'ticks')}

    fig = None
    inset_ax = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5,)))
        inset_ax = inset_axes(ax, width='40%', height='5%', loc=1)

    if stack_axis is None:
        stack_axis = data.dims[0]

    skip_colorbar = True
    if cbarmap is None:
        skip_colorbar = False
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap

    # should be exactly two
    other_dim = [d for d in data.dims if d != stack_axis][0]
    other_coord = data.coords[other_dim]

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # do our best: some entries are colormap factories rather than colormaps
        try:
            cmap = cmap()
        except Exception:
            # might still be fine (e.g. already a plain callable)
            pass

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        if transpose:
            ax.axhline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)
        else:
            ax.axvline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)

    # meat of the plotting
    for coord_dict, marginal in data.T.iterate_axis(stack_axis):
        color = cmap(coord_dict[stack_axis])
        if transpose:
            if mode == 'line':
                ax.plot(marginal.values, marginal.coords[marginal.dims[0]].values, color=color, **plot_kwargs)
            else:
                assert mode == 'scatter'
                raise NotImplementedError()
        else:
            if mode == 'line':
                marginal.plot(ax=ax, color=color, **plot_kwargs)
            else:
                assert mode == 'scatter'
                ax.scatter(*marginal.T.to_arrays(), color=color, **plot_kwargs)
                ax.set_xlabel(marginal.dims[0])

    coord_limits = [other_coord.min().item(), other_coord.max().item()]
    if transpose:
        # intensity runs along x and the non-stacked coordinate along y
        ax.set_xlabel('Spectrum Intensity (arb).')
        ax.set_ylabel(label_for_dim(data, other_dim))
        ax.set_ylim(coord_limits)
    else:
        ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
        ax.set_ylabel('Spectrum Intensity (arb).')
        ax.set_xlim(coord_limits)
    ax.set_title(title, fontsize=14)

    try:
        if inset_ax is not None and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)
            fancy_labels(inset_ax)

            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # already rendered
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Beispiel #5
0
def plot_spatial_reference(reference_map: DataType,
                           data_list: List[DataType],
                           offset_list: Optional[List[Dict[str, Any]]] = None,
                           annotation_list: Optional[List[str]] = None,
                           out: Optional[str] = None,
                           plot_refs: bool = True):
    """
    Helpfully plots data against a reference scanning dataset. This is essential to understand
    where data was taken and can be used early in the analysis phase in order to highlight the
    location of your datasets against core levels, etc.

    Each entry of ``data_list`` is drawn on the reference map as a point, a line or a filled
    rectangle, depending on how many of the map's two dimensions it spans as arrays. It is
    then labelled with its annotation.

    :param reference_map: A scanning photoemission like dataset
    :param data_list: A list of datasets you want to plot the relative locations of
    :param offset_list: Optionally, offsets given as coordinate dicts
    :param annotation_list: Optionally, text annotations for the data (default "1", "2", ...)
    :param out: Where to save the figure if we are outputting to disk
    :param plot_refs: Whether to plot reference figures for each of the pieces of data in `data_list`
    :return: the saved path when ``out`` is given, otherwise ``(fig, [ax] + reference_axes)``
    """
    if offset_list is None:
        offset_list = [None] * len(data_list)

    if annotation_list is None:
        annotation_list = [str(i + 1) for i in range(len(data_list))]

    # NOTE(review): the return value is discarded; confirm whether it was meant to
    # replace `reference_map`.
    normalize_to_spectrum(reference_map)

    fig, ax, ax_refs = _spatial_reference_layout(len(data_list), plot_refs)

    try:
        reference_map = reference_map.S.spectra[0]
    except Exception:
        # best effort: presumably `reference_map` is already a single spectrum
        pass

    reference_map = reference_map.S.mean_other(['x', 'y', 'z'])

    ref_dims = reference_map.dims[::-1]

    assert len(reference_map.dims) == 2
    reference_map.S.plot(ax=ax, cmap='Blues')

    # `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; `pyplot.get_cmap` is
    # available in every version.
    cmap = plt.get_cmap('Reds')
    label_scale = 0.03  # annotation offset, as a fraction of the axes' data span
    rendered_annotations = []
    for i, (data, offset, annotation) in enumerate(zip(data_list, offset_list, annotation_list)):
        if offset is None:
            try:
                offset = data.S.logical_offsets - reference_map.S.logical_offsets
            except ValueError:
                offset = {}

        coords = {c: unwrap_xarray_item(data.coords[c]) for c in ref_dims}
        n_array_coords = sum(isinstance(cv, (np.ndarray, xr.DataArray)) for cv in coords.values())

        level = 0.4 + (0.5 * i / len(data_list))
        color = cmap(level)
        x = coords[ref_dims[0]] + offset.get(ref_dims[0], 0)
        y = coords[ref_dims[1]] + offset.get(ref_dims[1], 0)

        ref_x, ref_y, off_x, off_y = _spatial_reference_mark_data(
            ax, x, y, n_array_coords, color, cmap(level, alpha=0.5))

        dp = ddata_daxis_units(ax)
        text_location = np.asarray([ref_x, ref_y]) + dp * label_scale * np.asarray([off_x, off_y])
        text = ax.annotate(annotation, text_location, color='black', size=15)
        rendered_annotations.append(text)
        text.set_path_effects([
            path_effects.Stroke(linewidth=2, foreground='white'),
            path_effects.Normal(),
        ])

        if plot_refs:
            _spatial_reference_plot_ref(ax_refs[i], data, ref_dims, annotation, color)

    ax.set_title('')
    remove_colorbars()
    fancy_labels(ax)
    plt.tight_layout()

    try:
        from adjustText import adjust_text
    except ImportError:
        # optional dependency: annotations simply stay where they were placed
        pass
    else:
        adjust_text(rendered_annotations,
                    ax=ax,
                    avoid_points=False,
                    avoid_objects=False,
                    avoid_self=False,
                    autoalign='xy')

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, [ax] + ax_refs


def _spatial_reference_layout(n_references, plot_refs):
    """Create the figure, the main axes for the reference map and one axes per reference plot."""
    if n_references == 1 and plot_refs:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        return fig, axes[0], [axes[1]]

    if plot_refs:
        # The reference map fills a 2x2 block on the left; the reference plots fill a
        # 2 x (2 * n_extra_axes) grid to its right in row-major order.
        n_extra_axes = 1 + (n_references // 4)
        fig = plt.figure(figsize=(6 * (1 + n_extra_axes), 5))
        spec = gridspec.GridSpec(ncols=2 * (1 + n_extra_axes), nrows=2, figure=fig)
        ax = fig.add_subplot(spec[:2, :2])
        ax_refs = [
            fig.add_subplot(spec[i // (2 * n_extra_axes), 2 + i % (2 * n_extra_axes)])
            for i in range(n_references)
        ]
        return fig, ax, ax_refs

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    return fig, ax, []


def _spatial_reference_mark_data(ax, x, y, n_array_coords, color, fill_color):
    """Draw one dataset's extent on the map; return the label anchor (x, y) and direction (dx, dy)."""
    ref_x, ref_y = x, y
    off_x, off_y = 0, 0

    if n_array_coords == 0:
        # a single point, labelled above
        off_y = 1
        ax.scatter([x], [y], s=60, color=color)
    elif n_array_coords == 1:
        # a line along whichever coordinate is an array
        if isinstance(x, (np.ndarray, xr.DataArray)):
            y = [y] * len(x)
            ref_x = np.min(x)
            off_x = -1
        else:
            x = [x] * len(y)
            ref_y = np.max(y)
            off_y = 1

        ax.plot(x, y, color=color, linewidth=3)
    elif n_array_coords == 2:
        # a filled rectangle, labelled at its top-left corner
        off_y = 1
        min_x, max_x = np.min(x), np.max(x)
        min_y, max_y = np.min(y), np.max(y)
        ref_x, ref_y = min_x, max_y

        ax.add_patch(patches.Rectangle((min_x, min_y),
                                       max_x - min_x,
                                       max_y - min_y,
                                       facecolor=fill_color))

    return ref_x, ref_y, off_x, off_y


def _spatial_reference_plot_ref(ax_ref, data, ref_dims, annotation, color):
    """Plot `data` averaged onto at most two preferred dimensions, framed in its marker color."""
    keep_preference = list(ref_dims) + [
        'eV',
        'temperature',
        'kz',
        'hv',
        'kp',
        'kx',
        'ky',
        'phi',
        'theta',
        'beta',
        'pixel',
    ]
    keep = [d for d in keep_preference if d in data.dims][:2]
    data.S.mean_other(keep).S.plot(ax=ax_ref)
    ax_ref.set_title(annotation)
    fancy_labels(ax_ref)
    frame_with(ax_ref, color=color, linewidth=3)