def slice_plot(particle_group, stat_key='sigma_x', n_slice=40, slice_key='z'):
    """
    Slice plot of a statistic versus slice position, with the slice charge
    density shaded on a right-hand axis.

    The beam is split into `n_slice` slices along `slice_key` using
    `slice_statistics`. The left axis shows `stat_key` against
    'mean_' + slice_key. The right axis shades 'density', which is the slice
    charge divided by the slice extent ('ptp_' + slice_key). When the x unit
    renders as 's', the density unit C/s is relabelled as A.

    Parameters
    ----------
    particle_group : object
        Particle container accepted by `slice_statistics`; must also provide
        `.units(key)`.
    stat_key : str
        Slice statistic plotted on the left axis.
    n_slice : int
        Number of slices.
    slice_key : str
        Coordinate to slice along.
    """
    mean_key = 'mean_' + slice_key
    extent_key = 'ptp_' + slice_key
    density_key = 'density'

    stats = slice_statistics(
        particle_group,
        n_slice=n_slice,
        slice_key=slice_key,
        keys=[mean_key, stat_key, extent_key, 'charge'],
    )
    stats[density_key] = stats['charge'] / stats[extent_key]

    fig, ax = plt.subplots()

    # Rescale each series to a convenient SI prefix.
    xs, _, x_prefix = nice_array(stats[mean_key])
    ys, _, y_prefix = nice_array(stats[stat_key])
    densities, _, density_prefix = nice_array(stats[density_key])

    x_label_units = x_prefix + particle_group.units(mean_key).unitSymbol
    y_label_units = y_prefix + particle_group.units(stat_key).unitSymbol

    base_density_units = f'C/{particle_group.units(mean_key)}'
    # Charge per second is a current, so show it in amperes.
    if base_density_units == 'C/s':
        base_density_units = 'A'
    density_label_units = density_prefix + base_density_units

    ax.set_xlabel(f'{mean_key} ({x_label_units})')
    ax.set_ylabel(f'{stat_key} ({y_label_units})')
    ax.plot(xs, ys, color='black')

    density_ax = ax.twinx()
    density_ax.set_ylabel(f'{density_key} ({density_label_units})')
    density_ax.fill_between(xs, 0, densities, color='black', alpha=0.2)
def plot_stats(astra_object, keys=['norm_emit_x', 'sigma_z'], sections=['cavity', 'solenoid'], fieldmaps=None, verbose=False):
    """
    Plot stats versus mean_z, with fieldmaps plotted from two sections.

    One panel is drawn per entry of `keys`. A final bottom panel shows the
    fieldmaps of `sections[0]` on its left axis and `sections[1]` on a
    right-hand twin axis.

    Parameters
    ----------
    astra_object : object
        Provides `.input`, `.stat(key)` and `.units(key)`.
    keys : list of str
        Stat keys to plot, one panel each. May be empty.
    sections : list of str
        Exactly two fieldmap sections; labels and colors are defined for
        'cavity' and 'solenoid'.
    fieldmaps : ignored
        Kept for backward compatibility. Fieldmaps are always loaded from
        `astra_object.input` with `load_fieldmaps`.
    verbose : bool
        Passed to the fieldmap loading helpers.

    TODO: quadrupoles
    """
    # Validate before doing any fieldmap I/O.
    assert len(sections) == 2, 'TODO: more general'

    astra_input = astra_object.input
    fmaps = load_fieldmaps(astra_input, sections=sections, verbose=verbose)

    nplots = len(keys) + 1
    # squeeze=False keeps axs indexable even when there is only one panel
    # (keys empty). Plain subplots(1) returns a bare Axes.
    fig, axs = plt.subplots(nplots, squeeze=False)
    axs = axs[:, 0]

    xdat = astra_object.stat('mean_z')
    xmin = min(xdat)
    xmax = max(xdat)

    for ax, key in zip(axs, keys):
        unit = astra_object.units(key)
        ydat = astra_object.stat(key)
        ndat, factor, prefix = nice_array(ydat)
        ax.set_ylabel(f'{key} ({prefix}{unit})')
        ax.set_xlim(xmin, xmax)
        ax.plot(xdat, ndat)

    # Bottom panel: first section on the left axis, second on a twin RHS axis.
    ax1 = axs[-1]
    ax1rhs = ax1.twinx()
    field_axes = [ax1, ax1rhs]
    ylabel = {'cavity': '$E_z$ (MV/m)', 'solenoid': '$B_z$ (T)'}
    color = {'cavity': 'green', 'solenoid': 'blue'}
    for a, section in zip(field_axes, sections):
        for ix in find_fieldmap_ixlist(astra_input, section):
            dat = fieldmap_data(astra_input, section=section, index=ix, fieldmaps=fmaps, verbose=verbose)
            a.plot(*dat.T, label=f'{section}_{ix}', color=color[section])
        a.set_ylabel(ylabel[section])

    ax1.set_xlabel('$z$ (m)')
    ax1.set_xlim(xmin, xmax)
