Exemple #1
0
def load(num_data=10000,
         generator='pythia',
         pad=True,
         with_bc=True,
         cache_dir="/.energyflow/datasets"):
    """Loads events from the `cmb` tree of a locally cached `hgg.root` file.

    The `Energy`, `CosT`, `PHI` and `PDGID` branches are combined into the
    `X` array and the `BCL` branch is returned as `y`. Nothing is downloaded;
    the file must already exist on disk.

    **Arguments**

    - **num_data** : _int_
        - The number of events to return. A value of `-1` means return all
        events that were read.
    - **generator** : _str_
        - Not used by this function.
    - **pad** : _bool_
        - If `True`, the events are passed through `_pad_events_axis1` and
        stacked. If `False`, the returned `X` is an object array in which
        each event keeps only the rows whose first column is positive.
    - **with_bc** : _bool_
        - Not used by this function.
    - **cache_dir** : _str_
        - Path of the directory holding `hgg.root`, *relative to the user's
        home directory* (it is appended to the home directory as-is).

    **Returns**

    - _numpy.ndarray_, _numpy.ndarray_
        - The `X` and `y` components of the dataset, truncated to `num_data`
        entries unless `num_data` is `-1`.
    """

    # `cache_dir` is a suffix of the home directory; fall back to expanduser
    # when HOME is not set so we don't fail on `None + str`
    home = os.getenv('HOME')
    if home is None:
        home = os.path.expanduser('~')

    fpath = home + cache_dir + "/" + "hgg.root"
    print(fpath)

    # the context manager guarantees the ROOT file is closed once the
    # branches have been read into memory
    with ur.open(fpath) as root_file:
        tree = root_file["cmb"]  # open the tree

        # show() prints the branch listing itself and returns nothing
        tree.show()

        #------- get tuples --------
        energy = tree['Energy'].array()
        cosT = tree['CosT'].array()
        PHI = tree['PHI'].array()
        PDGID = tree['PDGID'].array()
        BCL = tree['BCL'].array()
        print("------- load tupules successful --------")
        print(tree)

    #------- combine 2-d arrays to a 3-d array -------
    # the four branches are stacked along a new leading axis; the conversion
    # to NumPy presumably needs every event to have the same number of
    # entries -- TODO confirm against the file contents
    comb = np.array(ak.Array([energy, cosT, PHI, PDGID]))

    # move the leading (branch) axis to the end
    outcomb = comb.transpose((1, 2, 0))

    Xs, ys = [outcomb], [BCL]

    # get X array
    if pad:
        max_len_axis1 = max([X.shape[1] for X in Xs])
        X = np.vstack([_pad_events_axis1(x, max_len_axis1) for x in Xs])
    else:
        # drop the rows whose first column is not positive
        X = np.asarray([x[x[:, 0] > 0] for X in Xs for x in X], dtype='O')

    # get y array
    y = np.concatenate(ys)

    # chop down to specified amount of data
    if num_data > -1:
        X, y = X[:num_data], y[:num_data]

    return X, y
