예제 #1
0
def slice_plot(particle_group, stat_key='sigma_x', n_slice=40, slice_key='z'):
    """
    Plot one slice statistic as a line, with the slice density shaded behind it.

    The particles are divided into `n_slice` slices along `slice_key` by
    `slice_statistics`. The left axis shows `stat_key` versus the slice mean
    of `slice_key`; a twin right axis shows charge / ptp of each slice as a
    filled region.

    Nothing is returned; the figure is left as the current matplotlib figure.
    """
    mean_key = 'mean_' + slice_key
    span_key = 'ptp_' + slice_key
    density_key = 'density'

    stats = slice_statistics(particle_group,
                             n_slice=n_slice,
                             slice_key=slice_key,
                             keys=[mean_key, stat_key, span_key, 'charge'])

    # Charge per unit length of the slice key
    stats[density_key] = stats['charge'] / stats[span_key]

    fig, ax = plt.subplots()

    # Rescale each series; only the scaled values and SI prefix are needed
    x, _, x_prefix = nice_array(stats[mean_key])
    y, _, y_prefix = nice_array(stats[stat_key])
    density, _, density_prefix = nice_array(stats[density_key])

    # Prefixed unit strings for the two left-hand labels
    x_symbol = particle_group.units(mean_key).unitSymbol
    y_symbol = particle_group.units(stat_key).unitSymbol
    x_units = x_prefix + x_symbol
    y_units = y_prefix + y_symbol

    # Density unit: charge per slice-key unit, shown as Amps when that is C/s
    density_units = f'C/{particle_group.units(mean_key)}'
    density_units = density_prefix + ('A' if density_units == 'C/s' else density_units)

    # Statistic on the left axis
    ax.plot(x, y, color='black')
    ax.set_xlabel(f'{mean_key} ({x_units})')
    ax.set_ylabel(f'{stat_key} ({y_units})')

    # Density on the right axis
    ax_density = ax.twinx()
    ax_density.set_ylabel(f'{density_key} ({density_units})')
    ax_density.fill_between(x, 0, density, color='black', alpha=0.2)
예제 #2
0
def plot_stats(astra_object, keys=['norm_emit_x', 'sigma_z'], sections=['cavity', 'solenoid'], fieldmaps = {}, verbose=False):
    """
    Plots stats, with fieldmaps plotted from sections.

    One panel is drawn per entry of `keys` (each versus 'mean_z'), followed by
    a final panel that shows the fieldmap data of the two `sections`: the
    first section on the left axis and the second on a twin right axis.

    Parameters:
        astra_object: object providing `.input`, `.stat(key)` and `.units(key)`.

        keys: stat keys to plot, one panel each. May be empty, in which case
            only the fieldmap panel is drawn.

        sections: exactly two section names. Each must be a key of the
            label/colour tables below ('cavity', 'solenoid').

        fieldmaps: accepted for interface compatibility; this function does
            not read it (fieldmaps are loaded with `load_fieldmaps`).

        verbose: forwarded to `load_fieldmaps` and `fieldmap_data`.

    Raises:
        AssertionError: if `sections` does not have exactly two entries.

    The list/dict defaults are never mutated here, so sharing them between
    calls is harmless.

    TODO: quadrupoles

    """

    astra_input = astra_object.input

    fmaps = load_fieldmaps(astra_input, sections=sections, verbose=verbose)

    assert len(sections) == 2, 'TODO: more general'

    nplots = len(keys) + 1

    # squeeze=False always returns a 2D array of axes, so indexing works even
    # when keys is empty (nplots == 1 would otherwise return a bare Axes).
    fig, axs = plt.subplots(nplots, squeeze=False)
    axs = axs[:, 0]

    xdat = astra_object.stat('mean_z')
    xmin = min(xdat)
    xmax = max(xdat)
    for i, key in enumerate(keys):
        ax = axs[i]
        unit = astra_object.units(key)
        ydat = astra_object.stat(key)

        # Only the scaled data and SI prefix are needed for the panel
        ndat, _, prefix = nice_array(ydat)
        ax.set_ylabel(f'{key} ({prefix}{unit})')
        ax.set_xlim(xmin, xmax)
        ax.plot(xdat, ndat)

    # Last panel: first section on the left axis, second on a RHS twin axis
    ax1 = axs[-1]
    ax1rhs = ax1.twinx()
    field_axes = [ax1, ax1rhs]

    ylabel = {'cavity': '$E_z$ (MV/m)', 'solenoid':'$B_z$ (T)'}
    color = {'cavity': 'green', 'solenoid':'blue'}

    for a, section in zip(field_axes, sections):
        c = color[section]
        for ix in find_fieldmap_ixlist(astra_input, section):
            dat = fieldmap_data(astra_input, section=section, index=ix, fieldmaps=fmaps, verbose=verbose)
            # dat.T is unpacked as the positional (x, y) arguments of plot
            a.plot(*dat.T, label=f'{section}_{ix}', color=c)
        a.set_ylabel(ylabel[section])
    ax1.set_xlabel('$z$ (m)')
    ax1.set_xlim(xmin, xmax)
