def test_modify_initial_tree(NEXP=10):
    """Add pseudoexperiments into TTree/TChain.

    For every pseudoexperiment the module-level histogram ``h2`` is sampled,
    wrapped into ``Ostap.Functions.FuncTH2`` of (``pt``, ``eta``) and written
    into the chain as a new branch ``w<i>``.  Afterwards each ``w<i>`` is used
    as a weight to get accepted/rejected sums for the module-level ``cut``
    and the resulting efficiencies are accumulated in a counter.

    NEXP: number of pseudoexperiments (and hence of new branches) to add.
    """

    files = prepare_data(1, 100000)

    logger.info('Add %s pseudoexepriments into TTree/TChain' % NEXP)

    logger.info('#files:    %s' % len(files))
    data = Data('S', files)
    logger.info('Initial Tree/Chain:\n%s' % data.chain.table(prefix='# '))

    ## pseudo experiments: one new weight-branch per sampled histogram
    for e in progress_bar(range(NEXP)):
        h2_new = h2.sample()
        func = Ostap.Functions.FuncTH2(h2_new, 'pt', 'eta')
        data.chain.add_new_branch('w%d' % e, func)

    ## rebuild the Data object from the same files
    ## (presumably to pick up the freshly written branches - to be confirmed)
    data = Data('S', files)
    logger.info('Tree/Chain after:\n%s' % data.chain.table(prefix='# '))

    counter = SE()
    for e in range(NEXP):
        weight = 'w%d' % e
        accepted = data.chain.sumVar('1', weight * cut)
        rejected = data.chain.sumVar('1', weight * ~cut)
        efficiency = 1 / (1 + rejected / accepted)
        logger.info("Experiment %3d, accepted/rejected %s/%s , eff = %s " %
                    (e, accepted, rejected, efficiency))
        counter += efficiency
    logger.info('Statistics of pseudoexperiments %s' % counter)
    ## the rms unit was printed as ``]%]'' instead of ``[%]''
    logger.info('Mean/rms: %s[%%]/%.4f[%%]' %
                (counter.mean() * 100, counter.rms() * 100))
# Example #2
# 0
def prepare_data(tmpdir, nfiles=100, nentries=100, ppservers=(), silent=True):
    """Prepare the test data: create ``nfiles`` ROOT files in parallel.

    tmpdir    : directory passed to ``CleanUp.tempfile`` for the file names
    nfiles    : number of files to request
    nentries  : passed to ``create_tree`` together with each file name
    ppservers : forwarded to ``Parallel.WorkManager``
    silent    : forwarded to ``Parallel.WorkManager``

    Returns ``Data('S', files)`` where ``files`` is the sorted list of
    distinct task results that exist on disk.
    """

    ## Use generic Task from Kisa
    from ostap.parallel.parallel import GenericTask as Task
    from ostap.utils.cleanup import CleanUp
    from ostap.trees.data import Data

    task = Task(processor=create_tree)
    wmgr = Parallel.WorkManager(ppservers=ppservers, silent=silent)

    files = [
        CleanUp.tempfile(prefix='test_kisa_', suffix='.root', dir=tmpdir)
        for _ in range(nfiles)
    ]

    wmgr.process(task, [(f, nentries) for f in files])

    ## keep only those results that are really present on disk;
    ## the set removes duplicates, ``sorted`` gives a reproducible order
    the_files = sorted({f for f in task.results() if os.path.exists(f)})

    return Data('S', the_files)
def prepare_data(nfiles=10, nentries=100):
    """Prepare data for the test.

    Creates ``nfiles`` temporary ROOT files, one ``create_tree`` call per
    file, each called with ``nentries``.

    Returns ``Data('S', files)`` built from the sorted ``create_tree`` results.
    """

    ## import once, not on every loop iteration
    from ostap.utils.cleanup import CleanUp

    files = []
    for _ in progress_bar(range(nfiles)):
        tmpfile = CleanUp.tempfile(prefix='ostap-test-selectors-',
                                   suffix='.root')
        files.append(create_tree(tmpfile, nentries))

    files.sort()
    return Data('S', files)
