Exemplo n.º 1
0
	def openRootFile(self):
		'''
		Build self.chain, a DmpChain on "CollectionTree", from self.filelist.

		Files for which self.alreadyskimmed() returns a true value are left out.
		Raises IOError when the resulting chain reports no entries; otherwise
		the chain output directory is set to self.outputdir and the number of
		entries is stored in self.nvts.
		'''
		chain = DmpChain("CollectionTree")
		self.chain = chain
		for path in self.filelist:
			if self.alreadyskimmed(path):
				continue
			chain.Add(path)
		if chain.GetEntries():
			chain.SetOutputDir(self.outputdir)
			self.nvts = chain.GetEntries()
		else:
			raise IOError("0 events in DmpChain - something went wrong")
Exemplo n.º 2
0
def main(args=None):
    """Fill the MC kinetic-energy histogram from a file list and save it.

    ``args`` is handed to ``ArgumentParser.parse_args`` (``None`` means the
    process command line). The parsed options are forwarded to the helper
    functions ``addToChain``, ``readMC`` and ``createOutFile``.
    """
    # Command-line interface
    cli = ArgumentParser(
        usage="Usage: %(prog)s [options]",
        description="MonteCarlo Analyzer",
    )
    cli.add_argument("-l", "--list", type=str, dest='list',
                     help='Input file list')
    cli.add_argument("-d", "--dir", type=str, dest='directory',
                     help='Target Directory')
    cli.add_argument("-v", "--verbose", dest='verbose', default=False,
                     action='store_true', help='run in high verbosity mode')
    opts = cli.parse_args(args)

    # Histogram of the true kinetic energy: 10000 bins over [0, 200000]
    n_bins, x_min, x_max = 10000, 0, 200000
    eKinHisto = TH1D("eKinHisto", "MC True Kinetic Energy",
                     n_bins, x_min, x_max)

    # DAMPE shared libraries must be loaded before the ROOT-side classes
    # below can be imported
    for library in ("libDmpEvent.so", "libDmpKernel.so", "libDmpService.so"):
        gSystem.Load(library)

    from ROOT import DmpChain
    from ROOT import DmpVSvc

    # DAMPE python modules
    import DMPSW

    # Chain holding all input files; addToChain returns the event count
    chain = DmpChain("CollectionTree")
    n_events = addToChain(opts, chain)

    # Event loop
    readMC(opts, chain, n_events, eKinHisto, kStep=1e+4)

    # Output: create the file, then write the histogram and the file itself
    out_file = createOutFile(opts)
    eKinHisto.Write()
    out_file.Write()
Exemplo n.º 3
0
 def isFlight():
     global error_code
     if debug: print 'check if data is flight data'
     ch = DmpChain("CollectionTree")
     ch.Add(infile)
     nevts = int(ch.GetEntries())
     if not nevts:
         error_code = 1004
         raise Exception("zero events")
     evt = ch.GetDmpEvent(0)
     tstart = getTime(evt)
     #for i in tqdm(xrange(nevts)):
     #    evt = ch.GetDmpEvent(i)
     #    if i == 0:
     evt = ch.GetDmpEvent(nevts - 1)
     tstop = getTime(evt)
     flight_data = True if ch.GetDataType() == DmpChain.kFlight else False
     del ch
     del evt
     return flight_data, dict(tstart=tstart, tstop=tstop)
Exemplo n.º 4
0
    def merge(self, filelist, i):
        '''
        Merge all ROOT files of the python list "filelist" into one output file.

        The output goes to 'merger_out' + self.subdir + '/'. Its name is the
        base name of the first input file in which the six characters that
        precede the '.reco' / '.mc' marker are replaced by "i" zero-padded to
        five digits (those six characters are presumably the per-file index -
        TODO confirm against the naming scheme).

        Nothing is done if the output file already exists. On success the
        mapping basename(output) -> filelist is stored in self.equivalence.

        Raises Exception if the first file name contains neither '.reco' nor
        '.mc'.
        '''
        # Detect the file type on the base name only: a directory that happens
        # to contain 'reco' or 'mc' must not decide the type of the file.
        first_name = basename(filelist[0])
        for marker in ('.reco', '.mc'):
            marker_pos = first_name.find(marker)
            if marker_pos >= 0:
                break
        else:
            raise Exception(
                "Could not identify file type. Looking for '.reco.root' or '.mc.root'"
            )
        outfile = ('merger_out' + self.subdir + '/' +
                   first_name[0:marker_pos - 6] + "%05d" % (i, ) + marker +
                   '.root')
        if isfile(outfile):
            return

        # Open files, build TChain
        dmpch = DmpChain("CollectionTree")
        metachain = TChain("RunMetadataTree")
        for f in filelist:
            dmpch.Add(f)
        # Only the metadata of the first file is cloned; it is patched below
        metachain.Add(filelist[0])

        # Sum the event numbers stored in the run metadata of every input
        nrofevents = 0
        for f in filelist:
            fo = TFile.Open(f)
            rmt = fo.Get("RunMetadataTree")
            simuheader = DmpRunSimuHeader()
            rmt.SetBranchAddress("DmpRunSimuHeader", simuheader)
            rmt.GetEntry(0)
            nrofevents += simuheader.GetEventNumber()
            fo.Close()

        # Every other header field is copied from the first input file
        fob = TFile.Open(filelist[0])
        rmt = fob.Get("RunMetadataTree")
        simuheader = DmpRunSimuHeader()
        rmt.SetBranchAddress("DmpRunSimuHeader", simuheader)
        rmt.GetEntry(0)
        runNumber = simuheader.GetRunNumber()
        spectrumType = simuheader.GetSpectrumType()
        sourceType = simuheader.GetSourceType()
        vertexRadius = simuheader.GetVertexRadius()
        sourceGen = simuheader.GetSourceGen()
        maxEne = simuheader.GetMaxEne()
        minEne = simuheader.GetMinEne()
        version = simuheader.GetVersion()
        fluxFile = simuheader.GetFluxFile()
        orbitFile = simuheader.GetOrbitFile()
        prescale = simuheader.GetPrescale()
        startMJD = simuheader.GetMJDstart()
        stopMJD = simuheader.GetMJDstop()
        seed = simuheader.GetSeed()
        fob.Close()

        # Write new file; it is closed even if cloning or filling fails
        tf = TFile(outfile, "RECREATE")
        try:
            ot = dmpch.CloneTree(-1, "fast")
            om = metachain.CloneTree(0)
            simhdr = DmpRunSimuHeader()
            om.SetBranchAddress("DmpRunSimuHeader", simhdr)
            for entry in xrange(metachain.GetEntries()):
                metachain.GetEntry(entry)
                simhdr.SetEventNumber(nrofevents)
                simhdr.SetRunNumber(runNumber)
                simhdr.SetSpectrumType(spectrumType)
                simhdr.SetSourceType(sourceType)
                simhdr.SetVertexRadius(vertexRadius)
                simhdr.SetSourceGen(sourceGen)
                simhdr.SetMaxEne(maxEne)
                simhdr.SetMinEne(minEne)
                simhdr.SetVersion(version)
                simhdr.SetFluxFile(fluxFile)
                simhdr.SetOrbitFile(orbitFile)
                simhdr.SetPrescale(prescale)
                simhdr.SetMJDstart(startMJD)
                simhdr.SetMJDstop(stopMJD)
                simhdr.SetSeed(seed)
                om.Fill()
            tf.Write()
        finally:
            tf.Close()

        self.equivalence[basename(outfile)] = filelist
        # No explicit cleanup: all locals are released when the method returns
