Example #1
0
    def test_slice(self):
        """Verify that TrainData.getSlice and TrainData.skim produce the
        same content as slicing the underlying SimpleArrays directly.

        Builds three 600-element arrays, stores two as features and one
        as truth, then checks that both slicing paths agree with the
        per-array ``getSlice(2, 3)`` references.
        """
        print('TestTrainData: skim')
        a = self.createSimpleArray('int32', 600)
        b = self.createSimpleArray('float32', 600)
        d = self.createSimpleArray('float32', 600)

        # Reference result: slice each array individually.
        a_slice = a.getSlice(2, 3)
        b_slice = b.getSlice(2, 3)
        d_slice = d.getSlice(2, 3)

        def check_matches(td_like):
            # Rebuild SimpleArrays from the transferred numpy lists and
            # compare them against the directly-sliced references.
            fl = td_like.transferFeatureListToNumpy(False)
            tl = td_like.transferTruthListToNumpy(False)
            self.assertEqual(a_slice, SimpleArray(fl[0], fl[1]))
            self.assertEqual(b_slice, SimpleArray(fl[2], fl[3]))
            self.assertEqual(d_slice, SimpleArray(tl[0], tl[1]))

        td = TrainData()
        td._store([a, b], [d], [])

        # Slicing the TrainData must match the per-array slices.
        check_matches(td.getSlice(2, 3))

        # test skim: skim(2) reduces td in place to element 2 only,
        # which must carry the same content as the slice above.
        td.skim(2)
        check_matches(td)
Example #2
0
class PredictCallback(Callback):
    """Keras callback that periodically runs the model on a fixed sample
    and forwards the prediction to a user-supplied function.

    Triggering is either every ``after_n_batches`` batches or at the end
    of each epoch; the two modes are mutually exclusive.
    """

    def __init__(
            self,
            samplefile,
            function_to_apply=None,  #needs to be function(counter,[model_input], [predict_output], [truth])
            after_n_batches=50,
            on_epoch_end=False,
            use_event=0,
            decay_function=None):
        super().__init__()
        self.samplefile = samplefile
        self.function_to_apply = function_to_apply
        self.decay_function = decay_function
        self.counter = 0        # batches seen since last trigger
        self.call_counter = 0   # how many times function_to_apply ran
        self.after_n_batches = after_n_batches
        self.run_on_epoch_end = on_epoch_end

        # Both trigger modes requested: epoch-end wins, batch mode is off.
        if self.run_on_epoch_end and self.after_n_batches >= 0:
            print(
                'PredictCallback: can only be used on epoch end OR after n batches, falling back to epoch end'
            )
            self.after_n_batches = 0

        self.td = TrainData()
        self.td.readIn(samplefile)
        if use_event >= 0:
            # Reduce the sample to the single requested event.
            self.td.skim(event=use_event)

    def on_train_begin(self, logs=None):
        pass

    def reset(self):
        """Restart the invocation counter passed to function_to_apply."""
        self.call_counter = 0

    def predict_and_call(self, counter):
        """Predict on the stored sample and invoke function_to_apply.

        Note: the ``counter`` argument is not used; the internal
        call_counter is passed to the function instead.
        """
        prediction = self.model.predict(self.td.x)
        if not isinstance(prediction, list):
            prediction = [prediction]

        self.function_to_apply(self.call_counter, self.td.x, prediction,
                               self.td.y)
        self.call_counter += 1

    def on_epoch_end(self, epoch, logs=None):
        self.counter = 0
        if self.decay_function is not None:
            # Let the caller adjust the batch interval between epochs.
            self.after_n_batches = self.decay_function(self.after_n_batches)
        if not self.run_on_epoch_end:
            return
        self.predict_and_call(epoch)

    def on_batch_end(self, batch, logs=None):
        if self.after_n_batches <= 0:
            return
        self.counter += 1
        if self.counter <= self.after_n_batches:
            return
        self.counter = 0
        self.predict_and_call(batch)