예제 #3
0
def density_plot(particle_group, key='x', bins=None, **kwargs):
    """
    1D density plot. Also see: marginal_plot

    Histograms `particle_group[key]` weighted by `particle_group['weight']`,
    divides by the bin width, and draws the result as a bar chart. The y axis
    is labelled in C per x-unit, or in A when the unit of `key` is 's'.

    Parameters:
        particle_group: object supporting `len()`, `[key]`, `['weight']` and
            `.units(key).unitSymbol`.

        key: quantity to histogram.

        bins: passed to `np.histogram`. If not given (or falsy), one bin per
            100 particles is used, with a minimum of one bin.

        **kwargs: forwarded to `plt.subplots`.

    Returns:
        The matplotlib figure.

    Example:

        density_plot(P, 'x', bins=100)

    """

    if not bins:
        # np.histogram rejects bins=0, which the plain n/100 rule would give
        # for fewer than 100 particles.
        bins = max(1, len(particle_group) // 100)

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key])
    w = particle_group['weight']
    u1 = particle_group.units(key).unitSymbol
    ux = p1 + u1

    fig, ax = plt.subplots(**kwargs)

    hist, bin_edges = np.histogram(x, bins=bins, weights=w)
    hist_width = np.diff(bin_edges)
    hist_x = bin_edges[:-1] + hist_width / 2  # bin centers
    hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
    ax.bar(hist_x, hist_y, hist_width, color='grey')

    # Special label for C/s = A
    if u1 == 's':
        # x was already scaled by f1, so the prefix for Amps is derived from
        # the combined factor hist_f / f1.
        _, hist_prefix = nice_scale_prefix(hist_f / f1)
        ax.set_ylabel(f'{hist_prefix}A')
    else:
        ax.set_ylabel(f'{hist_prefix}C/{ux}')

    ax.set_xlabel(f'{key} ({ux})')

    return fig
예제 #4
0
def plot_stat(impact_object, y='sigma_x', x='mean_z', nice=True):
    """
    Plots stat output of key y vs key x.

    If particles have the same stat keys, these will also be plotted as red
    points. This part is best-effort: any error while reading the particles
    is ignored and only the line is drawn.

    If nice, a nice SI prefix and scaling will be used to make the numbers
    reasonably sized.

    Parameters:
        impact_object: object providing `.stat(key)`, `.units(key)` and
            `.particles` (iterated by name; each entry indexed by stat key).

        y: stat key for the vertical axis.

        x: stat key for the horizontal axis.

        nice: rescale both axes with `nice_array` and prefix the units.

    Nothing is returned; the figure is left as the current matplotlib figure.

    """
    I = impact_object  # convenience
    fig, ax = plt.subplots()

    units1 = str(I.units(x))
    units2 = str(I.units(y))

    X = I.stat(x)
    Y = I.stat(y)

    # Scale factors default to 1 so the particle points below can always be
    # divided by them.
    f1 = 1
    f2 = 1
    if nice:
        X, f1, prefix1 = nice_array(X)
        Y, f2, prefix2 = nice_array(Y)
        units1 = prefix1 + units1
        units2 = prefix2 + units2

    ax.set_xlabel(x + f' ({units1})')
    ax.set_ylabel(y + f' ({units2})')

    # line plot, drawn on the axes created above
    ax.plot(X, Y)

    # Particle points use the same scale factors as the line.
    # Exception (not a bare except) so Ctrl-C and SystemExit still propagate.
    try:
        ax.scatter([I.particles[name][x] / f1 for name in I.particles],
                   [I.particles[name][y] / f2 for name in I.particles],
                   color='red')
    except Exception:
        pass