Exemplo n.º 5
0
def main(infile, debug=False):
    pdgs = dict(Proton=2212,
                Electron=11,
                Muon=13,
                Gamma=22,
                He=2,
                Li=3,
                Be=4,
                B=5,
                C=6,
                N=7,
                O=8)
    chksum = None
    global error_code

    if not infile.startswith("root://"):
        infile = abspath(infile)
        try:
            chksum = TMD5.FileChecksum(infile).AsString()
        except ReferenceError:
            pass
    DmpChain.SetVerbose(-1)
    DmpEvent.SetVerbosity(-1)
    if debug:
        DmpChain.SetVerbose(1)
        DmpEvent.SetVerbosity(1)
    types = ("mc:simu", "mc:reco", "2A")

    branches = {
        "mc:simu": [
            'EventHeader', 'DmpSimuStkHitsCollection', 'DmpSimuPsdHits',
            'DmpSimuNudHits0Collection', 'DmpSimuNudHits1Collection',
            'DmpSimuNudHits2Collection', 'DmpSimuNudHits3Collection',
            'DmpBgoSptStruct', 'DmpSimuBgoHits', 'DmpEvtSimuHeader',
            'DmpEvtSimuPrimaries', 'DmpTruthTrajectoriesCollection',
            'DmpSimuSeondaryVtxCollection', 'DmpEvtOrbit'
        ],
        "mc:reco": [
            'StkKalmanTracks', 'DmpEvtBgoHits', 'DmpEvtBgoRec', 'EventHeader',
            'DmpEvtNudRaw', 'DmpEvtSimuPrimaries', 'StkClusterCollection',
            'DmpStkLadderAdcCollection', 'DmpStkEventMetadata', 'DmpPsdHits',
            'DmpSimuPsdHits', 'DmpEvtPsdRec', 'DmpGlobTracks', 'DmpEvtOrbit',
            'DmpEvtSimuHeader', 'DmpTruthTrajectoriesCollection'
        ],
        "2A": [
            'EventHeader', 'DmpEvtPsdRaw', 'DmpPsdHits', 'DmpEvtBgoRaw',
            'DmpEvtBgoHits', 'StkClusterCollection',
            'DmpStkLadderAdcCollection', 'DmpStkEventMetadata',
            'StkKalmanTracks', 'DmpEvtNudRaw', 'EvtAttitudeContainer',
            'DmpEvtBgoRec', 'DmpEvtPsdRec', 'DmpGlobTracks'
        ]
    }

    # fixme: this wraps the null pointer testing.

    #def md5sum(fname):
    #    from subprocess import Popen, PIPE
    #    cmd = "md5sum {fname}".

    def checkBranches(tree, branches):
        global error_code
        if debug: print 'checking branches.'
        for b in branches:
            res = tree.FindBranch(b)
            if res is None:
                error_code = 1001
                raise Exception("missing branch %s", b)
        return True

    def verifyEnergyBounds(fname, emin, emax):
        if debug: print 'verify energy range from DmpRunSimuHeader'
        ch = TChain("CollectionTree")
        ch.Add(fname)
        nentries = ch.GetEntries()
        h1 = TH1D("h1", "hEnergy", 10, TMath.Log10(emin), TMath.Log10(emax))
        ch.Project("h1", "TMath::Log10(DmpEvtSimuPrimaries.pvpart_ekin)")
        underflow = h1.GetBinContent(0)
        overflow = h1.GetBinContent(11)
        assert (
            underflow == 0), "{n} events found below emin={emin} MeV".format(
                n=int(underflow), emin=emin)
        assert (
            overflow == 0), "{n} events found above emax={emax} MeV".format(
                n=int(overflow), emax=emax)
        del h1
        del ch

    def testPdgId(fname):
        global error_code
        if debug: print 'testing for PDG id'
        bn = basename(fname).split(".")[0].split("-")[0]
        if (not bn.startswith("all")) or (("bkg" or "background" or "back")
                                          in bn.lower()):
            return True
        else:
            try:
                particle = bn.replace("all", "")
                assert particle in pdgs.keys(), "particle type not supported"
                part = pdgs[particle]
                ch = TChain("CollectionTree")
                ch.Add(fname)
                h1 = TH1D("h1", "hPdgId", 10, part - 5, part + 5)
                if part < 10:
                    if debug: print 'Ion mode, subtract'
                    ch.Project(
                        "h1",
                        "TMath::Floor(DmpEvtSimuPrimaries.pvpart_pdg/10000.) - 100000"
                    )
                else:
                    ch.Project("h1", "DmpEvtSimuPrimaries.pvpart_pdg")
                delta = TMath.Abs(part - h1.GetMean())
                width = h1.GetRMS()
                if width > 0.1:
                    raise ValueError(
                        "pdg Id verification failed, distribution too broad, expect delta"
                    )
                if delta > 0.1:
                    raise ValueError(
                        "pdg Id verification failed, distribution not centered on %i but on %1.1f"
                        % (part, h1.GetMean()))
                if debug:
                    print "Pdg Hist: Mean = {mean}, RMS = {rms}".format(
                        mean=h1.GetMean(), rms=h1.GetRMS())
                del h1
                del ch
            except Exception as err:
                error_code = 1003
                raise Exception(err.message)
            return True

    def isNull(ptr):
        if debug: print 'test if pointer is null.'
        try:
            ptr.IsOnHeap()
        except ReferenceError:
            return True
        else:
            return False

    def getTime(evt):
        if debug: print 'extract timestamp'
        if isNull(evt.pEvtHeader()):
            return -1.
        sec = evt.pEvtHeader().GetSecond()
        ms = evt.pEvtHeader().GetMillisecond()
        time = float("{second}.{ms}".format(second=sec, ms=ms))
        return time

    def checkHKD(fname):
        global error_code
        if debug: print 'check HKD trees'
        trees = dict(
            SatStatus=['DmpHKDSatStatus'],
            HighVoltage=['DmpHKDHighVoltage'],
            TempSatellite=['DmpHKDTempSatellite'],
            TempPayloadNegative=['DmpHKDTempNegative'],
            TempPayloadPositive=['DmpHKDTempPositive'],
            CurrentPayloadNegative=['DmpHKDCurrentNegative'],
            CurrentPayloadPositive=['DmpHKDCurrentPositive'],
            StatusPayloadNegative=['DmpHKDStatusNegative'],
            StatusPayloadPositive=['DmpHKDStatusPositive'],
            StatusPowerSupplyPositive=['DmpHKDStatusPowerSupplyPositive'],
            StatusPowerSupplyNegative=['DmpHKDStatusPowerSupplyNegative'],
            PayloadDataProcesser=['DmpHKDPayloadDataProcessor'],
            PayloadManager=['DmpHKDPayloadManager'])
        try:
            for tree, branches in trees.iteritems():
                ch = TChain("HousekeepingData/{tree}".format(tree=tree))
                ch.Add(fname)
                if ch.GetEntries() == 0:
                    raise Exception("HKD tree %s empty", tree)
                checkBranches(ch, branches)
        except Exception as err:
            error_code = 1002
            raise Exception(err.message)
        return True

    def extractVersion(fname):
        if debug: print 'extract version'
        ch = TChain("RunMetadataTree")
        ch.Add(fname)
        ch.SetBranchStatus("*", 1)
        svn_rev = TString()
        tag = TString()
        ch.SetBranchAddress("SvnRev", svn_rev)
        ch.SetBranchAddress("tag", tag)
        ch.GetEntry(0)
        return str(tag), str(svn_rev)

    def extractEnergyBounds(fname):
        if debug: print 'extract energy boundaries'
        ch = TChain("RunMetadataTree")
        ch.Add(fname)
        simuHeader = DmpRunSimuHeader()
        b_sH = ch.GetBranch("DmpRunSimuHeader")
        ch.SetBranchAddress("DmpRunSimuHeader", simuHeader)
        b_sH.GetEntry(0)
        return simuHeader.GetMinEne(), simuHeader.GetMaxEne()

    def isFlight():
        global error_code
        if debug: print 'check if data is flight data'
        ch = DmpChain("CollectionTree")
        ch.Add(infile)
        nevts = int(ch.GetEntries())
        if not nevts:
            error_code = 1004
            raise Exception("zero events")
        evt = ch.GetDmpEvent(0)
        tstart = getTime(evt)
        #for i in tqdm(xrange(nevts)):
        #    evt = ch.GetDmpEvent(i)
        #    if i == 0:
        evt = ch.GetDmpEvent(nevts - 1)
        tstop = getTime(evt)
        flight_data = True if ch.GetDataType() == DmpChain.kFlight else False
        del ch
        del evt
        return flight_data, dict(tstart=tstart, tstop=tstop)

    def getSize(lfn):
        global error_code
        if debug: print 'extracting file size'
        if lfn.startswith("root://"):
            server = "root://{server}".format(server=lfn.split("/")[2])
            xc = client.FileSystem(server)
            is_ok, res = xc.stat(lfn.replace(server, ""))
            if not is_ok.ok:
                error_code = 2000
                raise Exception(is_ok.message)
            return res.size
        else:
            return getsize(lfn)

    def getModDate(lfn):
        from datetime import datetime
        global error_code
        if debug: print 'creation date'
        if lfn.startswith("root://"):
            server = "root://{server}".format(server=lfn.split("/")[2])
            xc = client.FileSystem(server)
            is_ok, res = xc.stat(lfn.replace(server, ""))
            if not is_ok.ok:
                error_code = 2000
                raise Exception(is_ok.message)
            return datetime.strptime(res.modtimestr, "%Y-%m-%d %H:%M:%S")
        else:
            return datetime.fromtimestamp(getmtime(lfn))

    def getTaskName(lfn):
        url = lfn
        if lfn.startswith("root://"):
            server = "root://{server}".format(server=lfn.split("/")[2])
            url = lfn.replace(server, "")
        return url.split("/")[-2]

    def isFile(lfn):
        global error_code
        if debug: print 'checking file access'
        if debug: print "LFN: ", lfn
        if lfn.startswith("root://"):
            server = "root://{server}".format(server=lfn.split("/")[2])
            if debug: print "server: ", server
            xc = client.FileSystem(server)
            fname = lfn.replace(server, "")
            if debug: print "LFN: ", fname
            if debug: print xc.stat(fname)
            is_ok, res = xc.stat(fname)
            if debug:
                print 'is_ok: ', is_ok
                print 'res: ', res
            if not is_ok.ok:
                error_code = 2000
                raise Exception(is_ok.message)
            if debug:
                print "res.flags: ", res.flags
                print "StatInfoFlags.IS_READABLE: ", StatInfoFlags.IS_READABLE
            if res.flags == 0:
                if debug:
                    print '[!] FIXME: XRootD.client.FileSystem.stat() returned StatInfoFlags = 0, this flag is not supported'
                res.flags = StatInfoFlags.IS_READABLE
            return True if res.flags >= StatInfoFlags.IS_READABLE else False
        else:
            return isfile(lfn)

    tstart = -1.
    tstop = -1.
    fsize = 0.
    moddate = "NONE"
    tname = "None"
    good = True
    eMax = -1.
    eMin = -1.
    comment = "NONE"
    f_type = "Other"
    svn_rev = "None"
    tag = "None"
    nevts = 0
    try:
        good = isFile(infile)
        if not good:
            error_code = 2000
            raise Exception("could not access file")

        fsize = getSize(infile)
        moddate = getModDate(infile).strftime("%Y-%m-%d_%H:%M:%S")
        tag, svn_rev = extractVersion(infile)
        tch = TChain("CollectionTree")
        tch.Add(infile)
        nevts = int(tch.GetEntries())
        if nevts == 0:
            error_code = 1004
            good = False
            raise IOError("zero events.")
        flight_data, stat = isFlight()
        if flight_data:
            good = checkHKD(infile)
            tstart = stat.get("tstart", -1.)
            tstop = stat.get("tstop", -1.)
            f_type = "2A"
        else:
            tname = getTaskName(infile)
            #print 'mc data'
            simu_branches = [tch.FindBranch(b) for b in branches['mc:simu']]
            reco_branches = [tch.FindBranch(b) for b in branches['mc:reco']]

            if None in simu_branches:
                f_type = "mc:reco"
                if None in reco_branches:
                    error_code = 1001
                    good = False
                    raise Exception("missing branches in mc:reco")
            else:
                f_type = "mc:simu"
            if (testPdgId(infile)): good = True
            eMin, eMax = extractEnergyBounds(infile)
            if eMin != eMax:
                try:
                    verifyEnergyBounds(infile, eMin, eMax)
                except AssertionError as msg:
                    error_code = 1005
                    raise Exception(msg.message)
        if good:
            assert f_type in types, "found non-supported dataset type!"
            good = checkBranches(tch, branches[f_type])

    except Exception as err:
        comment = str(err.message)
        good = False

    f_out = dict(lfn=infile,
                 nevts=nevts,
                 tstart=tstart,
                 tstop=tstop,
                 good=good,
                 error_code=error_code,
                 comment=comment,
                 size=fsize,
                 type=f_type,
                 version=tag,
                 SvnRev=svn_rev,
                 emax=eMax,
                 emin=eMin,
                 checksum=chksum,
                 last_modified=moddate,
                 task=tname)
    # convert True / False to 0/1
    for key, value in f_out.iteritems():
        if isinstance(value, bool): f_out[key] = int(value)
    return f_out