Exemple #2
0
def load(num_data=100000,
         generator='pythia',
         pad=True,
         with_bc=False,
         cache_dir='~/.energyflow'):
    """Loads samples from the dataset (which in total is contained in twenty 
    files). Any file that is needed that has not been cached will be 
    automatically downloaded. Downloading a file causes it to be cached for
    later use. Basic checksums are performed.

    **Arguments**

    - **num_data** : _int_
        - The number of events to return. A value of `-1` means read in all
        events.
    - **generator** : _str_
        - Specifies which Monte Carlo generator the events should come from.
        Currently, the options are `'pythia'` and `'herwig'`.
    - **pad** : _bool_
        - Whether to pad the events with zeros to make them the same length.
        Note that if set to `False`, the returned `X` array will be an object
        array and not a 3-d array of floats.
    - **with_bc** : _bool_
        - Whether to include jets coming from bottom or charm quarks. Changing
        this flag does not mask out these jets but rather accesses an entirely
        different dataset. The datasets with and without b and c quarks should
        not be combined.
    - **cache_dir** : _str_
        - The directory where to store/look for the files. Note that 
        `'datasets'` is automatically appended to the end of this path.

    **Returns**

    - _3-d numpy.ndarray_, _1-d numpy.ndarray_
        - The `X` and `y` components of the dataset as specified above. If
        `pad` is `False` then `X` will be an object array holding the events,
        each of which is a 2-d ndarray.

    **Raises**

    - _ValueError_
        - If `generator` is not one of the known generators.
    - _RuntimeError_
        - If a required file could not be obtained from any source.
    """

    # check for valid options
    if generator not in GENERATORS:
        raise ValueError("'generator' must be in " + str(GENERATORS))

    # get number of files we need
    num_files = int(np.ceil(num_data /
                            NUM_PER_FILE)) if num_data > -1 else MAX_NUM_FILES
    if num_files > MAX_NUM_FILES:
        warnings.warn(
            'More data requested than available. Providing the full dataset.')
        num_files = MAX_NUM_FILES
        num_data = -1

    # index into global variables
    bc = 'bc' if with_bc else 'nobc'
    urls = URLS[generator][bc]
    hashes = HASHES[generator][bc]

    # obtain files
    Xs, ys = [], []
    for i in range(num_files):
        fpath = None
        for j, source in enumerate(SOURCES):

            # placeholder so the messages below never refer to an unbound
            # (or stale, from the previous file) name when the URL lookup
            # itself is what fails
            filename = 'file {}'.format(i)

            try:
                url = urls[source][i]
                filename = url.split('/')[-1].split('?')[0]

                fpath = _get_filepath(filename,
                                      url,
                                      cache_dir,
                                      file_hash=hashes['sha256'][i])

                # we succeeded, so don't continue trying to download this file
                break

            except Exception as e:
                print(str(e))

                # if this was our last source, raise an error
                if j == len(SOURCES) - 1:
                    m = 'Failed to download {} from any source.'.format(
                        filename)
                    raise RuntimeError(m)

                # otherwise indicate we're trying again
                print(
                    "Failed to download {} from source '{}', trying next source..."
                    .format(filename, source))

        # only reachable without a path if there were no sources to try
        if fpath is None:
            raise RuntimeError('No sources available to download from.')

        # load file and append arrays; the context manager closes the
        # archive even if one of the keys is missing
        with np.load(fpath) as f:
            Xs.append(f['X'])
            ys.append(f['y'])

    # get X array
    if pad:
        max_len_axis1 = max([X.shape[1] for X in Xs])
        X = np.vstack([_pad_events_axis1(x, max_len_axis1) for x in Xs])
    else:
        # keep only the rows whose first column is positive
        X = np.asarray([x[x[:, 0] > 0] for X in Xs for x in X], dtype='O')

    # get y array
    y = np.concatenate(ys)

    # chop down to specified amount of data
    if num_data > -1:
        X, y = X[:num_data], y[:num_data]

    return X, y