예제 #5
0
def plot_stats(astra_object, keys=['norm_emit_x', 'sigma_z'], sections=['cavity', 'solenoid'], fieldmaps = {}, verbose=False):
    """
    Plots stats, with fieldmaps plotted from sections.

    One panel is drawn per entry of `keys` (each versus 'mean_z'), followed by
    a final panel that is handed to `add_fieldmaps_to_axes`.

    Parameters:
        astra_object: object providing `.input`, `.stat(key)` and `.units(key)`.

        keys: stat keys to plot, one panel each. May be empty, in which case
            only the fieldmap panel is drawn.

        sections: exactly two section names; forwarded to `load_fieldmaps`.

        fieldmaps: accepted for interface compatibility; this function does
            not read it.

        verbose: forwarded to `load_fieldmaps`.

    Raises:
        AssertionError: if `sections` does not have exactly two entries.

    The list/dict defaults are never mutated here, so sharing them between
    calls is harmless.

    TODO: quadrupoles

    """

    astra_input = astra_object.input

    # The return value is not used below; the call is kept in case it has
    # side effects (for example verbose output).
    load_fieldmaps(astra_input, sections=sections, verbose=verbose)

    assert len(sections) == 2, 'TODO: more general'

    nplots = len(keys) + 1

    # squeeze=False always returns a 2D array of axes, so indexing works even
    # when keys is empty (nplots == 1 would otherwise return a bare Axes).
    fig, axs = plt.subplots(nplots, squeeze=False)
    axs = axs[:, 0]

    xdat = astra_object.stat('mean_z')
    xmin = min(xdat)
    xmax = max(xdat)
    for i, key in enumerate(keys):
        ax = axs[i]
        unit = astra_object.units(key)
        ydat = astra_object.stat(key)

        # Only the scaled data and SI prefix are needed for the panel
        ndat, _, prefix = nice_array(ydat)
        ax.set_ylabel(f'{key} ({prefix}{unit})')
        ax.set_xlim(xmin, xmax)
        ax.plot(xdat, ndat)

    # NOTE(review): the sections drawn here are hard-coded and do not follow
    # the `sections` argument - confirm whether add_fieldmaps_to_axes should
    # receive `sections` instead.
    add_fieldmaps_to_axes(astra_object, axs[-1], bounds=(xmin, xmax),
                          sections=['cavity', 'solenoid'],
                          include_labels=True)
예제 #6
0
def marginal_plot(particle_group, key1='t', key2='p', bins=None):
    """
    Density plot and projections

    Draws a weighted hexbin of `key2` versus `key1` with a weighted histogram
    of each coordinate along the top and right edges, then calls `plt.show()`.

    Parameters:
        particle_group: object supporting `len()`, `[key]`, `['weight']` and
            `.units(key).unitSymbol`.

        key1: quantity on the horizontal axis.

        key2: quantity on the vertical axis.

        bins: hexbin grid size and histogram bin count. If not given (or
            falsy), sqrt(n/4) is used, with a minimum of one.

    Nothing is returned.

    Example:

        marginal_plot(P, 't', 'energy', bins=200)

    """

    if not bins:
        n = len(particle_group)
        # At least one bin: zero would divide by zero below
        bins = max(1, int(np.sqrt(n / 4)))

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key1])
    y, _, p2 = nice_array(particle_group[key2])

    w = particle_group['weight']

    u1 = particle_group.units(key1).unitSymbol
    u2 = particle_group.units(key2).unitSymbol
    ux = p1 + u1
    uy = p2 + u2

    labelx = f'{key1} ({ux})'
    labely = f'{key2} ({uy})'

    fig = plt.figure()

    # 4x4 grid: joint plot in the lower-left 3x3, marginals along top and right
    gs = GridSpec(4, 4)

    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    ax_marg_x = fig.add_subplot(gs[0, 0:3])
    ax_marg_y = fig.add_subplot(gs[1:4, 3])

    # Proper weighting: sum the weights falling in each hexagon
    ax_joint.hexbin(x,
                    y,
                    C=w,
                    reduce_C_function=np.sum,
                    gridsize=bins,
                    cmap=cmap,
                    vmin=1e-15)

    # np.ptp instead of the ndarray.ptp method, which NumPy 2.0 removed
    dx = np.ptp(x) / bins
    dy = np.ptp(y) / bins

    # The top marginal also divides by f1, so it is labelled per unprefixed
    # unit (u1); the side marginal is labelled per prefixed unit (uy).
    ax_marg_x.hist(x, weights=w / dx / f1, bins=bins, color='gray')
    ax_marg_y.hist(y,
                   orientation="horizontal",
                   weights=w / dy,
                   bins=bins,
                   color='gray')

    # Turn off tick labels on marginals
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Set labels on joint
    ax_joint.set_xlabel(labelx)
    ax_joint.set_ylabel(labely)

    # Set labels on marginals
    ax_marg_x.set_ylabel(f'C/{u1}')
    ax_marg_y.set_xlabel(f'C/{uy}')
    plt.show()