Exemplo n.º 6
0
from os.path import getsize, abspath, isfile, splitext
#from tqdm import tqdm
gROOT.SetBatch(True)
gROOT.ProcessLine("gErrorIgnoreLevel = 3002;")
res = gSystem.Load("libDmpEvent")
if res != 0:
    raise ImportError("could not import libDmpEvent, mission failed.")
from ROOT import DmpChain

infile = argv[1]
ev_id = argv[2]
event_id = []
if "," in ev_id: event_id = tuple([int(e) for e in ev_id.split(",")])
else:
    event_id = tuple([int(ev_id)])

ch = DmpChain("CollectionTree")
ch.SetOutputDir(abspath("."))
ch.Add(infile)

nevts = ch.GetEntries()
if not nevts: raise Exception("no events found in file.")
if max(event_id) > nevts: raise Exception("Event ID > total number of events.")

for i in event_id:
    print 'getting event %i' % i
    pev = ch.GetDmpEvent(i)
    ev_no = pev.pEvtSimuHeader().GetEventNuber()
    print i, ev_no
    ch.SaveCurrentEvent()
ch.Terminate()
Exemplo n.º 7
0
class Skim(object):
	'''
	Event skimmer working on a list of ROOT files.

	The constructor reads the file list, builds a DmpChain from the files
	that have not been skimmed yet and determines the particle to select.
	selection() holds the per-event cuts.
	'''

	def __init__(self,filename,particle=None,outputdir='skim_output'):
		'''
		filename:  text file with one ROOT file name per line
		particle:  particle alias (see identifyParticle); if None the
		           particle is guessed from "filename"
		outputdir: directory for the skimmed files, created if missing
		'''
		self.filename = filename
		with open(filename,'r') as fi:
			# Blank lines are skipped so that they never become empty file names
			self.filelist = [line.strip() for line in fi if line.strip()]

		self.addPrefix()

		self.outputdir = outputdir
		if not os.path.isdir(outputdir): os.mkdir(outputdir)

		self.particle = self.identifyParticle(particle)

		self.openRootFile()

		self.t0 = time.time()
		self.selected = 0
		self.skipped = 0

	def addPrefix(self):
		'''
		adds the XrootD prefix to the filelist, in case it is missing

		The decision is taken on the first entry only: the prefix is added to
		every entry if the first one neither contains 'root://' nor is an
		existing local file. An empty list is left untouched.
		'''
		if not self.filelist:
			return
		first = self.filelist[0]
		if 'root://' not in first and not os.path.isfile(first):
			self.filelist = ['root://xrootd-dampe.cloud.ba.infn.it/' + x for x in self.filelist]

	def identifyParticle(self,part):
		'''
		Particle identification based on either the argument or the file name

		Returns the PDG code (11, 2212 or 22) as int, or None if nothing
		matches. With part=None every alias except the single-letter first
		one is searched for in self.filename.
		'''
		e = ['e','elec','electron','11','E','Elec','Electron']
		p = ['p','prot','proton','2212','P','Prot','Proton']
		gamma = ['g','gamma','photon','22','Gamma','Photon']

		for cat in [e,p,gamma]:
			# index 3 of every alias list holds the PDG code
			if part is None:
				matched = any(x in self.filename for x in cat[1:])
			else:
				matched = part in cat
			if matched:
				return int(cat[3])
		return None

	def openRootFile(self):
		'''
		Creates the DmpChain

		Files already skimmed are not added. Raises IOError if the chain has
		no entries; otherwise stores the entry count in self.nvts.
		'''
		self.chain = DmpChain("CollectionTree")
		for f in self.filelist:
			if not self.alreadyskimmed(f):
				self.chain.Add(f)
		entries = self.chain.GetEntries()
		if not entries:
			raise IOError("0 events in DmpChain - something went wrong")
		self.chain.SetOutputDir(self.outputdir)
		self.nvts = entries

	def getRunTime(self):
		'''Seconds elapsed since the end of __init__'''
		return (time.time() - self.t0)
	def getSelected(self):
		return self.selected
	def getSkipped(self):
		return self.skipped

	def alreadyskimmed(self,f):
		'''
		Checks if file f has already been skimmed

		i.e. if self.outputdir holds a file named like f with '.root'
		replaced by '_UserSel.root'.
		'''
		temp_f = os.path.basename(f).replace('.root','_UserSel.root')
		temp_f = self.outputdir + '/' + temp_f

		if os.path.isfile(temp_f):
			# single-argument form: same output under Python 2 and Python 3
			print("%s  already skimmed" % os.path.basename(f))
			return True
		return False

	def run(self):
		# NOTE(review): analysis() and end() are not defined in this part of
		# the file - they must be provided elsewhere
		self.analysis()
		self.end()

	def selection(self,event):
		'''
		Returns True if the event passes all cuts, False otherwise.

		Cuts, in order: primary particle type, trigger, valid BGO track,
		BGO containment, maximum layer energy fraction, maximum-energy bar
		position in layers 1 to 3.
		'''
		if self.particle is None:
			if event.pEvtSimuPrimaries().pvpart_pdg not in [11,2212,22]:
				return False
		else:
			if event.pEvtSimuPrimaries().pvpart_pdg != self.particle :
				return False

		# High energy trigger
		if not event.pEvtHeader().GeneratedTrigger(3):
			return False

		# Here: Russlan skimmer
		BGO_TopZ = 46
		BGO_BottomZ = 448
		bgoRec = event.pEvtBgoRec()

		# "BGO track containment cut" -Russlan
		# Index 0 is the YZ projection and index 1 the XZ projection in BOTH
		# lists, so that a slope is always combined with its own intercept.
		bgoRec_slope = [ bgoRec.GetSlopeYZ() , bgoRec.GetSlopeXZ() ]
		bgoRec_intercept = [ bgoRec.GetInterceptYZ() , bgoRec.GetInterceptXZ() ]
		# slope and intercept both zero: no usable track in that projection
		if (bgoRec_slope[1]==0 and bgoRec_intercept[1]==0) or (bgoRec_slope[0]==0 and bgoRec_intercept[0]==0):
			return False

		# "BGO containment cut": track extrapolated to the top and bottom Z
		topX = bgoRec_slope[1]*BGO_TopZ + bgoRec_intercept[1]
		topY = bgoRec_slope[0]*BGO_TopZ + bgoRec_intercept[0]
		bottomX = bgoRec_slope[1]*BGO_BottomZ + bgoRec_intercept[1]
		bottomY = bgoRec_slope[0]*BGO_BottomZ + bgoRec_intercept[0]
		if not all( [ abs(x) < 280 for x in [topX,topY,bottomX,bottomY] ] ):
			return False

		# "cut maxElayer": largest layer energy over corrected total energy
		ELayer_max = max([0] + [bgoRec.GetELayer(i) for i in range(14)])
		totalE = bgoRec.GetElectronEcor()
		if totalE == 0:
			# no energy to normalise to: reject instead of dividing by zero
			return False
		rMaxELayerTotalE = ELayer_max / totalE
		if rMaxELayerTotalE > 0.35:
			return False

		# "cut maxBarLayer": bar with the highest energy in layers 1, 2, 3
		barNumberMaxEBarLay1_2_3 = [-1 for i in [1,2,3]]
		MaxEBarLay1_2_3 = [0 for i in [1,2,3]]
		bgoHits = event.pEvtBgoHits()
		for ihit in range(0, bgoHits.GetHittedBarNumber()):
			hitE = (bgoHits.fEnergy)[ihit]
			lay = bgoHits.GetLayerID(ihit)
			if lay in [1,2,3]:
				if hitE > MaxEBarLay1_2_3[lay-1]:
					# 5-bit field starting at bit 6 of the global bar id
					iBar =  ((bgoHits.fGlobalBarId)[ihit]>>6) & 0x1f
					MaxEBarLay1_2_3[lay-1] = hitE
					barNumberMaxEBarLay1_2_3[lay-1] = iBar
		# reject if a layer has no hit (-1) or its maximum is in bar 0 or 21
		for j in range(3):
			if barNumberMaxEBarLay1_2_3[j] <=0 or barNumberMaxEBarLay1_2_3[j] == 21:
				return False

		return True