def density_plot(particle_group, key='x', bins=None, **kwargs):
    """
    1D density plot: weighted histogram of `key`, divided by bin width.

    Also see: marginal_plot

    Example:

        density_plot(P, 'x', bins=100)

    Parameters
    ----------
    particle_group : object
        Must support `len()`, item access for `key` and 'weight', and
        `.units(key)`.
    key : str
        Coordinate to histogram.
    bins : int or sequence, optional
        Passed to `np.histogram`. If falsy, about one bin per 100 particles
        is used, with a minimum of one bin.
    **kwargs
        Passed to `plt.subplots`.

    Returns
    -------
    The matplotlib figure.
    """
    if not bins:
        n = len(particle_group)
        # np.histogram rejects bins=0, which int(n / 100) gives for fewer
        # than 100 particles; always use at least one bin.
        bins = max(1, int(n / 100))

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key])
    w = particle_group['weight']
    u1 = particle_group.units(key).unitSymbol
    ux = p1 + u1

    labelx = f'{key} ({ux})'

    fig, ax = plt.subplots(**kwargs)

    hist, bin_edges = np.histogram(x, bins=bins, weights=w)
    hist_width = np.diff(bin_edges)
    hist_x = bin_edges[:-1] + hist_width / 2
    hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
    ax.bar(hist_x, hist_y, hist_width, color='grey')

    # Special label for C/s = A. The histogram is per *scaled* x unit, so
    # the prefix is recomputed from the ratio of the two scale factors.
    if u1 == 's':
        _, hist_prefix = nice_scale_prefix(hist_f / f1)
        ax.set_ylabel(f'{hist_prefix}A')
    else:
        ax.set_ylabel(f'{hist_prefix}C/{ux}')

    ax.set_xlabel(labelx)

    return fig
def plot_stat(impact_object, y='sigma_x', x='mean_z', nice=True):
    """
    Plot stat output of key y vs key x.

    If particles have the same stat key, these will also be plotted as red
    points. This is best effort: any error while collecting them just skips
    the overlay.

    If nice, a nice SI prefix and scaling will be used to make the numbers
    reasonably sized.
    """
    I = impact_object  # convenience
    fig, ax = plt.subplots()

    units1 = str(I.units(x))
    units2 = str(I.units(y))

    X = I.stat(x)
    Y = I.stat(y)

    if nice:
        X, f1, prefix1 = nice_array(X)
        Y, f2, prefix2 = nice_array(Y)
        units1 = prefix1 + units1
        units2 = prefix2 + units2
    else:
        f1 = 1
        f2 = 1

    ax.set_xlabel(x + f' ({units1})')
    ax.set_ylabel(y + f' ({units2})')

    # Line plot on the axes created above, not the implicit current axes.
    ax.plot(X, Y)

    # Best-effort particle overlay. Catch Exception rather than using a bare
    # except, so KeyboardInterrupt/SystemExit still propagate.
    try:
        ax.scatter([I.particles[name][x] / f1 for name in I.particles],
                   [I.particles[name][y] / f2 for name in I.particles], color='red')
    except Exception:
        pass