Exemple #3
0
def load(dataset,
         num_data=100000,
         pad=False,
         cache_dir='~/.energyflow',
         source='zenodo',
         which='all',
         include_keys=None,
         exclude_keys=None):
    """Loads in the Z+jet Pythia/Herwig + Delphes datasets. Any file that is
    needed that has not been cached will be automatically downloaded.
    Downloaded files are cached for later use. Checksum verification is
    performed.

    **Arguments**

    - **dataset**: {`'Herwig'`, `'Pythia21'`, `'Pythia25'`, `'Pythia26'`}
        - The dataset (specified by which generator/tune was used to produce
        it) to load. Note that this argument is not sensitive to
        capitalization.
    - **num_data**: _int_
        - The number of events to read in. A value of `-1` means to load all
        available events.
    - **pad**: _bool_
        - Whether to pad the particles with zeros in order to form contiguous
        arrays. If `False`, each particle entry is returned as a 1-d object
        array whose elements are 2-d arrays holding only the rows with a
        positive first column.
    - **cache_dir**: _str_
        - Path to the directory where the dataset files should be stored.
    - **source**: {`'dropbox'`, `'zenodo'`}
        - Which location to obtain the files from.
    - **which**: {`'gen'`, `'sim'`, `'all'`}
        - Which type(s) of events to read in. Each dataset has corresponding
        generated events at particle-level and simulated events at
        detector-level.
    - **include_keys**: _list_ or _tuple_ of _str_, or `None`
        - If not `None`, load these keys from the dataset files. A value of
        `None` uses all available keys (the `KEYS` global variable of this
        module contains the available keys as keys of the dictionary and
        brief descriptions as values). Note that keys do not have 'sim' or
        'gen' prepended to the front yet.
    - **exclude_keys**: _list_ or _tuple_ of _str_, or `None`
        - Any keys to exclude from loading. Most useful when a small number of
        keys should be excluded from the default set of keys.

    **Returns**

    - _dict_ of _numpy.ndarray_
        - A dictionary of the jet, particle, and observable arrays for the
        specified dataset.

    **Raises**

    - _ValueError_
        - If `dataset`, `source`, `which`, or any requested key is not
        recognized.
    """

    # load info from JSON file (only the first time this function runs)
    global INFO
    if 'INFO' not in globals():
        with open(os.path.join(EF_DATA_DIR, 'ZjetsDelphes.json'), 'r') as f:
            INFO = json.load(f)

    # check that options are valid
    dataset_low = dataset.lower()
    if dataset_low not in DATASETS:
        raise ValueError("Invalid dataset '{}'".format(dataset))
    if source not in SOURCES:
        raise ValueError("Invalid source '{}'".format(source))

    # handle selecting keys
    keys = set(KEYS.keys()) if include_keys is None else set(include_keys)
    keys -= set() if exclude_keys is None else set(exclude_keys)
    for key in keys:
        if key not in KEYS:
            raise ValueError("Unrecognized key '{}'".format(key))

    # create dictionary to store values to be returned
    levels = ['gen', 'sim'] if which == 'all' else [which.lower()]
    for level in levels:
        if level != 'gen' and level != 'sim':
            raise ValueError(
                "Unrecognized specification '{}' ".format(level) +
                "for argument 'which', allowed options are 'all', 'gen', 'sim'"
            )
    vals = {'{}_{}'.format(level, key): [] for level in levels for key in keys}

    # 'sim_Zs' is never read from the files, so drop it if it was generated
    if 'sim_Zs' in vals:
        del vals['sim_Zs']

    # get filenames
    filenames = [
        FILENAME_PATTERNS[dataset_low].format(i) for i in range(NUM_FILES)
    ]

    # get urls
    if source == 'dropbox':
        db_link_hashes = INFO['dropbox_link_hashes'][dataset_low]
        urls = [
            DROPBOX_URL_PATTERN.format(dl, fn)
            for dl, fn in zip(db_link_hashes, filenames)
        ]
    elif source == 'zenodo':
        urls = [
            ZENODO_URL_PATTERN.format(ZENODO_RECORD, fn) for fn in filenames
        ]
    else:
        # a source listed in SOURCES for which no URL pattern is known here
        raise ValueError("No URL pattern for source '{}'".format(source))

    # get hashes
    hashes = INFO['hashes'][dataset_low]['sha256']

    n = 0
    subdir = os.path.join('datasets', 'ZjetsDelphes')
    for filename, url, h in zip(filenames, urls, hashes):

        # check if we have enough events
        if n >= num_data and num_data != -1:
            break

        fpath = _get_filepath(filename,
                              url,
                              cache_dir,
                              cache_subdir=subdir,
                              file_hash=h)

        # load file; the context manager closes the archive even if a key
        # is missing or an array fails to load
        with np.load(fpath) as f:

            # add relevant arrays to vals
            for i, (key, val) in enumerate(vals.items()):

                if 'particles' not in key or pad:
                    val.append(f[key])
                else:
                    # strip the rows whose first column is not positive
                    val.append([np.array(ps[ps[:, 0] > 0]) for ps in f[key]])

                # increment number of events (counted once per file, using
                # whichever key comes first)
                if i == 0:
                    n += len(val[-1])

    # warn if we don't have enough events
    if num_data > n:
        warnings.warn('Only have {} events when {} were requested'.format(
            n, num_data))

    # concatenate arrays
    s = slice(0, num_data if num_data > -1 else None)
    for key, val in vals.items():

        if 'particles' not in key:
            vals[key] = np.concatenate(val, axis=0)[s]

        elif pad:
            max_len_axis1 = max([X.shape[1] for X in val])
            vals[key] = np.vstack(
                [_pad_events_axis1(x, max_len_axis1) for x in val])[s]

        else:
            # events have different numbers of particles, so they cannot be
            # handed to np.concatenate as nested lists (recent NumPy refuses
            # to build ragged arrays implicitly); fill an object array
            # element by element instead
            events = [event for chunk in val for event in chunk][s]
            ragged = np.empty(len(events), dtype='O')
            for k, event in enumerate(events):
                ragged[k] = event
            vals[key] = ragged

    return vals