Example #1
0
class PredictCallback(Callback):
    """Periodically run ``self.model.predict`` on a fixed sample and pass the
    result to a user-supplied function.

    The sample is read once from ``samplefile`` into a ``TrainData`` object.
    Prediction is triggered either at the end of every epoch
    (``on_epoch_end=True``) or every ``after_n_batches`` training batches,
    never both.

    Args:
        samplefile: file passed to ``TrainData.readIn``.
        function_to_apply: callable invoked as
            ``function_to_apply(call_counter, td.x, predicted, td.y)``, where
            ``predicted`` is always a list of model outputs. It is called
            unconditionally on every trigger, so it must be provided.
        after_n_batches: trigger every this many batches; a value ``<= 0``
            disables batch-based triggering.
        on_epoch_end: trigger at the end of every epoch instead. Takes
            precedence over a positive ``after_n_batches``.
        use_event: if ``>= 0``, ``td.skim(event=use_event)`` is applied to the
            sample (presumably keeping only that event -- confirm against
            TrainData); pass a negative value to skip skimming.
        decay_function: optional callable applied to ``after_n_batches`` at
            every epoch end; its return value becomes the new interval.
    """

    def __init__(
            self,
            samplefile,
            function_to_apply=None,  # needs to be function(counter,[model_input], [predict_output], [truth])
            after_n_batches=50,
            on_epoch_end=False,
            use_event=0,
            decay_function=None):
        super(PredictCallback, self).__init__()
        self.samplefile = samplefile
        self.function_to_apply = function_to_apply
        # Batches seen since the last batch-triggered prediction.
        self.counter = 0
        # Number of times function_to_apply has been called (cleared by reset()).
        self.call_counter = 0
        self.decay_function = decay_function

        self.after_n_batches = after_n_batches
        self.run_on_epoch_end = on_epoch_end

        # Only a positive interval conflicts with epoch-end mode; 0 or a
        # negative value already disables batch-based triggering, so no
        # warning is needed for those.
        if self.run_on_epoch_end and self.after_n_batches > 0:
            print(
                'PredictCallback: can only be used on epoch end OR after n batches, falling back to epoch end'
            )
            self.after_n_batches = 0

        # Load the sample once; it is reused for every prediction.
        self.td = TrainData()
        self.td.readIn(samplefile)
        if use_event >= 0:
            self.td.skim(event=use_event)

    def on_train_begin(self, logs=None):
        pass

    def reset(self):
        """Restart the call counter passed to function_to_apply."""
        self.call_counter = 0

    def predict_and_call(self, counter):
        """Predict on the stored sample and hand the result to function_to_apply.

        ``counter`` (epoch or batch index from the caller) is accepted for
        interface compatibility but not used; the running ``call_counter``
        is passed to function_to_apply instead.
        """
        predicted = self.model.predict(self.td.x)
        # Single-output models return a bare array; normalise to a list so
        # function_to_apply always receives a list of outputs.
        if not isinstance(predicted, list):
            predicted = [predicted]

        self.function_to_apply(self.call_counter, self.td.x, predicted,
                               self.td.y)
        self.call_counter += 1

    def on_epoch_end(self, epoch, logs=None):
        """Reset the batch counter, apply decay_function, and predict if in
        epoch-end mode."""
        self.counter = 0
        if self.decay_function is not None:
            self.after_n_batches = self.decay_function(self.after_n_batches)
        if not self.run_on_epoch_end:
            return
        self.predict_and_call(epoch)

    def on_batch_end(self, batch, logs=None):
        """Predict once every ``after_n_batches`` batches (if enabled)."""
        if self.after_n_batches <= 0:
            return
        self.counter += 1
        # '>=' so that after_n_batches=N fires on the N-th batch, not the
        # (N+1)-th.
        if self.counter >= self.after_n_batches:
            self.counter = 0
            self.predict_and_call(batch)
Example #2
0
        print('reading truth')

        truth, _ = readListArray(
            filename=infile,
            treename="Delphes",
            branchname="layercluster_simcluster_fractions",
            nevents=nentries,
            list_size=3500,
            n_feat_per_element=20,
            zeropad=True,
            list_size_cut=True)

elif 'meta' == infile[-4:]:
    from DeepJetCore.TrainData import TrainData
    td = TrainData()
    td.readIn(infile)
    features = td.x[0]
    truth = td.y[0][:, :, 0:-1]  #cut off energy
    nentries = len(td.x[0])

# Expected array layout: B x V x F (presumably batch x vertices x features).
name = 'brg'

print("plotting")