def plot_stats(astra_object, keys=['norm_emit_x', 'sigma_z'], sections=['cavity', 'solenoid'], fieldmaps=None, verbose=False):
    """
    Plot stats versus mean_z, with fieldmaps of `sections` in a bottom panel.

    One panel is drawn per entry of `keys`. The last panel is filled by
    `add_fieldmaps_to_axes` for the requested `sections`.

    Parameters
    ----------
    astra_object : object
        Provides `.input`, `.stat(key)` and `.units(key)`.
    keys : list of str
        Stat keys to plot, one panel each. May be empty.
    sections : list of str
        Exactly two fieldmap sections to draw.
    fieldmaps : ignored
        Kept for backward compatibility.
    verbose : bool
        Passed to `load_fieldmaps`.

    TODO: quadrupoles
    """
    assert len(sections) == 2, 'TODO: more general'

    astra_input = astra_object.input
    # NOTE(review): the result is not passed to add_fieldmaps_to_axes. The
    # call is kept so `verbose` output and any loading errors are unchanged.
    load_fieldmaps(astra_input, sections=sections, verbose=verbose)

    nplots = len(keys) + 1
    # squeeze=False keeps axs indexable even when keys is empty.
    fig, axs = plt.subplots(nplots, squeeze=False)
    axs = axs[:, 0]

    xdat = astra_object.stat('mean_z')
    xmin = min(xdat)
    xmax = max(xdat)

    for ax, key in zip(axs, keys):
        unit = astra_object.units(key)
        ydat = astra_object.stat(key)
        ndat, factor, prefix = nice_array(ydat)
        ax.set_ylabel(f'{key} ({prefix}{unit})')
        ax.set_xlim(xmin, xmax)
        ax.plot(xdat, ndat)

    # Pass the caller's sections through (previously hard-coded).
    add_fieldmaps_to_axes(astra_object, axs[-1], bounds=(xmin, xmax),
                          sections=sections, include_labels=True)
def marginal_plot(particle_group, key1='t', key2='p', bins=None):
    """
    Density plot and projections.

    The joint panel is a weight-summed hexbin of key2 vs key1. The top and
    right panels are weighted marginal histograms. The top one is normalized
    per base unit of key1 (labelled C/<unit1>). The right one is normalized
    per nice-scaled unit of key2 (labelled C/<prefix><unit2>). The figure is
    shown with plt.show().

    Example:

        marginal_plot(P, 't', 'energy', bins=200)

    Parameters
    ----------
    bins : int, optional
        Histogram bins and hexbin gridsize. If falsy, sqrt(n / 4) is used,
        with a minimum of 1.
    """
    if not bins:
        n = len(particle_group)
        # At least one bin, so very small groups (n < 4) still plot.
        bins = max(1, int(np.sqrt(n / 4)))

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key1])
    y, f2, p2 = nice_array(particle_group[key2])
    w = particle_group['weight']

    u1 = particle_group.units(key1).unitSymbol
    u2 = particle_group.units(key2).unitSymbol
    ux = p1 + u1
    uy = p2 + u2

    labelx = f'{key1} ({ux})'
    labely = f'{key2} ({uy})'

    fig = plt.figure()

    gs = GridSpec(4, 4)

    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    ax_marg_x = fig.add_subplot(gs[0, 0:3])
    ax_marg_y = fig.add_subplot(gs[1:4, 3])

    # Proper weighting: each hexagon holds the summed weight of its particles.
    ax_joint.hexbin(x, y, C=w, reduce_C_function=np.sum, gridsize=bins, cmap=cmap, vmin=1e-15)

    # np.ptp instead of ndarray.ptp, which was removed in NumPy 2.0.
    dx = np.ptp(x) / bins
    dy = np.ptp(y) / bins

    ax_marg_x.hist(x, weights=w / dx / f1, bins=bins, color='gray')
    ax_marg_y.hist(y, orientation="horizontal", weights=w / dy, bins=bins, color='gray')

    # Turn off tick labels on marginals
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Set labels on joint
    ax_joint.set_xlabel(labelx)
    ax_joint.set_ylabel(labely)

    # Set labels on marginals
    ax_marg_x.set_ylabel(f'C/{u1}')
    ax_marg_y.set_xlabel(f'C/{uy}')
    plt.show()