예제 #7
0
def density_and_slice_plot(particle_group,
                           key1='t',
                           key2='p',
                           stat_keys=['norm_emit_x', 'norm_emit_y'],
                           bins=100,
                           n_slice=30):
    """
    2D density image with slice statistics overlaid on a right-hand axis.

    The left axis shows a weighted 2D histogram of `key2` versus `key1`. A
    twin right axis shows each of `stat_keys`, computed per slice along
    `key1`, together with a shaded slice density (charge / ptp) that is
    rescaled to half the largest spread among the plotted statistics.

    Parameters:
        particle_group: object supporting `[key]`, `['weight']` and
            `.units(key).unitSymbol`.

        key1: quantity on the horizontal axis, also used as the slice key.

        key2: quantity on the vertical axis.

        stat_keys: slice statistics to overlay (list or tuple). All must
            share the same unit symbol.

        bins: passed to `np.histogram2d`.

        n_slice: number of slices, passed to `slice_statistics`.

    Raises:
        AssertionError: if the `stat_keys` do not share one unit symbol.

    Nothing is returned; the figure is left as the current matplotlib figure.

    Example:

        density_and_slice_plot(P, 't', 'energy', stat_keys=['norm_emit_x', 'norm_emit_y'], n_slice=200)

    """

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key1])
    y, _, p2 = nice_array(particle_group[key2])
    w = particle_group['weight']

    u1 = particle_group.units(key1).unitSymbol
    u2 = particle_group.units(key2).unitSymbol
    ux = p1 + u1
    uy = p2 + u2

    fig, ax = plt.subplots()

    ax.set_xlabel(f'{key1} ({ux})')
    ax.set_ylabel(f'{key2} ({uy})')

    # Manual histogramming version
    H, xedges, yedges = np.histogram2d(x, y, weights=w, bins=bins)
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    # H is transposed to put x along the horizontal image axis
    ax.imshow(H.T,
              cmap=cmap,
              vmin=1e-16,
              origin='lower',
              extent=extent,
              aspect='auto')

    # Slice data. list() lets stat_keys be a tuple as well as a list.
    slice_dat = slice_statistics(particle_group,
                                 n_slice=n_slice,
                                 slice_key=key1,
                                 keys=list(stat_keys) +
                                 ['ptp_' + key1, 'mean_' + key1, 'charge'])

    slice_dat['density'] = slice_dat['charge'] / slice_dat['ptp_' + key1]

    ax2 = ax.twinx()
    # Slice positions, put on the same scale as the x axis
    x2 = slice_dat['mean_' + key1] / f1
    ulist = [particle_group.units(k).unitSymbol for k in stat_keys]

    # Largest spread among the statistics sets the common scale.
    # np.ptp instead of the ndarray.ptp method, which NumPy 2.0 removed.
    max2 = max(np.ptp(slice_dat[k]) for k in stat_keys)

    f3, p3 = nice_scale_prefix(max2)

    # All statistics share one axis, so they must share one unit
    stat_unit = ulist[0]
    assert all([u == stat_unit for u in ulist])
    stat_unit = p3 + stat_unit
    for k in stat_keys:
        ax2.plot(x2, slice_dat[k] / f3, label=k)
    ax2.legend()
    ax2.set_ylabel(f'({stat_unit})')
    ax2.set_ylim(bottom=0)

    # Add density, rescaled so its peak sits at half of max2 on this axis
    y2 = slice_dat['density']
    y2 = y2 * max2 / y2.max() / f3 / 2
    ax2.fill_between(x2, 0, y2, color='black', alpha=0.1)
