def make_plot(items,
              figfile='',
              xlabel='',
              ylabel='',
              x_log=False,
              y_log=False,
              labels=[],
              title='',
              ps=qu.PlotStyle('dark')):
    """Plot every entry of ``items`` as a line on a new figure and show it.

    Args:
        items: Iterable of series; each one is passed to ``ax.plot`` as-is.
        figfile: If non-empty, the figure is saved to this path before it is
            shown.
        xlabel, ylabel, title: Axis labels and plot title.
        x_log, y_log: Switch the corresponding axis to a log scale.
        labels: Legend labels, matched to ``items`` by position. It may be
            shorter than ``items`` (or empty); series without a label are
            drawn unlabelled.
        ps: Plot style providing ``colors`` and ``canv_plt``.
    """
    # Clear pyplot's current axes and figure before creating the new figure.
    plt.cla()
    plt.clf()
    fig, ax = plt.subplots(1, 1)
    colors = ps.colors

    any_label = False
    for i, item in enumerate(items):
        # The previous test (len(labels) >= i) indexed one past the end of
        # `labels`, so the default empty list raised IndexError.
        label = labels[i] if i < len(labels) else None
        any_label = any_label or label is not None
        ax.plot(item, label=label, color=colors[i % len(colors)])

    if x_log:
        ax.set_xscale('log')
    if y_log:
        ax.set_yscale('log')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # A legend is only meaningful when at least one series carries a label.
    if any_label:
        ax.legend()

    if figfile != '':
        fig.savefig(figfile, transparent=True, facecolor=ps.canv_plt)
    plt.show()
def MetricPlot(model_history,
               acc_range=(0.5,1.),
               loss_range=(0.,0.7),
               acc_log=False,
               loss_log=False,
               plotpath='/',
               model_keys=[],
               plotstyle=qu.PlotStyle('dark')):
    """Show accuracy and loss histories side by side, one figure per model.

    For each model key the left panel plots the 'acc' and 'val_acc' entries
    of ``model_history[key]`` and the right panel plots 'loss' and
    'val_loss', both against the epoch number (starting at 1). Each figure is
    handed to ``qu.SaveSubplots`` and then shown.

    Args:
        model_history: Mapping ``model key -> {'acc', 'val_acc', 'loss',
            'val_loss'}`` with one value per epoch.
        acc_range, loss_range: ``(min, max)`` y-limits of the two panels.
        acc_log, loss_log: Use a log y-axis for the corresponding panel.
        plotpath: Directory passed to ``qu.SaveSubplots`` as ``savedir``.
        model_keys: Models to plot; an empty list selects every key of
            ``model_history``.
        plotstyle: Plot style forwarded to the drawing and saving helpers.
    """
    if model_keys == []:
        model_keys = list(model_history.keys())

    for name in model_keys:
        fig, ax = plt.subplots(1, 2, figsize=(15, 5))

        # Left panel: accuracy curves. The epoch axis is shared by both panels.
        acc_names = ['acc', 'val_acc']
        acc_curves = [model_history[name][entry] for entry in acc_names]
        epochs = np.arange(len(acc_curves[0])) + 1
        acc_options = dict(
            y_min=acc_range[0],
            y_max=acc_range[1],
            y_log=acc_log,
            xlabel='Epoch',
            ylabel='Accuracy',
            title='Model accuracy for {}'.format(name),
            ps=plotstyle,
        )
        pu.multiplot_common(ax[0], epochs, acc_curves, acc_names, **acc_options)

        # Right panel: loss curves.
        loss_names = ['loss', 'val_loss']
        loss_curves = [model_history[name][entry] for entry in loss_names]
        loss_options = dict(
            y_min=loss_range[0],
            y_max=loss_range[1],
            y_log=loss_log,
            xlabel='Epoch',
            ylabel='Loss',
            title='Model loss for {}'.format(name),
            ps=plotstyle,
        )
        pu.multiplot_common(ax[1], epochs, loss_curves, loss_names, **loss_options)

        # Grid lines on both panels, coloured by the plot style.
        for panel in ax.flatten():
            panel.grid(True, color=plotstyle.grid_plt)

        save_names = ['accuracy_{}'.format(name), 'loss_{}'.format(name)]
        qu.SaveSubplots(fig, ax, save_names, savedir=plotpath, ps=plotstyle)
        plt.show()
    return
def roc_plot(ax,
             xlist,
             ylist,
             xlabel='False positive rate',
             ylabel='True positive rate',
             x_min=0,
             x_max=1.1,
             x_log=False,
             y_min=0,
             y_max=1.1,
             y_log=False,
             linestyles=[],
             colorgrouping=-1,
             extra_lines=[[[0, 1], [0, 1]]],
             labels=[],
             atlas_x=-1,
             atlas_y=-1,
             simulation=False,
             textlist=[],
             title='',
             ps=qu.PlotStyle('dark'),
             colors=None):
    """Draw ROC curves on ``ax``.

    Thin convenience wrapper: every argument is forwarded unchanged to
    ``multiplot``. The only difference from calling ``multiplot`` directly is
    the default ``extra_lines``, which adds the line from (0, 0) to (1, 1).
    """
    forwarded = {
        'xlabel': xlabel,
        'ylabel': ylabel,
        'x_min': x_min,
        'x_max': x_max,
        'x_log': x_log,
        'y_min': y_min,
        'y_max': y_max,
        'y_log': y_log,
        'linestyles': linestyles,
        'colorgrouping': colorgrouping,
        'extra_lines': extra_lines,
        'labels': labels,
        'atlas_x': atlas_x,
        'atlas_y': atlas_y,
        'simulation': simulation,
        'textlist': textlist,
        'title': title,
        'ps': ps,
        'colors': colors,
    }
    multiplot(ax, xlist, ylist, **forwarded)
    return None
