def flatten_bkg_templates(fnames_to_run):
    """
    Flatten the background templates and save them to ``.coffea`` files.

    Each file in ``fnames_to_run`` is read with ``load``. Its jet multiplicity
    is taken from the file name ("3Jets" if that string is in the basename,
    otherwise "4PJets"). For every lepton channel and every template in the
    file, the nominal ("<proc>_nosys") template is stored as a plain copy and
    every other template is replaced by the output of ``Plotter.flatten``
    called with copies of the nominal and the systematic template.

    One output file per jet multiplicity listed in the module-level
    ``njets_to_run`` is written into ``input_dir``.

    Parameters
    ----------
    fnames_to_run : iterable of str
        Paths of the template files to process. Files whose jet multiplicity
        is not in ``njets_to_run`` are skipped with a message (they used to
        trigger a NameError because no output dict existed for them).
    """
    # One output accumulator per requested jet multiplicity, created in the
    # fixed order 3Jets -> 4PJets so the files are written in that order too.
    histo_dicts = {
        jmult: processor.dict_accumulator({
            "Muon": {},
            "Electron": {}
        })
        for jmult in ("3Jets", "4PJets") if jmult in njets_to_run
    }

    for bkg_file in fnames_to_run:
        jmult = "3Jets" if "3Jets" in os.path.basename(bkg_file) else "4PJets"
        if jmult not in histo_dicts:
            # No accumulator exists for this multiplicity; skip the file
            # before paying for the load.
            print(f"Skipping {bkg_file}: {jmult} is not in njets_to_run")
            continue

        hdict = load(bkg_file)
        out_dict = histo_dicts[jmult]
        for lep, templates in hdict.items():
            for tname, orig_template in templates.items():
                # Template names look like "<proc>_<sys>"; "data_obs" itself
                # contains an underscore, so it is special-cased.
                proc = "data_obs" if "data_obs" in tname else tname.split(
                    "_")[0]
                # Split on "<proc>_", drop empty pieces and keep the
                # alphabetically first remaining piece as the systematic name.
                sys_name = sorted(filter(None, tname.split(f"{proc}_")))[0]
                print(lep, jmult, sys_name, proc)

                nominal = templates[f"{proc}_nosys"].copy()
                if sys_name == "nosys":
                    # the nominal template is stored unchanged
                    flattened_histo = nominal
                else:
                    # copies are passed so the loaded histograms stay intact
                    flattened_histo = Plotter.flatten(
                        nosys=nominal, systematic=orig_template.copy())

                ## save template histos to coffea dict
                out_dict[lep][tname] = flattened_histo.copy()

    for jmult, histo_dict in histo_dicts.items():
        coffea_out = os.path.join(
            input_dir,
            f"test_flattened_templates_lj_{jmult}_bkg_{args.year}_{jobid}.coffea"
        )
        save(histo_dict, coffea_out)
        print(f"{coffea_out} written")
def _relative_deviation(template, nominal):
    """
    Return the output of ``Plotter.get_ratio_arrays`` for
    (template - nominal) / nominal, using the bin edges of the first dense
    axis of ``nominal``.
    """
    nominal_vals = nominal.values()[()]
    return Plotter.get_ratio_arrays(
        num_vals=template.values()[()] - nominal_vals,
        denom_vals=nominal_vals,
        input_bins=nominal.dense_axes()[0].edges())