예제 #8
0
def plot_stats_with_layout(astra_object, ykeys=['sigma_x', 'sigma_y'], ykeys2=['sigma_z'],
                           xkey='mean_z', xlim=None,
                           nice=True,
                           include_layout=False,
                           include_labels=True,
                           include_particles=True,
                           include_legend=True, **kwargs):
    """
    Plots stat output multiple keys.

    If a list of ykeys2 is given, these will be put on the right hand axis. This can also be given as a single key.

    Parameters:
        astra_object: object providing `.stat(key)`, `.units(key)`,
            `.particles` (indexed by integer position) and `.stop`.

        ykeys: key or list of keys for the left axis (solid lines).

        ykeys2: key or list of keys for the right axis (dashed lines).

        xkey: stat key for the horizontal axis.

        xlim: optional (min, max) in unscaled units; defaults to the full
            range of the x data.

        **kwargs: forwarded to `plt.subplots`.

    Logical switches:
        nice: a nice SI prefix and scaling will be used to make the numbers reasonably sized. Default: True

        include_legend: The plot will include the legend. Default: True

        include_layout: the layout plot will be displayed at the bottom. Default: False

        include_labels: the layout will include element labels. Default: True

        include_particles: particles within xlim are drawn as points
            (best-effort; errors reading them are ignored). Default: True

    Raises:
        AssertionError: if keys sharing an axis have different units.

    Nothing is returned; the figure is left as the current matplotlib figure.

    Copied almost verbatim from lume-impact's Impact.plot.plot_stats_with_layout

    """
    I = astra_object # convenience

    if include_layout:
        fig, all_axis = plt.subplots(2, gridspec_kw={'height_ratios': [4, 1]}, **kwargs)
        ax_layout = all_axis[-1]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots( **kwargs)
        ax_plot = [all_axis]

    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]

    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())

    # No need for a legend if there is only one plot
    if len(ykeys)==1 and not ykeys2:
        include_legend=False

    X = I.stat(xkey)

    # Only get the data we need
    if xlim:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]
    else:
        xlim = X.min(), X.max()
        good = slice(None,None,None) # everything

    # Try particles within these bounds
    Pnames = []
    X_particles = []

    if include_particles:
        # Best-effort: any failure disables the particle points.
        # Exception (not a bare except) so Ctrl-C and SystemExit still propagate.
        try:
            for pname in range(len(I.particles)): # Modified from Impact
                xp = I.particles[pname][xkey]
                if xp >= xlim[0] and xp <= xlim[1]:
                    Pnames.append(pname)
                    X_particles.append(xp)
            X_particles = np.array(X_particles)
        except Exception:
            Pnames = []
    else:
        Pnames = []

    # X axis scaling
    units_x = str(I.units(xkey))
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x  = prefix_x+units_x
    else:
        factor_x = 1

    # set all but the layout
    for ax in ax_plot:
        ax.set_xlim(xlim[0]/factor_x, xlim[1]/factor_x)
        ax.set_xlabel(f'{xkey} ({units_x})')

    # Draw for Y1 and Y2

    linestyles = ['solid','dashed']

    ii = -1 # counter for colors, shared by both axes so colors do not repeat
    for ix, keys in enumerate([ykeys, ykeys2]):
        if not keys:
            continue
        ax = ax_plot[ix]
        linestyle = linestyles[ix]

        # Check that units are compatible
        ulist = [I.units(key) for key in keys]
        if len(ulist) > 1:
            for u2 in ulist[1:]:
                assert ulist[0] == u2, f'Incompatible units: {ulist[0]} and {u2}'
        # String representation
        unit = str(ulist[0])

        # Data, restricted to the same selection as X
        data = [I.stat(key)[good] for key in keys]

        if nice:
            # One common factor for every key on this axis
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix+unit
        else:
            factor = 1

        # Make a line and point
        for key, dat in zip(keys, data):
            ii += 1
            color = 'C'+str(ii)
            ax.plot(X, dat/factor, label=f'{key} ({unit})', color=color, linestyle=linestyle)

            # Particles (best-effort, e.g. the key may not exist for them)
            if Pnames:
                try:
                    Y_particles = np.array([I.particles[name][key] for name in Pnames])
                    ax.scatter(X_particles/factor_x, Y_particles/factor, color=color)
                except Exception:
                    pass
        ax.set_ylabel(', '.join(keys)+f' ({unit})')

    # Collect legend entries from every axis into one legend on the first
    if include_legend:
        lines = []
        labels = []
        for ax in ax_plot:
            a, b = ax.get_legend_handles_labels()
            lines += a
            labels += b
        ax_plot[0].legend(lines, labels, loc='best')

    # Layout
    if include_layout:

        if xkey == 'mean_z':
            # Layout shares the (unscaled) x range of the main plot
            ax_layout.set_xlim(xlim[0], xlim[1])
        else:
            # Otherwise the layout covers 0 to I.stop on its own mean_z axis
            ax_layout.set_xlabel('mean_z')
            xlim = (0, I.stop)
        add_fieldmaps_to_axes(I,  ax_layout, bounds=xlim, include_labels=include_labels)
