Beispiel #1
0
import sys
from xgbooster import XGBooster


if __name__ == '__main__':
    # Script entry point: builds an XGBooster from a stored model file,
    # encodes it, reads the COMPAS sample file and initialises the
    # bookkeeping containers used by the rest of the script.
    # NOTE(review): `Options`, `os` and `copy` are used below but are not
    # imported in the visible part of this file -- confirm they are in scope.

    # parsing command-line options
    options = Options(sys.argv)

    # making output unbuffered
    # (done on Python 2 only: stdout is reopened with buffering set to 0)
    if sys.version_info.major == 2:
        sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)

    # create the booster from the model file passed as `from_model`
    xgb = XGBooster(options, from_model='../temp/compas_data/compas_data_nbestim_50_maxdepth_3_testsplit_0.2.mod.pkl')

    # encode it and save the encoding to another file
    xgb.encode()

    # independent deep copy of the already encoded booster
    # (how the copy is used is outside the visible code)
    xgb2 = copy.deepcopy(xgb)

    # read all sample lines at once; each entry keeps its trailing newline
    with open('../bench/fairml/compas/compas.samples', 'r') as fp:
        lines = fp.readlines()

    # timers
    # NOTE(review): what each of the four lists measures is not visible here.
    ltimes = []
    vtimes = []
    ftimes = []
    etimes = []

    # containers and counter filled by the (not visible) remainder of the script
    tested = set()
    errors = []
    reduced = 0
Beispiel #2
0
def _cpu_time():
    """Return the user CPU time (in seconds) reported by getrusage() for this
    process plus its children."""
    return resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
        resource.getrusage(resource.RUSAGE_SELF).ru_utime