def test_add_to_dataset(NEXP=10):
    """Add pseudoexperiments into RooDataSet.

    A dataset with ``mass``, ``pt`` and ``eta`` is filled from the chain,
    then ``NEXP`` variables ``w<i>`` are added to it, each one built from a
    fresh sample of the module-level histogram ``h2``.  Every ``w<i>`` is
    finally used as a weight to get the efficiency of the module-level
    ``cut``; the efficiencies are accumulated in a counter and reported.
    """

    logger.info('Add %s pseudoexepriments into RooDataSet' % NEXP)

    files = prepare_data(1, 100000)
    logger.info('#files:    %s' % len(files))

    data = Data('S', files)
    chain_table = data.chain.table(prefix='# ')
    logger.info('Initial Tree/Chain:\n%s' % chain_table)

    import ostap.fitting.pyselectors
    filled = data.chain.fill_dataset(['mass', 'pt', 'eta'])
    dataset, _ = filled

    logger.info('Initial dataset:\n%s' % dataset.table(prefix='# '))

    ## pseudo experiments: one new variable per sampled histogram
    for index in progress_bar(range(NEXP)):
        sampled = h2.sample()
        dataset.add_new_var('w%d' % index,
                            Ostap.Functions.FuncRooTH2(sampled, 'pt', 'eta'))

    logger.info('Dataset after:\n%s' % dataset.table(prefix='# '))

    ## efficiency of the cut for each of the weights
    counter = SE()
    for index in range(NEXP):
        wname = 'w%d' % index
        accepted = dataset.sumVar('1', wname * cut)
        rejected = dataset.sumVar('1', wname * ~cut)
        efficiency = 1 / (1 + rejected / accepted)
        report = (index, accepted, rejected, efficiency)
        logger.info("Experiment %3d, accepted/rejected %s/%s , eff = %s " %
                    report)
        counter += efficiency

    logger.info('Statistics of pseudoexperiments %s' % counter)
    mean_pc = counter.mean() * 100
    rms_pc = counter.rms() * 100
    logger.info('Mean/rms: %s[%%]/%.4f[%%]' % (mean_pc, rms_pc))
def test_addbranch():
    """Five ways to add branch into TTree/TChain
    - using string formula (TTreeFormula-based)
    - using a dictionary of string formulas (several branches at once)
    - using pure python function
    - using histogram/function
    - using histogram sampling

    After each step the chain is reloaded and the presence of the new
    branch(es) is asserted.
    """

    files = prepare_data(100, 1000)
    ## for a quick check one can use less data: prepare_data ( 2 , 10 )

    logger.info('#files:    %s' % len(files))
    data = Data('S', files)
    logger.info('Initial Tree/Chain:\n%s' % data.chain.table(prefix='# '))

    # =========================================================================
    ## 1) add new branch as TTree-formula:
    # =========================================================================
    data.chain.add_new_branch('et', 'sqrt(pt*pt+mass*mass)')

    ## reload the chain and check:
    data.reload()
    logger.info('With formula:\n%s' % data.chain.table(prefix='# '))
    assert 'et' in data.chain, "Branch ``et'' is  not here!"

    # =========================================================================
    ## 2) add several new branches as TTree-formula:
    # =========================================================================
    data.chain.add_new_branch(
        {
            'Et1': 'sqrt(pt*pt+mass*mass)',
            'Et2': 'sqrt(pt*pt+mass*mass)*2',
            'Et3': 'sqrt(pt*pt+mass*mass)*3'
        }, None)

    ## reload the chain and check:
    data.reload()
    logger.info('With formula:\n%s' % data.chain.table(prefix='# '))
    assert 'Et1' in data.chain, "Branch ``Et1'' is  not here!"
    assert 'Et2' in data.chain, "Branch ``Et2'' is  not here!"
    assert 'Et3' in data.chain, "Branch ``Et3'' is  not here!"

    # =========================================================================
    ## 3) add new branch as pure python function
    # =========================================================================
    et2 = lambda tree: tree.pt**2 + tree.mass**2

    data.chain.add_new_branch('et2', et2)

    ## reload the chain and check:
    data.reload()
    logger.info('With python:\n%s' % data.chain.table(prefix='# '))
    assert 'et2' in data.chain, "Branch ``et2'' is  not here!"

    # =========================================================================
    ## 4) add new branch as histogram-function
    # =========================================================================
    h1 = ROOT.TH1D('h1', 'some pt-correction', 100, 0, 10)
    h1 += lambda x: 1.0 + math.tanh(0.2 * (x - 5))

    from ostap.trees.funcs import FuncTH1
    ptw = FuncTH1(h1, 'pt')
    data.chain.add_new_branch('ptw', ptw)

    ## reload the chain and check:
    data.reload()
    logger.info('With histogram:\n%s' % data.chain.table(prefix='# '))
    assert 'ptw' in data.chain, "Branch ``ptw'' is  not here!"

    # =========================================================================
    ## 5) add the variable sampled from the histogram
    # =========================================================================
    h2 = ROOT.TH1D('2', 'Gauss', 120, -6, 6)
    for i in range(100000):
        h2.Fill(random.gauss(0, 1))

    data.chain.add_new_branch('hg', h2)

    ## reload the chain and check:
    data.reload()
    logger.info('With sampled:\n%s' % data.chain.table(prefix='# '))
    ## the message used to name a non-existing branch ``g''
    assert 'hg' in data.chain, "Branch ``hg'' is  not here!"