예제 #9
0
def plot_stats_with_layout(impact_object,
                           ykeys=['sigma_x', 'sigma_y'],
                           ykeys2=['mean_kinetic_energy'],
                           xkey='mean_z',
                           xlim=None,
                           ylim=None,
                           ylim2=None,
                           nice=True,
                           tex=True,
                           include_layout=True,
                           include_labels=True,
                           include_markers=True,
                           include_particles=True,
                           include_legend=True,
                           return_figure=False,
                           **kwargs):
    """
    Plots stat output multiple keys.

    If a list of ykeys2 is given, these will be put on the right hand axis. This can also be given as a single key.

    Parameters:
        impact_object: object providing `.stat(key)`, `.units(key)`,
            `.particles` (iterated by name) and `.stop`.

        ykeys: key or list of keys for the left axis (solid lines).

        ykeys2: key or list of keys for the right axis (dashed lines).

        xkey: stat key for the horizontal axis.

        xlim: optional (min, max) in unscaled units; defaults to the full
            range of the x data.

        ylim: optional (min, max) for the left axis, in unscaled units.

        ylim2: optional (min, max) for the right axis, in unscaled units.

        **kwargs: forwarded to `plt.subplots`.

    Logical switches:
        nice: a nice SI prefix and scaling will be used to make the numbers reasonably sized. Default: True

        tex: use mathtext (TeX) for plot labels. Default: True

        include_legend: The plot will include the legend.  Default: True

        include_layout: the layout plot will be displayed at the bottom.  Default: True

        include_labels: the layout will include element labels.  Default: True

        include_markers: forwarded to `add_layout_to_axes`.  Default: True

        include_particles: particles within xlim are drawn as points
            (best-effort; errors reading them are ignored).  Default: True

        return_figure: return the figure object for further manipulation. Default: False

    Returns:
        The figure if return_figure is set, otherwise None.

    Raises:
        AssertionError: if keys sharing an axis have different units.

    """
    I = impact_object  # convenience

    if include_layout:
        fig, all_axis = plt.subplots(2,
                                     gridspec_kw={'height_ratios': [4, 1]},
                                     **kwargs)
        ax_layout = all_axis[-1]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        ax_plot = [all_axis]

    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]

    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())

    # No need for a legend if there is only one plot
    if len(ykeys) == 1 and not ykeys2:
        include_legend = False

    X = I.stat(xkey)

    # Only get the data we need
    if xlim:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]
    else:
        xlim = X.min(), X.max()
        good = slice(None, None, None)  # everything

    # Try particles within these bounds
    Pnames = []
    X_particles = []

    if include_particles:
        # Best-effort: any failure disables the particle points.
        # Exception (not a bare except) so Ctrl-C and SystemExit still propagate.
        try:
            for pname in I.particles:
                xp = I.particles[pname][xkey]
                if xp >= xlim[0] and xp <= xlim[1]:
                    Pnames.append(pname)
                    X_particles.append(xp)
            X_particles = np.array(X_particles)
        except Exception:
            Pnames = []
    else:
        Pnames = []

    # X axis scaling
    units_x = str(I.units(xkey))
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x = prefix_x + units_x
    else:
        factor_x = 1

    # set all but the layout

    # Handle tex labels
    xlabel = mathlabel(xkey, units=units_x, tex=tex)

    for ax in ax_plot:
        ax.set_xlim(xlim[0] / factor_x, xlim[1] / factor_x)
        ax.set_xlabel(xlabel)

    # Draw for Y1 and Y2

    linestyles = ['solid', 'dashed']

    ii = -1  # counter for colors, shared by both axes so colors do not repeat
    for ix, keys in enumerate([ykeys, ykeys2]):
        if not keys:
            continue
        ax = ax_plot[ix]
        linestyle = linestyles[ix]

        # Check that units are compatible
        ulist = [I.units(key) for key in keys]
        if len(ulist) > 1:
            for u2 in ulist[1:]:
                assert ulist[
                    0] == u2, f'Incompatible units: {ulist[0]} and {u2}'
        # String representation
        unit = str(ulist[0])

        # Data, restricted to the same selection as X
        data = [I.stat(key)[good] for key in keys]

        if nice:
            # One common factor for every key on this axis
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        # Make a line and point
        for key, dat in zip(keys, data):
            ii += 1
            color = 'C' + str(ii)

            # Handle tex labels
            label = mathlabel(key, units=unit, tex=tex)
            ax.plot(X,
                    dat / factor,
                    label=label,
                    color=color,
                    linestyle=linestyle)

            # Particles (best-effort, e.g. the key may not exist for them)
            if Pnames:
                try:
                    Y_particles = np.array(
                        [I.particles[name][key] for name in Pnames])
                    ax.scatter(X_particles / factor_x,
                               Y_particles / factor,
                               color=color)
                except Exception:
                    pass

        # Handle tex labels
        ylabel = mathlabel(*keys, units=unit, tex=tex)
        ax.set_ylabel(ylabel)

        # Set limits, considering the scaling. ix selects which user limit
        # applies; for ix == 1, ax is the twin (right-hand) axis.
        user_ylim = ylim if ix == 0 else ylim2
        if user_ylim:
            ax.set_ylim(np.array(user_ylim) / factor)

    # Collect legend entries from every axis into one legend on the first
    if include_legend:
        lines = []
        labels = []
        for ax in ax_plot:
            a, b = ax.get_legend_handles_labels()
            lines += a
            labels += b
        ax_plot[0].legend(lines, labels, loc='best')

    # Layout
    if include_layout:

        # Gives some space to the top plot
        ax_layout.set_ylim(-1, 1.5)

        if xkey == 'mean_z':
            # Layout shares the (unscaled) x range of the main plot
            ax_layout.set_axis_off()
            ax_layout.set_xlim(xlim[0], xlim[1])
        else:
            # Otherwise the layout covers 0 to I.stop on its own mean_z axis
            ax_layout.set_xlabel('mean_z')
            xlim = (0, I.stop)
        add_layout_to_axes(I,
                           ax_layout,
                           bounds=xlim,
                           include_labels=include_labels,
                           include_markers=include_markers)

    if return_figure:
        return fig
