Example #1
0
    def test_slice(self):
        """Compare TrainData slicing and skimming against SimpleArray.getSlice.

        Three arrays are created (two stored as features, one as truth).
        The reference is each array's own ``getSlice(2, 3)``; the arrays
        rebuilt from ``TrainData.getSlice(2, 3)`` and, afterwards, from the
        TrainData itself following ``skim(2)`` must both equal it.
        """
        print('TestTrainData: skim')

        # Creation order is kept: int32 feature, float32 feature, float32 truth.
        arrays = [
            self.createSimpleArray(dtype, 600)
            for dtype in ('int32', 'float32', 'float32')
        ]
        expected = [arr.getSlice(2, 3) for arr in arrays]

        td = TrainData()
        td._store(arrays[:2], arrays[2:], [])
        sliced = td.getSlice(2, 3)

        # Each stored array occupies two consecutive entries of the
        # transferred list; both are handed to the SimpleArray constructor.
        features = sliced.transferFeatureListToNumpy(False)
        truths = sliced.transferTruthListToNumpy(False)
        rebuilt = [
            SimpleArray(features[0], features[1]),
            SimpleArray(features[2], features[3]),
            SimpleArray(truths[0], truths[1]),
        ]
        for reference, candidate in zip(expected, rebuilt):
            self.assertEqual(reference, candidate)

        # Same comparison after skimming the original TrainData in place.
        td.skim(2)
        features = td.transferFeatureListToNumpy(False)
        truths = td.transferTruthListToNumpy(False)
        rebuilt = [
            SimpleArray(features[0], features[1]),
            SimpleArray(features[2], features[3]),
            SimpleArray(truths[0], truths[1]),
        ]
        for reference, candidate in zip(expected, rebuilt):
            self.assertEqual(reference, candidate)
Example #2
0
 def sub_test_store(self, readWrite):
     """Store arrays in a TrainData and check they are returned unchanged.

     Two feature arrays, two truth arrays and one weight array are stored.
     The array counts, the reported feature shapes, and the content of both
     feature arrays and the first truth array are then verified.

     Args:
         readWrite: if true, the TrainData is written to
             ``testfile.tdjctd`` and read back into a fresh object before
             the checks, so the file round trip is covered too.
     """
     td = TrainData()
     x, y, w = self.createSimpleArray('int32'), self.createSimpleArray('float32'), self.createSimpleArray('int32')
     x_orig = x.copy()
     # The third array is created but not used; the call is kept so the
     # sequence of createSimpleArray invocations is unchanged.
     x2, y2, _ = self.createSimpleArray('float32'), self.createSimpleArray('float32'), self.createSimpleArray('int32')
     x2_orig = x2.copy()
     y_orig = y.copy()

     td._store([x, x2], [y, y2], [w])

     if readWrite:
         filename = "testfile.tdjctd"
         try:
             td.writeToFile(filename)
             td = TrainData()
             td.readFromFile(filename)
         finally:
             # Remove the temporary file even when the round trip fails.
             # Errors are ignored on purpose: this mirrors the best-effort
             # behaviour of the former ``rm -f`` without needing a shell.
             try:
                 os.remove(filename)
             except OSError:
                 pass

     shapes = td.getNumpyFeatureShapes()
     self.assertEqual([[3, 5, 6], [1], [3, 5, 6], [1]], shapes, "shapes")

     self.assertEqual(2, td.nFeatureArrays())
     self.assertEqual(2, td.nTruthArrays())
     self.assertEqual(1, td.nWeightArrays())

     f = td.transferFeatureListToNumpy(False)
     t = td.transferTruthListToNumpy(False)
     # The weight transfer is exercised, but its content is not compared.
     td.transferWeightListToNumpy(False)

     # Each stored array occupies two consecutive list entries; the second
     # one is converted to int64 before rebuilding the SimpleArray.
     xnew = SimpleArray(f[0], np.array(f[1], dtype='int64'))
     self.assertEqual(x_orig, xnew)

     xnew = SimpleArray(f[2], np.array(f[3], dtype='int64'))
     self.assertEqual(x2_orig, xnew)

     ynew = SimpleArray(t[0], np.array(t[1], dtype='int64'))
     self.assertEqual(y_orig, ynew)
Example #3
0
# Load the input file given by args.i into td and reduce it with skim();
# td and args are defined before this fragment.
# NOTE(review): judging by the commented-out split() code below, skim(e)
# presumably keeps only the e'th element - confirm against TrainData.
td.readFromFile(args.i)
td.skim(int(args.e))
#td=td.split(int(args.e)+1)#get the first e+1 elements
#if int(args.e)>0:
#    td.split(1) #reduce to the last element (the e'th one)

# Load the model, passing the project's custom objects to load_model.
model = load_model(args.inputModel, custom_objects=get_custom_objects())

# Run the model on the (skimmed) TrainData.
predicted = predict_from_TrainData(model, td, batchsize=100000)

# Only the first entry of the prediction is used.
pred = predicted[0]
# The feature list is taken out of td; entry 0 is kept as the features and
# entry 1 as 'rs' (presumably the row splits - TODO confirm).
feat = td.transferFeatureListToNumpy()
rs = feat[1]
feat = feat[0]
#weights = td.transferWeightListToNumpy()
# Only the first entry of the truth list is used.
truth = td.transferTruthListToNumpy()[0]
# td is cleared once all needed arrays have been extracted.
td.clear()

print(feat.shape)
print(truth.shape)

# One figure with two panels: a 3D axis on the left, a 2D axis on the right.
fig = plt.figure(figsize=(10, 4))
ax = [fig.add_subplot(1, 2, 1, projection='3d'), fig.add_subplot(1, 2, 2)]

# Build the dictionaries used by the plotting calls that follow.
data = create_index_dict(truth, pred, usetf=False)
feats = create_feature_dict(feat)

make_cluster_coordinates_plot(
    plt,
    ax[1],
    data['truthHitAssignementIdx'],  #[ V ]
Example #4
0
import numpy as np
from DeepJetCore.TrainData import TrainData
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import math
from numba import jit
from inference import collect_condensates, make_inference_dict

# Read one TrainData file and move its feature, truth and weight lists to
# numpy, keeping them as attributes x, y and z on the td object.
td = TrainData()
#td.readFromFile("../results_partial/predictions/pred_9.djctd")
td.readFromFile("../data/test_data/9.djctd")
td.x = td.transferFeatureListToNumpy()
td.y = td.transferTruthListToNumpy()
td.z = td.transferWeightListToNumpy()

# Short module-level aliases for the same three lists.
x = td.x
y = td.y
z = td.z

# Number of entries in the feature list.
print(len(x))

# Shapes of the first three feature entries (the list must hold at least
# three, otherwise these lines raise IndexError).
print(x[0].shape)
print(x[1].shape)
print(x[2].shape)
#print(y.shape)
#print(z.shape)

# Build the inference dictionary from the first three feature entries.
data = make_inference_dict(td.x[0], td.x[1], td.x[2])