def get_bkg_templates(fnames_to_run):
    """
    Plot smoothed and flattened systematic templates against the originals.

    Each file in ``fnames_to_run`` is read with ``load``. Its jet multiplicity
    is taken from the file name ("3Jets" if that string is in the basename,
    otherwise "4PJets"). For every lepton channel and every non-nominal
    template, one figure is drawn showing the relative deviation from the
    nominal ("<proc>_nosys") template of:

    * the original systematic template (filled),
    * the outputs of ``Plotter.smoothing_mttbins`` for frac = 0.2, 0.4, 0.6,
    * the output of ``Plotter.flatten``.

    Figures are saved under ``outdir/<lep>/<jmult>/<sys>/``; nothing is
    returned.

    Parameters
    ----------
    fnames_to_run : iterable of str
        Paths of the template files to process.
    """
    # shared line style of all step curves
    step_opts = {"linestyle": "-", "linewidth": 2}

    for bkg_file in fnames_to_run:
        hdict = load(bkg_file)
        jmult = "3Jets" if "3Jets" in os.path.basename(bkg_file) else "4PJets"
        for lep, templates in hdict.items():
            for tname, orig_template in templates.items():
                # Template names look like "<proc>_<sys>"; "data_obs" itself
                # contains an underscore, so it is special-cased.
                proc = "data_obs" if "data_obs" in tname else tname.split(
                    "_")[0]
                # Split on "<proc>_", drop empty pieces and keep the
                # alphabetically first remaining piece as the systematic name.
                sys_name = sorted(filter(None, tname.split(f"{proc}_")))[0]

                # nothing to compare for the nominal template itself
                if sys_name == "nosys":
                    continue
                print(lep, jmult, sys_name, proc)

                nominal_hist = templates[f"{proc}_nosys"].copy()

                # x axis spans the number of bins of the first dense axis
                x_lims = (0, nominal_hist.dense_axes()[0].centers().size)

                # perform smoothing for each frac value (0.2, 0.4, 0.6);
                # each entry is (smoothed histogram, frac)
                smoothed_histos_list = []
                for frac_val in np.arange(2, 7, 2):
                    frac = frac_val / 10.
                    smoothed = Plotter.smoothing_mttbins(
                        nosys=nominal_hist,
                        systematic=orig_template,
                        mtt_centers=mtt_centers,
                        nbinsx=len(linearize_binning[0]) - 1,
                        nbinsy=len(linearize_binning[1]) - 1,
                        frac=frac)
                    smoothed_histos_list.append((smoothed, frac))

                # perform flattening
                flattened_histo = Plotter.flatten(nosys=nominal_hist,
                                                  systematic=orig_template)

                # plot relative deviation
                fig, ax = plt.subplots()
                fig.subplots_adjust(hspace=.07)

                # plot original dist
                orig_masked_vals, orig_masked_bins = _relative_deviation(
                    orig_template, nominal_hist)
                ax.fill_between(orig_masked_bins,
                                orig_masked_vals,
                                facecolor="k",
                                step="post",
                                alpha=0.5,
                                label="Unsmoothed")

                # plot smoothed versions
                for smooth_histo, frac in smoothed_histos_list:
                    smooth_masked_vals, smooth_masked_bins = _relative_deviation(
                        smooth_histo, nominal_hist)
                    ax.step(smooth_masked_bins,
                            smooth_masked_vals,
                            where="post",
                            label=f"Frac={frac}",
                            **step_opts)

                # plot flattened val
                flat_masked_vals, flat_masked_bins = _relative_deviation(
                    flattened_histo, nominal_hist)
                ax.step(flat_masked_bins,
                        flat_masked_vals,
                        where="post",
                        label="Flat",
                        **step_opts)

                ax.legend(loc="upper right", title=f"{sys_name}, {proc}")
                ax.axhline(0,
                           linestyle="--",
                           color=(0, 0, 0, 0.5),
                           linewidth=1)
                ax.autoscale()
                ax.set_xlim(x_lims)
                # raw string: "\o" in a normal string is an invalid escape
                ax.set_xlabel(
                    r"$m_{t\bar{t}}$ $\otimes$ |cos($\theta^{*}_{t_{l}}$)|")
                ax.set_ylabel("Rel. Deviation from Nominal")

                # add lepton/jet multiplicity label
                ax.text(0.02,
                        0.94,
                        f"{leptypes[lep]}, {jet_mults[jmult]}",
                        fontsize=rcParams["font.size"] * 0.9,
                        horizontalalignment="left",
                        verticalalignment="bottom",
                        transform=ax.transAxes)

                ## draw vertical lines for distinguishing different ctstar bins
                # NOTE(review): the 5 equal sections are hard-coded; presumably
                # this matches len(linearize_binning[1]) - 1 -- confirm.
                for ybin in range(1, 5):
                    ax.axvline(x_lims[1] * ybin / 5, color="k", linestyle="--")

                hep.cms.label(ax=ax,
                              data=False,
                              paper=False,
                              year=args.year,
                              lumi=round(data_lumi_year[f"{lep}s"] / 1000., 1))

                pltdir = os.path.join(outdir, lep, jmult, sys_name)
                os.makedirs(pltdir, exist_ok=True)

                figname = os.path.join(
                    pltdir, "_".join(
                        [jmult, lep, sys_name, proc, "SmoothedFlatVals_Comp"]))
                fig.savefig(figname)
                print(f"{figname} written")
                # close this specific figure rather than whatever is current
                plt.close(fig)