def remove_zero_energy(a, e, threshold=0.01):
    """Return the entries of ``a`` whose energy ``e`` exceeds ``threshold``.

    ``a`` is indexed with the boolean mask ``e > threshold``, so ``e`` must
    have the same shape as the leading dimension(s) of ``a`` that it selects.
    Entries with ``e <= threshold`` are treated as zero-energy and dropped.

    Args:
        a: array to filter.
        e: energies used to build the mask.
        threshold: energies at or below this value are dropped; defaults to
            0.01, the original hard-coded cut.

    Returns:
        The selected entries of ``a``.
    """
    return a[e > threshold]

Example #3
0
#!/usr/bin/env python


from argparse import ArgumentParser
# The first positional argument of ArgumentParser is `prog`, not the
# description, so the text must be passed by keyword; otherwise it is shown
# as the program name in the usage line.
parser = ArgumentParser(
    description='Browse entries in a DeepHGCal TrainData file. Assumes standard ordering (3D image as second entry)')
parser.add_argument('inputFile', help='TrainData file to browse')
args = parser.parse_args()

import matplotlib.pyplot as plt
from plotting import plot4d

from DeepJetCore.TrainData import TrainData
import mpl_toolkits.mplot3d.art3d as a3d

# Load the file and keep only the second input (the 3D image, per the
# ordering assumed in the parser description); drop the rest to free memory.
td = TrainData()
td.readIn(args.inputFile)
x_chmap = td.x[1]
del td
nentries = x_chmap.shape[0]

# Each entry needs at least 4 axes (shape[3] is read below); presumably
# x, y, z and colour/channel -- confirm against the data format.
entry_shape = x_chmap[0].shape
ncolors = entry_shape[3]

# Integer centres: '//' keeps the Python 2 integer result under Python 3,
# where '/' would give floats that cannot be used as array indices.
xmax = entry_shape[0]
xcenter = xmax // 2
ymax = entry_shape[1]
ycenter = ymax // 2
zcenter = entry_shape[2] // 2

print(ncolors)

# Iterating the array walks its first axis, i.e. all nentries entries.
for entry in x_chmap:
    print(entry.shape)
Example #4
0
#monitored during the training) or with callbacks that make plots (still under development for GANs)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def plotgrid(in_array, nplotsx, nplotsy, outname, channel=0):
    """Draw the first ``nplotsx * nplotsy`` images of ``in_array`` on a grid
    and save the figure.

    Image ``k`` (row-major over the grid) is ``in_array[k, :, :, channel]``
    shown with ``imshow``. The figure is written to ``outname`` and then
    closed.

    Args:
        in_array: array indexed as [image, row, col, channel]; must hold at
            least ``nplotsx * nplotsy`` images.
        nplotsx: number of grid columns.
        nplotsy: number of grid rows.
        outname: output path passed to ``fig.savefig``.
        channel: index along the last axis to display; defaults to 0, the
            previously hard-coded channel.
    """
    # squeeze=False keeps `ax` 2-D even for a single row or column; by
    # default it collapses to 1-D (or a bare Axes) and ax[i][j] would fail.
    fig, ax = plt.subplots(nplotsy, nplotsx, squeeze=False)
    # ax.flat walks the grid row-major, matching image index i * nplotsx + j.
    for counter, axis in enumerate(ax.flat):
        axis.imshow(in_array[counter, :, :, channel])
    fig.savefig(outname)
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)


import numpy as np
from DeepJetCore.TrainData import TrainData
# Compare generator output with real inputs on a 2 x n_show grid: the top row
# shows the first n_show generated images, the bottom row the first n_show
# inputs.
# NOTE(review): `train` is not defined in this snippet; it is presumably the
# training object from the surrounding script (provides .generator and
# .outputDir) -- confirm.
import os
import sys

n_show = 4
td = TrainData()
td.readIn("/eos/home-j/jkiesele/DeepNtuples/GraphGAN_test/test/out_800.meta")
x_gen = train.generator.predict(td.x)
# predict returns a bare array for single-output models; wrap it so that
# x_gen[0] is always the first output rather than the first sample.
if not isinstance(x_gen, list):
    x_gen = [x_gen]
forplots = np.concatenate([x_gen[0][:n_show], td.x[0][:n_show]], axis=0)
plotgrid(forplots,
         nplotsx=n_show,
         nplotsy=2,
         # join works whether or not outputDir ends with a path separator.
         outname=os.path.join(train.outputDir, "comparison.pdf"))

# sys.exit does not depend on the `site` module, unlike the builtin exit().
sys.exit()