Пример #1
0
def hue_brightness_plot(data: xr.Dataset, ax=None, out=None, **kwargs):
    """Draw a dataset's polarization and intensity as a single color image.

    The image colors come from ``polarization_intensity_to_color``, which is
    called with ``data`` and every keyword argument given here.

    :param data: Dataset containing ``intensity`` and ``polarization``
    :param ax: Axes to draw on; a new figure is created when omitted
    :param out: If given, the plot is saved and the output path is returned
    :param kwargs: ``figsize`` sizes a newly created figure; all keyword
        arguments (``figsize`` included) are forwarded to the color conversion
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
        where ``fig`` is ``None`` if ``ax`` was supplied
    """
    assert ('intensity' in data and 'polarization' in data)

    fig = None
    if ax is None:
        figure_size = kwargs.get('figsize', (7, 5))
        fig, ax = plt.subplots(figsize=figure_size)

    row_dim = data.intensity.dims[0]
    col_dim = data.intensity.dims[1]
    rows = data.coords[row_dim].values
    cols = data.coords[col_dim].values

    # imshow puts the second dimension along x and the first along y
    image_extent = [cols[0], cols[-1], rows[0], rows[-1]]
    colors = polarization_intensity_to_color(data, **kwargs)
    ax.imshow(
        colors,
        extent=image_extent,
        aspect='auto',
        origin='lower',
    )
    ax.set_xlabel(col_dim)
    ax.set_ylabel(row_dim)
    ax.grid(False)

    if out is None:
        return fig, ax

    plt.savefig(path_for_plot(out), dpi=400)
    return path_for_plot(out)
Пример #2
0
def plot_data_to_bz2d(data: DataType, cell, rotate=None, shift=None, scale=None, ax=None,
                      mask=True, out=None, bz_number=None, **kwargs):
    """Plot k-space data with ``pcolormesh``, optionally masked to a 2D Brillouin zone.

    :param data: Data to plot; passed through ``normalize_to_spectrum`` and
        required to report ``S.is_kspace``
    :param cell: Lattice vectors. A two-vector cell is padded with a zero third
        component and a ``[0, 0, 1]`` third vector
    :param rotate: Optional rotation angle (passed to ``np.cos``/``np.sin``)
        applied to the coordinates
    :param shift: Optional shift applied to the coordinates, after scaling
    :param scale: Optional scale applied to the coordinates, after rotation
    :param ax: Axes to draw on. When omitted a new figure is created and the
        zone outline is drawn with ``bz2d_plot``
    :param mask: When true, points selected by the zone mask are set to NaN
    :param out: If given, the plot is saved and the output path is returned
    :param bz_number: Pair used to offset the mesh by a combination of the
        inverse cell's rows; defaults to ``(0, 0)``
    :param kwargs: ``cmap`` selects the colormap (name or instance)
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    :raises AssertionError: If the data is not in k-space
    """
    import copy

    data = normalize_to_spectrum(data)

    # The original parenthesized `'message' and condition`, which dropped the message.
    assert data.S.is_kspace, 'You must k-space convert data before plotting to BZs'

    if bz_number is None:
        bz_number = (0, 0)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(9, 9))
        bz2d_plot(cell, paths='all', ax=ax)

    if len(cell) == 2:
        cell = [list(c) + [0] for c in cell] + [[0, 0, 1]]

    icell = np.linalg.inv(cell).T

    # Prep coordinates and mask
    raveled = data.T.meshgrid(as_dataset=True)
    dims = data.dims
    if rotate is not None:
        c, s = np.cos(rotate), np.sin(rotate)
        rotation = np.array([(c, -s), (s, c)])

        raveled = raveled.T.transform_coords(dims, rotation)

    if scale is not None:
        raveled = raveled.T.scale_coords(dims, scale)

    if shift is not None:
        raveled = raveled.T.shift_coords(dims, shift)

    copied = data.values.copy()

    if mask:
        built_mask = apply_mask_to_coords(raveled, build_2dbz_poly(cell=cell), dims)
        copied[built_mask.T] = np.nan

    cmap = kwargs.get('cmap', matplotlib.cm.Blues)
    if isinstance(cmap, str):
        cmap = matplotlib.cm.get_cmap(cmap)

    # Copy before changing the "bad" color so the shared colormap instance
    # (e.g. matplotlib.cm.Blues) is not modified for every other plot.
    cmap = copy.copy(cmap)
    cmap.set_bad((1, 1, 1, 0))

    delta_x = np.dot(np.array(bz_number), icell[:2, 0])
    delta_y = np.dot(np.array(bz_number), icell[:2, 1])

    ax.pcolormesh(raveled.data_vars[dims[0]].values + delta_x,
                  raveled.data_vars[dims[1]].values + delta_y,
                  copied.T, cmap=cmap)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
def fermi_surface_slices(arr: xr.DataArray,
                         n_slices=9,
                         ev_per_slice=0.02,
                         bin=0.01,
                         out=None,
                         **kwargs):
    """Build a three-column holoviews layout of energy slices of ``arr``.

    Slice ``i`` covers ``eV`` from ``-ev_per_slice * i - bin`` up to
    ``-ev_per_slice * i``. Dimensions outside the recognised angle, momentum
    and energy names are summed away before the energy window is summed.

    :param arr: Data to slice
    :param n_slices: Number of slices to produce
    :param ev_per_slice: Energy step between the upper edges of the slices
    :param bin: Width of each energy window
    :param out: If given, the layout is rendered to disk and the path returned
    :return: The output path when ``out`` is given, otherwise the layout
    """
    import holoviews as hv  # pylint: disable=import-error

    kept_dims = ['theta', 'beta', 'phi', 'eV', 'kp', 'kx', 'ky']

    def slice_image(index):
        upper = -ev_per_slice * index
        lower = upper - bin
        extra_dims = [d for d in arr.dims if d not in kept_dims]
        window = arr.sum(extra_dims).sel(eV=slice(lower, upper)).sum('eV')
        return hv.Image(window, label='%g eV' % upper)

    slices = [slice_image(i) for i in range(n_slices)]
    layout = hv.Layout(slices).cols(3)

    if out is None:
        return layout

    renderer = hv.renderer('matplotlib').instance(fig='svg', holomap='gif')
    filename = path_for_plot(out)
    renderer.save(layout, path_for_holoviews(filename))
    return filename