Exemplo n.º 8
0
def findElectrons(opts):
    
    ### Load Python modules

    import os
    import math
    import numpy as np
    from array import array
    from os.path import isdir, abspath

    ### Load ROOT modules
    from ROOT import TClonesArray, TFile, TTree, gSystem, gROOT, AddressOf
    from ROOT import TH2F, TH1F, TMath, TGraphAsymmErrors

    ###Load DAMPE libs

    gSystem.Load("libDmpEvent.so")
    gSystem.Load("libDmpEventFilter.so")
    
    gSystem.Load("libDmpKernel.so")
    gSystem.Load("libDmpService.so")

    ###Load DAMPE modules

    from ROOT import DmpChain, DmpEvent, DmpFilterOrbit, DmpPsdBase, DmpCore
    from ROOT import DmpSvcPsdEposCor, DmpVSvc   #DmpRecPsdManager
    import DMPSW

    gROOT.SetBatch(True)

    ############################# Searching for electrons

    ####### Reading input files

    #Creating DAMPE chain for input files
    dmpch = DmpChain("CollectionTree")
    
    #Reading input files
    if not opts.input:
        files = [f.replace("\n","") for f in open(opts.list,'r').readlines()]
        for ifile, f in enumerate(files):
            DMPSW.IOSvc.Set("InData/Read" if ifile == 0 else "InData/ReadMore",f)
            if os.path.isfile(f):
                dmpch.Add(f)
                if opts.verbose:
                    print('\nInput file read: {} -> {}'.format(ifile,f))
    else:
        DMPSW.IOSvc.Set("InData/Read",opts.input)
        if os.path.isfile(opts.input):
            dmpch.Add(opts.input)
            if opts.verbose:
                print('\nInput file read: {}'.format(opts.input))
    
    #Defining the total number of events
    nevents = dmpch.GetEntries()

    if opts.verbose:
        print('\nTotal number of events: {}'.format(nevents))
        print("\nPrinting the chain...\n")
        dmpch.Print()
    
    ####### Setting the output directory to the chain
    dmpch.SetOutputDir(abspath(opts.outputDir),"electrons")

    ####### Processing input files

    ###Histos

    #Defining log binning

    #np.logspace binning
    nBins=1000
    eMax=6
    eMin=0
    eBinning = np.logspace(eMin, eMax, num=(nBins+1))
    
    #custom binning
    ''' 
    nBins = 1000
    eMin=0.1
    eMax=1000000
    EDmax = []
    EDEdge = [] 
    EDstepX=np.log10(eMax/eMin)/nBins
    for iedge in range(0, nBins):
        EDEdge.append(eMin*pow(10,iedge*EDstepX))
        EDmax.append(eMin*pow(10,(iedge+1)*EDstepX))
    EDEdge.append(EDmax[-1])
    Edges= array('d',EDEdge) # this makes a bound array for TH1F
    '''

    #Pointing
    h_terrestrial_lat_vs_long =  TH2F("h_terrestrial_lat_vs_long","latitude vs longitude",360,0,360,180,-90,90)

    ## Energy
    h_energy_all = TH1F("h_energy_all","all particle energy",nBins,eBinning)
    h_energyCut = TH1F("h_energyCut","all particle energy - 20 GeV cut",nBins,eBinning)
    h_energyCut_SAAcut = TH1F("h_energyCut_SAAcut","all particle energy - 20 GeV cut (no SAA)",nBins,eBinning)
    h_energyCut_noTrack = TH1F("h_energyCut_noTrack","all particle energy - 20 GeV cut (NO TRACK)",nBins,eBinning)
    h_energyCut_Track = TH1F("h_energyCut_Track","all particle energy - 20 GeV cut (TRACK)",nBins,eBinning)
    h_energyCut_TrackMatch = TH1F("h_energyCut_TrackMatch","all particle energy - 20 GeV cut (TRACK match)",nBins,eBinning)
    
    ##BGO
    h_energyBGOl=[]  #energy of BGO vertical layer (single vertical plane)
    for BGO_idxl in range(14):
        histoName = "h_energyBGOl_" + str(BGO_idxl)
        histoTitle = "BGO energy deposit layer " + str(BGO_idxl)
        tmpHisto = TH1F(histoName,histoTitle,1000,0,1e+6)
        h_energyBGOl.append(tmpHisto)

    h_energyBGOb = [] #energy of BGO lateral layer (single bars of a plane)
    h_BGOb_maxEnergyFraction = [] #fraction of the maximum released energy for each bar on each layer of the BGO calorimeter

    for BGO_idxl in range(14):
        tmp_eLayer = []
        for BGO_idxb in range(23):
            histoName = "h_energyBGOl_" + str(BGO_idxl) + "_BGOb_" + str(BGO_idxb)
            histoTitle = "BGO energy deposit layer " + str(BGO_idxl) + " bar " + str(BGO_idxb)
            tmpHisto = TH1F(histoName,histoTitle,1000,0,1e+6)
            tmp_eLayer.append(tmpHisto)
            
        maxhistoName = "h_BGO_maxEnergyFraction_l_" + str(BGO_idxl)
        maxhistoTitle = "fraction of the maximum released energy layer " + str(BGO_idxl)
        tmpMaxHisto = TH1F(maxhistoName,maxhistoTitle,100,0,1)
        h_BGOb_maxEnergyFraction.append(tmpMaxHisto)
        h_energyBGOb.append(tmp_eLayer)

    h_BGOl_maxEnergyFraction = TH1F("h_BGOl_maxEnergyFraction","Fraction of the maximum released energy",100,0,1)

    h_thetaBGO = TH1F("h_thetaBGO","theta BGO",100,0,90)

    ##STK

    h_STK_nTracks = TH1F("h_STK_nTracks","number of tracks",1000,0,1000)
    h_STK_trackChi2norm = TH1F("h_STK_trackChi2norm","\chi^2/n track",100,0,200)
    h_STK_nTracksChi2Cut = TH1F("h_STK_nTracksChi2Cut","number of tracks (\chi^2 cut)",1000,0,1000)
        
    h_stk_cluster_XvsY = []
    for iLayer in range(6):
        hName = 'h_stkCluster_XvsY_l_'+str(iLayer)
        hTitle = 'cluster X vs Y - plane '+str(iLayer)
        tmpHisto = TH2F(hName,hTitle,1000,-500,500,1000,-500,500)
        h_stk_cluster_XvsY.append(tmpHisto)

    h_ThetaSTK = TH1F("h_ThetaSTK","theta STK",100,0,90)
    h_deltaTheta = TH1F("h_deltaTheta","\Delta theta",500,-100,100)
    
    h_resX_STK_BGO = TH1F("h_resX_STK_BGO","BGO/STK residue layer X",200,-1000,1000)
    h_resY_STK_BGO = TH1F("h_resY_STK_BGO","BGO/STK residue layer Y",200,-1000,1000)

    h_imapctPointSTK = TH2F("h_imapctPointSTK","STK impact point",1000,-500,500,1000,-500,500)

    h_stk_chargeClusterX = TH1F("h_stk_chargeClusterX","STK charge on cluster X",10000,0,10000)
    h_stk_chargeClusterY = TH1F("h_stk_chargeClusterY","STK charge on cluster Y",10000,0,10000)

    ##PSD

    h_psd_ChargeX = []
    for lidx in range (2):
        histoName = "h_psd_ChargeX_l" + str(lidx)
        histoTitle = "PSD X charge layer " + str(lidx)
        tmpHisto = TH1F(histoName,histoTitle,10000,0,10000)
        h_psd_ChargeX.append(tmpHisto)

    h_psd_ChargeY = []
    for lidx in range (2):
        histoName = "h_psd_ChargeY_l" + str(lidx)
        histoTitle = "PSD Y charge layer " + str(lidx)
        tmpHisto = TH1F(histoName,histoTitle,10000,0,10000)
        h_psd_ChargeY.append(tmpHisto)

    ###

    ### Analysis cuts

    eCut = 50       #Energy cut in GeV

    ### DAMPE geometry

    BGOzTop = 46.
    BGOzBot = 448.

    #Filtering for SAA
    if not opts.mc:
        DMPSW.IOSvc.Set("OutData/NoOutput", "True")
        DMPSW.IOSvc.Initialize()
        pFilter = DmpFilterOrbit("EventHeader")
        pFilter.ActiveMe()
    
    #Starting loop on files

    if opts.debug:
        if opts.verbose:
            print('\nDebug mode activated... the number of chain events is limited to 1000')
        nevents = 1000
    
    for iev in xrange(nevents):

        if opts.mc:
            DmpVSvc.gPsdECor.SetMCflag(1)
        pev=dmpch.GetDmpEvent(iev)

        #Get latitude and longitude
        longitude = pev.pEvtAttitude().lon_geo
        latitude = pev.pEvtAttitude().lat_geo

        #Get particle total energy
        etot=pev.pEvtBgoRec().GetTotalEnergy()/1000.
        h_energy_all.Fill(etot)
        if etot < eCut:
            continue
        h_energyCut.Fill(etot)

        #Get BGO energy deposit for each layer (vertical BGO shower profile)
        v_bgolayer  = np.array([pev.pEvtBgoRec().GetELayer(ibgo) for ibgo in range(14)])
        
        for BGO_idxl in range(14):
            h_energyBGOl[BGO_idxl].Fill(v_bgolayer[BGO_idxl])  

        #Get BGO energy deposit for each bar (lateral BGO shower profile) of each layer

        for ilay in xrange(0,14):
            v_bgolayer_bars  = np.array([pev.pEvtBgoRec().GetEdepPos(ilay,ibar) for ibar in xrange(0,23)])
            #Fraction of the maximum energy deposit of the particle crossing the BGO on a certain layer (single bars)
            h_BGOb_maxEnergyFraction[ilay].Fill(np.max(v_bgolayer_bars)/1000./etot)
            for idx_BGOb in range (23):
                h_energyBGOb[ilay][idx_BGOb].Fill(v_bgolayer_bars[idx_BGOb])
            

        #Fraction of the maximum energy deposit of the particle crossing the BGO
        h_BGOl_maxEnergyFraction.Fill(np.max(v_bgolayer)/1000./etot)

        #BGO acceptance projection
        
        projectionX_BGO_BGOTop =  pev.pEvtBgoRec().GetInterceptXZ() +BGOzTop  * pev.pEvtBgoRec().GetSlopeXZ()
        projectionY_BGO_BGOTop =  pev.pEvtBgoRec().GetInterceptYZ() +BGOzTop  * pev.pEvtBgoRec().GetSlopeYZ()


        #SAA filter
        if not opts.mc:
            inSAA = pFilter.IsInSAA(pev.pEvtHeader().GetSecond())
            #inSAA = False
            if (inSAA): 
                continue
            h_energyCut_SAAcut.Fill(etot)
            h_terrestrial_lat_vs_long.Fill(longitude,latitude)

        tgZ = math.atan(np.sqrt( (pev.pEvtBgoRec().GetSlopeXZ()*pev.pEvtBgoRec().GetSlopeXZ()) + (pev.pEvtBgoRec().GetSlopeYZ()*pev.pEvtBgoRec().GetSlopeYZ()) ) );
        theta_bgo = tgZ*180./math.pi

        h_thetaBGO.Fill(theta_bgo)

        #Tracks
        ntracks = pev.NStkKalmanTrack()

        if ntracks < 0:
            print "\nTRACK ERROR: number of tracks < 0 - ABORTING\n"
            break
        if ntracks == 0:
            h_energyCut_noTrack.Fill(etot)
        
        h_STK_nTracks.Fill(ntracks)
        h_energyCut_Track.Fill(etot)

        res_X_min = 1000
        res_Y_min = 1000
        trackID_X = -9
        trackID_Y = -9

        lTrackIDX = []
        lTrackIDY = []

        residueXmin = []
        residueYmin = []

        #Loop on STK tracks to get the STK charge measurement

        for iTrack in range(ntracks):
            tmpTrack = pev.pStkKalmanTrack(iTrack)
            chi2_norm = tmpTrack.getChi2()/(tmpTrack.getNhitX()+tmpTrack.getNhitY()-4)
            h_STK_trackChi2norm.Fill(chi2_norm)

            if chi2_norm > 25: 
                continue
        
            h_STK_nTracksChi2Cut.Fill(ntracks)

            l0ClusterX = l0ClusterY = False

            for iCluster in range(tmpTrack.GetNPoints()):
                clux = tmpTrack.pClusterX(iCluster)
                cluy = tmpTrack.pClusterY(iCluster)
                if clux and clux.getPlane() == 0:
                    l0ClusterX = True
                if cluy and cluy.getPlane() == 0:
                    l0ClusterY = True

                # check plot for the dead region of STK
                if(clux and cluy):
                    h_stk_cluster_XvsY[clux.getPlane()].Fill(clux.GetX(),cluy.GetY())


            if l0ClusterX == False and l0ClusterY == False:
                continue

            #### Tracks characteristics

            theta_stk =math.acos(tmpTrack.getDirection().CosTheta())*180./math.pi;

            delta_theta_STK_BGO = theta_stk - theta_bgo

            #STK impact point
            trackImpactPointX = tmpTrack.getImpactPoint().x()
            trackImpactPointY = tmpTrack.getImpactPoint().y()

            #Track projections
            trackProjX = tmpTrack.getDirection().x()*(BGOzTop - tmpTrack.getImpactPoint().z()) + tmpTrack.getImpactPoint().x()
            trackProjY = tmpTrack.getDirection().y()*(BGOzTop - tmpTrack.getImpactPoint().z()) + tmpTrack.getImpactPoint().y()

            #Track residues
            resX_STK_BGO = projectionX_BGO_BGOTop - trackProjX
            resY_STK_BGO = projectionY_BGO_BGOTop - trackProjY

            resX_STK_BGO_top = trackImpactPointX - (pev.pEvtBgoRec().GetInterceptXZ() + tmpTrack.getImpactPoint().z() * pev.pEvtBgoRec().GetSlopeXZ())
            resY_STK_BGO_top = trackImpactPointY - (pev.pEvtBgoRec().GetInterceptYZ() + tmpTrack.getImpactPoint().z() * pev.pEvtBgoRec().GetSlopeYZ())

            ####

            h_ThetaSTK.Fill(theta_stk)
            h_deltaTheta.Fill(delta_theta_STK_BGO)

            h_imapctPointSTK.Fill(trackImpactPointX,trackImpactPointY)
                
            h_resX_STK_BGO.Fill(tmpTrack.getImpactPoint().x() - (pev.pEvtBgoRec().GetInterceptXZ() + tmpTrack.getImpactPoint().z() * pev.pEvtBgoRec().GetSlopeXZ()))
            h_resY_STK_BGO.Fill(tmpTrack.getImpactPoint().y() - (pev.pEvtBgoRec().GetInterceptYZ() + tmpTrack.getImpactPoint().z() * pev.pEvtBgoRec().GetSlopeYZ()))
    
            if abs(theta_stk - theta_bgo) > 25:
                continue
                    
            #Selecting good tracks for charge measurement

            if abs(resX_STK_BGO_top) < 200 and abs(resX_STK_BGO) < 60:
                lTrackIDX.append(tmpTrack)
                residueXmin.append(res_X_min)
                if res_X_min > abs(resX_STK_BGO_top):
                    res_X_min = abs(resX_STK_BGO_top)
                    trackID_X = iTrack
                    

            if abs(resY_STK_BGO_top) < 200 and abs(resY_STK_BGO) < 60:
                lTrackIDY.append(tmpTrack)
                residueYmin.append(res_Y_min)
                if res_Y_min > abs(resY_STK_BGO_top):
                    res_Y_min = abs(resY_STK_BGO_top)
                    trackID_Y = iTrack

        if(trackID_X == -9): 
            continue
        if(trackID_Y == -9): 
            continue

        track_ID = -9
        #print trackID_X
        
        if(trackID_X == trackID_Y):
            track_ID = trackID_X
        else:
            trackX = pev.pStkKalmanTrack(trackID_X)
            trackY = pev.pStkKalmanTrack(trackID_Y)
            chi2X = trackX.getChi2() /(trackX.getNhitX()+trackX.getNhitY()-4);
            chi2Y = trackY.getChi2() /(trackY.getNhitX()+trackY.getNhitY()-4);
            npointX = trackX.GetNPoints()
            npointY = trackY.GetNPoints()

            if(npointX == npointY or abs(npointX - npointY) == 1):
                if(chi2X < chi2Y):
                    if trackID_X in lTrackIDY:
                        track_ID = trackID_X
                    elif trackID_Y in lTrackIDX:
                            track_ID = trackID_Y
                    else:
                        common_id = list(set(lTrackIDX).intersection(lTrackIDY))
                        searchForTrack(
                                        common_id,
                                        lTrackIDX,
                                        lTrackIDY,
                                        residueXmin,
                                        residueYmin,
                                        track_ID
                                    )
                else:
                    if trackID_Y in lTrackIDX:
                        track_ID = trackID_Y
                    elif trackID_X in lTrackIDY:
                            track_ID = trackID_X
                    else:
                        common_id = list(set(lTrackIDX).intersection(lTrackIDY))
                        searchForTrack(
                                        common_id,
                                        lTrackIDX,
                                        lTrackIDY,
                                        residueXmin,
                                        residueYmin,
                                        track_ID
                                    )
            else:
                if(npointX > npointY):
                    if trackID_X in lTrackIDY:
                        track_ID = trackID_X
                    elif trackID_Y in lTrackIDX:
                            track_ID = trackID_Y
                    else:
                        common_id = list(set(lTrackIDX).intersection(lTrackIDY))
                        searchForTrack(
                                        common_id,
                                        lTrackIDX,
                                        lTrackIDY,
                                        residueXmin,
                                        residueYmin,
                                        track_ID
                                    )
                else:
                    if trackID_Y in lTrackIDX:
                        track_ID = trackID_Y
                    elif trackID_X in lTrackIDY:
                            track_ID = trackID_X
                    else:
                        common_id = list(set(lTrackIDX).intersection(lTrackIDY))
                        searchForTrack(
                                        common_id,
                                        lTrackIDX,
                                        lTrackIDY,
                                        residueXmin,
                                        residueYmin,
                                        track_ID
                                    )
        if(track_ID == -9): 
            continue

        h_energyCut_TrackMatch.Fill(etot)

        #Select the matched track
        track_sel = pev.pStkKalmanTrack(track_ID)
        theta_track_sel =math.acos(track_sel.getDirection().CosTheta())*180./math.pi;
        deltaTheta_rec_sel = theta_bgo - theta_track_sel
        track_correction = track_sel.getDirection().CosTheta();

        cluChargeX = -1000
        cluChargeY = -1000

        for iclu in xrange(0,track_sel.GetNPoints()):
            clux = track_sel.pClusterX(iclu)
            cluy = track_sel.pClusterY(iclu)
            if (clux and clux.getPlane() == 0):
                cluChargeX = clux.getEnergy()*track_correction
            if (cluy and cluy.getPlane() == 0):
                cluChargeY = cluy.getEnergy()*track_correction
        
        h_stk_chargeClusterX.Fill(cluChargeX)
        h_stk_chargeClusterY.Fill(cluChargeY)


        #Loop on PSD hits to get PSD charge measurement
        
        '''

        #PSD fiducial volume cut

        psd_YZ_top = -324.7
        psd_XZ_top = -298.5
        stk_to_psd_topY = (track_sel.getDirection().y()*(psd_YZ_top - track_sel.getImpactPoint().z()) + track_sel.getImpactPoint().y())
        stk_to_psd_topX = (track_sel.getDirection().x()*(psd_XZ_top - track_sel.getImpactPoint().z()) + track_sel.getImpactPoint().x())

        if(abs(stk_to_psd_topX) > 400.): 
            continue
        if(abs(stk_to_psd_topY) > 400.): 
            continue

        '''
       

        PSDXlayer0 = -298.5
        PSDXlayer1 = -284.5

        PSDYlayer0 = -324.7
        PSDYlayer1 = -310.7
        
        psdChargeX     = [[]for _ in range(2)]
        psdGIDX        = [[]for _ in range(2)]
        psdPathlengthX = [[]for _ in range(2)]
        psdPositionX   = [[]for _ in range(2)]

        psdChargeY     = [[]for _ in range(2)]
        psdGIDY        = [[]for _ in range(2)]
        psdPathlengthY = [[]for _ in range(2)]
        psdPositionY   = [[]for _ in range(2)]

        for lPSD in xrange(0,pev.NEvtPsdHits()):
            
            if pev.pEvtPsdHits().IsHitMeasuringX(lPSD):
                crossingX = False
                lenghtX = [-99999.,-99999.]
                array_lenghtX = array('d',lenghtX)

                if(pev.pEvtPsdHits().GetHitZ(lPSD) == PSDXlayer0):
                    npsdX = 0
                if(pev.pEvtPsdHits().GetHitZ(lPSD)== PSDXlayer1):
                    npsdX = 1
                
                if not opts.mc:
                    crossingX = DmpVSvc.gPsdECor.GetPathLengthPosition(pev.pEvtPsdHits().fGlobalBarID[lPSD],track_sel.getDirection(),track_sel.getImpactPoint(), array_lenghtX)

                if crossingX:
                    psdChargeX[npsdX].append(pev.pEvtPsdHits().fEnergy[lPSD]) 
                    psdGIDX[npsdX].append(pev.pEvtPsdHits().fGlobalBarID[lPSD]) 
                    psdPathlengthX[npsdX].append(array_lenghtX[1])
                    psdPositionX[npsdX].append(pev.pEvtPsdHits().GetHitX(lPSD))
            
            elif pev.pEvtPsdHits().IsHitMeasuringY(lPSD):
                crossingY = False
                lenghtY = [-99999.,-99999.]
                array_lenghtY = array('d',lenghtY)

                if(pev.pEvtPsdHits().GetHitZ(lPSD) == PSDYlayer0):
                    npsdY = 0
                if(pev.pEvtPsdHits().GetHitZ(lPSD)== PSDYlayer1):
                    npsdY = 1
                
                if not opts.mc:
                    crossingY = DmpVSvc.gPsdECor.GetPathLengthPosition(pev.pEvtPsdHits().fGlobalBarID[lPSD],track_sel.getDirection(),track_sel.getImpactPoint(), array_lenghtY)

                if crossingY:
                    psdChargeY[npsdY].append(pev.pEvtPsdHits().fEnergy[lPSD]) 
                    psdGIDY[npsdY].append(pev.pEvtPsdHits().fGlobalBarID[lPSD]) 
                    psdPathlengthY[npsdY].append(array_lenghtY[1])
                    psdPositionY[npsdY].append(pev.pEvtPsdHits().GetHitY(lPSD))
        
        '''
        print psdChargeX
        print psdGIDX
        print psdPathlengthX
        print psdPositionX

        print psdChargeY
        print psdGIDY
        print psdPathlengthY
        print psdPositionY
        
        '''

        psdFinalChargeX = [-999,-999]
        psdFinalChargeY = [-999,-999]

        #psdFinalChargeX_corr = [-999,-999]
        #psdFinalChargeY_corr = [-999,-999]

        psdFinalChargeX_proj = [-999,-999]
        psdFinalChargeY_proj = [-999,-999]

        psdX_pathlength = [-999,-999]
        psdY_pathlength = [-999,-999]

        psdX_position = [-999,-999]
        psdY_position = [-999,-999]
         
        PsdEC_tmpX = 0.
        PsdEC_tmpY = 0.

        for ipsd in xrange(0,2):
            
            if(len(psdChargeY[ipsd]) > 0):
                pos_max_len = np.argmax(psdPathlengthY[ipsd])
                lenghtY = [-99999.,-99999.]
                array_lenghtY = array('d',lenghtY)
                test_pos = False 
                if not opts.mc:
                    test_pos = DmpVSvc.gPsdECor.GetPathLengthPosition(psdGIDY[ipsd][pos_max_len],track_sel.getDirection(),track_sel.getImpactPoint(), array_lenghtY)
                 
                '''   
                PsdEC_tmpY = -1.
                if test_pos:
                    PsdEC_tmpY = DmpVSvc.gPsdECor.GetPsdECorSp3(psdGIDY[ipsd][pos_max_len], array_lenghtY[0])
                '''

                psdFinalChargeY[ipsd] = psdChargeY[ipsd][pos_max_len]
                h_psd_ChargeY[ipsd].Fill(psdFinalChargeY[ipsd])
                #psdFinalChargeY_corr[ipsd] = psdChargeY[ipsd][pos_max_len]*PsdEC_tmpY
                psdFinalChargeY_proj[ipsd] = array_lenghtY[0]
                psdY_pathlength[ipsd] = array_lenghtY[1]
                psdY_position[ipsd] =  psdPositionY[ipsd][pos_max_len]   
                


            if(len(psdChargeX[ipsd]) > 0):  
                pos_max_len = np.argmax(psdPathlengthX[ipsd])
                lenghtX = [-99999.,-99999.]
                array_lenghtX = array('d',lenghtX)
                test_pos = False 
                
                if not opts.mc:
                    test_pos = DmpVSvc.gPsdECor.GetPathLengthPosition(psdGIDX[ipsd][pos_max_len],track_sel.getDirection(),track_sel.getImpactPoint(), array_lenghtY)
                '''    
                PsdEC_tmpY = -1.
                if test_pos:
                    PsdEC_tmpX = DmpVSvc.gPsdECor.GetPsdECorSp3(psdGIDX[ipsd][pos_max_len], array_lenghtX[0])
                '''
                psdFinalChargeX[ipsd] = psdChargeX[ipsd][pos_max_len]
                h_psd_ChargeX[ipsd].Fill(psdFinalChargeX[ipsd])
                #psdFinalChargeX_corr[ipsd] = psdChargeX[ipsd][pos_max_len]*PsdEC_tmpX
                psdFinalChargeX_proj[ipsd] = array_lenghtX[0]
                psdX_pathlength[ipsd] = array_lenghtX[1]
                psdX_position[ipsd] =  psdPositionX[ipsd][pos_max_len] 










    ### Writing output files to file

    if opts.data:

        tf_skim = TFile(opts.outputFile,"RECREATE")

        h_energy_all.Write()
        h_energyCut.Write()
        h_energyCut_SAAcut.Write()
        h_energyCut_noTrack.Write()
        h_energyCut_Track.Write()
        h_energyCut_TrackMatch.Write()

        for BGO_idxl in range(14):
            h_energyBGOl[BGO_idxl].Write()
            h_BGOb_maxEnergyFraction[BGO_idxl].Write()
            for BGO_idxb in range(23):
                h_energyBGOb[BGO_idxl][BGO_idxb].Write()
        
        h_thetaBGO.Write()
        h_BGOl_maxEnergyFraction.Write()
        h_terrestrial_lat_vs_long.Write()

        h_STK_nTracks.Write()
        h_STK_trackChi2norm.Write()
        h_STK_nTracksChi2Cut.Write()

        for iLayer in range(6):
            h_stk_cluster_XvsY[iLayer].Write()

        h_ThetaSTK.Write()
        h_deltaTheta.Write()

        h_imapctPointSTK.Write()
        h_resX_STK_BGO.Write()
        h_resY_STK_BGO.Write()

        h_stk_chargeClusterX.Write()
        h_stk_chargeClusterY.Write()

        h_psd_ChargeX[0].Write()
        h_psd_ChargeX[1].Write()

        h_psd_ChargeY[0].Write()
        h_psd_ChargeY[1].Write()

        tf_skim.Close()

    
        