def density_and_slice_plot(particle_group, key1='t', key2='p', stat_keys=['norm_emit_x', 'norm_emit_y'], bins=100, n_slice=30):
    """
    2D density image of key2 vs key1, overlaid with slice statistics.

    The left axes show a weighted 2D histogram (np.histogram2d) of key2
    against key1. A right-hand twin axis plots each of `stat_keys` per slice
    (the beam is sliced along key1 into `n_slice` slices), together with a
    shaded slice charge density. The density is rescaled so its peak is half
    the largest stat range, so it is only a shape indicator.

    Example:

        density_and_slice_plot(P, 't', 'energy', stat_keys=['norm_emit_x'])

    Parameters
    ----------
    stat_keys : sequence of str
        Slice statistics to plot; all must share the same unit.
    bins : int
        Bins for the 2D histogram.
    n_slice : int
        Number of slices along key1.
    """
    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key1])
    y, f2, p2 = nice_array(particle_group[key2])
    w = particle_group['weight']

    u1 = particle_group.units(key1).unitSymbol
    u2 = particle_group.units(key2).unitSymbol
    ux = p1 + u1
    uy = p2 + u2

    labelx = f'{key1} ({ux})'
    labely = f'{key2} ({uy})'

    fig, ax = plt.subplots()
    ax.set_xlabel(labelx)
    ax.set_ylabel(labely)

    # Weighted 2D histogram shown as an image.
    H, xedges, yedges = np.histogram2d(x, y, weights=w, bins=bins)
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    ax.imshow(H.T, cmap=cmap, vmin=1e-16, origin='lower', extent=extent, aspect='auto')

    # Slice data. list() so tuple stat_keys also work.
    slice_dat = slice_statistics(particle_group, n_slice=n_slice, slice_key=key1,
                                 keys=list(stat_keys) + ['ptp_' + key1, 'mean_' + key1, 'charge'])

    slice_dat['density'] = slice_dat['charge'] / slice_dat['ptp_' + key1]

    ax2 = ax.twinx()

    # Slice centers in the same nice-scaled units as the image x axis.
    x2 = slice_dat['mean_' + key1] / f1

    ulist = [particle_group.units(k).unitSymbol for k in stat_keys]
    # np.ptp instead of ndarray.ptp, which was removed in NumPy 2.0.
    max2 = max(np.ptp(slice_dat[k]) for k in stat_keys)

    f3, p3 = nice_scale_prefix(max2)

    stat_unit = ulist[0]
    assert all(u == stat_unit for u in ulist)
    stat_unit = p3 + stat_unit

    for k in stat_keys:
        ax2.plot(x2, slice_dat[k] / f3, label=k)
    ax2.legend()
    ax2.set_ylabel(f'({stat_unit})')
    ax2.set_ylim(bottom=0)

    # Add density, scaled so its peak is half the largest stat range.
    y2 = slice_dat['density']
    y2 = y2 * max2 / y2.max() / f3 / 2
    ax2.fill_between(x2, 0, y2, color='black', alpha=0.1)