Пример #4
0
def scatter_with_std(data: DataType, name_to_plot=None, ax=None, fmt='o', out=None, **kwargs):
    """Plot a variable with error bars taken from its ``<name>_std`` companion.

    :param data: Dataset holding the variable and its ``_std`` variable
    :param name_to_plot: Variable name. When omitted, the dataset must contain
        exactly one variable whose name lacks ``_std``
    :param ax: Axes to draw on; a new figure is created when omitted
    :param fmt: Marker format passed to ``ax.errorbar``
    :param out: If given, the plot is saved and the output path is returned
    :param kwargs: ``figsize`` (consumed here) sizes a new figure; everything
        else is forwarded to ``ax.errorbar``
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    """
    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
        assert (name_to_plot + '_std') in data.data_vars.keys()

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.pop('figsize', (7, 5,)))

    x, y = data.data_vars[name_to_plot].T.to_arrays()

    std = data.data_vars[name_to_plot + '_std'].values
    ax.errorbar(x, y, yerr=std, fmt=fmt, markeredgecolor='black', **kwargs)

    # Set the limits before saving; previously the saved figure skipped this step.
    ax.set_xlim([np.min(x), np.max(x)])

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #5
0
def false_color_plot(data_r: xr.Dataset, data_g: xr.Dataset, data_b: xr.Dataset, ax=None, out=None, invert=False, pmin=0, pmax=1, **kwargs):
    """Combine three datasets into one RGB image and draw it.

    Each input is normalized independently: the ``pmin`` percentile is
    subtracted, values above the ``pmax`` percentile are clipped to it, and the
    result is divided by its maximum. The caller's arrays are left untouched.

    :param data_r: Data for the red channel
    :param data_g: Data for the green channel
    :param data_b: Data for the blue channel
    :param ax: Axes to draw on; a new figure is created when omitted
    :param out: If given, the plot is saved and the output path is returned
    :param invert: When true, the HSV value channel is replaced by ``1 - value``
    :param pmin: Lower percentile as a fraction (multiplied by 100 internally)
    :param pmax: Upper percentile as a fraction (multiplied by 100 internally)
    :param kwargs: ``figsize`` sizes a newly created figure
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    """
    data_r, data_g, data_b = [normalize_to_spectrum(d) for d in (data_r, data_g, data_b)]
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5,)))

    def normalize_channel(channel):
        # Subtracting out-of-place yields a new array. The original used `-=`,
        # which rewrote the values of the data passed in by the caller.
        channel = channel - np.percentile(channel, 100 * pmin)
        upper = np.percentile(channel, 100 * pmax)
        channel[channel > upper] = upper
        return channel / np.max(channel)

    cs = dict(data_r.coords)
    cs['dim_color'] = [1, 2, 3]

    arr = xr.DataArray(
        np.stack([normalize_channel(data_r.values),
                  normalize_channel(data_g.values),
                  normalize_channel(data_b.values)], axis=-1),
        coords=cs,
        dims=list(data_r.dims) + ['dim_color'],
    )

    if invert:
        vs = arr.values
        vs[vs > 1] = 1
        hsv = matplotlib.colors.rgb_to_hsv(vs)
        # Indexing assumes a two-dimensional image plus the color axis
        hsv[:, :, 2] = 1 - hsv[:, :, 2]
        arr.values = matplotlib.colors.hsv_to_rgb(hsv)

    imshow_arr(arr, ax=ax)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #6