예제 #10
0
def marginal_plot(particle_group, key1='t', key2='p', bins=None, **kwargs):
    """
    Density plot and projections

    Draws a weighted hexbin of `key2` versus `key1` with a weighted density
    histogram of each coordinate along the top and right edges.

    Parameters:
        particle_group: object supporting `len()`, `[key]`, `['weight']` and
            `.units(key).unitSymbol`.

        key1: quantity on the horizontal axis.

        key2: quantity on the vertical axis.

        bins: hexbin grid size and histogram bin count. If not given (or
            falsy), sqrt(n/4) is used, with a minimum of one.

        **kwargs: forwarded to `plt.figure`.

    Returns:
        The matplotlib figure.

    Example:

        marginal_plot(P, 't', 'energy', bins=200)

    """

    if not bins:
        n = len(particle_group)
        # At least one bin: np.histogram rejects bins=0, which the sqrt rule
        # would give for fewer than 4 particles.
        bins = max(1, int(np.sqrt(n / 4)))

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key1])
    y, _, p2 = nice_array(particle_group[key2])

    w = particle_group['weight']

    u1 = particle_group.units(key1).unitSymbol
    u2 = particle_group.units(key2).unitSymbol
    ux = p1 + u1
    uy = p2 + u2

    labelx = f'{key1} ({ux})'
    labely = f'{key2} ({uy})'

    fig = plt.figure(**kwargs)

    # 4x4 grid: joint plot in the lower-left 3x3, marginals along top and right
    gs = GridSpec(4, 4)

    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    ax_marg_x = fig.add_subplot(gs[0, 0:3])
    ax_marg_y = fig.add_subplot(gs[1:4, 3])

    # Proper weighting: sum the weights falling in each hexagon
    ax_joint.hexbin(x,
                    y,
                    C=w,
                    reduce_C_function=np.sum,
                    gridsize=bins,
                    cmap=CMAP0,
                    vmin=1e-20)

    # Top histogram: weighted counts divided by bin width
    hist, bin_edges = np.histogram(x, bins=bins, weights=w)
    hist_width = np.diff(bin_edges)
    hist_x = bin_edges[:-1] + hist_width / 2  # bin centers
    hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
    ax_marg_x.bar(hist_x, hist_y, hist_width, color='gray')
    # Special label for C/s = A
    if u1 == 's':
        # x was already scaled by f1, so the prefix for Amps is derived from
        # the combined factor hist_f / f1.
        _, hist_prefix = nice_scale_prefix(hist_f / f1)
        ax_marg_x.set_ylabel(f'{hist_prefix}A')
    else:
        ax_marg_x.set_ylabel(f'{hist_prefix}C/{ux}')

    # Side histogram: same construction, drawn horizontally
    hist, bin_edges = np.histogram(y, bins=bins, weights=w)
    hist_width = np.diff(bin_edges)
    hist_x = bin_edges[:-1] + hist_width / 2  # bin centers
    hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
    ax_marg_y.barh(hist_x, hist_y, hist_width, color='gray')
    ax_marg_y.set_xlabel(f'{hist_prefix}C/{uy}')

    # Turn off tick labels on marginals
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Set labels on joint
    ax_joint.set_xlabel(labelx)
    ax_joint.set_ylabel(labely)

    return fig
