コード例 #1
0
 def sub_test_store(self, readWrite):
     """Check that data stored via ``TrainData._store`` can be retrieved intact.

     Two feature arrays (int32 and float32), two float32 truth arrays and one
     int32 weight array are stored in a fresh ``TrainData``. When *readWrite*
     is true, the container is also written to a file and read back into a new
     ``TrainData`` before the checks, so the file round-trip is exercised too.

     The test then verifies:
       * the feature shapes reported by ``getNumpyFeatureShapes``;
       * the number of feature, truth and weight arrays;
       * that every stored array, rebuilt from the transferred numpy lists,
         compares equal to the array that was stored.

     Args:
         readWrite: if true, round-trip the data through a temporary file.
     """
     import tempfile  # local import: only this test needs it

     td = TrainData()
     x = self.createSimpleArray('int32')
     y = self.createSimpleArray('float32')
     w = self.createSimpleArray('int32')
     x2 = self.createSimpleArray('float32')
     y2 = self.createSimpleArray('float32')

     # Keep copies of the inputs. The originals are handed to _store and are
     # not relied on afterwards; NOTE(review): presumably _store may consume or
     # move them — confirm.
     x_orig = x.copy()
     x2_orig = x2.copy()
     y_orig = y.copy()
     y2_orig = y2.copy()
     w_orig = w.copy()

     td._store([x, x2], [y, y2], [w])

     if readWrite:
         # A private temporary directory avoids clobbering files in the
         # working directory. It is removed even if writing or reading fails,
         # and no shell command is needed for cleanup.
         with tempfile.TemporaryDirectory() as tmpdir:
             path = os.path.join(tmpdir, "testfile.tdjctd")
             td.writeToFile(path)
             td = TrainData()
             td.readFromFile(path)

     shapes = td.getNumpyFeatureShapes()
     self.assertEqual([[3, 5, 6], [1], [3, 5, 6], [1]], shapes, "shapes")

     self.assertEqual(2, td.nFeatureArrays())
     self.assertEqual(2, td.nTruthArrays())
     self.assertEqual(1, td.nWeightArrays())

     # With padding disabled each stored array appears as two consecutive
     # list entries: the data, then the values used as row splits below.
     f = td.transferFeatureListToNumpy(False)
     t = td.transferTruthListToNumpy(False)
     wl = td.transferWeightListToNumpy(False)

     xnew = SimpleArray(f[0], np.array(f[1], dtype='int64'))
     self.assertEqual(x_orig, xnew)

     x2new = SimpleArray(f[2], np.array(f[3], dtype='int64'))
     self.assertEqual(x2_orig, x2new)

     ynew = SimpleArray(t[0], np.array(t[1], dtype='int64'))
     self.assertEqual(y_orig, ynew)

     y2new = SimpleArray(t[2], np.array(t[3], dtype='int64'))
     self.assertEqual(y2_orig, y2new)

     wnew = SimpleArray(wl[0], np.array(wl[1], dtype='int64'))
     self.assertEqual(w_orig, wnew)
コード例 #2
0
ファイル: play.py プロジェクト: abao1999/DRNOC
import numpy as np
from DeepJetCore.TrainData import TrainData
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import math
from numba import jit
from inference import collect_condensates, make_inference_dict

def _load_sample(path):
    """Read a DeepJetCore file and attach its arrays to the returned object.

    The feature, truth and weight lists are transferred to numpy (in that
    order) and stored on the returned ``TrainData`` as ``.x``, ``.y`` and
    ``.z`` respectively.
    """
    sample = TrainData()
    sample.readFromFile(path)
    sample.x = sample.transferFeatureListToNumpy()
    sample.y = sample.transferTruthListToNumpy()
    sample.z = sample.transferWeightListToNumpy()
    return sample


def _report_feature_shapes(features):
    """Print the number of feature arrays, then the shapes of the first three."""
    print(len(features))
    for position in range(3):
        print(features[position].shape)


# Alternative input used during development: a prediction file such as
# "../results_partial/predictions/pred_9.djctd".
td = _load_sample("../data/test_data/9.djctd")

# Short module-level aliases for the transferred lists. NOTE(review): the
# truth/weight containers are lists, so they have no single ``.shape``.
x, y, z = td.x, td.y, td.z

_report_feature_shapes(x)

data = make_inference_dict(x[0], x[1], x[2])