Example #1
    def read_mc_primary(self,energy_var=MC_P_EN,\
                       type_var=MC_P_TY,\
                       zenith_var=MC_P_ZE,\
                       weight_var=MC_P_WE):
        """
        Trigger the readout of MC Primary information
        Rename variables to magic keywords if necessary

        Keyword Args:
            energy_var (str): simulated primary energy
            type_var (str): simulated primary type
            zenith_var (str): simulated primary zenith
            weight_var (str): a weight, e.g. interaction probability
        """

        self.read_variables([energy_var,type_var,zenith_var,weight_var])
        for varname,defaultname in [(energy_var,MC_P_EN),\
                                    (type_var,MC_P_TY),\
                                    (zenith_var,MC_P_ZE),
                                    (weight_var,MC_P_WE)]:
            if varname != defaultname:
                Logger.warning("..renaming {} to {}..".format(varname,defaultname))
                self.vardict[varname].name = defaultname

        self._mc_p_readout = True
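
A minimal usage sketch (assuming `sim` is an instantiated Simulation category, cf. Example #20, and that its variable definitions use non-default names; the variable names here are hypothetical):

    # read out and rename the primary information to the magic
    # keywords (MC_P_EN, MC_P_TY, MC_P_ZE, MC_P_WE)
    sim.read_mc_primary(energy_var="PrimaryEnergy",
                        type_var="PrimaryType",
                        zenith_var="PrimaryZenith",
                        weight_var="InteractionWeight")
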
Example #2
    def get_weights(self,model,model_kwargs = None):
        """
        Calculate weights for the variables in this category

        Args:
            model (callable): A model to be evaluated

        Keyword Args:
            model_kwargs (dict): Will be passed to model
        """
        if not self.mc_p_readout:
            self.read_mc_primary()

        if model_kwargs is None:
            model_kwargs = dict()
        func_kwargs = {MC_P_EN : self.get(MC_P_EN),\
                       MC_P_TY : self.get(MC_P_TY),\
                       MC_P_WE : self.get(MC_P_WE)}

        for key in MC_P_ZE,MC_P_GW,MC_P_TS,DATASETS:
            reg = key
            if key == DATASETS:
                reg = 'mc_datasets'
            try:
                func_kwargs[reg] = self.get(key)
            except KeyError:
                Logger.warning("No MCPrimary {0} informatiion! Trying to omit..".format(key))

        func_kwargs.update(model_kwargs)
        Logger.info("Getting weights for datasets {}".format(self.datasets.__repr__()))
        self._weights = pd.Series(self._weightfunction(model,self.datasets,\
                                 **func_kwargs))
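
A usage sketch (assuming `sim` is a harvested Simulation category and `my_flux_model` is a hypothetical callable flux model; the `gamma` parameter is likewise made up):

    # the model is evaluated inside the category's weight function;
    # model_kwargs are forwarded to the model call
    sim.get_weights(my_flux_model, model_kwargs={"gamma": -2.0})
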
Example #3
File: canvases.py Project: achim1/pyevsel
    def create_top_stacked_axes(self,heights=(1.,)):
        """
        Create several axes for subplot on top of each other

        Args:
            heights (iterable): relative heights of the axes, e.g.
                                heights = [.2,.1,.6] will split the canvas
                                so that the axes use these fractions of the
                                available vertical space
        """
        assert sum(heights) <= 1., "heights must be in relative heights"
    
        cfg = get_config_item("canvas")
        left = cfg["leftpadding"]
        right = cfg["rightpadding"]
        bot  = cfg["bottompadding"]
        top  = cfg["toppadding"]
        width = 1. - left - right
        height = 1. - top - bot
        
        heights = [height*h for h in heights]
        heights.reverse()
        Logger.debug("Using heights {0}".format(heights.__repr__()))
        abs_bot = 0 +bot     
        axes = [p.axes([left,abs_bot,width,heights[0]])]
        restheights = heights[1:]
        abs_bot = bot + heights[0]
        for h in restheights:
            theaxes = p.axes([left,abs_bot,width,h])
            p.setp(theaxes.get_xticklabels(), visible=False)
            axes.append(theaxes)
            abs_bot += h
    
        self.axes = axes
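
A usage sketch (assuming `canvas` is a YStackedCanvas instance, cf. Example #19):

    # three stacked axes using 20%, 10% and 60% of the drawable height;
    # the paddings come from the "canvas" config section
    canvas.create_top_stacked_axes(heights=(.2, .1, .6))
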
Example #4
File: plotting.py Project: achim1/pyevsel
    def indicate_cut(self,ax,arrow=True):
        """
        If cuts are given, indicate them by lines

        Args:
            ax (pylab.axes): axes to draw on

        """
        vmin,vmax = ax.get_ylim()
        hmin,hmax = ax.get_xlim()
        for cut in self.cuts:
            for name,(operator,value) in cut:
                if name != self.dataname:
                    continue
                Logger.debug('Found cut! {0} on {1}'.format(name,value))
                width = vmax/50.
                ax.vlines(value,ymin=vmin,ymax=vmax,linestyle=':')
                length = (hmax - hmin)*0.1
                shape = 'left'
                inversed = False
                if operator in ('>','>='):
                    shape = 'right'
                    inversed = True

                if arrow:
                    ax = create_arrow(ax,value,vmax*0.1, -1., 0, length, width= width,shape=shape,log=True)
                #ax.add_patch(arr)
                if not inversed:
                    ax.axvspan(value, hmax, facecolor=self.color_palette["prohibited"], alpha=0.5)
                else:
                    ax.axvspan(hmin, value, facecolor=self.color_palette["prohibited"], alpha=0.5)
Example #5
File: variables.py Project: achim1/pyevsel
    def harvest(self,*filenames):
        #FIXME: filenames is not used, just
        #there for compatibility

        if self.is_harvested:
            return
        # list comprehension instead of filter, so len() works on Python 3 too
        harvested = [var for var in self.variables if var.is_harvested]
        if len(harvested) != len(self.variables):
            Logger.error("Variables have to be harvested for compound variable {} first!".format(self.name))
            return
        self.declare_harvested()
Example #6
    def load_vardefs(self,module):
        """
        Load the variable definitions from a module

        Args:
            module (python module): Needs to contain variable definitions
        """

        all_vars = inspect.getmembers(module)
        all_vars = [x[1] for x in all_vars if isinstance(x[1],variables.AbstractBaseVariable)]
        for v in all_vars:
            if v.name in self.vardict:
                Logger.debug("Variable {} already defined,skipping!".format(v.name))
                continue
            self.add_variable(v)
Example #7
File: dataset.py Project: achim1/pyevsel
    def read_variables(self,variable_defs,names=None):
        """
        Read out the variable for all categories

        Args:
            variable_defs: A python module containing variable definitions
        
        Keyword Args:
            names (list): read out only these variables if given
        """
        for cat in self.categories:
            Logger.debug("Reading variables for {}".format(cat))
            cat.load_vardefs(variable_defs)
            cat.read_variables(names=names)
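
A usage sketch (assuming `dataset` is a Dataset instance; `my_vardefs` and the variable names are hypothetical):

    import my_vardefs  # a module containing Variable definitions

    # read out only the listed variables for every category
    dataset.read_variables(my_vardefs, names=["energy", "zenith"])
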
Example #8
    def estimate_livetime(self,force=False):
        """
        Calculate the livetime from run start/stop times, account for gaps
        
        Keyword Args:
            force (bool): override existing livetime
        """
        if self.livetime and self.livetime != "guess":
            Logger.warning("There is already a livetime of {:4.2f} ".format(self.livetime))
            if force:
                Logger.warning("Applying force...")
            else:
                Logger.warning("If you really want to do this, use force = True")
                return
        
        if not self._runstartstop_set:
            if (RUN_STOP in self.vardict.keys()) and (RUN_START in self.vardict.keys()):
                self._runstartstop_set = True
            else:
                Logger.warning("Need to set run start and stop times first! use object.set_run_start_stop")
                return

        Logger.warning("This is a crude estimate! Rather use a good run list or something!")
        lengths = self.get(RUN_STOP) - self.get(RUN_START)
        gaps    = self.get(RUN_START)[1:] - self.get(RUN_STOP)[:-1] #trust me!
        #h = self.nodes["header"].read()
        #h0 = h[:-1]
        #h1 = h[1:]
        ##FIXME
        #lengths = ((h["time_end_mjd_day"] - h["time_start_mjd_day"]) * 24. * 3600. +
        #           (h["time_end_mjd_sec"] - h["time_start_mjd_sec"]) +
        #           (h["time_end_mjd_ns"] - h["time_start_mjd_ns"])*1e-9 )
 
        #gaps = ((h1["time_start_mjd_day"] - h0["time_end_mjd_day"]) * 24.  * 3600. +
        #        (h1["time_start_mjd_sec"] - h0["time_end_mjd_sec"]) +
        #        (h1["time_start_mjd_ns"] - h0["time_end_mjd_ns"])*1e-9)
 

        # detector livetime is the duration of all events + the length of all
        # gaps between events that are short enough to be not downtime (guess: 30s)
        est_ltime =  ( lengths.sum() + gaps[(0<gaps) & (gaps<30)].sum() )
        self.set_livetime(est_ltime)
        return 
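
A usage sketch (assuming `data` is a Data category; `runstart` and `runstop` are hypothetical Variable instances, cf. Example #12):

    data.set_run_start_stop(runstart_var=runstart, runstop_var=runstop)
    data.estimate_livetime(force=True)  # overrides an existing livetime
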
Example #9
File: weighting.py Project: achim1/pyevsel
    def __call__(self,energy,ptype,\
                 zenith=None,mapping=False,
                 weight=None):
        """


        Args:
            energy: primary MC energy
            ptype: primary MC particle type
            zenith: cos (?) zenith

        Keyword Args:
            mapping: do a mapping to pdg
            weight: e.g. interactionprobabilityweights for nue

        Returns:
            numpy.ndarray: weights
        """

        # FIXME: mapping argument should go away
        if mapping:
            pmap = {14:ParticleType.PPlus, 402:ParticleType.He4Nucleus, 1407:ParticleType.N14Nucleus, 2713:ParticleType.Al27Nucleus, 5626:ParticleType.Fe56Nucleus}
            ptype = map(lambda x : pmap[x], ptype )

        # FIXME: This is too ugly and not general
        # Python 2 style introspection: inspect the flux callable
        # to figure out whether it accepts a zenith argument
        can_use_zenith = False
        if hasattr(self.flux,"__call__"):
            if hasattr(self.flux.__call__,"im_func"):
                args = inspect.getargs(self.flux.__call__.im_func.func_code)
                if len(args.args) == 4: # account for self
                    can_use_zenith = True
            else:
                can_use_zenith = True # method wrapper created by NewNuflux 
        else:
            args = inspect.getargs(self.flux.func_code) 
            if len(args.args) == 3:
                can_use_zenith = True
        if (zenith is not None) and can_use_zenith:
            Logger.debug("Using zenith!")
            return self.flux(energy,ptype,zenith)/self.gen(energy,particle_type=ptype,cos_theta=zenith)
        else:
            Logger.debug("Not using zenith!")
            return self.flux(energy,ptype)/self.gen(energy,particle_type=ptype,cos_theta=zenith)
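
A usage sketch (assuming `gen` is a generation pdf from GetGenerator and `flux` a flux model, as used in GetModelWeight, Example #16):

    weight = Weight(gen, flux)
    # per-event weights from primary energy, type and cos(zenith)
    w = weight(mc_p_en, mc_p_ty, zenith=mc_p_ze)
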
Example #10
    def get_files(self,*args,**kwargs):
        """
        Load files for this category
        uses pyevsel.utils.files.harvest_files

        Args:
            *args (list of strings): Path to possible files

        Keyword Args:
            datasets (dict(dataset_id : nfiles)): if given, load only files from dataset dataset_id;
                set the nfiles parameter to the number of L2 files the loaded files will represent
            force (bool): forcibly reload filelist (pre-readout vars will be lost)
            all other kwargs will be passed to
            utils.files.harvest_files
        """
        force = kwargs.pop("force", False)
        if self.is_harvested:
            Logger.info("Variables have already been harvested!\
                         if you really want to reload the filelist,\
                         use 'force=True'.\
                         If you do so, all your harvested variables will be deleted!")
            if not force:
                return
            else:
                Logger.warning("..using force..")

        if "datasets" in kwargs:
            filtered_files = []
            self.datasets = kwargs.pop("datasets")
            files = harvest_files(*args,**kwargs)
            datasets = [self._ds_regexp(x) for x in files]
            assert len(datasets) == len(files)

            ds_files = list(zip(datasets,files))  # materialize: the pairs are reused per dataset
            for k in self.datasets.keys():
                filtered_files.extend([x[1] for x in ds_files if x[0] == k])
            files = filtered_files
        else:
            files = harvest_files(*args,**kwargs)

        self.files = files
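
A usage sketch (assuming `sim` is a Simulation category; path, prefix and dataset numbers are hypothetical):

    # load only files of dataset 11029, representing 1000 L2 files
    sim.get_files("/data/sim/level2",
                  prefix="Level2_",
                  ending=".h5",
                  datasets={11029: 1000})
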
Example #11
File: __init__.py Project: achim1/pyevsel
def GetCategoryConfig(name):
    """
    Get the relevant config section from the actual
    config for a category

    Args:
        name (string): Name of a category to search for
    """

    configs = yaml.load(open(CONFIGFILE, "r"), Loader=yaml.SafeLoader)
    for cfg in configs["categories"]:
        if cfg["name"] == name:
            # FIXME little hack for bad latex parsing
            # by yaml
            # cleanlabel = cfg["label"]
            cleanlabel = SLASHES.sub(r"\\", cfg["label"])
            cfg["label"] = cleanlabel
            return cfg
    Logger.warning("No config for {0} found!".format(name))
    return cfg
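
A usage sketch (the category name is hypothetical):

    cfg = GetCategoryConfig("numu")
    if cfg is not None:
        label = cfg["label"]  # label with latex slashes cleaned up
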
Example #12
    def set_run_start_stop(self,runstart_var=variables.Variable(None),runstop_var=variables.Variable(None)):
        """
        Let the category know which variables
        describe the start and stop times of a run

        Keyword Args:
            runstart_var (pyevsel.variables.variables.Variable): beginning of a run
            runstop_var (pyevsel.variables.variables.Variable): end of a run

        """
        #FIXME
        for var,name in [(runstart_var,RUN_START),(runstop_var,RUN_STOP)]:
            if var.name is None:
                Logger.warning("No {0} available".format(name))
            elif name in self.vardict:
                Logger.info("..{0} already defined, skipping...".format(name))
                continue
            else:
                if var.name != name:
                    Logger.info("..renaming {0} to {1}..".format(var.name,name))
                    var.name = name
                newvar = deepcopy(var)
                self.vardict[name] = newvar

        self._runstartstop_set = True
Example #13
File: variables.py Project: achim1/pyevsel
def freedman_diaconis_bins(data,leftedge,\
                         rightedge,minbins=20,\
                         maxbins=70,fallbackbins=DEFAULT_BINS):
    """
    Get a number of bins for a histogram
    following Freedman/Diaconis

    Args:
        data (array-like): the data to be binned
        leftedge (float): left bin edge
        rightedge (float): right bin edge
        minbins (int): the minimum number of bins
        maxbins (int): the maximum number of bins
        fallbackbins (int): a number of bins which is returned
                            if the calculation fails

    Returns:
        nbins (int): number of bins, minbins < bins < maxbins
    """

    try:
        finite_data = n.isfinite(data)
        q3          = n.percentile(data[finite_data],75)
        q1          = n.percentile(data[finite_data],25)
        n_data      = len(data)
        h           = (2*(q3-q1))/(n_data**(1./3.))  # cube root, not (n**1)/3
        bins = (rightedge - leftedge)/h
    except Exception as e:
        Logger.warn("Calculate Freedman-Draconis bins failed {0}".format( e.__repr__()))
        bins = fallbackbins

    if not n.isfinite(bins):
        Logger.warn("Calculate Freedman-Draconis bins failed, calculated nan bins, returning fallback")
        bins = fallbackbins

    if bins < minbins:
        bins = minbins
    if bins > maxbins:
        bins = maxbins

    return int(bins)
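
A usage sketch (with random toy data):

    import numpy as n

    data = n.random.normal(0., 1., 10000)
    # number of bins, clamped to [minbins, maxbins]
    nbins = freedman_diaconis_bins(data, -5., 5.)
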
Example #14
File: variables.py Project: achim1/pyevsel
def harvest(filenames,definitions,**kwargs):
    """
    Extract the variable data from the provided files

    Args:
        filenames (list): the files to extract from
                          currently supported: {0}

    Keyword Args:
        transformation (func): will be applied to the read out data

    Returns:
        pd.Series or pd.DataFrame
    """.format(REGISTERED_FILEEXTENSIONS.__repr__())

    data = pd.Series()
    for filename in filenames:
        filetype = f.strip_all_endings(filename)[1]
        assert filetype in REGISTERED_FILEEXTENSIONS, "Filetype {} not known!".format(filetype)
        assert os.path.exists(filename), "File {} does not exist!".format(filename)
        Logger.debug("Attempting to harvest {1} file {0}".format(filename,filetype))
        
        if filetype == ".h5" and not isinstance(filename, tables.table.Table):
            # store = pd.HDFStore(filename)
            hdftable = tables.openFile(filename)

        else:
            hdftable = filename

        tmpdata = pd.Series()
        for definition in definitions:
            if filetype == ".h5":
                try:
                    # data = store.select_column(*definition)
                    tmpdata = hdftable.get_node("/" + definition[0]).col(definition[1])
                    tmpdata = pd.Series(tmpdata, dtype=n.float64)
                    Logger.debug("Found {} entries in table for {}{}".format(len(tmpdata),definition[0],definition[1]))
                    break
                except tables.NoSuchNodeError:
                    Logger.debug("Can not find definition {0} in {1}! ".format(definition, filename))
                    continue

            elif filetype == ".root":
                tmpdata = rn.root2rec(filename, *definition)
                tmpdata = pd.Series(data)
        if filetype == ".h5":
            hdftable.close()

        #tmpdata = harvest_single_file(filename, filetype,definitions)
        # self.data = self.data.append(data.map(self.transform))
        # concat should be much faster
        if "transformation" in kwargs:
            transform = kwargs['transformation']
            data = pd.concat([data, tmpdata.map(transform)])
        else:
            data = pd.concat([data, tmpdata])
        del tmpdata
    return data
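
A usage sketch (the file name and table/column definitions are hypothetical):

    # each definition is a (node, column) pair; alternatives are tried in order
    energies = harvest(["mydata.h5"],
                       [("MCPrimary", "energy"), ("MCPrimary1", "energy")],
                       transformation=n.log10)
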
Example #15
File: plotting.py Project: achim1/pyevsel
    def add_ratio(self,names_upper,names_under,\
                  total_ratio=None,total_ratio_errors=None,\
                  log=False,label="data/$\Sigma$ bg"):
        """
        Add a ratio plot to the canvas

        """
        if not isinstance(names_upper,list):
            names_upper = [names_upper]
        if not isinstance(names_under,list):
            names_under = [names_under]

        name = "".join(names_upper) + "_" + "".join(names_under)
        first_upper = names_upper.pop()
        upper_hist = self.histograms[first_upper]
        upper_ws   = self.histograms[first_upper].stats.weightsum
        for name in names_upper:
            upper_hist += self.histograms[name] 
            upper_ws   += self.histograms[name].stats.weightsum
        first_under = names_under.pop()
        under_hist = self.histograms[first_under]
        under_ws = self.histograms[first_under].stats.weightsum
        for name in names_under:
            under_hist += self.histograms[name]
            under_ws   += self.histograms[name].stats.weightsum
    
        upper_hist.normalized()
        under_hist.normalized()
        ratio = d.histfuncs.histratio(upper_hist,under_hist,\
                                      log=False,ylabel=label)
        if total_ratio is None:
            total_ratio = upper_ws/under_ws
            Logger.info("Calculated scalar ratio of {:4.2f} from histos".format(total_ratio))

        #ratio.y[ratio.y > 0] = ratio.y[ratio.y > 0] + total_ratio -1
        self.histratios[ratio_name] = (ratio,total_ratio,total_ratio_errors,label)
        return ratio_name
Example #16
File: weighting.py Project: achim1/pyevsel
def GetModelWeight(model,datasets,\
                   mc_datasets=None,\
                   mc_p_en=None,\
                   mc_p_ty=None,\
                   mc_p_ze=None,\
                   mc_p_we=1.,\
                   mc_p_ts=1.,\
                   mc_p_gw=1.,\
                   **model_kwargs):
    """
    Compute weights using a predefined model

    Args:
        model (func): Used to calculate the target flux
        datasets (dict): Get the generation pdf for these datasets from the db
                         dict needs to be dataset_id -> nfiles
    Keyword Args:
        mc_p_en (array-like): primary energy
        mc_p_ty (array-like): primary particle type
        mc_p_ze (array-like): primary particle cos(zenith)
        mc_p_we (array-like): weight for mc primary, e.g. some interaction probability

    Returns (array-like): Weights
    """
    if model_kwargs:
        flux = model(**model_kwargs)
    else:
        flux = model()
    # FIXME: There is a factor of 5000 not accounted
    # for -> 1e4 is for the conversion of
    factor = 1.
    gen  = GetGenerator(datasets)
    if int(list(gen.spectra.keys())[0]) in NUTYPES:
        Logger.debug('Patching weights')
        factor = 5000
    weight = Weight(gen,flux)
    return factor*mc_p_we*weight(mc_p_en,mc_p_ty,zenith=mc_p_ze)
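
A usage sketch (`my_flux_model` and the dataset numbers are hypothetical; the mc_p_* arrays come e.g. from a read-out MC primary, cf. Example #1):

    weights = GetModelWeight(my_flux_model, {11029: 1000},
                             mc_p_en=energies,
                             mc_p_ty=types,
                             mc_p_ze=coszeniths)
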
Example #17
    returns unknown items unchanged, e.g. a color string like 'k' for matplotlib
    """

    def __getitem__(self,item):
        if item in self:
            return self.get(item)
        else:
            return item


seaborn_loaded = False
try:
    import seaborn.apionly as sb

    seaborn_loaded = True
    Logger.debug("Seaborn found!")
except ImportError:
    Logger.warning("Seaborn not found! Using predefined color palette")

    
def get_color_palette(name="dark"):
    """
    Load a color pallete, use seaborn if available
    """
    if not seaborn_loaded:
        color_palette = ColorDict()   # stolen from seaborn color-palette
        color_palette[0] = (0.2980392156862745, 0.4470588235294118, 0.6901960784313725)
        color_palette[5] = (0.8, 0.7254901960784313, 0.4549019607843137) #(0.39215686274509803, 0.7098039215686275, 0.803921568627451)
        color_palette["k"] = "k"
        color_palette[1] = (0.3333333333333333, 0.6588235294117647, 0.40784313725490196)
        color_palette[2] = (0.7686274509803922, 0.3058823529411765, 0.3215686274509804)
Example #18
File: plotting.py Project: achim1/pyevsel
    def plot(self,heights=(.5,.2,.2),\
             axes_locator=((0,"c"),(1,"r"),(2,"h")),\
             combined_distro=True,\
             combined_ratio=True,\
             combined_cumul=True,
             log=True):
        '''
        Create the plot

        Keyword Args:
            heights (tuple): relative heights of the subplot axes
            axes_locator (tuple): pairs of (axes index, plot type), where the
                                  type is "h" (histogram), "r" (ratio) or
                                  "c" (cumulative distribution)
            combined_distro (bool): draw all distributions in one axes
            combined_ratio (bool): draw all ratios in one axes
            combined_cumul (bool): draw all cumulative distributions in one axes
            log (bool): logarithmic y-axis
        '''

        Logger.info("Found {} distributions".format(len(self.histograms)))
        Logger.info("Found {} ratios".format(len(self.histratios)))
        Logger.info("Found {} cumulative distributions".format(len(self.cumuls)))
        if not axes_locator:
            axes_locator = self._locate_axes(combined_cumul,combined_ratio,combined_distro)

        # calculate the amount of needed axes
        assert len(axes_locator) == len(heights), "Need to specify exactly as many heights as plots you want to have"

        self.canvas = c.YStackedCanvas(axeslayout=heights)
        
        cu_axes = [x for x in axes_locator if x[1] == "c"]
        h_axes = [x for x in axes_locator if x[1] == "h"]
        r_axes = [x for x in axes_locator if x[1] == "r"]
        maxheights = []
        minheights = []
        for ax in cu_axes:
            cur_ax = self.canvas.select_axes(ax[0])
            if combined_cumul:
                for k in self.cumuls.keys():
                    self._draw_distribution(cur_ax,k,cumulative=True,log=log)
                break
            else:
                k = list(self.cumuls.keys())[ax[0]]
                self._draw_distribution(cur_ax,k,cumulative=True,log=log)
        for ax in r_axes:
            cur_ax = self.canvas.select_axes(ax[0])
            if combined_ratio:
                for k in self.histratios.keys():
                    self._draw_histratio(k,cur_ax)
                break
            else:
                k = list(self.histratios.keys())[ax[0]]
                self._draw_histratio(k,cur_ax)

        for ax in h_axes:
            cur_ax = self.canvas.select_axes(ax[0])
            if combined_distro:
                for k in self.histograms.keys():
                    print "drawing..",k
                    self._draw_distribution(cur_ax,k,log=log)
                break
            else:
                k = list(self.histograms.keys())[ax[0]]
                ymax, ymin = self._draw_distribution(cur_ax,k,log=log)
            cur_ax.set_ylim(ymin=ymin - 0.1*ymin,ymax=1.1*ymax)
        lgax = self.canvas.select_axes(-1) # topmost axes
        lg = lgax.legend(**LoadConfig()['legend'])
        legendwidth = LoadConfig()['legendwidth']
        lg.get_frame().set_linewidth(legendwidth)
        # plot the cuts
        if self.cuts:
            for ax in h_axes:
                self.indicate_cut(ax,arrow=True)
            for ax in r_axes + cu_axes:
                self.indicate_cut(ax,arrow=False)
        # cleanup
        leftplotedge = n.inf
        rightplotedge = -n.inf
        minplotrange = n.inf
        maxplotrange = -n.inf
        for h in self.histograms.values():
            if not h.bincenters[h.bincontent > 0].sum():
                continue
            if h.bincenters[h.bincontent > 0][0] < leftplotedge:
                leftplotedge = h.bincenters[h.bincontent > 0][0]
            if h.bincenters[h.bincontent > 0][-1] > rightplotedge:
                rightplotedge = h.bincenters[h.bincontent > 0][-1]
            if min(h.bincontent[h.bincontent > 0]) < minplotrange:
                minplotrange = min(h.bincontent[h.bincontent > 0])
            if max(h.bincontent[h.bincontent > 0]) > maxplotrange:
                maxplotrange = max(h.bincontent[h.bincontent > 0])

        if log:
            maxplotrange *= 8
        else:
            maxplotrange *= 1.2
        if n.isfinite(leftplotedge):
            self.canvas.limit_xrange(xmin=leftplotedge)
        if n.isfinite(rightplotedge):
            self.canvas.limit_xrange(xmax=rightplotedge)
        for ax in h_axes:
            self.canvas.select_axes(ax[0]).set_ylim(ymax=maxplotrange,ymin=minplotrange)

        #if n.isfinite(minplotrange):
        #    self.canvas.limit_yrange(ymin=minplotrange - 0.1*minplotrange)
        #if n.isfinite(maxplotrange):
        #    self.canvas.limit_yrange(ymax=maxplotrange)
        self.canvas.eliminate_lower_yticks()
        # set the label on the lowest axes
        self.canvas.axes[0].set_xlabel(self.label)
Example #19
File: canvases.py Project: achim1/pyevsel
"""
Provides canvases for multi axes plots
"""

import os.path
import pylab as p

from pyevsel.plotting import get_config_item
from pyevsel.utils.logger import Logger
try:
    from IPython.core.display import Image
except ImportError:
    Logger.debug("Can not import IPython!")
    Image = lambda x : x

# golden cut values
# CW = current width of my thesis (adjust as needed)
CW = 5.78851
S  = 1.681

##########################################

class YStackedCanvas(object):
    """
    A canvas for plotting multiple axes
    """

    def __init__(self,axeslayout=(.2,.2,.5),figsize=(CW,CW*S)):
        """
        Axes indices go from bottom to top
        """
Example #20
File: __init__.py Project: achim1/pyevsel
def load_dataset(config, variables=None):
    """
    Loads a dataset according to a 
    configuration file
    
    Args:
        config (str): json style config file

    Keyword Args:
        variables (list): read out only these variables
    """

    # FIXME: add os.path.exists checks
    cfg = commentjson.load(open(config))
    categories = dict()
    weightfunctions = dict()
    models = dict()
    files_basepath = cfg["files_basepath"]
    for cat in cfg["categories"].keys():
        thiscat = cfg["categories"][cat]
        if thiscat["datatype"] == "simulation":
            categories[cat] = c.Simulation(cat)
            # remember that json keys are strings, so
            # convert to int
            datasets = {int(x): int(thiscat["datasets"][x]) for x in thiscat["datasets"]}
            categories[cat].get_files(
                os.path.join(files_basepath, thiscat["subpath"]),
                prefix=thiscat["file_prefix"],
                datasets=datasets,
                ending=thiscat["file_type"],
            )
            try:
                fluxclass, flux = thiscat["model"].split(".")
                models[cat] = getattr(dict(inspect.getmembers(fluxes))[fluxclass], flux)
            except ValueError:
                Logger.warning(
                    "{} does not seem to be a valid model for {}. This might cause troubles. If not, it is probably fine!".format(
                        thiscat["model"], cat
                    )
                )
                models[cat] = None
            weightfunctions[cat] = dict(inspect.getmembers(wgt))[thiscat["model_method"]]
        elif thiscat["datatype"] == "data":
            categories[cat] = c.Data(cat)
            categories[cat].get_files(
                os.path.join(files_basepath, thiscat["subpath"]),
                prefix=thiscat["file_prefix"],
                ending=thiscat["file_type"],
            )
            models[cat] = float(thiscat["livetime"])
            weightfunctions[cat] = dict(inspect.getmembers(wgt))[thiscat["model_method"]]

        elif thiscat["datatype"] == "reweighted":
            pass
        else:
            raise TypeError("Data type not understood. Has to be either 'simulation', 'reweighted' or 'data'!!")
    # at last we can take care of reweighted categories
    for cat in cfg["categories"].keys():
        thiscat = cfg["categories"][cat]
        if thiscat["datatype"] == "reweighted":
            categories[cat] = c.ReweightedSimulation(cat, categories[thiscat["parent"]])
            if thiscat["model"]:
                fluxclass, flux = thiscat["model"].split(".")
                models[cat] = getattr(dict(inspect.getmembers(fluxes))[fluxclass], flux)
                weightfunctions[cat] = dict(inspect.getmembers(wgt))[thiscat["model_method"]]
        elif thiscat["datatype"] in ["data", "simulation"]:
            pass
        else:
            raise TypeError("Data type not understood. Has to be either 'simulation', 'reweighted' or 'data'!!")

    # combined categories
    combined_categories = dict()
    # iterate the config section; the local dict starts empty
    for k in cfg.get("combined_categories", {}):
        combined_categories[k] = [categories[l] for l in cfg["combined_categories"][k]]

    # import variable defs
    vardefs = __import__(cfg["variable_definitions"])

    dataset = ds.Dataset(*categories.values(), combined_categories=combined_categories)
    dataset.read_variables(vardefs, names=variables)
    dataset.set_weightfunction(weightfunctions)
    dataset.get_weights(models=models)
    return dataset
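
A usage sketch (config file and variable names are hypothetical):

    dataset = load_dataset("dataset_config.json", variables=["energy", "zenith"])
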
Example #21
File: variables.py Project: achim1/pyevsel
import os
import pandas as pd
import tables
import abc

from pyevsel.utils import files as f
from pyevsel.utils.logger import Logger


DEFAULT_BINS = 70
REGISTERED_FILEEXTENSIONS = [".h5",".root"]

try:
    import root_numpy as rn
except ImportError:
    Logger.warning("No root_numpy found, root support is limited!")
    REGISTERED_FILEEXTENSIONS.remove(".root")

################################################################
# define a non-member function so that it can be used in a
# multiprocessing approach

#def harvest_single_file(filename, filetype, definitions):
#    """
#    Get the variable data from a fileobject
#    Optimized for hdf files
#
#    Args:
#        filename (str):
#        filetype (str): the extension of the filename, eg "h5"
#
Example #22
    def read_variables(self,names=None):
        """
        Harvest the variables in self.vardict

        Keyword Args:
            names (list): harvest only these variables
        """

        if names is None:
            names = self.vardict.keys()
        compound_variables = [] #harvest them later

        executor = fut.ProcessPoolExecutor(max_workers=MAX_CORES)
        future_to_varname = {}

        # first read out variables,
        # then compound variables
        # so make sure they are in the 
        # right order
        simple_vars = []
        for varname in names:
            try:
                if isinstance(self.vardict[varname],variables.CompoundVariable):
                    compound_variables.append(varname)
                    continue

                elif isinstance(self.vardict[varname],variables.VariableList):
                    compound_variables.append(varname)
                    continue
                else:
                    simple_vars.append(varname)
            except KeyError:
                Logger.warning("Cannot find {} in variables!".format(varname))
                continue
        for varname in simple_vars:
            # FIXME: Make it an option to not use
            # multi cpu readout!
            #self.vardict[varname].data = variables.harvest(self.files,self.vardict[varname].definitions)
            future_to_varname[executor.submit(variables.harvest,self.files,self.vardict[varname].definitions)] = varname
        #for future in tqdm.tqdm(fut.as_completed(future_to_varname),desc="Reading {0} variables".format(self.name), leave=True):
        progbar = False
        try:
            import pyprind
            n_it = len(future_to_varname)
            bar = pyprind.ProgBar(n_it,monitor=False,bar_char='#',title=self.name)
            progbar = True
        except ImportError:
            pass

        exc_caught = """"""
        for future in fut.as_completed(future_to_varname):
            varname = future_to_varname[future]
            Logger.debug("Reading {} finished".format(varname))
            try:
                data = future.result()
                Logger.debug("Found {} entries ...".format(len(data)))
                data = self.vardict[varname].transform(data)
            except Exception as exc:
                exc_caught += "Reading {} for {} generated an exception: {}\n".format(varname,self.name, exc)
                data = pd.Series([])

            self.vardict[varname].data = data
            self.vardict[varname].declare_harvested()
            if progbar: bar.update()
        for varname in compound_variables:
            #FIXME check if this causes a memory leak
            self.vardict[varname].rewire_variables(self.vardict)
            self.vardict[varname].harvest()
        if exc_caught:
            Logger.warning("During the variable readout some exceptions occured!\n" + exc_caught)
        self._is_harvested = True