def histogramOverlay(ax,
                     data,
                     labels,
                     xlabel,
                     ylabel,
                     x_min=0,
                     x_max=2200,
                     xbins=22,
                     normed=True,
                     y_log=False,
                     atlas_x=-1,
                     atlas_y=-1,
                     simulation=False,
                     textlist=[],
                     ps=None):
    """Overlay several semi-transparent histograms on one axis.

    Args:
        ax: Matplotlib axis to draw on.
        data: Sequence of value collections, one histogram per entry.
        labels: Legend labels, one per entry of ``data``.
        xlabel, ylabel: Axis labels.
        x_min, x_max: Range covered by the bins and by the x-axis.
        xbins: Number of equal-width bins spanning ``[x_min, x_max]``.
        normed: Passed to ``ax.hist`` as ``density``.
        y_log: Use a log scale on the y-axis.
        atlas_x, atlas_y, simulation, textlist: Accepted for interface
            compatibility; currently unused.
            TODO: find a replacement for the former ATLAS label drawing.
        ps: Plot style providing ``colors``, ``canv_plt``, ``text_plt`` and
            ``SetStylePlt``. ``None`` selects ``qu.PlotStyle('dark')``, built
            on demand instead of once at definition time.
    """
    if ps is None:
        ps = qu.PlotStyle('dark')

    # xbins bins need xbins + 1 edges including x_max. The previous
    # np.arange(x_min, x_max, step) left out the upper edge, so the last bin
    # (and every value falling in it) was silently dropped.
    bin_edges = np.linspace(x_min, x_max, xbins + 1)

    # Negative z-orders keep the histograms below the axis decorations.
    zorder_start = -1 * len(data)
    colors = ps.colors
    for i, vals in enumerate(data):
        ax.hist(vals,
                bins=bin_edges,
                density=normed,
                alpha=0.5,
                label=labels[i],
                color=colors[i % len(colors)],
                zorder=zorder_start + i)

    ax.set_xlim(x_min, x_max)
    if y_log:
        ax.set_yscale('log')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ps.SetStylePlt(ax)

    # Hack to keep the tick marks on top. This must be a call: assigning to
    # ax.set_zorder (as the code did before) replaced the method with an int
    # and never changed the z-order.
    ax.set_zorder(len(data) + 1)

    legend = ax.legend(facecolor=ps.canv_plt)
    for leg_text in legend.get_texts():
        leg_text.set_color(ps.text_plt)
    return
def MapStabilityTest(mapping_func,
                     b_vals=[0., .1, .5, 1., 1.0e14],
                     m_vals=[1.],
                     x=np.linspace(0.001, 4., 1000),
                     ps=qu.PlotStyle('dark'),
                     savedir='',
                     legend_size=-1):
    """Plot a mapping and its round trip for a grid of (b, m) parameters.

    For every combination of ``b_vals`` and ``m_vals`` the object
    ``mapping_func(b, m)`` is built; its ``Forward(x)`` is drawn in the left
    panel and ``Inverse(Forward(x))`` in the right panel, so deviations of the
    right panel from the identity expose instabilities of the round trip.
    The figure is shown and saved as 'mapping_test.png'.

    Args:
        mapping_func: Callable ``(b, m) -> object`` with ``Forward`` and
            ``Inverse`` methods.
        b_vals, m_vals: Parameter values to combine.
        x: Points at which the mapping is evaluated.
        ps: Plot style forwarded to ``pu.multiplot_common``.
        savedir: Directory for the saved image; empty means the current
            working directory.
        legend_size: Legend font size; values <= 0 leave the current
            setting untouched.
    """
    mb_combos = list(itertools.product(b_vals, m_vals))

    # With the default single m value the legend only needs to mention b.
    if m_vals == [1.]:
        forward_labels = ['f(x), b={:.1e}'.format(b) for (b, _) in mb_combos]
        reverse_labels = ['g(f(x)), b={:.1e}'.format(b) for (b, _) in mb_combos]
    else:
        forward_labels = [
            'f(x), b={:.1e}, m={:.1e}'.format(b, m) for (b, m) in mb_combos
        ]
        reverse_labels = [
            'g(f(x)), b={:.1e}, m={:.1e}'.format(b, m) for (b, m) in mb_combos
        ]

    forward = [mapping_func(b, m).Forward(x) for (b, m) in mb_combos]
    reverse = [
        mapping_func(b, m).Inverse(mapped)
        for (b, m), mapped in zip(mb_combos, forward)
    ]

    # The legend font size has to be configured before the legends are drawn;
    # it used to be set after plotting and so had no effect on this figure.
    # NOTE(review): assumes pu.multiplot_common creates the legend itself.
    if legend_size > 0:
        plt.rc('legend', fontsize=legend_size)

    fig, ax = plt.subplots(1, 2, figsize=(16, 6))
    y_min, y_max = (np.min(x), np.max(x))
    pu.multiplot_common(ax[0],
                        x,
                        forward,
                        forward_labels,
                        y_min=y_min,
                        y_max=y_max,
                        xlabel='x',
                        ylabel='y',
                        title='Forward Mapping',
                        ps=ps)
    pu.multiplot_common(ax[1],
                        x,
                        reverse,
                        reverse_labels,
                        y_min=y_min,
                        y_max=y_max,
                        xlabel='x',
                        ylabel='y',
                        title='Reverse Mapping',
                        ps=ps)
    plt.show()

    savename = 'mapping_test.png'
    if savedir != '':
        savename = savedir + '/' + savename
    fig.savefig(savename, transparent=True)
    return
