def main(args):
    """Book and write jet-level tagger/mass studies for one sample.

    Writes to ``rootfiles/THjetstudy_<setname>_<era>.root``:
      * DeepAK8 / ParticleNet top and Higgs scores for jets inside a
        softdrop-mass window,
      * softdrop mass of jets whose tagger score is > 0.9,
      * 2D maps of (m_t - m_H) vs. (top-score difference between the two
        assigned jets) for the tagger-based jet assignment (all events,
        events where neither assigned jet passes ``MatchToGen``, and events
        where both do) and for the MC-truth-based jet assignment.

    All histograms booked here are unweighted (no weight column is passed).

    Args:
        args: parsed command-line options; ``threads``, ``setname`` and
            ``era`` are read.
    """
    ROOT.ROOT.EnableImplicitMT(args.threads)
    start = time.time()
    selection = THClass(
        'dijet_nano/%s_%s_snapshot.txt' % (args.setname, args.era),
        int(args.era), 1, 1)
    # The node returned by OpenForSelection is not needed in this study.
    selection.OpenForSelection('None')

    # Jet-level plots (unweighted)
    jetPlots = HistGroup('jetPlots')

    # --- Taggers after mass selection ---
    # Boolean masks select the Dijet entries inside each mass window; the
    # score columns are then RVec-masked with them.
    selection.a.Define(
        'TopMassBools',
        'Dijet_msoftdrop_corrT > 105 && Dijet_msoftdrop_corrT < 210')
    selection.a.Define('DAK8TopScoresInMassWindow',
                       'Dijet_deepTag_TvsQCD[TopMassBools]')
    selection.a.Define('PNTopScoresInMassWindow',
                       'Dijet_particleNet_TvsQCD[TopMassBools]')
    jetPlots.Add(
        'DAK8TopScoresInMassWindow',
        selection.a.DataFrame.Histo1D(
            ('DAK8TopScoresInMassWindow',
             'DeepAK8 top score for jets in top mass window', 50, 0, 1),
            'DAK8TopScoresInMassWindow'))
    jetPlots.Add(
        'PNTopScoresInMassWindow',
        selection.a.DataFrame.Histo1D(
            ('PNTopScoresInMassWindow',
             'ParticleNet top score for jets in top mass window', 50, 0, 1),
            'PNTopScoresInMassWindow'))

    selection.a.Define(
        'HiggsMassBools',
        'Dijet_msoftdrop_corrH > 100 && Dijet_msoftdrop_corrH < 140')
    selection.a.Define('DAK8HiggsScoresInMassWindow',
                       'Dijet_deepTagMD_HbbvsQCD[HiggsMassBools]')
    selection.a.Define('PNHiggsScoresInMassWindow',
                       'Dijet_particleNet_HbbvsQCD[HiggsMassBools]')
    jetPlots.Add(
        'DAK8HiggsScoresInMassWindow',
        selection.a.DataFrame.Histo1D(
            ('DAK8HiggsScoresInMassWindow',
             'DeepAK8 Higgs score for jets in Higgs mass window', 50, 0, 1),
            'DAK8HiggsScoresInMassWindow'))
    jetPlots.Add(
        'PNHiggsScoresInMassWindow',
        selection.a.DataFrame.Histo1D(
            ('PNHiggsScoresInMassWindow',
             'ParticleNet Higgs score for jets in Higgs mass window', 50, 0,
             1), 'PNHiggsScoresInMassWindow'))

    # --- Mass after tagger selection (score > 0.9) ---
    selection.a.Define('TopDAK8Bools', 'Dijet_deepTag_TvsQCD > 0.9')
    selection.a.Define('TopPNBools', 'Dijet_particleNet_TvsQCD > 0.9')
    selection.a.Define('TopMassAfterDAK8Tag',
                       'Dijet_msoftdrop_corrT[TopDAK8Bools]')
    selection.a.Define('TopMassAfterPNTag',
                       'Dijet_msoftdrop_corrT[TopPNBools]')
    jetPlots.Add(
        'TopMassAfterDAK8Tag',
        selection.a.DataFrame.Histo1D(
            ('TopMassAfterDAK8Tag', 'Jet mass after DAK8 top score > 0.9', 25,
             50, 300), 'TopMassAfterDAK8Tag'))
    jetPlots.Add(
        'TopMassAfterPNTag',
        selection.a.DataFrame.Histo1D(
            ('TopMassAfterPNTag', 'Jet mass after PN top score > 0.9', 25, 50,
             300), 'TopMassAfterPNTag'))

    selection.a.Define('HiggsDAK8Bools', 'Dijet_deepTagMD_HbbvsQCD > 0.9')
    selection.a.Define('HiggsPNBools', 'Dijet_particleNet_HbbvsQCD > 0.9')
    selection.a.Define('HiggsMassAfterDAK8Tag',
                       'Dijet_msoftdrop_corrH[HiggsDAK8Bools]')
    selection.a.Define('HiggsMassAfterPNTag',
                       'Dijet_msoftdrop_corrH[HiggsPNBools]')
    jetPlots.Add(
        'HiggsMassAfterDAK8Tag',
        selection.a.DataFrame.Histo1D(
            ('HiggsMassAfterDAK8Tag', 'Jet mass after DAK8 Higgs score > 0.9',
             25, 50, 300), 'HiggsMassAfterDAK8Tag'))
    jetPlots.Add(
        'HiggsMassAfterPNTag',
        selection.a.DataFrame.Histo1D(
            ('HiggsMassAfterPNTag', 'Jet mass after PN Higgs score > 0.9', 25,
             50, 300), 'HiggsMassAfterPNTag'))

    # Gen-particle four-vectors, consumed by MatchToGen below.
    selection.a.Define(
        'GenPart_vect',
        'hardware::TLvector(GenPart_pt, GenPart_eta, GenPart_phi, GenPart_mass)'
    )

    out = ROOT.TFile.Open(
        'rootfiles/THjetstudy_%s_%s.root' % (args.setname, args.era),
        'RECREATE')
    out.cd()
    # Common branch point for the truth-based and tagger-based assignments.
    presel = selection.a.GetActiveNode()

    # Assign jets on truth in parallel; this branch is revisited in the loop.
    selection.ApplyTopPickViaMatch()
    truthtag = selection.a.Define(
        'MassDiff', 'Top_msoftdrop_corrT - Higgs_msoftdrop_corrH')
    nicenames = {"deepTag": "DAK8^{top}", "particleNet": "PN^{top}"}
    for t in ['deepTag', 'particleNet']:
        selection.a.SetActiveNode(presel)
        top_tagger = '%s_TvsQCD' % t
        # Signal region: assign top/Higgs jets with the top tagger.
        selection.ApplyTopPick(tagger=top_tagger, invert=False)

        # 'MassDiff' is redefined here on a branch separate from truthtag.
        selection.a.Define('MassDiff',
                           'Top_msoftdrop_corrT - Higgs_msoftdrop_corrH')
        selection.a.Define('NNDiff', 'Top_{0} - Higgs_{0}'.format(top_tagger))
        jetPlots.Add(
            'MassDiffvsNNDiff_%s' % t,
            selection.a.DataFrame.Histo2D(
                ('MassDiffvsNNDiff_%s' % t,
                 '(m_{{t}} - m_{{H}}) vs ({0}_{{t}} - {0}_{{H}})'.format(
                     nicenames[t]), 25, -100, 150, 40, -1, 1), 'MassDiff',
                'NNDiff'))
        # Look at unmatched pieces: neither assigned jet matches its gen
        # particle (both cuts are applied in sequence).
        checkpoint = selection.a.GetActiveNode()
        selection.a.Cut(
            'NotGenMatchTop',
            '!MatchToGen(6, Top_vect, GenPart_vect, GenPart_pdgId)')
        selection.a.Cut(
            'NotGenMatchH',
            '!MatchToGen(25, Higgs_vect, GenPart_vect, GenPart_pdgId)')
        jetPlots.Add(
            'MassDiffvsNNDiff_%s_BadMatch' % t,
            selection.a.DataFrame.Histo2D(
                ('MassDiffvsNNDiff_%s_BadMatch' % t,
                 '(m_{{t}} - m_{{H}}) vs ({0}_{{t}} - {0}_{{H}}) - Bad matches'
                 .format(nicenames[t]), 25, -100, 150, 40, -1, 1), 'MassDiff',
                'NNDiff'))
        # Look at matched pieces: both assigned jets match.
        selection.a.SetActiveNode(checkpoint)
        selection.a.Cut(
            'GenMatchTop',
            'MatchToGen(6, Top_vect, GenPart_vect, GenPart_pdgId)')
        selection.a.Cut(
            'GenMatchH',
            'MatchToGen(25, Higgs_vect, GenPart_vect, GenPart_pdgId)')
        jetPlots.Add(
            'MassDiffvsNNDiff_%s_GoodMatch' % t,
            selection.a.DataFrame.Histo2D((
                'MassDiffvsNNDiff_%s_GoodMatch' % t,
                '(m_{{t}} - m_{{H}}) vs ({0}_{{t}} - {0}_{{H}}) - Good matches'
                .format(nicenames[t]), 25, -100, 150, 40, -1, 1), 'MassDiff',
                                          'NNDiff'))
        # Assign jets on truth: tagger-specific column name, since the
        # truthtag branch is shared between both loop iterations.
        selection.a.SetActiveNode(truthtag)
        selection.a.Define('NNDiff_%s' % t,
                           'Top_{0} - Higgs_{0}'.format(top_tagger))
        jetPlots.Add(
            'MassDiffvsNNDiff_%s_TruthMatch' % t,
            selection.a.DataFrame.Histo2D((
                'MassDiffvsNNDiff_%s_TruthMatch' % t,
                '(m_{{t}} - m_{{H}}) vs ({0}_{{t}} - {0}_{{H}}) - Truth matches'
                .format(nicenames[t]), 25, -100, 150, 40, -1, 1), 'MassDiff',
                                          'NNDiff_%s' % t))

    jetPlots.Do('Write')
    # Close explicitly so the output is flushed here rather than relying on
    # PyROOT destroying the TFile at interpreter shutdown.
    out.Close()
    selection.a.PrintNodeTree('NodeTree.pdf')
    print('%s sec' % (time.time() - start))