def enumerate_all(options, xtype, xnum, smallest, usecld, usemhs, useumcs,
                  prefix):
    """
    Explain every unique sample of the spam10 benchmark and print statistics.

    The given preferences are written into `options` (which is modified in
    place), the samples are read from the hard-coded spam10 sample file, and
    each distinct sample is explained with an XGBooster built from the
    hard-coded spam10 model file. Per-sample and aggregate figures (CPU time,
    oracle calls, number and size of explanations) are printed to stdout.

    Parameters:
        options  -- options object passed to XGBooster; its `xtype`, `xnum`,
                    `reduce`, `smallest`, `usecld`, `usemhs`, `useumcs` and
                    `explain` attributes are overwritten here.
        xtype, xnum, smallest, usecld, usemhs, useumcs
                 -- values stored into the same-named attributes of `options`.
        prefix   -- string printed in front of every statistics line.

    Returns None.

    NOTE(review): the statistics assume `xgb.explain()` returns a sequence of
    explanations, each supporting len() -- confirm this also holds when
    `xnum` is 1.
    """
    # setting the right preferences
    options.xtype = xtype
    options.xnum = xnum
    options.reduce = 'lin'
    options.smallest = smallest
    options.usecld = usecld
    options.usemhs = usemhs
    options.useumcs = useumcs

    # reading all samples (duplicates are filtered out in the loop below)
    with open('../bench/nlp/spam/quant/spam10.samples', 'r') as fp:
        lines = fp.readlines()

    # timers and other variables
    times, calls = [], []
    xsize, exlen = [], []

    # doing everything incrementally is expensive;
    # let's restart the solver for every 10% of instances.
    # The interval must be an integer: with true division the modulus below
    # is a float and the restart would fire only for i == 0 unless the number
    # of samples is a multiple of 10. It is clamped to 1 so that fewer than
    # 10 samples cannot produce a zero modulus.
    restart_every = max(1, len(lines) // 10)

    tested = set()
    for i, s in enumerate(lines):
        if i % restart_every == 0:
            # creating a new XGBooster
            xgb = XGBooster(
                options,
                from_model=
                '../temp/spam10_data/spam10_data_nbestim_50_maxdepth_3_testsplit_0.2.mod.pkl'
            )

            # encode it and save the encoding to another file
            xgb.encode()

        options.explain = [float(v.strip()) for v in s.split(',')]

        # each distinct feature vector is explained only once
        sample = tuple(options.explain)
        if sample in tested:
            continue
        tested.add(sample)

        # the raw line is printed as is (including its line terminator)
        print(prefix, 'sample {0}: {1}'.format(i, s))

        # measuring the CPU time spent in the explanation call
        started = _cpu_time()
        expls = xgb.explain(options.explain)
        timer = _cpu_time() - started
        times.append(timer)

        # explanation sizes; an empty result is reported as size 0 instead
        # of failing in max()/min() or dividing by zero
        sizes = [len(x) for x in expls]
        avg_size = sum(sizes) / float(len(sizes)) if sizes else 0.0

        print(prefix, 'expls:', expls)
        print(prefix, 'nof x:', len(expls))
        print(prefix, 'timex: {0:.2f}'.format(timer))
        print(prefix, 'calls:', xgb.x.calls)
        print(prefix, 'Msz x:', max(sizes) if sizes else 0)
        print(prefix, 'msz x:', min(sizes) if sizes else 0)
        print(prefix, 'asz x: {0:.2f}'.format(avg_size))
        print('')

        calls.append(xgb.x.calls)
        xsize.append(avg_size)
        exlen.append(len(expls))

    print('')
    print('all samples:', len(lines))

    if not times:
        # no sample was explained, so there is nothing to aggregate
        # (max()/min() and the averages are undefined for empty lists)
        return

    # reporting the time spent
    print('{0} total time: {1:.2f}'.format(prefix, sum(times)))
    print('{0} max time per instance: {1:.2f}'.format(prefix, max(times)))
    print('{0} min time per instance: {1:.2f}'.format(prefix, min(times)))
    print('{0} avg time per instance: {1:.2f}'.format(prefix,
                                                      sum(times) / len(times)))
    print('{0} total oracle calls: {1}'.format(prefix, sum(calls)))
    print('{0} max oracle calls per instance: {1}'.format(prefix, max(calls)))
    print('{0} min oracle calls per instance: {1}'.format(prefix, min(calls)))
    print('{0} avg oracle calls per instance: {1:.2f}'.format(
        prefix,
        float(sum(calls)) / len(calls)))
    print('{0} avg number of explanations per instance: {1:.2f}'.format(
        prefix,
        float(sum(exlen)) / len(exlen)))
    print('{0} avg explanation size per instance: {1:.2f}'.format(
        prefix,
        float(sum(xsize)) / len(xsize)))
    print('')
Beispiel #3
0
            # build a booster from the loaded data and train it; train()
            # yields the train accuracy, the test accuracy and the model
            xgb = XGBooster(options, from_data=data)
            train_accuracy, test_accuracy, model = xgb.train()

        # read a sample from options.explain
        # (a comma-separated string that is replaced by a list of floats)
        if options.explain:
            options.explain = [
                float(v.strip()) for v in options.explain.split(',')
            ]

        if options.encode:
            # NOTE(review): `xgb` must already be bound (possibly to a falsy
            # value) on every path reaching this point -- the binding is not
            # visible in this excerpt.
            if not xgb:
                # no booster yet: create it from the model file given as the
                # first positional file
                xgb = XGBooster(options, from_model=options.files[0])

            # encode it and save the encoding to another file
            xgb.encode(test_on=options.explain)

        if options.explain:
            if not xgb:
                # abduction-based approach requires an encoding
                xgb = XGBooster(options, from_encoding=options.files[0])
            if (options.encode == "ortools"):
                # the "ortools" encoding has its own explanation routine
                expl = xgb.explain_ortools(options.explain)
            else:
                # explain using anchor or the abduction-based approach
                expl = xgb.explain(options.explain)

            # here we take only the first explanation in case enumeration was done
            if options.xnum != 1:
                expl = expl[0]
Beispiel #4
0
            # `count` above 1 is taken as an absolute number of instances;
            # otherwise it is taken as a fraction of all instances (so a
            # value of exactly 1 selects all of them). Either way the result
            # is capped at len(insts).
            if count > 1:
                nof_insts = min(int(count), len(insts))
            else:
                nof_insts = min(int(len(insts) * count), len(insts))
            print(f'considering {nof_insts} instances')

        # dataset name: file name of `data` without directory and extension
        base = os.path.splitext(os.path.basename(data))[0]
        # path of the model file, derived from the dataset name, `num` and
        # `adepth`
        mfile = 'temp/{0}/{0}_nbestim_{1}_maxdepth_{2}_testsplit_0.2.mod.pkl'.format(
            base, num, adepth)

        # one log file per approach, opened for writing
        # NOTE(review): these handles are not opened with a `with` statement;
        # confirm they are closed in the part of the code not visible here.
        slog = open(f'results/smt/{base}.log', 'w')
        mlog = open(f'results/mx/{base}.log', 'w')

        # creating booster objects
        # (both from the same model file, configured with `soptions` and
        # `moptions` respectively)
        sxgb = XGBooster(soptions, from_model=mfile)
        sxgb.encode(test_on=None)
        mxgb = XGBooster(moptions, from_model=mfile)
        mxgb.encode(test_on=None)

        # accumulators filled by the loop below
        # NOTE(review): presumably per-instance times, calls and memory usage
        # for the two approaches -- their use is not visible in this excerpt.
        stimes = []
        mtimes = []
        mcalls = []
        smem = []
        mxmem = []

        #with open("/tmp/texture.samples", 'r') as fp:
        #    insts = [line.strip() for line in fp.readlines()]

        # NOTE(review): `nof_insts` is assigned only inside the branch at the
        # top of this excerpt; confirm it is bound on every path reaching
        # this loop.
        for i, inst in enumerate(insts):
            # stop once `nof_insts` instances have been processed
            if i == nof_insts:
                break