# Example #6
# 0
## the four fill configurations: (timing label, shortcut, use_frame)
_FILL_MODES = (
    ("No SHORTCUT, no FRAME", False, False),
    ("   SHORTCUT, no FRAME", True, False),
    ("No SHORTCUT,    FRAME", False, True),
    ("   SHORTCUT,    FRAME", True, True),
)


def _fill_test_variables(mJPsi, mass_mev, x_accessor):
    """Build the list of variables used by every section of the fill test."""
    return [
        Variable(mJPsi, accessor='mass'),
        Variable('massMeV', 'mass in MeV', 3000, 3200, mass_mev),
        Variable('vv102', 'vv10[2]', -1, 100, '1.0*vv10[2]'),
        Variable('fevt', accessor='1.0*evt'), ('pt', ), ('eta', ),
        ('x', 'some variable', 0, 5000, x_accessor)
    ]


def _pt_cut(logger):
    """Get the python cut: a lambda, or ``ptcut`` if ``DILL_PY3_issue`` is set."""
    if not DILL_PY3_issue:
        return lambda s: s.pt > 3
    logger.warning('There is an issue with dill+python3: avoid lambda!')
    return ptcut


def _x_accessor(logger):
    """Get the python accessor for ``x``: a lambda, or ``xvar`` if ``DILL_PY3_issue`` is set."""
    if not DILL_PY3_issue:
        return lambda s: (s.mass + s.pt + s.eta) / s.eta
    logger.warning('There is an issue with dill+python3: avoid lambda!')
    return xvar


def _fill_in_all_modes(chain, config, logger, parallel=False):
    """Fill a dataset once per entry of ``_FILL_MODES``.

    Returns the list of the filled datasets and the list of timing rows.
    """
    datasets = []
    rows = []
    for label, shortcut, use_frame in _FILL_MODES:
        with timing(label, logger=None) as timer:
            logger.info(attention(timer.name))
            selector = SelectorWithVars(**config)
            if parallel:
                chain.parallel_fill(selector,
                                    shortcut=shortcut,
                                    use_frame=use_frame,
                                    max_files=1)
            else:
                chain.fill_dataset(selector,
                                   shortcut=shortcut,
                                   use_frame=use_frame)
            datasets.append(selector.data)
        rows.append((timer.name, '%.3fs' % timer.delta))
    return datasets, rows


def _log_timing_table(logger, title, rows):
    """Format the timing rows as a table, log it and return it."""
    table = T.table([('Configuration', 'CPU')] + rows,
                    title=title,
                    prefix='# ',
                    alignment='rr')
    logger.info('%s\n%s' % (title, table))
    return table


def _check_same(logger, first, first_name, second, second_name):
    """Log an error if two datasets compare as different."""
    if first != second:
        logger.error('Datasets %-6s and %-7s are different!' %
                     (first_name, second_name))