def THstudies(args):
    """Book and write kinematic, MC-truth and N-1 tagging plots for a sample.

    Writes to ``rootfiles/THstudies_<setname>_<era>[_<variation>].root``:
      * weighted kinematic distributions of the two leading jets,
      * (MC only) the jet index assigned to the top/Higgs from MC truth,
      * for each of DeepAK8 and ParticleNet, N-1 distributions of the
        mass and tagger cuts, assuming the leading jet is the top and the
        subleading jet the Higgs.

    Args:
        args: parsed command-line options; ``setname``, ``era``,
            ``threads``, ``trigEff`` and ``variation`` are read.
    """
    print('PROCESSING: %s %s' % (args.setname, args.era))
    ROOT.ROOT.EnableImplicitMT(args.threads)
    start = time.time()
    # Base setup
    selection = THClass(
        'dijet_nano/%s_%s_snapshot.txt' % (args.setname, args.era),
        int(args.era), 1, 1)
    selection.OpenForSelection('None')
    selection.a.Define(
        'Dijet_vect',
        'hardware::TLvector(Dijet_pt_corr, Dijet_eta, Dijet_phi, Dijet_msoftdrop_corrT)'
    )
    selection.a.Define('mth', 'hardware::InvariantMass(Dijet_vect)')
    # Use the top version of the corrected mass, since it still has JES/JER
    # which both would get anyway.
    selection.a.Define('m_avg',
                       '(Dijet_msoftdrop_corrT[0]+Dijet_msoftdrop_corrT[1])/2'
                       )
    selection.ApplyTrigs(args.trigEff)
    # Data gets no extra nominal weight; MC is weighted by genWeight times
    # the cross-section scale.
    selection.a.MakeWeightCols(
        extraNominal='' if selection.a.isData else 'genWeight*%s' %
        selection.GetXsecScale())

    # Kinematic definitions
    selection.a.Define('pt0', 'Dijet_pt_corr[0]')
    selection.a.Define('pt1', 'Dijet_pt_corr[1]')
    selection.a.Define('HT', 'pt0+pt1')
    selection.a.Define('deltaEta', 'abs(Dijet_eta[0] - Dijet_eta[1])')
    selection.a.Define('deltaPhi',
                       'hardware::DeltaPhi(Dijet_phi[0],Dijet_phi[1])')
    kinOnly = selection.a.Define(
        'deltaY', 'abs(Dijet_vect[0].Rapidity() - Dijet_vect[1].Rapidity())')

    # Kinematic plots (weighted with the nominal weight)
    kinPlots = HistGroup('kinPlots')
    kinPlots.Add(
        'pt0',
        selection.a.DataFrame.Histo1D(('pt0', 'Lead jet pt', 100, 350, 2350),
                                      'pt0', 'weight__nominal'))
    kinPlots.Add(
        'pt1',
        selection.a.DataFrame.Histo1D(
            ('pt1', 'Sublead jet pt', 100, 350, 2350), 'pt1',
            'weight__nominal'))
    kinPlots.Add(
        'HT',
        selection.a.DataFrame.Histo1D(
            ('HT', 'Sum of pt of two leading jets', 150, 700, 3700), 'HT',
            'weight__nominal'))
    kinPlots.Add(
        'deltaEta',
        selection.a.DataFrame.Histo1D(
            ('deltaEta', '| #Delta #eta |', 48, 0, 4.8), 'deltaEta',
            'weight__nominal'))
    kinPlots.Add(
        'deltaPhi',
        selection.a.DataFrame.Histo1D(
            ('deltaPhi', '| #Delta #phi |', 32, 1, 3.14), 'deltaPhi',
            'weight__nominal'))
    kinPlots.Add(
        'deltaY',
        selection.a.DataFrame.Histo1D(('deltaY', '| #Delta y |', 60, 0, 3),
                                      'deltaY', 'weight__nominal'))

    # Check MC truth to get jet idx assignment. Truth matching is only
    # meaningful for simulation, so skip it for data inputs.
    if not selection.a.isData:
        selection.ApplyTopPickViaMatch()
        kinPlots.Add(
            'tIdx_true',
            selection.a.DataFrame.Histo1D(
                ('tIdx_true', 'Top jet idx based on MC truth', 2, 0, 2),
                'tIdx'))
        kinPlots.Add(
            'hIdx_true',
            selection.a.DataFrame.Histo1D(
                ('hIdx_true', 'Higgs jet idx based on MC truth', 2, 0, 2),
                'hIdx'))

    # Do N-1 setup before splitting into DAK8 and PN - assume leading top.
    #    This is a 50/50 assumption that cuts the statistics in half, but
    #    it lets us make the plots with the real-world possibility of
    #    Higgs/top cross-contamination, without too much hassle.
    selection.a.SetActiveNode(kinOnly)
    selection.a.ObjectFromCollection('LeadTop', 'Dijet', 0)
    nminus1Node = selection.a.ObjectFromCollection('SubleadHiggs', 'Dijet', 1)

    out = ROOT.TFile.Open(
        'rootfiles/THstudies_%s_%s%s.root' %
        (args.setname, args.era,
         '_' + args.variation if args.variation != 'None' else ''), 'RECREATE')
    out.cd()
    for t in ['deepTag', 'particleNet']:
        top_tagger = '%s_TvsQCD' % t
        # NOTE(review): for particleNet this yields 'particleNetMD_HbbvsQCD',
        # while main() reads 'Dijet_particleNet_HbbvsQCD' - confirm the
        # MD column is defined upstream for both taggers.
        higgs_tagger = '%sMD_HbbvsQCD' % t

        # N-1: each node has all cuts of the group except one; plot the
        # variable of the cut that was dropped.
        selection.a.SetActiveNode(nminus1Node)
        nminusGroup = selection.GetNminus1Group(t)
        nminusNodes = selection.a.Nminus1(nminusGroup)
        for n in nminusNodes.keys():
            if n.startswith('m'):
                # Mass cuts
                bins = [25, 50, 300]
                if n.startswith('mH'):
                    var = 'SubleadHiggs_msoftdrop_corrH'
                else:
                    var = 'LeadTop_msoftdrop_corrT'
            elif n == 'full':
                # Node with every cut applied; nothing is plotted for it.
                continue
            else:
                # Tagger-score cuts
                bins = [50, 0, 1]
                if n.endswith('H_cut'):
                    var = 'SubleadHiggs_%s' % higgs_tagger
                else:
                    var = 'LeadTop_%s' % top_tagger
            print('N-1: Plotting %s for node %s' % (var, n))
            kinPlots.Add(
                n + '_nminus1', nminusNodes[n].DataFrame.Histo1D(
                    (n + '_nminus1', n + '_nminus1', bins[0], bins[1],
                     bins[2]), var, 'weight__nominal'))

    kinPlots.Do('Write')
    # Close explicitly so the output is flushed here rather than relying on
    # PyROOT destroying the TFile at interpreter shutdown.
    out.Close()
    selection.a.PrintNodeTree('NodeTree.pdf', verbose=True)
    print('%s sec' % (time.time() - start))