def plot_stats_with_layout(astra_object, ykeys=['sigma_x', 'sigma_y'], ykeys2=['sigma_z'],
                           xkey='mean_z', xlim=None,
                           nice=True,
                           include_layout=False,
                           include_labels=True,
                           include_particles=True,
                           include_legend=True, **kwargs):
    """
    Plots stat output multiple keys.

    If a list of ykeys2 is given, these will be put on the right hand axis.
    This can also be given as a single key.

    Logical switches:
        nice: a nice SI prefix and scaling will be used to make the numbers
            reasonably sized. Default: True
        include_legend: The plot will include the legend. Default: True
        include_layout: the fieldmap layout will be displayed in a panel at
            the bottom. Default: False
        include_labels: passed to add_fieldmaps_to_axes for element labels.
            Default: True
        include_particles: overlay the same stats from the object's
            particles (best effort) as scatter points. Default: True

    Extra kwargs are passed to plt.subplots.

    Copied almost verbatim from lume-impact's Impact.plot.plot_stats_with_layout
    """
    I = astra_object  # convenience

    if include_layout:
        fig, all_axis = plt.subplots(2, gridspec_kw={'height_ratios': [4, 1]}, **kwargs)
        ax_layout = all_axis[-1]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        ax_plot = [all_axis]

    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]

    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())

    # No need for a legend if there is only one plot
    if len(ykeys) == 1 and not ykeys2:
        include_legend = False

    X = I.stat(xkey)

    # Only get the data we need
    if xlim:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]
    else:
        xlim = X.min(), X.max()
        good = slice(None, None, None)  # everything

    # Try particles within these bounds. Best effort: any failure (no
    # particles, missing key, ...) just disables the particle overlay.
    Pnames = []
    X_particles = []
    if include_particles:
        try:
            for pname in range(len(I.particles)):  # Modified from Impact
                xp = I.particles[pname][xkey]
                if xp >= xlim[0] and xp <= xlim[1]:
                    Pnames.append(pname)
                    X_particles.append(xp)
            X_particles = np.array(X_particles)
        except Exception:
            Pnames = []

    # X axis scaling
    units_x = str(I.units(xkey))
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x = prefix_x + units_x
    else:
        factor_x = 1

    # set all but the layout
    for ax in ax_plot:
        ax.set_xlim(xlim[0] / factor_x, xlim[1] / factor_x)
        ax.set_xlabel(f'{xkey} ({units_x})')

    # Draw for Y1 and Y2
    linestyles = ['solid', 'dashed']

    ii = -1  # counter for colors, shared across both axes
    for ix, keys in enumerate([ykeys, ykeys2]):
        if not keys:
            continue
        ax = ax_plot[ix]
        linestyle = linestyles[ix]

        # Check that units are compatible
        ulist = [I.units(key) for key in keys]
        for u2 in ulist[1:]:
            assert ulist[0] == u2, f'Incompatible units: {ulist[0]} and {u2}'
        # String representation
        unit = str(ulist[0])

        # Data
        data = [I.stat(key)[good] for key in keys]

        if nice:
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        # Make a line and point
        for key, dat in zip(keys, data):
            ii += 1
            color = 'C' + str(ii)
            ax.plot(X, dat / factor, label=f'{key} ({unit})', color=color, linestyle=linestyle)

            # Particles (best effort; skip keys the particles do not have)
            if Pnames:
                try:
                    Y_particles = np.array([I.particles[name][key] for name in Pnames])
                    ax.scatter(X_particles / factor_x, Y_particles / factor, color=color)
                except Exception:
                    pass
        ax.set_ylabel(', '.join(keys) + f' ({unit})')

    # Collect legend from both axes into one
    if include_legend:
        lines = []
        labels = []
        for ax in ax_plot:
            a, b = ax.get_legend_handles_labels()
            lines += a
            labels += b
        ax_plot[0].legend(lines, labels, loc='best')

    # Layout
    if include_layout:
        if xkey == 'mean_z':
            ax_layout.set_xlim(xlim[0], xlim[1])
        else:
            ax_layout.set_xlabel('mean_z')
            # NOTE(review): assumes the Astra object exposes .stop like the
            # Impact object this was copied from — confirm.
            xlim = (0, I.stop)
        add_fieldmaps_to_axes(I, ax_layout, bounds=xlim, include_labels=include_labels)