def _check_modes_agree(logger, tag, datasets):
    """Compare the first dataset of the list with all subsequent ones."""
    for i, dataset in enumerate(datasets[1:], 2):
        _check_same(logger, datasets[0], '%s_1' % tag, dataset,
                    '%s_%d' % (tag, i))


def _check_parallel_agrees(logger, tag, serial, parallel):
    """Compare each sequentially filled dataset with its parallel twin."""
    for i, (one, two) in enumerate(zip(serial, parallel), 1):
        _check_same(logger, one, '%s_%d' % (tag, i), two,
                    '%sp_%d' % (tag, i))


def test_fitting_fill_1():
    """Compare the different ways to fill a dataset from the chain.

    Four sections (trivial or non-trivial variables, with or without an
    additional python cut) are processed.  In each section the dataset is
    filled with every combination of ``shortcut``/``use_frame`` flags, first
    with ``fill_dataset`` and then with ``parallel_fill``.  The timings are
    logged as tables, and the datasets are compared with each other: any
    difference is reported via ``logger.error``.
    """
    logger = getLogger('test_fitting_fill_1')

    ## prepare data
    with timing("Prepare test data", logger=logger):
        files = prepare_data(4, 5000)
        data = Data('S', files)

    chain = data.chain

    mJPsi = ROOT.RooRealVar('mJPsi', 'mass(J/Psi) [GeV]', 3.0 * GeV, 3.2 * GeV)

    selection = "pt>7 && eta<3"

    # =========================================================================
    logger.info(attention('All trivial variables'))
    # =========================================================================

    variables = _fill_test_variables(mJPsi, 'mass*1000.0',
                                     '(mass+pt+eta)/eta')
    config = {'variables': variables, 'selection': selection}

    ds1, rows = _fill_in_all_modes(chain, config, logger)

    if DataSet_NEW_FILL:
        with timing(" pure-FRAME (new) ", logger=None) as t5:
            logger.info(attention(t5.name))
            ds1_5, _ = chain.make_dataset(silent=False, **config)
        rows.append((t5.name, '%.3fs' % t5.delta))

    title1 = "All trivial variables"
    table1 = _log_timing_table(logger, title1, rows)

    _check_modes_agree(logger, 'ds1', ds1)
    if DataSet_NEW_FILL:
        _check_same(logger, ds1[0], 'ds1_1', ds1_5, 'ds1_5')

    ds1p, rows = _fill_in_all_modes(chain, config, logger, parallel=True)

    title1p = "All trivial variables (parallel)"
    table1p = _log_timing_table(logger, title1p, rows)

    _check_parallel_agrees(logger, 'ds1', ds1, ds1p)

    # =========================================================================
    logger.info(attention('Trivial variables + CUT'))
    # =========================================================================

    variables = _fill_test_variables(mJPsi, 'mass*1000', '(mass+pt+eta)/eta')
    config = {
        'variables': variables,
        'selection': selection,
        'cuts': _pt_cut(logger)
    }  ## ATTENTION: no trivial cuts!

    ds2, rows = _fill_in_all_modes(chain, config, logger)

    title2 = "Trivial variables + CUT"
    table2 = _log_timing_table(logger, title2, rows)

    _check_modes_agree(logger, 'ds2', ds2)

    ## two of these calls used to pass the misspelled keyword ``maX_files``
    ds2p, rows = _fill_in_all_modes(chain, config, logger, parallel=True)

    title2p = "Trivial variables + CUT (parallel)"
    table2p = _log_timing_table(logger, title2p, rows)

    _check_same(logger, ds1[0], 'ds1_1', ds2[0], 'ds2_1')
    _check_parallel_agrees(logger, 'ds2', ds2, ds2p)

    # =========================================================================
    logger.info(attention('Non-trivial variables'))
    # =========================================================================

    variables = _fill_test_variables(mJPsi, 'mass*1000', _x_accessor(logger))
    config = {'variables': variables, 'selection': selection}

    ds3, rows = _fill_in_all_modes(chain, config, logger)

    title3 = "Non-trivial variables"
    table3 = _log_timing_table(logger, title3, rows)

    _check_same(logger, ds1[0], 'ds1_1', ds3[0], 'ds3_1')
    ## the messages here used to name ds2_* instead of ds3_*
    _check_modes_agree(logger, 'ds3', ds3)

    ds3p, rows = _fill_in_all_modes(chain, config, logger, parallel=True)

    title3p = "Non-trivial variables (parallel)"
    table3p = _log_timing_table(logger, title3p, rows)

    _check_parallel_agrees(logger, 'ds3', ds3, ds3p)

    # =========================================================================
    logger.info(attention('Non-trivial variables + CUT'))
    # =========================================================================

    variables = _fill_test_variables(mJPsi, 'mass*1000', _x_accessor(logger))
    config = {
        'variables': variables,
        'selection': selection,
        'cuts': _pt_cut(logger)
    }  ## ATTENTION: no trivial cuts!

    ds4, rows = _fill_in_all_modes(chain, config, logger)

    title4 = "Non-trivial variables + CUT"
    table4 = _log_timing_table(logger, title4, rows)

    _check_same(logger, ds1[0], 'ds1_1', ds4[0], 'ds4_1')
    _check_modes_agree(logger, 'ds4', ds4)

    ds4p, rows = _fill_in_all_modes(chain, config, logger, parallel=True)

    title4p = "Non-trivial variables + CUT (parallel)"
    table4p = _log_timing_table(logger, title4p, rows)

    _check_parallel_agrees(logger, 'ds4', ds4, ds4p)

    ## final summary: repeat all the timing tables together
    summary = ((title1, table1), (title1p, table1p),
               (title2, table2), (title2p, table2p),
               (title3, table3), (title3p, table3p),
               (title4, table4), (title4p, table4p))
    for title, table in summary:
        logger.info('%s\n%s' % (title, table))