Example #3
0
    def __init__(
            self,
            samplefile,
            function_to_apply=None,  # expected signature: function(counter, [model_input], [predict_output], [truth])
            after_n_batches=50,
            batchsize=10,
            on_epoch_end=False,
            use_event=0,
            decay_function=None,
            offset=0):
        """Load the sample file, optionally skim to one event, and set up
        a generator for batched prediction.

        ``offset`` seeds call_counter so output numbering can continue
        from a previous run.
        """
        super(PredictCallback, self).__init__()

        self.samplefile = samplefile
        self.function_to_apply = function_to_apply
        self.decay_function = decay_function
        self.counter = 0
        self.call_counter = offset
        self.after_n_batches = after_n_batches
        self.run_on_epoch_end = on_epoch_end

        # Epoch-end mode and batch-interval mode are mutually exclusive;
        # fall back to epoch-end when both are requested.
        if self.run_on_epoch_end and self.after_n_batches >= 0:
            print(
                'PredictCallback: can only be used on epoch end OR after n batches, falling back to epoch end'
            )
            self.after_n_batches = 0

        sample = TrainData()
        sample.readFromFile(samplefile)
        if use_event >= 0:
            sample.skim(use_event)
        self.td = sample

        # NOTE(review): self.batchsize is pinned to 1 while the generator
        # is configured with the batchsize argument — looks deliberate but
        # worth confirming against callers.
        self.batchsize = 1
        self.gen = trainDataGenerator()
        self.gen.setBatchSize(batchsize)
        self.gen.setSkipTooLargeBatches(False)
Example #4
0
args = parser.parse_args()

import DeepJetCore
from keras.models import load_model
from DeepJetCore.compiled.c_trainDataGenerator import trainDataGenerator
from DeepJetCore.evaluation import predict_from_TrainData
from DeepJetCore.customObjects import get_custom_objects
from DeepJetCore.TrainData import TrainData
import matplotlib.pyplot as plt
from ragged_plotting_tools import make_cluster_coordinates_plot, make_original_truth_shower_plot
from index_dicts import create_index_dict, create_feature_dict

# Load the input file and skim it down to the single requested event.
td = TrainData()
td.readFromFile(args.i)
td.skim(int(args.e))
#td=td.split(int(args.e)+1)#get the first e+1 elements
#if int(args.e)>0:
#    td.split(1) #reduce to the last element (the e'th one)

# Rebuild the model with any project-specific custom layers/objects.
model = load_model(args.inputModel, custom_objects=get_custom_objects())

# Run prediction on the (skimmed) TrainData in one large batch.
predicted = predict_from_TrainData(model, td, batchsize=100000)

pred = predicted[0]
feat = td.transferFeatureListToNumpy()
rs = feat[1]  # second element of the feature transfer — presumably row splits; confirm against TrainData
feat = feat[0]
#weights = td.transferWeightListToNumpy()
truth = td.transferTruthListToNumpy()[0]
# Free the TrainData buffers once everything is transferred to numpy.
td.clear()
    # NOTE(review): this indented section appears to have been pasted from
    # inside a loop or function in another file — 'expected_here' and this
    # 'gen' are not defined at module scope, and the indentation makes the
    # module a SyntaxError as-is. Restore the enclosing scope or re-indent.
    print('setting buffer')
    gen.setBuffer(td)
    print('done setting buffer')

    #print("expect: 5, 4, 4, 5, 4, 4 = 6 ")
    nbatches = gen.getNBatches()
    #print(nbatches)
    print("gen.getNTotal()", gen.getNTotal())  # NOTE(review): original comment flagged this value as unreliable — verify getNTotal()
    print("gen.getNBatches()", gen.getNBatches())

    #gen.debug=True
    i_e = 0
    for i in range(nbatches):
        print("batch", i, "is last ", gen.lastBatch())
        d = gen.getBatch()
        nelems = d.nElements()
        # A None entry in expected_here is stepped over; the batch is then
        # checked against the entry that follows it.
        if expected_here[i_e] is None:
            i_e += 1
        assert expected_here[i_e] == nelems
        i_e += 1
# Sanity check: after skimming td to a single event, a fresh generator
# should produce a first batch containing exactly one element.
print('checking 1 example generator')

td.skim(0)
gen = trainDataGenerator()
gen.setBuffer(td)
d = gen.getBatch()
nelems = d.nElements()

assert nelems == 1