Пример #1
0
    def test_slice(self):
        """Check TrainData.getSlice and TrainData.skim against SimpleArray.getSlice.

        Two feature arrays (``a``, ``b``) and one truth array (``d``) are
        stored in a TrainData. Both ``td.getSlice(2, 3)`` and an in-place
        ``td.skim(2)`` must give back arrays equal to the per-array
        ``getSlice(2, 3)`` results.
        """
        print('TestTrainData: slice and skim')
        a = self.createSimpleArray('int32', 600)
        b = self.createSimpleArray('float32', 600)
        d = self.createSimpleArray('float32', 600)

        # Reference slices are taken before the arrays are handed to _store.
        expected = (a.getSlice(2, 3), b.getSlice(2, 3), d.getSlice(2, 3))

        def unpack(train_data):
            # transfer*ListToNumpy(False) returns a flat list in which each
            # stored array occupies two consecutive entries; rebuild a, b
            # (features) and d (truth) from those pairs.
            fl = train_data.transferFeatureListToNumpy(False)
            tl = train_data.transferTruthListToNumpy(False)
            return (SimpleArray(fl[0], fl[1]),
                    SimpleArray(fl[2], fl[3]),
                    SimpleArray(tl[0], tl[1]))

        td = TrainData()
        td._store([a, b], [d], [])

        # Slicing the container must match slicing each array individually.
        for want, got in zip(expected, unpack(td.getSlice(2, 3))):
            self.assertEqual(want, got)

        # skim(2) applied in place must select the same element.
        td.skim(2)
        for want, got in zip(expected, unpack(td)):
            self.assertEqual(want, got)
Пример #2
0
 def sub_test_store(self, readWrite):
     """Round-trip arrays through ``TrainData._store`` and verify them.

     Stores two feature arrays, two truth arrays and one weight array, then
     checks the array counts, the feature shapes, and the contents of both
     feature arrays and both truth arrays.

     Args:
         readWrite: if truthy, the container is additionally written to
             ``testfile.tdjctd``, read back into a fresh ``TrainData``, and
             the file is deleted before the checks run.
     """
     import contextlib
     import os

     td = TrainData()
     x, y, w = self.createSimpleArray('int32'), self.createSimpleArray('float32'), self.createSimpleArray('int32')
     x_orig = x.copy()
     x2, y2, _ = self.createSimpleArray('float32'), self.createSimpleArray('float32'), self.createSimpleArray('int32')
     x2_orig = x2.copy()
     y_orig = y.copy()
     y2_orig = y2.copy()  # y2 is stored too, so verify it as well

     # Comparisons use the copies above. NOTE(review): presumably _store takes
     # over the passed arrays, which is why the originals are not reused.
     td._store([x, x2], [y, y2], [w])

     if readWrite:
         fname = "testfile.tdjctd"
         try:
             td.writeToFile(fname)
             td = TrainData()
             td.readFromFile(fname)
         finally:
             # Always remove the scratch file (same effect as `rm -f`, but
             # without spawning a shell, and also when write/read raises).
             with contextlib.suppress(FileNotFoundError):
                 os.remove(fname)

     shapes = td.getNumpyFeatureShapes()
     self.assertEqual([[3, 5, 6], [1], [3, 5, 6], [1]], shapes, "shapes")

     self.assertEqual(2, td.nFeatureArrays())
     self.assertEqual(2, td.nTruthArrays())
     self.assertEqual(1, td.nWeightArrays())

     f = td.transferFeatureListToNumpy(False)
     t = td.transferTruthListToNumpy(False)
     # Weights are transferred (exercising the call) but not compared here.
     td.transferWeightListToNumpy(False)

     def rebuild(arrays, i):
         # Entries i and i+1 of the flat transfer list form one SimpleArray;
         # the second part is cast to int64 as the constructor is fed here.
         return SimpleArray(arrays[i], np.array(arrays[i + 1], dtype='int64'))

     self.assertEqual(x_orig, rebuild(f, 0))
     self.assertEqual(x2_orig, rebuild(f, 2))
     self.assertEqual(y_orig, rebuild(t, 0))
     self.assertEqual(y2_orig, rebuild(t, 2))
Пример #3
0
# Load the input file and keep only the event selected by -e
# (presumably skim(e) keeps just the e'th element — confirm).
td.readFromFile(args.i)
td.skim(int(args.e))

model = load_model(args.inputModel, custom_objects=get_custom_objects())
predicted = predict_from_TrainData(model, td, batchsize=100000)
pred = predicted[0]

# First feature entry is the feature matrix, second is kept as `rs`.
feature_arrays = td.transferFeatureListToNumpy()
feat, rs = feature_arrays[0], feature_arrays[1]
del feature_arrays
truth = td.transferTruthListToNumpy()[0]
td.clear()

print(feat.shape)
print(truth.shape)

fig = plt.figure(figsize=(10, 4))
ax = [
    fig.add_subplot(1, 2, 1, projection='3d'),
    fig.add_subplot(1, 2, 2),
]

data = create_index_dict(truth, pred, usetf=False)
feats = create_feature_dict(feat)

make_cluster_coordinates_plot(
    plt,
    ax[1],
    data['truthHitAssignementIdx'],  #[ V ]
Пример #4
0
import numpy as np
from DeepJetCore.TrainData import TrainData
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import math
from numba import jit
from inference import collect_condensates, make_inference_dict

td = TrainData()
# Alternative input: "../results_partial/predictions/pred_9.djctd"
td.readFromFile("../data/test_data/9.djctd")

# Transfer features, truth and weights to numpy (in that order), attach them
# to td as .x/.y/.z, and keep short module-level aliases.
td.x, td.y, td.z = (
    td.transferFeatureListToNumpy(),
    td.transferTruthListToNumpy(),
    td.transferWeightListToNumpy(),
)
x, y, z = td.x, td.y, td.z

print(len(x))
for idx in range(3):
    print(x[idx].shape)

data = make_inference_dict(*x[:3])