0
def plot_with_std(data: DataType, name_to_plot=None, ax=None, out=None, **kwargs):
    """Plot a variable as a line with a shaded ``±std`` band around it.

    :param data: Dataset holding the variable and its ``<name>_std`` variable
    :param name_to_plot: Variable name. When omitted, the dataset must contain
        exactly one variable whose name lacks ``_std``
    :param ax: Axes to draw on; a new figure is created when omitted
    :param out: If given, the plot is saved and the output path is returned
    :param kwargs: ``figsize`` (consumed here) sizes a new figure; everything
        else is forwarded to both the line plot and ``ax.fill_between``
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    """
    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
        assert (name_to_plot + '_std') in data.data_vars.keys()

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.pop('figsize', (7, 5,)))

    data.data_vars[name_to_plot].plot(ax=ax, **kwargs)
    x, y = data.data_vars[name_to_plot].T.to_arrays()

    std = data.data_vars[name_to_plot + '_std'].values
    ax.fill_between(x, y - std, y + std, alpha=0.3, **kwargs)

    # Set the limits before saving; previously the saved figure skipped this step.
    ax.set_xlim([np.min(x), np.max(x)])

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #7
0
def spin_difference_spectrum(spin_dr,
                             title=None,
                             ax=None,
                             out=None,
                             scatter=False,
                             **kwargs):
    """Plot a 1D intensity spectrum colored by its spin polarization.

    The input is converted with ``to_intensity_polarization``; if that raises
    ``AssertionError`` the input is used as-is and must already expose
    ``intensity`` and ``polarization``. Nothing is drawn unless the intensity
    is one-dimensional.

    :param spin_dr: Spin-resolved data
    :param title: Axes title; defaults to ``'Spin Polarization'``
    :param ax: Axes to draw on; a new figure is created when omitted
    :param out: If given, the figure is saved, cleared, and the path returned
    :param scatter: Draw colored points instead of colored line segments
    :param kwargs: Unused
    :return: The saved path when ``out`` is given; otherwise the figure is
        shown and ``None`` is returned
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))

    try:
        as_intensity = to_intensity_polarization(spin_dr)
    except AssertionError:
        as_intensity = spin_dr
    intensity = as_intensity.intensity
    pol = as_intensity.polarization.copy(deep=True)

    if len(intensity.dims) == 1:
        inset_ax = inset_axes(ax, width="30%", height="5%", loc=1)
        coord = intensity.coords[intensity.dims[0]]
        points = np.array([coord.values, intensity.values]).T.reshape(-1, 1, 2)
        pol.values[np.isnan(pol.values)] = 0
        pol.values[pol.values > 1] = 1
        pol.values[pol.values < -1] = -1

        # A colormap called with floats expects values in [0, 1]. Polarization
        # was clipped to [-1, 1] above, so rescale it; otherwise every negative
        # polarization collapses onto the lowest color of the map.
        color_positions = (pol.values + 1) / 2
        colormap = cm.get_cmap('RdBu')

        if scatter:
            ax.scatter(coord.values, intensity.values, c=colormap(color_positions), s=1.5)
        else:
            # One color per segment, hence one fewer color than points
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, colors=colormap(color_positions[:-1]))

            ax.add_collection(lc)

        ax.set_xlim(coord.min().item(), coord.max().item())
        ax.set_ylim(0, intensity.max().item() * 1.15)
        ax.set_ylabel('ARPES Spectrum Intensity (arb.)')
        ax.set_xlabel(label_for_dim(spin_dr, dim_name=intensity.dims[0]))
        ax.set_title(title if title is not None else 'Spin Polarization')
        polarization_colorbar(inset_ax)

    if out is not None:
        savefig(out, dpi=400)
        plt.clf()
        return path_for_plot(out)
    else:
        plt.show()
Пример #8
0
def plot_movie(data: xr.DataArray,
               time_dim,
               interval=None,
               fig=None,
               ax=None,
               out=None,
               **kwargs):
    """Animate ``data`` frame by frame along ``time_dim``.

    The first image drawn is the mean over ``time_dim``; every frame then
    replaces the image values with the slice at one coordinate of ``time_dim``.
    Color limits span the whole dataset unless ``vmin``/``vmax`` are supplied.
    Data reporting ``S.is_subtracted`` uses ``'RdBu'`` with limits symmetric
    about zero.

    :param data: Data to animate
    :param time_dim: Name of the dimension to step through
    :param interval: Delay between frames in milliseconds; defaults to 100
    :param fig: Figure owning ``ax``; derived from ``ax`` when omitted
    :param ax: Axes to draw on; a new figure is created when omitted
    :param out: If given, the animation is written with the ffmpeg writer and
        the output path is returned
    :param kwargs: ``vmin``/``vmax`` override the color limits; the rest is
        forwarded to the xarray plot call
    :return: The saved path when ``out`` is given, otherwise the animation
    :raises TypeError: If ``data`` is not a ``DataArray``
    """
    if not isinstance(data, xr.DataArray):
        raise TypeError('You must provide a DataArray')

    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))
    elif fig is None:
        # FuncAnimation needs the figure; previously it received None here.
        fig = ax.figure

    cmap = arpes.config.SETTINGS.get('interactive',
                                     {}).get('palette', 'viridis')
    vmax = data.max().item()
    vmin = data.min().item()

    if data.S.is_subtracted:
        cmap = 'RdBu'
        vmax = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -vmax

    if 'vmax' in kwargs:
        vmax = kwargs.pop('vmax')
    if 'vmin' in kwargs:
        vmin = kwargs.pop('vmin')

    # Draw on the requested axes; previously `ax` was never passed, so the
    # image went to whichever axes happened to be current.
    plot = data.mean(time_dim).transpose().plot(ax=ax,
                                                vmax=vmax,
                                                vmin=vmin,
                                                cmap=cmap,
                                                **kwargs)

    def init():
        plot.set_array(np.asarray([]))
        return plot,

    animation_coords = data.coords[time_dim].values

    def animate(i):
        coordinate = animation_coords[i]
        data_for_plot = data.sel(**{time_dim: coordinate})
        plot.set_array(data_for_plot.values.T.ravel())
        return plot,

    computed_interval = interval if interval else 100

    # NOTE(review): `repeat` is a boolean flag, so 500 only acts as True.
    # `repeat_delay=500` may have been intended; confirm before changing.
    anim = animation.FuncAnimation(fig,
                                   animate,
                                   init_func=init,
                                   repeat=500,
                                   frames=len(animation_coords),
                                   interval=computed_interval,
                                   blit=True)

    if out is not None:
        # Build the writer only when saving, so callers that just want the
        # animation object are not forced to have ffmpeg available.
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=1000 / computed_interval,
                        metadata=dict(artist='Me'),
                        bitrate=1800)
        anim.save(path_for_plot(out), writer=writer)
        return path_for_plot(out)

    return anim
def magnify_circular_regions_plot(data: DataType,
                                  magnified_points,
                                  mag=10,
                                  radius=0.05,
                                  cmap='viridis',
                                  color=None,
                                  edgecolor='red',
                                  out=None,
                                  ax=None,
                                  **kwargs):
    """Plot ``data`` and redraw circular regions with a reduced upper color limit.

    The full image is drawn first. A second image is overlaid in which only
    points inside the regions around ``magnified_points`` are kept (the rest is
    NaN and transparent), using an upper color limit divided by ``mag``. Each
    region is outlined by an ellipse.

    :param data: Data to plot; passed through ``normalize_to_spectrum``
    :param magnified_points: Centers of the regions, in data coordinates
    :param mag: Factor by which the upper color limit of the overlay is divided
    :param radius: Region size, used with the axes extent to size the ellipses
    :param cmap: Colormap for the base image and the overlay
    :param color: Ellipse color, or a list with one entry per point
    :param edgecolor: Ellipse edge color, or a list with one entry per point
    :param out: If given, the plot is saved and the output path is returned
    :param ax: Axes to draw on; a new figure is created when omitted
    :param kwargs: ``figsize`` sizes a newly created figure
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    """
    import copy

    data = normalize_to_spectrum(data)
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (
            7,
            5,
        )))

    mesh = data.plot(ax=ax, cmap=cmap)
    clim = list(mesh.get_clim())
    clim[1] = clim[1] / mag

    n_points = len(data.values.ravel())
    mask = np.zeros(shape=(n_points, )) > 0  # all-False boolean mask
    pts = np.zeros(shape=(
        n_points,
        2,
    ))

    raveled = data.T.ravel()
    pts[:, 0] = raveled[data.dims[0]]
    pts[:, 1] = raveled[data.dims[1]]

    x0, y0 = ax.transAxes.transform((0, 0))  # lower left in pixels
    x1, y1 = ax.transAxes.transform((1, 1))  # upper right in pixels
    dx = x1 - x0
    dy = y1 - y0
    maxd = max(dx, dy)
    xlim, ylim = ax.get_xlim(), ax.get_ylim()

    # Ellipse extents in data units, chosen so the outline looks circular on screen
    width = radius * maxd / dx * (xlim[1] - xlim[0])
    height = radius * maxd / dy * (ylim[1] - ylim[0])

    if not isinstance(edgecolor, list):
        edgecolor = [edgecolor for _ in range(len(magnified_points))]

    if not isinstance(color, list):
        color = [color for _ in range(len(magnified_points))]

    pts[:, 1] = (pts[:, 1]) / (xlim[1] - xlim[0])
    pts[:, 0] = (pts[:, 0]) / (ylim[1] - ylim[0])
    # (debugging prints of the point ranges were removed here)

    for c, ec, point in zip(color, edgecolor, magnified_points):
        patch = matplotlib.patches.Ellipse(point,
                                           width,
                                           height,
                                           color=c,
                                           edgecolor=ec,
                                           fill=False,
                                           linewidth=2,
                                           zorder=4)
        # Unattached patch used only for its point-containment test
        patchfake = matplotlib.patches.Ellipse([point[1], point[0]], radius,
                                               radius)
        ax.add_patch(patch)
        mask = np.logical_or(mask, patchfake.contains_points(pts))

    data_masked = data.copy(deep=True)
    data_masked.values = np.array(data_masked.values, dtype=np.float32)

    # Honor the requested colormap for the overlay and copy it before editing
    # the "bad" color, so the shared registered colormap is left untouched.
    overlay_cmap = copy.copy(matplotlib.cm.get_cmap(name=cmap))
    overlay_cmap.set_bad(color=(1, 1, 1, 0))
    data_masked.values[np.swapaxes(
        np.logical_not(mask.reshape(data.values.shape[::-1])), 0, 1)] = np.nan

    aspect = ax.get_aspect()
    extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
    ax.imshow(data_masked.values,
              cmap=overlay_cmap,
              extent=extent,
              zorder=3,
              clim=clim,
              origin='lower')
    ax.set_aspect(aspect)

    # Keep the frame above the overlay image and the ellipses
    for spine in ['left', 'top', 'right', 'bottom']:
        ax.spines[spine].set_zorder(5)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #10
0
def overlapped_stack_dispersion_plot(data: DataType, stack_axis=None, ax=None, title=None, out=None,
                                     max_stacks=100, use_constant_correction=False, transpose=False,
                                     negate=False, s=1, scale_factor=None, linewidth=1, palette=None, **kwargs):
    """Draw each slice along ``stack_axis`` as a curve offset by its coordinate.

    Each curve has a background removed (its first value when
    ``use_constant_correction`` is set, otherwise a straight line between its
    end points), is divided by the global maximum, multiplied by
    ``scale_factor`` and added to the slice's coordinate value.

    :param data: Two-dimensional data; passed through ``normalize_to_spectrum``
    :param stack_axis: Dimension to stack along; defaults to the first one
    :param ax: Axes to draw on; a new figure is created when omitted
    :param title: Axes title; defaults to the data label followed by ``Stack``
    :param out: If given, the plot is saved and the output path is returned
    :param max_stacks: Data is rebinned along ``stack_axis`` above this count
    :param use_constant_correction: Subtract the first value instead of a line
    :param transpose: Swap the plotted x and y
    :param negate: Negate the values before plotting
    :param s: Marker size for the scatter drawn when ``palette`` is given
    :param scale_factor: Curve height scale; derived from the data when omitted
    :param linewidth: Line width for the line drawn when no ``palette`` is given
    :param palette: Colormap (name or callable) used to color points
    :param kwargs: Forwarded to ``ax.plot`` / ``ax.scatter``
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    """
    data = normalize_to_spectrum(data)

    if stack_axis is None:
        stack_axis = data.dims[0]

    other_axes = list(data.dims)
    other_axes.remove(stack_axis)
    other_axis = other_axes[0]

    stack_coord = data.coords[stack_axis]
    if len(stack_coord.values) > max_stacks:
        data = rebin(data, reduction={
            stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks)),
        })

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))

    if title is None:
        title = '{} Stack'.format(data.S.label.replace('_', ' '))

    max_over_stacks = np.max(data.values)

    cvalues = data.coords[other_axis].values
    if scale_factor is None:
        maximum_deviation = -np.inf

        for _, marginal in data.T.iterate_axis(stack_axis):
            marginal_values = -marginal.values if negate else marginal.values
            marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

            if use_constant_correction:
                true_ys = (marginal_values - marginal_offset)
            else:
                true_ys = (marginal_values - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values)))

            maximum_deviation = np.max([maximum_deviation] + list(np.abs(true_ys)))

        scale_factor = 0.02 * (np.max(cvalues) - np.min(cvalues)) / maximum_deviation

    iteration_order = -1  # might need to fiddle with this in certain cases
    for coord_dict, marginal in list(data.T.iterate_axis(stack_axis))[::iteration_order]:
        coord_value = coord_dict[stack_axis]

        xs = cvalues
        marginal_values = -marginal.values if negate else marginal.values
        marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

        if use_constant_correction:
            true_ys = (marginal_values - marginal_offset) / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        else:
            true_ys = (marginal_values - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values))) \
                      / max_over_stacks
            ys = scale_factor * true_ys + coord_value

        raw_colors = 'black'
        if palette:
            if isinstance(palette, str):
                palette = cm.get_cmap(palette)
            raw_colors = palette(np.abs(true_ys / max_over_stacks))

        if transpose:
            xs, ys = ys, xs

        # Draw on the requested axes. The pyplot-level calls used before
        # targeted the current axes and ignored a caller-supplied `ax`.
        if isinstance(raw_colors, str):
            ax.plot(xs, ys, linewidth=linewidth, color=raw_colors, **kwargs)
        else:
            ax.scatter(xs, ys, color=raw_colors, s=s, **kwargs)

    x_label = other_axis
    y_label = stack_axis

    if transpose:
        x_label, y_label = y_label, x_label

    ax.set_xlabel(label_for_dim(data, x_label))
    ax.set_ylabel(label_for_dim(data, y_label))

    ax.set_title(title)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    plt.show()

    return fig, ax
Пример #11
0
def stack_dispersion_plot(data: DataType, stack_axis=None, ax=None, title=None, out=None,
                          max_stacks=100, transpose=False,
                          use_constant_correction=False, correction_side=None,
                          color=None, c=None,
                          label=None,
                          shift=0,
                          no_scatter=False,
                          negate=False, s=1, scale_factor=None, linewidth=1, palette=None, zero_offset=False, uniform=False, **kwargs):
    """Draw each slice along ``stack_axis`` as a curve offset from the others.

    Background handling, first match wins: ``use_constant_correction``
    subtracts one end value; ``zero_offset`` subtracts nothing; ``uniform``
    subtracts nothing and offsets curves by their index rather than their
    coordinate; otherwise a straight line between the end points is removed.

    :param data: Two-dimensional data; passed through ``normalize_to_spectrum``
    :param stack_axis: Dimension to stack along; defaults to the first one
    :param ax: Axes to draw on; a new figure is created when omitted
    :param title: Axes title; defaults to the data label followed by ``Stack``
    :param out: If given, the plot is saved and the output path is returned
    :param max_stacks: Data is rebinned along ``stack_axis`` above this count
    :param transpose: Swap the plotted x and y
    :param use_constant_correction: Subtract a constant end value
    :param correction_side: ``'right'`` uses the last value, otherwise the first
    :param color: Curve color, or a callable given the stack coordinate value
    :param c: Fallback for ``color``
    :param label: Legend label, attached to the first curve drawn only
    :param shift: Per-curve shift subtracted from the plotted x values
    :param no_scatter: Force line drawing even for non-string colors
    :param negate: Negate the values before plotting
    :param s: Marker size for scatter drawing
    :param scale_factor: Curve height scale; derived from the data when omitted
    :param linewidth: Line width for line drawing
    :param palette: Colormap (name or callable) used to color points
    :param zero_offset: Skip background subtraction
    :param uniform: Offset curves by index instead of coordinate value
    :param kwargs: Forwarded to ``ax.plot`` / ``ax.scatter``
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    """
    data = normalize_to_spectrum(data)

    if stack_axis is None:
        stack_axis = data.dims[0]

    other_axes = list(data.dims)
    other_axes.remove(stack_axis)
    other_axis = other_axes[0]

    stack_coord = data.coords[stack_axis]
    if len(stack_coord.values) > max_stacks:
        data = rebin(data, reduction={
            stack_axis: int(np.ceil(len(stack_coord.values) / max_stacks)),
        })

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 7))

    if title is None:
        title = '{} Stack'.format(data.S.label.replace('_', ' '))

    max_over_stacks = np.max(data.values)

    cvalues = data.coords[other_axis].values
    if scale_factor is None:
        maximum_deviation = -np.inf

        for _, marginal in data.T.iterate_axis(stack_axis):
            marginal_values = -marginal.values if negate else marginal.values
            marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

            if use_constant_correction:
                true_ys = (marginal_values - marginal_offset)
            elif zero_offset:
                true_ys = marginal_values
            else:
                true_ys = (marginal_values - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values)))

            maximum_deviation = np.max([maximum_deviation] + list(np.abs(true_ys)))

        scale_factor = 0.02 * (np.max(cvalues) - np.min(cvalues)) / maximum_deviation

    iteration_order = -1  # might need to fiddle with this in certain cases
    # Running intersection of the coordinate ranges covered by every curve
    lim = [-np.inf, np.inf]
    labeled = False
    for i, (coord_dict, marginal) in enumerate(list(data.T.iterate_axis(stack_axis))[::iteration_order]):
        coord_value = coord_dict[stack_axis]

        xs = cvalues
        marginal_values = -marginal.values if negate else marginal.values
        marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]

        if use_constant_correction:
            offset = right_marginal_offset if correction_side == 'right' else marginal_offset
            true_ys = (marginal_values - offset) / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        elif zero_offset:
            true_ys = marginal_values / max_over_stacks
            ys = scale_factor * true_ys + coord_value
        elif uniform:
            true_ys = marginal_values / max_over_stacks
            ys = scale_factor * true_ys + i
        else:
            true_ys = (marginal_values - np.linspace(marginal_offset, right_marginal_offset, len(marginal_values))) \
                      / max_over_stacks
            ys = scale_factor * true_ys + coord_value

        raw_colors = color or c or 'black'

        if palette:
            if isinstance(palette, str):
                palette = cm.get_cmap(palette)
            raw_colors = palette(np.abs(true_ys / max_over_stacks))

        if transpose:
            xs, ys = ys, xs

        xs = xs - i * shift

        # The limits are applied to the coordinate axis, which is y after a
        # transpose. Previously they were always taken from `xs`, so a
        # transposed plot had its y-range set from the offset intensity values.
        coordinate_values = ys if transpose else xs
        lim = [max(lim[0], np.min(coordinate_values)), min(lim[1], np.max(coordinate_values))]

        label_for = '_nolegend_'
        if not labeled:
            labeled = True
            label_for = label

        color_for_plot = raw_colors
        if callable(color_for_plot):
            color_for_plot = color_for_plot(coord_value)

        if isinstance(raw_colors, (str, tuple)) or no_scatter:
            ax.plot(xs, ys, linewidth=linewidth, color=color_for_plot, label=label_for, **kwargs)
        else:
            ax.scatter(xs, ys, color=color_for_plot, s=s, label=label_for, **kwargs)

    x_label = other_axis
    y_label = stack_axis

    if transpose:
        x_label, y_label = y_label, x_label

    ax.set_xlabel(label_for_dim(data, x_label))
    ax.set_ylabel(label_for_dim(data, y_label))

    if transpose:
        ax.set_ylim(lim)
    else:
        ax.set_xlim(lim)

    ax.set_title(title)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #12
0
def offset_scatter_plot(data: DataType, name_to_plot=None, stack_axis=None, fermi_level=True, cbarmap=None, ax=None,
                        out=None, scale_coordinate=0.5, ylim=None, aux_errorbars=True, **kwargs):
    """Scatter each slice along ``stack_axis`` with error bars, nudging slices apart.

    Every slice is drawn with ``scatter_with_std`` and colored by its stack
    coordinate. Slice ``i`` has its other coordinate shifted by
    ``i * stride * scale_coordinate / 10`` so that points do not overlap.

    :param data: Dataset holding a 2D variable and its ``<name>_std`` variable
    :param name_to_plot: Variable name. When omitted, the dataset must contain
        exactly one variable whose name lacks ``_std``
    :param stack_axis: Dimension to iterate; defaults to the variable's first
    :param fermi_level: Draw the zero line and shaded band when ``eV`` is a
        non-stack dimension
    :param cbarmap: ``(colorbar, colormap)`` pair; looked up or generated when
        omitted, in which case the colorbar is drawn in an inset
    :param ax: Axes to draw on; a new figure is created when omitted
    :param out: If given, the plot is saved and the output path is returned
    :param scale_coordinate: Scale of the per-slice coordinate shift
    :param ylim: y-limits; required when ``aux_errorbars`` is true
    :param aux_errorbars: Also draw the error bars alone at ``ylim[0]``
    :param kwargs: ``figsize`` sizes a new figure, ``ticks`` goes to the
        generated colorbar map; all of it is forwarded to the colorbar call
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    :raises ValueError: If the variable is not two-dimensional
    :raises AssertionError: If ``data`` is not a Dataset, or ``aux_errorbars``
        is set without ``ylim``
    """
    assert isinstance(data, xr.Dataset)

    if name_to_plot is None:
        var_names = [k for k in data.data_vars.keys() if '_std' not in k]
        assert len(var_names) == 1
        name_to_plot = var_names[0]
        assert (name_to_plot + '_std') in data.data_vars.keys()

    if len(data.data_vars[name_to_plot].dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like.'
                         'Passed data included dimensions: {}'.format(data.data_vars[name_to_plot].dims))

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (11, 5,)))

    inset_ax = inset_axes(ax, width='40%', height='5%', loc='upper left')

    if stack_axis is None:
        stack_axis = data.data_vars[name_to_plot].dims[0]

    skip_colorbar = True
    if cbarmap is None:
        skip_colorbar = False
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            # No predefined map for this axis: build one from the coordinate
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # do our best
        try:
            cmap = cmap()
        except Exception:
            # might still be fine; `Exception` (not a bare except) so that
            # KeyboardInterrupt and SystemExit still propagate
            pass

    # should be exactly two
    other_dim = [d for d in data.dims if d != stack_axis][0]

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        ax.axhline(0, linestyle='--', color='red')
        ax.fill_betweenx([-1e6, 1e6], 0, 0.2, color='black', alpha=0.07)
        ax.set_ylim(ylim)

    # The coordinate spacing does not change between slices; compute it once
    delta = data.T.stride(generic_dim_names=False)[other_dim]

    # real plotting here
    for i, (coord, value) in enumerate(data.T.iterate_axis(stack_axis)):
        data_for = value.copy(deep=True)
        data_for.coords[other_dim] = data_for.coords[other_dim].copy(deep=True)
        data_for.coords[other_dim].values = data_for.coords[other_dim].values.copy()
        data_for.coords[other_dim].values -= i * delta * scale_coordinate / 10

        scatter_with_std(data_for, name_to_plot, ax=ax, color=cmap(coord[stack_axis]))

        if aux_errorbars:
            assert ylim is not None, 'ylim is required when aux_errorbars is set'
            data_for = data_for.copy(deep=True)
            flattened = data_for.data_vars[name_to_plot].copy(deep=True)
            flattened.values = ylim[0] * np.ones(flattened.values.shape)
            data_for = data_for.assign(**{name_to_plot: flattened})
            scatter_with_std(data_for, name_to_plot, ax=ax, color=cmap(coord[stack_axis]), fmt='none')

    ax.set_xlabel(other_dim)
    ax.set_ylabel(name_to_plot)
    fancy_labels(ax)

    try:
        if inset_ax and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)

            fancy_labels(inset_ax)
            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # colorbar already rendered
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #13
0
def flat_stack_plot(data: DataType, stack_axis=None, fermi_level=True, cbarmap=None, ax=None,
                    mode='line', title=None, out=None, transpose=False, **kwargs):
    """Overlay every slice along ``stack_axis`` on one axes, colored by coordinate.

    :param data: Two-dimensional data; passed through ``normalize_to_spectrum``
    :param stack_axis: Dimension to iterate; defaults to the first one
    :param fermi_level: Draw a dashed line at zero when ``eV`` is the
        non-stack dimension
    :param cbarmap: ``(colorbar, colormap)`` pair; looked up or generated when
        omitted, in which case the colorbar is drawn if an inset exists
    :param ax: Axes to draw on. When omitted a new figure and a colorbar inset
        are created; no inset is created for a supplied ``ax``
    :param mode: ``'line'`` or ``'scatter'``
    :param title: Axes title
    :param out: If given, the plot is saved and the output path is returned
    :param transpose: Put intensity on x and the coordinate on y
    :param kwargs: ``figsize`` sizes a new figure, ``ticks`` goes to the
        generated colorbar map; all of it is forwarded to the drawing calls
        and the colorbar
    :return: The saved path when ``out`` is given, otherwise ``(fig, ax)``
    :raises ValueError: If the data is not two-dimensional
    :raises NotImplementedError: For ``mode='scatter'`` with ``transpose``
    """
    data = normalize_to_spectrum(data)
    if len(data.dims) != 2:
        raise ValueError('In order to produce a stack plot, data must be image-like.'
                         'Passed data included dimensions: {}'.format(data.dims))

    fig = None
    inset_ax = None
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('figsize', (7, 5,)))
        inset_ax = inset_axes(ax, width='40%', height='5%', loc=1)

    if stack_axis is None:
        stack_axis = data.dims[0]

    skip_colorbar = True
    if cbarmap is None:
        skip_colorbar = False
        try:
            cbarmap = colorbarmaps_for_axis[stack_axis]
        except KeyError:
            cbarmap = generic_colorbarmap_for_data(data.coords[stack_axis], ax=inset_ax, ticks=kwargs.get('ticks'))

    cbar, cmap = cbarmap

    # should be exactly two
    other_dim = [d for d in data.dims if d != stack_axis][0]
    other_coord = data.coords[other_dim]

    if not isinstance(cmap, matplotlib.colors.Colormap):
        # do our best
        try:
            cmap = cmap()
        except Exception:
            # might still be fine; `Exception` (not a bare except) so that
            # KeyboardInterrupt and SystemExit still propagate
            pass

    if 'eV' in data.dims and 'eV' != stack_axis and fermi_level:
        if transpose:
            ax.axhline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)
        else:
            ax.axvline(0, color='red', alpha=0.8, linestyle='--', linewidth=1)

    # meat of the plotting
    for coord_dict, marginal in list(data.T.iterate_axis(stack_axis)):
        if transpose:
            if mode == 'line':
                ax.plot(marginal.values, marginal.coords[marginal.dims[0]].values, color=cmap(coord_dict[stack_axis]), **kwargs)
            else:
                assert mode == 'scatter'
                raise NotImplementedError()
        else:
            if mode == 'line':
                marginal.plot(ax=ax, color=cmap(coord_dict[stack_axis]), **kwargs)
            else:
                assert mode == 'scatter'
                ax.scatter(*marginal.T.to_arrays(), color=cmap(coord_dict[stack_axis]), **kwargs)
                ax.set_xlabel(marginal.dims[0])

    coordinate_range = [other_coord.min().item(), other_coord.max().item()]
    if transpose:
        # Transposed plots carry the coordinate on y and intensity on x, so the
        # labels and the coordinate range go on the swapped axes. Previously
        # the coordinate range was applied to the intensity axis.
        ax.set_xlabel('Spectrum Intensity (arb).')
        ax.set_ylabel(label_for_dim(data, other_dim))
        ax.set_ylim(coordinate_range)
    else:
        ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
        ax.set_ylabel('Spectrum Intensity (arb).')
        ax.set_xlim(coordinate_range)
    ax.set_title(title, fontsize=14)

    try:
        if inset_ax is not None and not skip_colorbar:
            inset_ax.set_xlabel(stack_axis, fontsize=16)
            fancy_labels(inset_ax)

            cbar(ax=inset_ax, **kwargs)
    except TypeError:
        # already rendered
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #14
0
def plot_spatial_reference(reference_map: DataType,
                           data_list: List[DataType],
                           offset_list: Optional[List[Dict[str, Any]]] = None,
                           annotation_list: Optional[List[str]] = None,
                           out: Optional[str] = None,
                           plot_refs: bool = True):
    """
    Plots the locations of datasets on top of a reference scanning dataset.

    The reference map is reduced to two of the ``x``/``y``/``z`` dimensions and
    drawn in blue. Each entry of ``data_list`` is then marked on it according to
    how many of its coordinates along those two dimensions are arrays: a point
    (none), a line (one) or a rectangle (two). Every marker gets a text
    annotation, and optionally a small reference plot of the dataset itself.

    :param reference_map: A scanning photoemission like dataset
    :param data_list: A list of datasets you want to plot the relative locations of
    :param offset_list: Optionally, offsets given as coordinate dicts. Missing entries
                        are computed from the datasets' ``S.logical_offsets``
    :param annotation_list: Optionally, text annotations for the data. Defaults to
                            ``'1'``, ``'2'``, ... in list order
    :param out: Where to save the figure if we are outputting to disk
    :param plot_refs: Whether to plot reference figures for each of the pieces of data in `data_list`
    :return: The saved path when ``out`` is given, otherwise ``(fig, axes)`` where
             ``axes`` is the map axes followed by the reference axes
    """
    if offset_list is None:
        offset_list = [None] * len(data_list)

    if annotation_list is None:
        annotation_list = [str(i + 1) for i in range(len(data_list))]

    # NOTE(review): the return value is discarded, so this call only matters if
    # normalize_to_spectrum has side effects — confirm whether an assignment
    # was intended.
    normalize_to_spectrum(reference_map)

    # Lay out the figure: the map on the left, reference plots (if any) on the right
    n_references = len(data_list)
    if n_references == 1 and plot_refs:
        fig, axes = plt.subplots(1, 2, figsize=(
            12,
            5,
        ))
        ax = axes[0]
        ax_refs = [axes[1]]
    elif plot_refs:
        n_extra_axes = 1 + (n_references // 4)
        fig = plt.figure(figsize=(
            6 * (1 + n_extra_axes),
            5,
        ))
        spec = gridspec.GridSpec(ncols=2 * (1 + n_extra_axes),
                                 nrows=2,
                                 figure=fig)
        # The map takes the full-height 2x2 cell block on the left
        ax = fig.add_subplot(spec[:2, :2])

        # Reference plots fill the remaining columns row by row
        ax_refs = [
            fig.add_subplot(spec[i // (2 * n_extra_axes),
                                 2 + i % (2 * n_extra_axes)])
            for i in range(n_references)
        ]
    else:
        ax_refs = []
        fig, ax = plt.subplots(1, 1, figsize=(
            6,
            5,
        ))

    # Best effort: use the first spectrum if the accessor provides one,
    # otherwise keep the reference map as given
    try:
        reference_map = reference_map.S.spectra[0]
    except Exception:
        pass

    reference_map = reference_map.S.mean_other(['x', 'y', 'z'])

    # Reversed so that ref_dims[0] is the plot's x dimension and ref_dims[1] its y
    # (presumably matching how S.plot orients the image — TODO confirm)
    ref_dims = reference_map.dims[::-1]

    assert len(reference_map.dims) == 2
    reference_map.S.plot(ax=ax, cmap='Blues')

    cmap = cm.get_cmap('Reds')
    rendered_annotations = []
    for i, (data, offset, annotation) in enumerate(
            zip(data_list, offset_list, annotation_list)):
        if offset is None:
            try:
                offset = data.S.logical_offsets - reference_map.S.logical_offsets
            except ValueError:
                offset = {}

        coords = {c: unwrap_xarray_item(data.coords[c]) for c in ref_dims}
        # Number of map dimensions along which this dataset actually extends
        n_array_coords = len([
            cv for cv in coords.values()
            if isinstance(cv, (np.ndarray, xr.DataArray))
        ])
        # Successive datasets get progressively darker reds
        color = cmap(0.4 + (0.5 * i / len(data_list)))
        x = coords[ref_dims[0]] + offset.get(ref_dims[0], 0)
        y = coords[ref_dims[1]] + offset.get(ref_dims[1], 0)
        # Anchor point of the annotation and the direction to nudge it in
        ref_x, ref_y = x, y
        off_x, off_y = 0, 0
        scale = 0.03

        if n_array_coords == 0:
            # Single position: draw a point
            off_y = 1
            ax.scatter([x], [y], s=60, color=color)
        if n_array_coords == 1:
            # Extended along one dimension: draw a line
            if isinstance(x, (np.ndarray, xr.DataArray)):
                y = [y] * len(x)
                ref_x = np.min(x)
                off_x = -1
            else:
                x = [x] * len(y)
                ref_y = np.max(y)
                off_y = 1

            ax.plot(x, y, color=color, linewidth=3)
        if n_array_coords == 2:
            # Extended along both dimensions: draw a translucent rectangle
            off_y = 1
            min_x, max_x = np.min(x), np.max(x)
            min_y, max_y = np.min(y), np.max(y)
            ref_x, ref_y = min_x, max_y

            color = cmap(0.4 + (0.5 * i / len(data_list)), alpha=0.5)
            rect = patches.Rectangle((min_x, min_y),
                                     max_x - min_x,
                                     max_y - min_y,
                                     facecolor=color)
            # Restore the opaque color for the reference plot frame below
            color = cmap(0.4 + (0.5 * i / len(data_list)))

            ax.add_patch(rect)

        dp = ddata_daxis_units(ax)
        text_location = np.asarray([
            ref_x,
            ref_y,
        ]) + dp * scale * np.asarray([off_x, off_y])
        text = ax.annotate(annotation, text_location, color='black', size=15)
        rendered_annotations.append(text)
        # White outline keeps the label readable over the map
        text.set_path_effects([
            path_effects.Stroke(linewidth=2, foreground='white'),
            path_effects.Normal()
        ])
        if plot_refs:
            ax_ref = ax_refs[i]
            # Dimensions to keep in the reference plot, most preferred first;
            # the first two that the dataset has are used
            keep_preference = list(ref_dims) + [
                'eV',
                'temperature',
                'kz',
                'hv',
                'kp',
                'kx',
                'ky',
                'phi',
                'theta',
                'beta',
                'pixel',
            ]
            keep = [d for d in keep_preference if d in data.dims][:2]
            data.S.mean_other(keep).S.plot(ax=ax_ref)
            ax_ref.set_title(annotation)
            fancy_labels(ax_ref)
            frame_with(ax_ref, color=color, linewidth=3)

    ax.set_title('')
    remove_colorbars()
    fancy_labels(ax)
    plt.tight_layout()

    # adjustText is optional; without it the annotations stay where they were placed
    try:
        from adjustText import adjust_text
        adjust_text(rendered_annotations,
                    ax=ax,
                    avoid_points=False,
                    avoid_objects=False,
                    avoid_self=False,
                    autoalign='xy')
    except ImportError:
        pass

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, [ax] + ax_refs
Пример #15
0
def reference_scan_spatial(data, out=None, **kwargs):
    """Plot a spatial scan in six panels and mark the scans that reference it.

    The first panel shows the data summed over every one of the 'cycle',
    'phi' and 'eV' dimensions that is present. The remaining five panels
    show consecutive energy windows (summed over 'eV' inside each window)
    stepping downward from the top of the energy axis, or from zero if the
    axis extends above zero.

    Every scan listed in ``data.S.referenced_scans`` is loaded to read its
    sample position; positions that lie within 3% of the plotted x and y
    ranges of one another are merged into a single annotation whose label is
    the comma-joined list of scan indices.

    Args:
        data: Input accepted by ``normalize_to_spectrum``. Must carry an
            'eV' coordinate.
        out: If not None, the figure is saved to ``path_for_plot(out)`` and
            that path is returned instead of the figure.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        ``path_for_plot(out)`` when ``out`` is given, otherwise the tuple
        ``(fig, ax)`` where ``ax`` is the 3x2 grid returned by
        ``plt.subplots``.
    """
    data = normalize_to_spectrum(data)

    dims = [d for d in data.dims if d in {'cycle', 'phi', 'eV'}]

    fig, ax = plt.subplots(3, 2, figsize=(15, 15))
    flat_axes = list(itertools.chain.from_iterable(ax))

    # Panel 0: everything integrated, including the full energy range.
    data.sum(dims, keep_attrs=True).plot(ax=flat_axes[0])
    flat_axes[0].set_title(r'Full \textbf{eV} range')

    # Panels 1-5 keep the energy axis so it can be sliced into windows.
    dims_except_eV = [d for d in dims if d != 'eV']
    summed_data = data.sum(dims_except_eV)

    energies = data.coords['eV']
    e_max = energies.max().item()
    rng = e_max - energies.min().item()

    # Windows start at the top of the energy axis, but never above zero.
    offset = min(e_max, 0)
    # Default window width is 0.2; for wide scans split the range in fifths.
    mul = rng / 5. if rng > 3 else 0.2

    for i in range(5):
        low_e, high_e = -mul * (i + 1) + offset, -mul * i + offset
        title = r'\textbf{eV}' + ': {:.2g} to {:.2g}'.format(low_e, high_e)
        summed_data.sel(eV=slice(low_e, high_e)).sum('eV').plot(
            ax=flat_axes[i + 1])
        flat_axes[i + 1].set_title(title)

    y_range = flat_axes[0].get_ylim()
    x_range = flat_axes[0].get_xlim()
    delta_one_percent = ((x_range[1] - x_range[0]) / 100,
                         (y_range[1] - y_range[0]) / 100)

    # Label offset relative to the marked point: the x component is scaled by
    # the x range and the y component by the y range. (The y component was
    # previously scaled by the x range, which misplaced labels whenever the
    # two axes covered different extents.)
    smart_delta = (2 * delta_one_percent[0], -1.5 * delta_one_percent[1])

    referenced = data.S.referenced_scans

    # idea here is to collect points by those that are close together, then
    # only plot one annotation
    condensed = []
    cutoff = 3  # 3 percent
    x_tolerance = cutoff * abs(delta_one_percent[0])
    y_tolerance = cutoff * abs(delta_one_percent[1])
    for index, _ in referenced.iterrows():
        ff = simple_load(index)

        x, y, _ = ff.S.sample_pos
        for cx, cy, cl in condensed:
            if abs(cx - x) < x_tolerance and abs(cy - y) < y_tolerance:
                cl.append(index)
                break
        else:
            # No existing cluster is close enough: start a new one.
            condensed.append((x, y, [index]))

    for fax in flat_axes:
        for cx, cy, cl in condensed:
            annotate_point(fax, (cx, cy),
                           ','.join([str(l) for l in cl]),
                           delta=smart_delta,
                           fontsize='large')

    plt.tight_layout()

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    return fig, ax
Пример #16
0
def spin_polarized_spectrum(spin_dr,
                            title=None,
                            ax=None,
                            out=None,
                            component='y',
                            scatter=False,
                            stats=False,
                            norm=None):
    """Plot spin-resolved spectra and the derived polarization on two axes.

    The first axis shows the ``up`` (red) and ``down`` (blue) channels of
    ``spin_dr`` against its 'eV' coordinate; the second axis shows the
    polarization computed by ``to_intensity_polarization`` on a fixed
    [-1, 1] scale with shaded positive/negative halves.

    Args:
        spin_dr: Dataset-like object with ``up`` and ``down`` variables and
            an 'eV' coordinate.
        title: Title for the spectrum axis. Defaults to 'Spin spectrum '.
        ax: Indexable pair of axes (``ax[0]`` for spectra, ``ax[1]`` for the
            polarization). A new 2x1 figure with a shared x axis is created
            when None.
        out: If not None, the figure is saved via ``savefig(out, dpi=400)``,
            the current figure is cleared, and ``path_for_plot(out)`` is
            returned.
        component: Label of the spin component, inserted into the title of
            the polarization axis.
        scatter: With ``stats=True``, draw points via ``scatter_with_std``
            instead of a line with a shaded +/- one-deviation band. Ignored
            when ``stats`` is False.
        stats: If True, resample ``spin_dr`` with ``bootstrap`` (N=100) and
            plot the values from ``mean_and_deviation`` together with their
            ``*_std`` spread.
        norm: Accepted for call compatibility; not used.

    Returns:
        ``path_for_plot(out)`` when ``out`` is given, otherwise ``ax``.
    """
    if ax is None:
        _, ax = plt.subplots(2, 1, sharex=True)

    if stats:
        spin_dr = bootstrap(lambda x: x)(spin_dr, N=100)
        pol = mean_and_deviation(to_intensity_polarization(spin_dr))
        counts = mean_and_deviation(spin_dr)
    else:
        counts = spin_dr
        pol = to_intensity_polarization(counts)

    ax_left = ax[0]
    ax_right = ax[1]

    # Each name matches its channel, so the red/blue convention is the same
    # with and without statistics (the two channels used to be swapped here,
    # inverting the colors of the non-stats plot).
    up = counts.up.data
    down = counts.down.data

    energies = spin_dr.coords['eV'].values
    min_e, max_e = np.min(energies), np.max(energies)

    # Plot the spectra
    if not stats:
        ax_left.plot(energies, up, 'r')
        ax_left.plot(energies, down, 'b')
    elif scatter:
        scatter_with_std(counts, 'up', color='red', ax=ax_left)
        scatter_with_std(counts, 'down', color='blue', ax=ax_left)
    else:
        v, s = counts.up.values, counts.up_std.values
        ax_left.plot(energies, v, 'r')
        ax_left.fill_between(energies, v - s, v + s, color='r', alpha=0.25)

        v, s = counts.down.values, counts.down_std.values
        ax_left.plot(energies, v, 'b')
        ax_left.fill_between(energies, v - s, v + s, color='b', alpha=0.25)

    # The default keeps its trailing space so the rendered title is unchanged.
    ax_left.set_title(title if title is not None else 'Spin spectrum ')
    ax_left.set_ylabel(r'\textbf{Spectrum Intensity}')
    ax_left.set_xlabel(r'\textbf{Kinetic energy} (eV)')
    ax_left.set_xlim(min_e, max_e)

    # Leave 20% headroom above the taller of the two channels.
    ax_left.set_ylim(0, max(np.max(down), np.max(up)) * 1.2)

    # Plot the polarization and associated statistical error bars
    if not stats:
        ax_right.plot(energies, pol.polarization.data, color='black')
    elif scatter:
        scatter_with_std(pol, 'polarization', ax=ax_right, color='black')
    else:
        v = pol.polarization.data
        s = pol.polarization_std.data
        ax_right.plot(energies, v, color='black')
        ax_right.fill_between(energies,
                              v - s,
                              v + s,
                              color='black',
                              alpha=0.25)

    # Tint the positive and negative halves of the polarization axis.
    ax_right.fill_between(energies, 0, 1, facecolor='blue', alpha=0.1)
    ax_right.fill_between(energies, -1, 0, facecolor='red', alpha=0.1)

    ax_right.set_title('Spin polarization, $\\text{S}_\\textbf{' + component +
                       '}$')
    ax_right.set_ylabel(r'\textbf{Polarization}')
    ax_right.set_xlabel(r'\textbf{Kinetic Energy} (eV)')
    ax_right.set_xlim(min_e, max_e)
    ax_right.axhline(0, color='white', linestyle=':')

    ax_right.set_ylim(-1, 1)
    ax_right.grid(True, axis='y')

    plt.tight_layout()

    if out is not None:
        # NOTE(review): `savefig` is resolved from module scope; presumably
        # the project's own save helper rather than `plt.savefig` -- confirm.
        savefig(out, dpi=400)
        plt.clf()
        return path_for_plot(out)

    return ax