def plot_stats_with_layout(impact_object, ykeys=['sigma_x', 'sigma_y'], ykeys2=['mean_kinetic_energy'],
                           xkey='mean_z', xlim=None, ylim=None, ylim2=None,
                           nice=True,
                           tex=True,
                           include_layout=True,
                           include_labels=True,
                           include_markers=True,
                           include_particles=True,
                           include_legend=True,
                           return_figure=False,
                           **kwargs):
    """
    Plots stat output multiple keys.

    If a list of ykeys2 is given, these will be put on the right hand axis.
    This can also be given as a single key.

    Limits:
        xlim: (min, max) of xkey in unscaled units; data outside is dropped.
        ylim, ylim2: (min, max) for the left / right axis, in unscaled units
            (they are divided by the same nice scale factor as the data).

    Logical switches:
        nice: a nice SI prefix and scaling will be used to make the numbers
            reasonably sized. Default: True
        tex: use mathtext (TeX) for plot labels. Default: True
        include_legend: The plot will include the legend. Default: True
        include_layout: the layout plot will be displayed at the bottom.
            Default: True
        include_labels: the layout will include element labels. Default: True
        include_markers: passed to add_layout_to_axes. Default: True
        include_particles: overlay the same stats from the object's
            particles (best effort) as scatter points. Default: True
        return_figure: return the figure object for further manipulation.
            Default: False

    Extra kwargs are passed to plt.subplots.
    """
    I = impact_object  # convenience

    if include_layout:
        fig, all_axis = plt.subplots(2, gridspec_kw={'height_ratios': [4, 1]}, **kwargs)
        ax_layout = all_axis[-1]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        ax_plot = [all_axis]

    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]

    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())

    # No need for a legend if there is only one plot
    if len(ykeys) == 1 and not ykeys2:
        include_legend = False

    X = I.stat(xkey)

    # Only get the data we need
    if xlim:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]
    else:
        xlim = X.min(), X.max()
        good = slice(None, None, None)  # everything

    # Try particles within these bounds. Best effort: any failure (no
    # particles, missing key, ...) just disables the particle overlay.
    Pnames = []
    X_particles = []
    if include_particles:
        try:
            for pname in I.particles:
                xp = I.particles[pname][xkey]
                if xp >= xlim[0] and xp <= xlim[1]:
                    Pnames.append(pname)
                    X_particles.append(xp)
            X_particles = np.array(X_particles)
        except Exception:
            Pnames = []

    # X axis scaling
    units_x = str(I.units(xkey))
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x = prefix_x + units_x
    else:
        factor_x = 1

    # set all but the layout
    xlabel = mathlabel(xkey, units=units_x, tex=tex)
    for ax in ax_plot:
        ax.set_xlim(xlim[0] / factor_x, xlim[1] / factor_x)
        ax.set_xlabel(xlabel)

    # Draw for Y1 (left axis) and Y2 (right axis)
    linestyles = ['solid', 'dashed']
    axis_ylims = [ylim, ylim2]

    ii = -1  # counter for colors, shared across both axes
    for ix, keys in enumerate([ykeys, ykeys2]):
        if not keys:
            continue
        ax = ax_plot[ix]
        linestyle = linestyles[ix]

        # Check that units are compatible
        ulist = [I.units(key) for key in keys]
        for u2 in ulist[1:]:
            assert ulist[0] == u2, f'Incompatible units: {ulist[0]} and {u2}'
        # String representation
        unit = str(ulist[0])

        # Data
        data = [I.stat(key)[good] for key in keys]

        if nice:
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        # Make a line and point
        for key, dat in zip(keys, data):
            ii += 1
            color = 'C' + str(ii)
            label = mathlabel(key, units=unit, tex=tex)
            ax.plot(X, dat / factor, label=label, color=color, linestyle=linestyle)

            # Particles (best effort; skip keys the particles do not have)
            if Pnames:
                try:
                    Y_particles = np.array([I.particles[name][key] for name in Pnames])
                    ax.scatter(X_particles / factor_x, Y_particles / factor, color=color)
                except Exception:
                    pass

        ax.set_ylabel(mathlabel(*keys, units=unit, tex=tex))

        # Set this axis' limits with this axis' own scale factor.
        this_ylim = axis_ylims[ix]
        if this_ylim:
            ax.set_ylim(np.array(this_ylim) / factor)

    # Collect legend from both axes into one
    if include_legend:
        lines = []
        labels = []
        for ax in ax_plot:
            a, b = ax.get_legend_handles_labels()
            lines += a
            labels += b
        ax_plot[0].legend(lines, labels, loc='best')

    # Layout
    if include_layout:
        # Gives some space to the top plot
        ax_layout.set_ylim(-1, 1.5)

        if xkey == 'mean_z':
            ax_layout.set_axis_off()
            ax_layout.set_xlim(xlim[0], xlim[1])
        else:
            ax_layout.set_xlabel('mean_z')
            xlim = (0, I.stop)
        add_layout_to_axes(I, ax_layout, bounds=xlim, include_labels=include_labels,
                           include_markers=include_markers)

    if return_figure:
        return fig