Exemplo n.º 9
0
def main(args=None):
    """Copy a ROOT file and append an ``MVAtree`` with DNN and BDT scores.

    The input file is copied to the output path, the ``CollectionTree`` of
    the input is read event by event, a Keras model and a BDT are loaded from
    the paths given on the command line, and one ``DNN_score`` / ``BDT_score``
    entry per event is written to a new ``MVAtree`` in the output file.

    Args:
        args: Optional list of command-line tokens handed to
            ``ArgumentParser.parse_args``. When ``None`` the parser falls
            back to ``sys.argv[1:]``.
    """
    usage = "Usage: %(prog)s [options]"
    description = "adding MVAtree to existing 2A file, will create a copy"
    parser = ArgumentParser(usage=usage, description=description)
    parser.add_argument("-i",
                        "--infile",
                        dest='infile',
                        type=str,
                        default=None,
                        help='name of input file',
                        required=True)
    parser.add_argument("-o",
                        "--outfile",
                        dest='outfile',
                        type=str,
                        default=None,
                        help='name of output file',
                        required=True)
    parser.add_argument("-m",
                        "--model",
                        dest="model",
                        type=str,
                        default="model.model",
                        help="name of Keras model file")
    parser.add_argument("-b",
                        "--bdt",
                        dest="bdt",
                        type=str,
                        default="bdt.pick",
                        help="name of sklearn BDT model file")
    opts = parser.parse_args(args)

    # ROOT is imported only after option parsing, as in the original code.
    from ROOT import TFile, TTree, gSystem, TObject, gROOT
    gSystem.Load("libDmpEvent.so")
    gROOT.SetBatch(True)
    # DmpChain is imported after libDmpEvent.so has been loaded.
    from ROOT import DmpChain

    # First, copy the input file to the output path; the new tree is added
    # to that copy, not to the input file.
    copy(opts.infile, opts.outfile)

    # Next, load the CollectionTree from the input file.
    dpch = DmpChain("CollectionTree")
    dpch.Add(opts.infile)
    fout = TFile(opts.outfile, "update")
    fTree = TTree("MVAtree", "MVA scores")

    # To add a new variable: create a length-1 array of doubles and register
    # it as a branch below.
    DNN_score = zeros(1, dtype=float)
    BDT_score = zeros(1, dtype=float)
    fTree.Branch("DNN_score", DNN_score, "DNN_score/D")
    fTree.Branch("BDT_score", BDT_score, "BDT_score/D")

    # First pass over the events: read out the input variables.
    nevts = dpch.GetEntries()
    print('found {i} events in {ifile}'.format(i=nevts, ifile=opts.infile))
    BgoTotalE = zeros(nevts, dtype=float)
    print('read out events')
    for i in xrange(nevts):
        pev = dpch.GetDmpEvent(i)
        # Add the usual per-event logic here, but *never* use `continue`:
        # the second loop below writes exactly one tree entry per event.
        BgoTotalE[i] = pev.pEvtBgoRec().GetTotalEnergy()
    dpch.Terminate()

    # NOTE(review): `X_norm` and `X` are not defined anywhere in this
    # function. They have to be built from the variables read out above (or
    # be provided at module level) before the predictions can run - confirm.

    # The model paths come from the parsed options; `args` is only the raw
    # argument list (or None) and has no `model` / `bdt` attributes.
    DNN = load_model(opts.model)
    DNN_sk = DNN.predict(X_norm)

    BDT = joblib.load(opts.bdt)
    BDT_sk = BDT.predict_proba(X)[:, 1]

    # Second pass: copy each score into its branch buffer and fill the tree.
    print('store scores')
    for i in xrange(nevts):
        DNN_score[0] = DNN_sk[i]
        BDT_score[0] = BDT_sk[i]
        fTree.Fill()
    fTree.Write()
    fout.Write("", TObject.kOverwrite)
    fout.Close()