def multiplot_common(ax,
                     xcenter,
                     lines,
                     labels,
                     xlabel,
                     ylabel,
                     x_min=None,
                     x_max=None,
                     y_min=None,
                     y_max=None,
                     x_log=False,
                     y_log=False,
                     x_ticks=None,
                     linestyles=[],
                     colorgrouping=-1,
                     extra_lines=[],
                     atlas_x=-1,
                     atlas_y=-1,
                     simulation=False,
                     textlist=[],
                     title='',
                     ps=None,
                     colors=None):
    """Draw several curves that share the x values ``xcenter`` on ``ax``.

    Args:
        ax: Matplotlib axis to draw on.
        xcenter: Common x values for every curve.
        lines: Sequence of y-value series, each as long as ``xcenter``.
        labels: Legend labels, one per entry of ``lines``.
        xlabel, ylabel, title: Axis labels and plot title.
        x_min, x_max: x-limits; ``None`` uses the extent of ``xcenter``.
        y_min, y_max: y-limits; ``None`` uses ``min(0, lowest value)`` and
            1.25 times the highest value of ``lines``.
        x_log, y_log: Use a log scale on the corresponding axis.
        x_ticks: If given, maximum number of major x ticks.
        linestyles: Line styles by curve index; cycled if shorter than
            ``lines``. Empty means solid lines.
        colorgrouping: If > 0, that many consecutive curves share a colour.
        extra_lines: Reference lines as ``[xs, ys]`` pairs, drawn dashed.
        atlas_x, atlas_y, simulation, textlist: Accepted for interface
            compatibility; currently unused.
        ps: Plot style. ``None`` selects ``qu.PlotStyle('dark')``, built on
            demand instead of once at definition time.
        colors: Colour cycle; defaults to ``ps.colors``.
    """
    if ps is None:
        ps = qu.PlotStyle('dark')

    # Derive any missing axis limits from the data.
    if x_min is None:
        x_min = np.min(xcenter)
    if x_max is None:
        x_max = np.max(xcenter)
    if y_min is None:
        y_min = np.minimum(0., np.min(np.column_stack(lines)))
    if y_max is None:
        y_max = 1.25 * np.max(np.column_stack(lines))

    if x_ticks is not None:
        ax.xaxis.set_major_locator(plt.MaxNLocator(x_ticks))

    # Reference lines follow the style's main colour (as in multiplot), so
    # they stay visible on a dark canvas.
    for extra_line in extra_lines:
        ax.plot(extra_line[0],
                extra_line[1],
                linestyle='--',
                color=ps.main_plt)

    if colors is None:
        colors = ps.colors

    for i, line in enumerate(lines):
        if len(linestyles) > 0:
            linestyle = linestyles[i % len(linestyles)]
        else:
            linestyle = 'solid'

        # Wrap around the palette in both modes so that many curves (or many
        # colour groups) can never index past its end.
        if colorgrouping > 0:
            color = colors[int(i // colorgrouping) % len(colors)]
        else:
            color = colors[i % len(colors)]

        ax.plot(xcenter,
                line,
                label=labels[i],
                linestyle=linestyle,
                color=color)

    if x_log:
        ax.set_xscale('log')
    if y_log:
        ax.set_yscale('log')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ps.SetStylePlt(ax)

    legend = ax.legend(facecolor=ps.canv_plt)
    for leg_text in legend.get_texts():
        leg_text.set_color(ps.text_plt)
    return
def multiplot(ax,
              xlist,
              ylist,
              xlabel='False positive rate',
              ylabel='True positive rate',
              x_min=0,
              x_max=1.1,
              x_log=False,
              y_min=0,
              y_max=1.1,
              y_log=False,
              linestyles=[],
              colorgrouping=-1,
              extra_lines=[],
              labels=[],
              atlas_x=-1,
              atlas_y=-1,
              simulation=False,
              textlist=[],
              title='',
              ps=None,
              colors=None):
    """Draw several curves, each with its own x values, on ``ax``.

    Unlike ``multiplot_common`` the curves do not share a common carrier:
    curve ``i`` is ``(xlist[i], ylist[i])``.

    Args:
        ax: Matplotlib axis to draw on.
        xlist, ylist: Matching sequences of x and y series.
        xlabel, ylabel, title: Axis labels and plot title.
        x_min, x_max, y_min, y_max: Axis limits.
        x_log, y_log: Use a log scale on the corresponding axis.
        linestyles: Line styles by curve index; cycled if shorter than the
            number of curves. Empty means solid lines.
        colorgrouping: If > 0, that many consecutive curves share a colour.
        extra_lines: Reference lines as ``[xs, ys]`` pairs, drawn dashed in
            the style's main colour.
        labels: Legend labels by curve index; curves beyond the end of the
            list are drawn unlabelled.
        atlas_x, atlas_y, simulation, textlist: Accepted for interface
            compatibility; currently unused.
        ps: Plot style. ``None`` selects ``qu.PlotStyle('dark')``, built on
            demand instead of once at definition time.
        colors: Colour cycle; defaults to ``ps.colors``.
    """
    if ps is None:
        ps = qu.PlotStyle('dark')

    for extra_line in extra_lines:
        ax.plot(extra_line[0],
                extra_line[1],
                linestyle='--',
                color=ps.main_plt)

    if colors is None:
        colors = ps.colors

    for i, (x, y) in enumerate(zip(xlist, ylist)):
        if len(linestyles) > 0:
            linestyle = linestyles[i % len(linestyles)]
        else:
            linestyle = 'solid'

        # Wrap around the palette in both modes so that many curves (or many
        # colour groups) can never index past its end.
        if colorgrouping > 0:
            color = colors[int(i // colorgrouping) % len(colors)]
        else:
            color = colors[i % len(colors)]

        # A labels list shorter than the curve list used to raise IndexError.
        label = labels[i] if i < len(labels) else None

        ax.plot(x, y, label=label, linestyle=linestyle, color=color)

    if x_log:
        ax.set_xscale('log')
    if y_log:
        ax.set_yscale('log')
    ax.set_title(title)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ps.SetStylePlt(ax)

    legend = ax.legend(facecolor=ps.canv_plt)
    for leg_text in legend.get_texts():
        leg_text.set_color(ps.text_plt)
    return
def EnergySummary(train_dfs,
                  valid_dfs,
                  data_dfs,
                  energy_name,
                  model_name,
                  plotpath,
                  extensions=['png'],
                  plot_size=750,
                  full=True,
                  ps=qu.PlotStyle('dark'),
                  **kwargs):
    """Build, draw and save the energy-regression summary canvases.

    For every key of ``train_dfs`` and each of the three data sets ('train',
    'valid', 'all data') this fills 1D energy histograms, energy-ratio
    histograms, 2D ratio-vs-energy plots and IQR plots. All of them are drawn
    on one ROOT canvas per key (one column per data set, one row per plot
    type) and the canvas is saved once per entry of ``extensions``.

    Args:
        train_dfs, valid_dfs, data_dfs: Dictionaries keyed by particle type
            ('pp' or 'p0'). Each value must support
            ``value[column].to_numpy()`` for the columns 'clusterE',
            'cluster_ENG_CALIB_TOT' and ``energy_name``. All three are
            assumed to share the keys of ``train_dfs``.
        energy_name: Column holding the predicted energy.
        model_name: Prefix of the saved image names.
        plotpath: Prefix prepended to the image name as-is (no separator is
            inserted).
        extensions: File extensions to save each canvas with.
        plot_size: Pixel size of one pad of the canvas.
        full: If False, the three plain 1D energy histograms are left out.
        ps: Plot style.
        **kwargs: Optional overrides:
            max_energy (default 2000. GeV), max_energy_2d (default:
            max_energy), bin_energy (300), ratio_range_2d ([0.3, 1.7]),
            bins_2d ([200, 70]).

    Returns:
        Dict with 'canv' (canvases by key), 'plots' (list of the plot
        dictionaries that were drawn), 'curves' ([mean_curves,
        mean_curves_zoomed]) and 'legend' (legends by key).
    """
    ps.SetStyle()

    # Plot settings; each one can be overridden through **kwargs.
    max_energy = kwargs.get('max_energy', 2000.)  # GeV
    # The 2D range follows max_energy unless it is given explicitly. (The
    # 'max_energy_2d' override used to overwrite max_energy instead.)
    max_energy_2d = kwargs.get('max_energy_2d', max_energy)
    bin_energy = kwargs.get('bin_energy', 300)
    ratio_range_2d = kwargs.get('ratio_range_2d', [0.3, 1.7])
    # Copy, so that neither the caller's list nor the binning of later
    # iterations is altered. The zoomed plots use one x bin fewer; this used
    # to be done in place on an alias of bins_2d, which removed one more bin
    # on every pass of the loop.
    bins_2d = list(kwargs.get('bins_2d', [200, 70]))
    bins_2d_zoomed = [bins_2d[0] - 1] + bins_2d[1:]

    # Dictionaries to keep track of all our histogram objects.
    # Outer key is data type (charged pion, neutral pion),
    # inner key is data set (train, valid, all data).
    clusterE = {}  # reco energy
    clusterE_calib = {}  # cluster_ENG_CALIB_TOT (treated as the true energy)
    clusterE_pred = {}  # predicted energy
    clusterE_ratio1 = {}  # ratio1: E_pred / ENG_CALIB_TOT
    clusterE_ratio2 = {}  # ratio2: E_reco / ENG_CALIB_TOT
    clusterE_ratio2D = {}  # ratio1 vs. ENG_CALIB_TOT
    clusterE_ratio2D_zoomed = {}  # ratio1 vs. ENG_CALIB_TOT (zoomed on left)
    ratio1_iqr = {}  # IQR, from ratio1
    ratio2_iqr = {}  # IQR, from ratio2
    ratio1_iqr_zoomed = {}  # IQR, from ratio1 (zoomed)
    ratio2_iqr_zoomed = {}  # IQR, from ratio2 (zoomed)

    # Histogram stacks.
    energy_stacks = {}
    iqr_stacks = {}
    iqr_stacks_zoomed = {}

    # Mean/median curves from the 2D plots (one set for each).
    mean_curves = {}
    mean_curves_zoomed = {}

    # Canvases and legends, by data type.
    canvs = {}
    legends = {}

    per_key_stores = (clusterE, clusterE_calib, clusterE_pred,
                      clusterE_ratio1, clusterE_ratio2, clusterE_ratio2D,
                      clusterE_ratio2D_zoomed, ratio1_iqr, ratio2_iqr,
                      ratio1_iqr_zoomed, ratio2_iqr_zoomed, energy_stacks,
                      iqr_stacks, iqr_stacks_zoomed, mean_curves,
                      mean_curves_zoomed)

    key_conversions = {
        'pp': '#pi^{#pm}',
        'p0': '#pi^{0}',
    }

    dsets = {'train': train_dfs, 'valid': valid_dfs, 'all data': data_dfs}
    dkeys = list(dsets.keys())

    # The list of plot families we'll show (rows of the canvas).
    if full:
        plots = [
            clusterE, clusterE_calib, clusterE_pred, energy_stacks,
            clusterE_ratio2D, clusterE_ratio2D_zoomed, iqr_stacks,
            iqr_stacks_zoomed
        ]
    else:
        plots = [
            energy_stacks, clusterE_ratio2D, clusterE_ratio2D_zoomed,
            iqr_stacks, iqr_stacks_zoomed
        ]
    nx = len(dkeys)
    ny = len(plots)

    for key in train_dfs.keys():  # assuming all DataFrame dicts share keys
        for store in per_key_stores:
            store[key] = {}

        for dkey, frame in dsets.items():
            key2 = '(' + key_conversions[key] + ', ' + dkey + ')'

            # --- 1D histograms -------------------------------------------
            h_reco = rt.TH1F(qu.RN(),
                             'E_{reco} ' + key2 + '; E_{reco} [GeV];Count',
                             bin_energy, 0., max_energy)
            h_calib = rt.TH1F(
                qu.RN(),
                'E_{calib}^{tot} ' + key2 + ';E_{calib}^{tot} [GeV];Count',
                bin_energy, 0., max_energy)
            h_pred = rt.TH1F(qu.RN(),
                             'E_{pred} ' + key2 + ';E_{pred} [GeV];Count',
                             bin_energy, 0., max_energy)
            h_ratio1 = rt.TH1F(
                qu.RN(), 'E_{pred} / E_{calib}^{tot} ' + key2 +
                ';E_{pred}/E_{calib}^{tot};Count', 250, 0., 10.)
            # The axis title used to contain a stray ']' ("^{tot]}").
            h_ratio2 = rt.TH1F(
                qu.RN(), 'E / E_{calib}^{tot} ' + key2 +
                ';E_{reco}/E_{calib}^{tot};Count', 250, 0., 10.)

            clusterE[key][dkey] = h_reco
            clusterE_calib[key][dkey] = h_calib
            clusterE_pred[key][dkey] = h_pred
            clusterE_ratio1[key][dkey] = h_ratio1
            clusterE_ratio2[key][dkey] = h_ratio2

            qu.SetColor(h_reco, ps.main, alpha=0.4)
            qu.SetColor(h_calib, rt.kPink + 9, alpha=0.4)
            qu.SetColor(h_pred, ps.curve, alpha=0.4)
            qu.SetColor(h_ratio1, ps.main, alpha=0.4)
            qu.SetColor(h_ratio2, ps.curve, alpha=0.4)

            meas = frame[key]['clusterE'].to_numpy()
            calib = frame[key]['cluster_ENG_CALIB_TOT'].to_numpy()
            pred = frame[key][energy_name].to_numpy()
            ratio1 = pred / calib
            ratio2 = meas / calib

            for e_meas, e_calib, e_pred, r1, r2 in zip(
                    meas, calib, pred, ratio1, ratio2):
                h_reco.Fill(e_meas)
                h_calib.Fill(e_calib)
                h_pred.Fill(e_pred)
                h_ratio1.Fill(r1)
                h_ratio2.Fill(r2)

            # Histogram stack for the energy ratios.
            energy_stacks[key][dkey] = rt.THStack(qu.RN(),
                                                  h_ratio1.GetTitle())
            energy_stacks[key][dkey].Add(h_ratio1)
            energy_stacks[key][dkey].Add(h_ratio2)

            # --- 2D energy ratio plots -----------------------------------
            # NOTE(review): the full-range plot passes offset=True and the
            # zoomed "(E + 1)" plot offset=False, the opposite of the IqrPlot
            # calls below. Kept as found; confirm against EnergyPlot2D.
            title = 'E_{pred}/E_{calib}^{tot} vs. E_{calib}^{tot} ' + key2 + ';E_{calib}^{tot} [GeV];E_{pred}/E_{calib}^{tot};Count'
            mean_curves[key][dkey], clusterE_ratio2D[key][
                dkey] = EnergyPlot2D(pred,
                                     calib,
                                     nbins=list(bins_2d),
                                     x_range=[0., max_energy_2d],
                                     y_range=ratio_range_2d,
                                     title=title,
                                     offset=True)

            title = 'E_{pred}/E_{calib}^{tot} vs. E_{calib}^{tot} ' + key2 + ';(E_{calib}^{tot} + 1) [GeV];E_{pred}/E_{calib}^{tot};Count'
            mean_curves_zoomed[key][dkey], clusterE_ratio2D_zoomed[key][
                dkey] = EnergyPlot2D(pred,
                                     calib,
                                     nbins=list(bins_2d_zoomed),
                                     x_range=[1., 1. + 0.01 * max_energy_2d],
                                     y_range=ratio_range_2d,
                                     title=title,
                                     offset=False)

            # --- Energy ratio IQR plots ----------------------------------
            iqr_title = 'IQR(E_{x}/E_{calib}^{tot}) ' + key2 + ';E_{calib}^{tot} [GeV];IQR'
            x_range = [0., max_energy_2d]
            nbins = int(bins_2d[0] / 2)
            ratio1_iqr[key][dkey] = IqrPlot(pred,
                                            calib,
                                            title=iqr_title,
                                            nbins=nbins,
                                            x_range=x_range)
            ratio1_iqr[key][dkey].SetLineColor(ps.main)
            ratio2_iqr[key][dkey] = IqrPlot(meas,
                                            calib,
                                            title=iqr_title,
                                            nbins=nbins,
                                            x_range=x_range)
            ratio2_iqr[key][dkey].SetLineColor(ps.curve)

            iqr_title_zoomed = 'IQR(E_{x}/E_{calib}^{tot}) ' + key2 + ';(E_{calib}^{tot} + 1) [GeV];IQR'
            x_range = [1., 1. + 0.01 * max_energy_2d]
            nbins = int(bins_2d_zoomed[0] / 2)
            ratio1_iqr_zoomed[key][dkey] = IqrPlot(pred,
                                                   calib,
                                                   title=iqr_title_zoomed,
                                                   nbins=nbins,
                                                   x_range=x_range,
                                                   offset=True)
            ratio1_iqr_zoomed[key][dkey].SetLineColor(ps.main)
            ratio2_iqr_zoomed[key][dkey] = IqrPlot(meas,
                                                   calib,
                                                   title=iqr_title_zoomed,
                                                   nbins=nbins,
                                                   x_range=x_range,
                                                   offset=True)
            ratio2_iqr_zoomed[key][dkey].SetLineColor(ps.curve)

            # Histogram stacks for the IQR plots.
            iqr_stacks[key][dkey] = rt.THStack(qu.RN(), iqr_title)
            iqr_stacks[key][dkey].Add(ratio1_iqr[key][dkey])
            iqr_stacks[key][dkey].Add(ratio2_iqr[key][dkey])

            iqr_stacks_zoomed[key][dkey] = rt.THStack(qu.RN(),
                                                      iqr_title_zoomed)
            iqr_stacks_zoomed[key][dkey].Add(ratio1_iqr_zoomed[key][dkey])
            iqr_stacks_zoomed[key][dkey].Add(ratio2_iqr_zoomed[key][dkey])

        # Legend for the overlapping plots (1D energy ratios and IQR plots).
        legends[key] = rt.TLegend(0.7, 0.7, 0.85, 0.85)
        legends[key].SetBorderSize(0)
        legends[key].AddEntry(clusterE_ratio1[key][dkeys[0]], 'x = pred', 'f')
        legends[key].AddEntry(clusterE_ratio2[key][dkeys[0]], 'x = reco', 'f')

        canvs[key] = rt.TCanvas(qu.RN(), 'c_' + str(key), nx * plot_size,
                                ny * plot_size)
        canvs[key].Divide(nx, ny)

        for i, plot in enumerate(plots):
            first_pad = nx * i + 1

            # Dispatch on object identity: the families are distinct dicts,
            # and '==' would compare their contents instead.
            is_energy_stack = plot is energy_stacks
            is_iqr_stack = plot is iqr_stacks or plot is iqr_stacks_zoomed
            is_ratio_2d = (plot is clusterE_ratio2D
                           or plot is clusterE_ratio2D_zoomed)

            for j, dkey in enumerate(dkeys):
                canvs[key].cd(first_pad + j)
                obj = plot[key][dkey]

                if is_energy_stack or is_iqr_stack:
                    obj.Draw('NOSTACK HIST' if is_energy_stack else 'NOSTACK C')
                    rt.gPad.SetGrid()
                    rt.gPad.SetLogy()
                    # NOTE(review): this x title is also applied to the IQR
                    # stacks, replacing the one from their own title string.
                    obj.GetHistogram().GetXaxis().SetTitle(
                        'E_{x}/E_{calib}^{tot}')
                    if is_energy_stack:
                        obj.GetHistogram().GetYaxis().SetTitle(
                            clusterE_ratio1[key][dkey].GetYaxis().GetTitle())
                        obj.SetMinimum(5.0e-1)
                        obj.SetMaximum(2.0e5)
                    else:
                        obj.GetHistogram().GetYaxis().SetTitle(
                            ratio1_iqr[key][dkey].GetYaxis().GetTitle())
                        obj.SetMinimum(1.0e-2)
                        obj.SetMaximum(1.)

                    if plot is iqr_stacks_zoomed:
                        rt.gPad.SetLogx()
                        rt.gPad.SetBottomMargin(0.15)
                        obj.GetXaxis().SetTitleOffset(1.5)

                    if is_iqr_stack:
                        obj.SetMinimum(1.0e-3)

                    legends[key].SetTextColor(ps.text)
                    legends[key].Draw()

                elif is_ratio_2d:
                    obj.Draw('COLZ')
                    if plot is clusterE_ratio2D:
                        mean_curves[key][dkey].Draw('SAME')
                    else:
                        mean_curves_zoomed[key][dkey].Draw('SAME')
                        rt.gPad.SetLogx()
                        rt.gPad.SetBottomMargin(0.15)
                        obj.GetXaxis().SetTitleOffset(1.5)
                    rt.gPad.SetLogz()
                    rt.gPad.SetRightMargin(0.2)
                    obj.GetXaxis().SetMaxDigits(4)

                else:
                    obj.Draw('HIST')
                    obj.SetMinimum(5.0e-1)
                    rt.gPad.SetLogy()

        # Draw the canvas and save it once per requested extension.
        canvs[key].Draw()
        image_name = '_'.join([model_name, key, 'plots'])
        for ext in extensions:
            canvs[key].SaveAs(plotpath + image_name + '.' + ext)

    results = {}
    results['canv'] = canvs
    results['plots'] = plots
    results['curves'] = [mean_curves, mean_curves_zoomed]
    results['legend'] = legends
    return results
def RocCurves(model_scores,
              data_labels,
              roc_fpr,
              roc_tpr,
              roc_thresh,
              roc_auc,
              indices=[],
              sample_weight=None,
              plotpath = '/',
              plotname = 'ROC',
              model_keys = [],
              model_labels = [],
              colors = [],
              drawPlots=True,
              figsize=(15,5),
              plotstyle=qu.PlotStyle('dark')):
    """Compute ROC curves and AUC per model and optionally plot them.

    The results are written into the caller-provided dictionaries
    ``roc_fpr``, ``roc_tpr``, ``roc_thresh`` and ``roc_auc`` under each model
    key, and the AUC of every model is printed.

    Args:
        model_scores: Mapping ``model key -> scores``; the scores are indexed
            with ``indices``.
        data_labels: True labels, indexed with ``indices``.
        roc_fpr, roc_tpr, roc_thresh, roc_auc: Output dictionaries, updated
            in place.
        indices: Selection applied to labels and scores. If its length
            differs from ``len(data_labels)`` (as for the default), all
            entries are used.
        sample_weight: Optional per-sample weights. When given with one entry
            per element of ``data_labels``, the same ``indices`` selection is
            applied to it; otherwise it is passed through unchanged.
        plotpath, plotname: Directory and base name given to
            ``qu.SaveSubplots``; the second panel is saved with the suffix
            '_zoom'.
        model_keys: Models to process; empty selects all of ``model_scores``.
        model_labels: Legend names by position; empty uses ``model_keys``.
        colors: Curve colours; empty uses ``plotstyle.colors``.
        drawPlots: If False, only the curves are computed.
        figsize: Size of the two-panel figure.
        plotstyle: Plot style forwarded to the drawing and saving helpers.
    """
    if model_keys == []:
        model_keys = list(model_scores.keys())
    if model_labels == []:
        model_labels = model_keys
    if colors == []:
        colors = plotstyle.colors
    if len(indices) != len(data_labels):
        indices = np.full(len(data_labels), True, dtype=np.dtype('bool'))

    # The weights must be restricted to the same samples as the labels and
    # scores; passing the full-length array next to a subset makes the
    # sample counts disagree.
    weights = sample_weight
    if sample_weight is not None and len(sample_weight) == len(data_labels):
        weights = np.asarray(sample_weight)[indices]

    for model_key in model_keys:
        roc_fpr[model_key], roc_tpr[model_key], roc_thresh[model_key] = roc_curve(
            data_labels[indices],
            model_scores[model_key][indices],
            drop_intermediate=False,
            sample_weight=weights
        )
        roc_auc[model_key] = auc(roc_fpr[model_key], roc_tpr[model_key])
        print('Area under curve for {}: {}'.format(model_key, roc_auc[model_key]))

    if not drawPlots:
        return

    # TODO: Sort model_keys by AUC

    # Make a plot of the ROC curves: full range on the left, zoomed on the
    # top-left corner on the right.
    fig, ax = plt.subplots(1, 2, figsize=figsize)
    xlist = [roc_fpr[x] for x in model_keys]
    ylist = [roc_tpr[x] for x in model_keys]
    labels = [
        '{} (area = {:.3f})'.format(model_labels[i], roc_auc[x])
        for i, x in enumerate(model_keys)
    ]

    # Raw string: '\p' is not a valid escape sequence in a normal literal.
    title = r'ROC curve: classification of $\pi^+$ vs. $\pi^0$'
    pu.roc_plot(ax[0],
                xlist=xlist,
                ylist=ylist,
                labels=labels,
                title=title,
                ps=plotstyle,
                colors=colors
    )

    pu.roc_plot(ax[1],
                xlist=xlist,
                ylist=ylist,
                x_min=0.,
                x_max=0.25,
                y_min=0.6,
                y_max=1.,
                labels=labels,
                title=title,
                ps=plotstyle,
                colors=colors
    )

    qu.SaveSubplots(fig, ax, [plotname, plotname + '_zoom'], savedir=plotpath, ps=plotstyle)
    plt.show()
    return
def ImagePlot(pcells,
              cluster,
              log=True,
              dynamic_range=False,
              layers=[],
              cell_shapes={},
              scaled_shape = [],
              latex_mpl = {},
              plotpath = '',
              filename = '',
              plotstyle=qu.PlotStyle('dark')):
    """Draw calorimeter images: one row per particle type, one column per layer.

    Args:
        pcells: Mapping ``particle type -> {layer -> array of cluster
            images}``.
        cluster: Index of the cluster to draw. A negative value draws the
            mean image over all clusters of the layer instead.
        log: Use a symmetric-log colour normalisation; otherwise a two-slope
            normalisation centred on zero.
        dynamic_range: If True the colour range is symmetric around zero and
            taken from the image; otherwise it is fixed to [-1, 1].
        layers: Layers to draw; empty selects every key of ``mu.cell_meta``.
        cell_shapes: Mapping ``layer -> (len_eta, len_phi)``; empty takes the
            shapes from ``mu.cell_meta``.
        scaled_shape: If non-empty, each image is rescaled to this shape with
            ``ImageScaleBlock`` before drawing.
        latex_mpl: Mapping ``particle type -> title text``; empty uses the
            built-in entries for 'p0' and 'pp'.
        plotpath, filename: If ``filename`` is non-empty the figure is saved
            to ``plotpath/filename`` (or just ``filename`` when ``plotpath``
            is empty).
        plotstyle: Plot style.
    """
    # Set some default values.
    if layers == []:
        layers = list(mu.cell_meta.keys())
    if cell_shapes == {}:
        cell_shapes = {
            key: (val['len_eta'], val['len_phi'])
            for key, val in mu.cell_meta.items()
        }
    if latex_mpl == {}:
        # Raw strings: '\p' is not a valid escape sequence.
        latex_mpl = {'p0': r'$\pi^{0}$', 'pp': r'$\pi^{+}$'}

    scaling = scaled_shape != []

    # squeeze=False always yields a 2D array of axes, so a single row,
    # a single column or a single panel all work.
    fig, ax = plt.subplots(len(pcells.keys()),
                           len(layers),
                           figsize=(60, 20),
                           squeeze=False)
    fig.patch.set_facecolor(plotstyle.canv_plt)
    cmap = plt.get_cmap('BrBG')

    for row, (ptype, pcell) in enumerate(pcells.items()):
        for col, layer in enumerate(layers):
            axis = ax[row][col]

            # Default behaviour: plot a single cluster. A negative cluster
            # index requests the average image. (The image used to be
            # re-read from pcell[layer][cluster] further down, which threw
            # the average away.)
            if cluster >= 0:
                image = pcell[layer][cluster].reshape(cell_shapes[layer])
            else:
                image = np.mean(pcell[layer],
                                axis=0).reshape(cell_shapes[layer])

            if dynamic_range:
                # Symmetric range around zero, wide enough for both extremes.
                vmax = np.maximum(np.abs(np.min(image)),
                                  np.abs(np.max(image)))
                if vmax == 0.:
                    vmax = 0.1  # avoid a degenerate range for empty images
                vmin = -vmax
            else:
                vmin, vmax = (-1., 1.)

            if log:
                norm = SymLogNorm(linthresh=0.001,
                                  linscale=0.001,
                                  vmin=vmin,
                                  vmax=vmax,
                                  base=10.)
            else:
                norm = TwoSlopeNorm(vmin=vmin, vcenter=0., vmax=vmax)

            if scaling:
                # ImageScaleBlock requires a 4d tensor, format is
                # [batch, eta, phi, channel].
                image = np.expand_dims(image, axis=(0, -1))
                image = np.squeeze(
                    ImageScaleBlock(new_shape=tuple(scaled_shape),
                                    normalization=True)([image]).numpy())

            im = axis.imshow(image,
                             extent=[-0.2, 0.2, -0.2, 0.2],
                             cmap=cmap,
                             origin='lower',
                             interpolation='nearest',
                             norm=norm)

            axis.set_title('{a} in {b}'.format(a=latex_mpl[ptype], b=layer))
            axis.set_xlabel(r"$\Delta\phi$")
            axis.set_ylabel(r"$\Delta\eta$")
            plotstyle.SetStylePlt(axis)

            # Colour bar in its own narrow axis to the right of the image.
            divider = make_axes_locatable(axis)
            cax = divider.append_axes('right', size='5%', pad=0.2)
            cb = fig.colorbar(im, cax=cax, orientation='vertical')
            plt.setp(plt.getp(cb.ax.axes, 'yticklabels'),
                     color=plotstyle.text_plt)

    # Save (optionally) and show the plots.
    if filename != '':
        # With an empty plotpath, '{}/{}' would point at the filesystem root.
        if plotpath == '':
            target = filename
        else:
            target = '{}/{}'.format(plotpath, filename)
        fig.savefig(target, transparent=True, facecolor=plotstyle.canv_plt)
    plt.show()
    return