def _weighted_density_histogram(values, weights, bins):
    """
    Weighted histogram of `values` divided by bin width, rescaled with nice_array.

    Returns (bin_centers, scaled_density, bin_widths, density_factor,
    density_prefix).
    """
    hist, bin_edges = np.histogram(values, bins=bins, weights=weights)
    bin_widths = np.diff(bin_edges)
    bin_centers = bin_edges[:-1] + bin_widths / 2
    scaled, factor, prefix = nice_array(hist / bin_widths)
    return bin_centers, scaled, bin_widths, factor, prefix


def _density_axis_label(base_unit, prefixed_unit, density_factor, density_prefix, coord_factor):
    """
    Axis label for a charge density per unit of a nice-scaled coordinate.

    For a base unit of 's' the density is a current, so the label is in
    amperes. The prefix is recomputed from density_factor / coord_factor
    because the histogram was taken over the scaled coordinate. Otherwise
    the label is '<density_prefix>C/<prefixed_unit>'.
    """
    if base_unit == 's':
        _, prefix = nice_scale_prefix(density_factor / coord_factor)
        return f'{prefix}A'
    return f'{density_prefix}C/{prefixed_unit}'


def marginal_plot(particle_group, key1='t', key2='p', bins=None, **kwargs):
    """
    Density plot and projections.

    The joint panel is a weight-summed hexbin of key2 vs key1. The top and
    right panels are weighted histograms divided by bin width, i.e. charge
    density per (nice-scaled) unit. A time coordinate ('s') is labelled as a
    current in A on either marginal.

    Example:

        marginal_plot(P, 't', 'energy', bins=200)

    Parameters
    ----------
    bins : int, optional
        Histogram bins and hexbin gridsize. If falsy, sqrt(n / 4) is used,
        with a minimum of 1.
    **kwargs
        Passed to plt.figure.

    Returns
    -------
    The matplotlib figure.
    """
    if not bins:
        n = len(particle_group)
        # At least one bin: np.histogram rejects bins=0 (n < 4).
        bins = max(1, int(np.sqrt(n / 4)))

    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(particle_group[key1])
    y, f2, p2 = nice_array(particle_group[key2])
    w = particle_group['weight']

    u1 = particle_group.units(key1).unitSymbol
    u2 = particle_group.units(key2).unitSymbol
    ux = p1 + u1
    uy = p2 + u2

    labelx = f'{key1} ({ux})'
    labely = f'{key2} ({uy})'

    fig = plt.figure(**kwargs)

    gs = GridSpec(4, 4)

    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    ax_marg_x = fig.add_subplot(gs[0, 0:3])
    ax_marg_y = fig.add_subplot(gs[1:4, 3])

    # Proper weighting: each hexagon holds the summed weight of its particles.
    ax_joint.hexbin(x, y, C=w, reduce_C_function=np.sum, gridsize=bins, cmap=CMAP0, vmin=1e-20)

    # Top histogram
    centers, heights, widths, dens_f, dens_prefix = _weighted_density_histogram(x, w, bins)
    ax_marg_x.bar(centers, heights, widths, color='gray')
    ax_marg_x.set_ylabel(_density_axis_label(u1, ux, dens_f, dens_prefix, f1))

    # Side histogram
    centers, heights, widths, dens_f, dens_prefix = _weighted_density_histogram(y, w, bins)
    ax_marg_y.barh(centers, heights, widths, color='gray')
    ax_marg_y.set_xlabel(_density_axis_label(u2, uy, dens_f, dens_prefix, f2))

    # Turn off tick labels on marginals
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Set labels on joint
    ax_joint.set_xlabel(labelx)
    ax_joint.set_ylabel(labely)

    return fig