예제 #11
0
def plot_stats_with_layout(gpt_object,
                           ykeys=['sigma_x', 'sigma_y'],
                           ykeys2=['mean_kinetic_energy'],
                           xkey='mean_z',
                           xlim=None,
                           nice=True,
                           include_layout=False,
                           include_labels=True,
                           include_legend=True,
                           **kwargs):
    """
    Plot several stat keys against a common x key.

    `ykeys` go on the left axis with solid lines. `ykeys2`, if given, go on a
    twin right axis with dashed lines. Either may be a single key string or a
    list of keys. Keys sharing an axis must have equal units.

    Switches:
        nice: rescale data and prefix the units with an SI prefix. Default: True

        include_legend: draw one legend covering both axes; turned off
            automatically when only a single line is plotted. Default: True

        include_layout: reserve a second, shorter row for a layout plot. The
            layout itself is not drawn yet (a TODO message is printed).
            Default: False

        include_labels: accepted but not used by this function. Default: True

    Extra keyword arguments are forwarded to `plt.subplots`.
    """
    # Figure and the list of axes that carry data
    if include_layout:
        fig, axis_rows = plt.subplots(2,
                                      gridspec_kw={'height_ratios': [4, 1]},
                                      **kwargs)
        layout_axis = axis_rows[-1]  # reserved for the layout; not drawn on yet
        data_axes = [axis_rows[0]]
    else:
        fig, only_axis = plt.subplots(**kwargs)
        data_axes = [only_axis]

    # Accept a bare string for either key group
    left_keys = [ykeys] if isinstance(ykeys, str) else ykeys
    right_keys = ykeys2
    if right_keys:
        if isinstance(right_keys, str):
            right_keys = [right_keys]
        data_axes.append(data_axes[0].twinx())

    # A single line needs no legend
    single_line = len(left_keys) == 1 and not right_keys
    show_legend = include_legend and not single_line

    # X data, optionally restricted to xlim
    x_values = gpt_object.stat(xkey)
    if xlim:
        selection = np.logical_and(x_values >= xlim[0], x_values <= xlim[1])
        x_values = x_values[selection]
    else:
        selection = slice(None, None, None)  # keep everything
        xlim = x_values.min(), x_values.max()

    # X axis scaling
    x_unit = str(gpt_object.units(xkey))
    x_factor = 1
    if nice:
        x_values, x_factor, x_prefix = nice_array(x_values)
        x_unit = x_prefix + x_unit

    x_label = f'{xkey} ({x_unit})'
    for axis in data_axes:
        axis.set_xlim(xlim[0] / x_factor, xlim[1] / x_factor)
        axis.set_xlabel(x_label)

    # One pass per axis; the color index runs on across both axes
    color_index = 0
    for axis, linestyle, keys in zip(data_axes,
                                     ('solid', 'dashed'),
                                     (left_keys, right_keys)):
        if not keys:
            continue

        # All keys on this axis must share the first key's units
        units = [gpt_object.units(key) for key in keys]
        first_unit = units[0]
        for other_unit in units[1:]:
            assert first_unit == other_unit, f'Incompatible units: {first_unit} and {other_unit}'
        unit_text = str(first_unit)

        series = [gpt_object.stat(key)[selection] for key in keys]

        # One common scale factor for every series on this axis
        y_factor = 1
        if nice:
            y_factor, y_prefix = nice_scale_prefix(np.ptp(series))
            unit_text = y_prefix + unit_text

        for key, values in zip(keys, series):
            axis.plot(x_values,
                      values / y_factor,
                      label=f'{key} ({unit_text})',
                      color=f'C{color_index}',
                      linestyle=linestyle)
            color_index += 1

        joined_keys = ', '.join(keys)
        axis.set_ylabel(f'{joined_keys} ({unit_text})')

    # Merge the handles of every axis into one legend on the first axis
    if show_legend:
        handles, labels = [], []
        for axis in data_axes:
            axis_handles, axis_labels = axis.get_legend_handles_labels()
            handles.extend(axis_handles)
            labels.extend(axis_labels)
        data_axes[0].legend(handles, labels, loc='best')

    # Layout
    if include_layout:
        print('TODO include_layout')