def test_addbranch():
    """Several ways to add branch into TTree/Tchain
    - using string formula (TTreeFormula-based), one or several at once
    - using pure python function
    - using histogram/function
    - using histogram sampling
    - using numpy arrays of various dtypes (only if numpy is importable)
    - using array.array buffers of various typecodes

    After each addition the chain is fetched again from `data`, its table
    is logged and the presence of the new branch is asserted.
    """

    ## files = prepare_data ( 100 , 1000 )
    files = prepare_data(2, 100)

    logger.info('#files:    %s' % len(files))
    data = Data('S', files)
    logger.info('Initial Tree/Chain:\n%s' % data.chain.table(prefix='# '))

    # =========================================================================
    ## 1) add new branch as TTree-formula:
    # =========================================================================
    with timing('expression', logger=logger):
        chain = data.chain
        chain.add_new_branch('et', 'sqrt(pt*pt+mass*mass)')
    ## reload the chain and check:
    logger.info('With formula:\n%s' % data.chain.table(prefix='# '))
    assert 'et' in data.chain, "Branch ``et'' is  not here!"

    # =========================================================================
    ## 2) add several new branches as TTree-formula:
    # =========================================================================
    with timing('simultaneous', logger=logger):
        chain = data.chain
        ## a dictionary { branch-name : expression } adds all branches at once
        chain.add_new_branch(
            {
                'Et1': 'sqrt(pt*pt+mass*mass)',
                'Et2': 'sqrt(pt*pt+mass*mass)*2',
                'Et3': 'sqrt(pt*pt+mass*mass)*3'
            }, None)
    ## reload the chain and check:
    logger.info('With formula:\n%s' % data.chain.table(prefix='# '))
    for bname in ('Et1', 'Et2', 'Et3'):
        assert bname in data.chain, "Branch ``%s'' is  not here!" % bname

    # =========================================================================
    ## 3) add new branch as pure python function
    # =========================================================================
    with timing('pyfunc', logger=logger):
        et2 = lambda tree: tree.pt**2 + tree.mass**2
        chain = data.chain
        chain.add_new_branch('et2', et2)
    ## reload the chain and check:
    logger.info('With python:\n%s' % data.chain.table(prefix='# '))
    assert 'et2' in data.chain, "Branch ``et2'' is  not here!"

    # =========================================================================
    ## 4) add new branch as histogram-function
    # =========================================================================
    with timing('histo-1', logger=logger):
        h1 = ROOT.TH1D('h1', 'some pt-correction', 100, 0, 10)
        h1 += lambda x: 1.0 + math.tanh(0.2 * (x - 5))
        from ostap.trees.funcs import FuncTH1
        ptw = FuncTH1(h1, 'pt')
        chain = data.chain
        chain.add_new_branch('ptw', ptw)
    ## reload the chain and check:
    logger.info('With histogram:\n%s' % data.chain.table(prefix='# '))
    assert 'ptw' in data.chain, "Branch ``ptw'' is  not here!"

    # =========================================================================
    ## 5) add the variable sampled from the histogram
    # =========================================================================
    with timing('histo-2', logger=logger):
        h2 = ROOT.TH1D('h2', 'Gauss', 120, -6, 6)
        for _ in range(100000):
            h2.Fill(random.gauss(0, 1))
        chain = data.chain
        chain.add_new_branch('hg', h2)
    ## reload the chain and check:
    logger.info('With sampled:\n%s' % data.chain.table(prefix='# '))
    assert 'hg' in data.chain, "Branch ``hg'' is  not here!"

    # =========================================================================
    ## 6) python function again
    # =========================================================================
    with timing('gauss', logger=logger):

        ## the function ignores whatever arguments it is called with
        def gauss(*_):
            return random.gauss(0, 1)

        chain = data.chain
        chain.add_new_branch('gauss', gauss)
    ## reload the chain and check:
    logger.info('With gauss:\n%s' % data.chain.table(prefix='# '))
    assert 'gauss' in data.chain, "Branch ``gauss'' is  not here!"

    # =========================================================================
    ## 7) add numpy array
    # =========================================================================
    try:
        import numpy
    except ImportError:
        ## numpy is optional: without it this part of the test is skipped
        numpy = None

    if numpy:

        ## ( timing label , branch name , fill value , numpy dtype name )
        ## the labels are kept verbatim, including the trailing blanks
        numpy_cases = (
            ('numpy float16', 'np_f16', +0.1, 'float16'),
            ('numpy float32', 'np_f32', -0.2, 'float32'),
            ('numpy float64', 'np_f64', +0.3, 'float64'),
            ('numpy int8 ', 'np_i8', -1, 'int8'),
            ('numpy uint8 ', 'np_ui8', +2, 'uint8'),
            ('numpy int16 ', 'np_i16', -3, 'int16'),
            ('numpy uint16 ', 'np_ui16', +4, 'uint16'),
            ('numpy int32 ', 'np_i32', -5, 'int32'),
            ('numpy uint32 ', 'np_ui32', +6, 'uint32'),
            ('numpy int64 ', 'np_i64', -7, 'int64'),
            ('numpy uint64 ', 'np_ui64', +8, 'uint64'),
        )

        for label, bname, value, tname in numpy_cases:

            with timing(label, logger=logger):
                adata = numpy.full(10000, value, dtype=getattr(numpy, tname))
                chain = data.chain
                chain.add_new_branch(bname, adata)
            ## reload the chain and check:
            logger.info('With numpy.%s:\n%s' %
                        (tname, data.chain.table(prefix='# ')))
            assert bname in data.chain, "Branch ``%s'' is  not here!" % bname

    # =========================================================================
    ## 8) add array.array buffer: ( typecode , fill value )
    # =========================================================================
    array_cases = (('f', +100.1), ('d', -200.2), ('i', -3), ('l', -4),
                   ('I', 5), ('L', 6), ('h', 7), ('H', 8))

    for typecode, value in array_cases:

        vname = 'arr_%s' % typecode
        with timing('array %s' % typecode, logger=logger):
            adata = array.array(typecode, 10000 * [value])
            chain = data.chain
            chain.add_new_branch(vname, adata)
        ## reload the chain and check:
        logger.info("With array '%s':\n%s" %
                    (typecode, data.chain.table(prefix='# ')))
        assert vname in data.chain, "Branch ``%s'' is  not here!" % vname