def plot_stats_with_layout(gpt_object, ykeys=['sigma_x', 'sigma_y'], ykeys2=['mean_kinetic_energy'],
                           xkey='mean_z', xlim=None,
                           nice=True,
                           include_layout=False,
                           include_labels=True,
                           include_legend=True, **kwargs):
    """
    Plot several stat keys against `xkey` on up to two y axes.

    `ykeys` go on the left axis (solid lines); `ykeys2`, if given, go on a
    twin right-hand axis (dashed lines). Either may be a single key string.
    All keys sharing an axis must have identical units.

    Options:
        xlim: (min, max) of xkey in unscaled units; data outside is dropped.
        nice: rescale values to a convenient SI prefix. Default: True
        include_legend: draw one combined legend. Default: True (forced off
            when there is exactly one ykey and no ykeys2)
        include_layout: reserve a bottom panel for a layout. Default: False.
            Not implemented yet; only prints a TODO message.
        include_labels: currently unused.

    Extra kwargs are passed to plt.subplots.
    """
    stats_source = gpt_object

    if include_layout:
        fig, all_axis = plt.subplots(2, gridspec_kw={'height_ratios': [4, 1]}, **kwargs)
        ax_layout = all_axis[-1]
        axes = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        axes = [all_axis]

    left_keys = [ykeys] if isinstance(ykeys, str) else ykeys
    right_keys = ykeys2
    if right_keys:
        if isinstance(right_keys, str):
            right_keys = [right_keys]
        axes.append(axes[0].twinx())

    # A single curve needs no legend.
    if not right_keys and len(left_keys) == 1:
        include_legend = False

    X = stats_source.stat(xkey)
    if not xlim:
        xlim = X.min(), X.max()
        good = slice(None)  # keep everything
    else:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]

    units_x = str(stats_source.units(xkey))
    factor_x = 1
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x = prefix_x + units_x

    for axis in axes:
        axis.set_xlim(xlim[0] / factor_x, xlim[1] / factor_x)
        axis.set_xlabel(f'{xkey} ({units_x})')

    # zip stops at len(axes), which is 1 exactly when there are no right keys.
    color_index = 0
    for axis, keys, linestyle in zip(axes, [left_keys, right_keys], ['solid', 'dashed']):
        if not keys:
            continue

        key_units = [stats_source.units(key) for key in keys]
        reference_unit = key_units[0]
        for other_unit in key_units[1:]:
            assert reference_unit == other_unit, f'Incompatible units: {reference_unit} and {other_unit}'
        unit = str(reference_unit)

        series = [stats_source.stat(key)[good] for key in keys]

        factor = 1
        if nice:
            factor, prefix = nice_scale_prefix(np.ptp(series))
            unit = prefix + unit

        for key, values in zip(keys, series):
            axis.plot(X, values / factor, label=f'{key} ({unit})',
                      color=f'C{color_index}', linestyle=linestyle)
            color_index += 1

        axis.set_ylabel(', '.join(keys) + f' ({unit})')

    if include_legend:
        handles, labels = [], []
        for axis in axes:
            axis_handles, axis_labels = axis.get_legend_handles_labels()
            handles.extend(axis_handles)
            labels.extend(axis_labels)
        axes[0].legend(handles, labels, loc='best')

    if include_layout:
        print('TODO include_layout')