/
trackRecorder.py
285 lines (245 loc) · 12.1 KB
/
trackRecorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
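#
# For recording tracks: trackRecorder writes each point of a track (with its
# gradients) into an SQLite table, and trackPlotter reads tracks back for plotting.
#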
import ast
import sqlite3
from numpy import *
import minimizeAlgorithm
from sqlUtils import queryDBGivenParams
# When True, mayavi is not imported and the mayavi-based plotting method that
# trackPlotter defines under `if not MAYAVI_OFF:` does not exist.
MAYAVI_OFF=True
if not MAYAVI_OFF:
    import mayavi.mlab as mlab
# Column definitions for the per-point track table created by trackRecorder.
# trackRecorder.record inserts rows positionally (no column list): first the
# point number, the point, its gradient dict, the projected gradient and that
# gradient's norm, then every value of the matching models-table row (queried
# with "*"). The columns after normProj must therefore line up, in order, with
# that models table's columns.
columnsString=''' (pointNum int, point text, gradDict text, projGrad text, normProj real,
eos text, rollMid real, rollScale real, a real, T real,
omega_c real, J real, gravMass real, edMax real, baryMass real,
ToverW real, arealR real, VoverC real, omg_c_over_Omg_c real,
rpoe real, Z_p real, Z_b real, Z_f real, h_direct real,
h_retro real, e_over_m real, shed real,
RedMax real, propRe real, runType int, runID text, lineNum int) '''
# The labels below correspond one-to-one, in order, to the columns omega_c ...
# RedMax above (presumably the headers of the originating model output; verify):
# Omega_c cJ/GMs^2 M/Ms Eps_c Mo/Ms T/W R_c v/c omg_c/Omg_c rp Z_p Z_b Z_f h-direct h-retro e/m Shed RedMax
# Untyped columns of the "<track>Metadata" table written by
# trackRecorder.recordTrackMetadata. trackPlotter.readTrackMetadata also parses
# this string (strips the parentheses, removes spaces/newlines, splits on
# commas) to get the metadata key names, so keep that format.
metadataColumns='''(independentVars, minimizedFunc, fixedFuncs, derivName, deltas,
stationaryParamsDict, maxSteps, changeBasis)'''
class trackRecorder(object):
    """Record the points of a track into an SQLite table.

    Each call to ``record`` writes one row to ``trackTableName``. The run's
    settings can be written once to the companion table
    ``trackTableName + "Metadata"`` with ``recordTrackMetadata``.
    """
    dbFile = None
    trackTableName = None
    dbConnection = None
    pointNumber = 0        # number assigned to the next recorded point
    doRecording = True     # record() does nothing (returns 1) when False
    independentVars = None
    trackMetaTable = None

    def __init__(self, trackTableName, dbConnection, independentVars):
        """Bind to ``dbConnection`` and create the track table if it is missing.

        :param trackTableName: name of the per-point track table.
        :param dbConnection: an open ``sqlite3.Connection``.
        :param independentVars: tuple of independent-variable names; its order
            is the coordinate order of the points passed to ``record``.
        """
        assert isinstance(trackTableName, str)
        assert isinstance(dbConnection, sqlite3.Connection)
        assert isinstance(independentVars, tuple)
        assert isinstance(independentVars[0], str)
        print("Initializing track recorder.")
        self.trackTableName = trackTableName
        self.dbConnection = dbConnection
        self.independentVars = independentVars
        self.trackMetaTable = trackTableName + "Metadata"
        if not self._tableExists(trackTableName):
            print("Track table '%s' not found in db; creating now" % trackTableName)
            # Table names cannot be bound as SQL parameters.
            self.dbConnection.cursor().execute(
                "CREATE TABLE %s %s" % (trackTableName, columnsString))
        else:
            # Recording stays enabled: record() skips point numbers that are
            # already present in the table.
            print("Track table '%s' already exists; recording into it "
                  "(existing point numbers are skipped)." % trackTableName)
        self.dbConnection.commit()

    def _tableExists(self, tableName):
        """Return True if a table named ``tableName`` exists in the database."""
        curs = self.dbConnection.cursor()
        # Bind the name instead of quoting it into the SQL text.
        curs.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                     (tableName,))
        return curs.fetchone() is not None

    def _insertRow(self, tableName, values):
        """Insert ``values`` positionally as one row of ``tableName`` (no commit).

        Values are bound as SQL parameters, so None is stored as NULL and text
        containing quotes or non-ASCII characters is stored intact. NumPy
        scalars are converted to Python scalars first because sqlite3 cannot
        bind all of them (e.g. numpy.int32).
        """
        import numpy
        row = [v.item() if isinstance(v, numpy.generic) else v for v in values]
        placeholders = ",".join("?" * len(row))
        self.dbConnection.cursor().execute(
            "INSERT INTO %s VALUES (%s)" % (tableName, placeholders), row)

    def recordTrackMetadata(self, independentVars, funcName, fixedNames, firstDeriv, deltas,
                            stationaryParamsDict, maxSteps, changeBasis):
        """Write the run's settings as the single row of the metadata table.

        Container arguments are stored as their ``str`` form with all spaces
        removed; ``changeBasis`` is stored as ``str(changeBasis)``.

        :returns: 0 if the row was written; 1 if the metadata table already
            existed, in which case nothing is written.
        """
        if self._tableExists(self.trackMetaTable):
            print("Track metadata table already exists, returning without recording metadata!")
            return 1
        print("Track metadata table '%s' not found in db; creating now" % self.trackMetaTable)
        self.dbConnection.cursor().execute(
            "CREATE TABLE %s %s" % (self.trackMetaTable, metadataColumns))
        entry = [str(independentVars).replace(" ", ""),
                 funcName,
                 str(fixedNames).replace(" ", ""),
                 firstDeriv,
                 str(tuple(deltas)).replace(" ", ""),
                 str(stationaryParamsDict).replace(" ", ""),
                 maxSteps,
                 str(changeBasis)]
        print(entry)
        self._insertRow(self.trackMetaTable, entry)
        self.dbConnection.commit()
        return 0

    def record(self, point, gradientDict, projectedGradFunc, normAfterProjection, modelsTable):
        """Append one point of the track, with its models-table row, to the track table.

        :param point: coordinates ordered like ``self.independentVars``; they
            are rounded to 10 decimal places before lookup and storage.
        :param gradientDict: gradients at the point, stored as ``str``.
        :param projectedGradFunc: projected gradient, stored as a tuple string.
        :param normAfterProjection: norm of the projected gradient.
        :param modelsTable: table whose row matching ``point`` is copied.
        :returns: 1 if recording is disabled, otherwise None.
        :raises IndexError: if no row of ``modelsTable`` matches ``point``.
        """
        if not self.doRecording:
            return 1
        assert len(point) == len(self.independentVars), "Point length doesnt match # independent vars!"
        roundArray = frompyfunc(round, 2, 1)
        point = tuple(roundArray(point, 10))
        pointAsStr = str(point).replace(" ", "")
        gradDictAsStr = str(gradientDict).replace(" ", "")
        projGradFuncAsStr = str(tuple(projectedGradFunc)).replace(" ", "")
        queryDict = dict(zip(self.independentVars, point))
        cursor = self.dbConnection.cursor()
        print("Fetching point entry data from database '%s' for recording" % modelsTable)
        matches = queryDBGivenParams("*", queryDict, cursor, modelsTable)
        if not matches:
            raise IndexError("No row of '%s' matches point %s" % (modelsTable, queryDict))
        # Values are bound as SQL parameters, so unicode text returned by
        # sqlite3 can be inserted back unchanged (no str() conversion needed).
        entry = [self.pointNumber, pointAsStr, gradDictAsStr,
                 projGradFuncAsStr, normAfterProjection] + list(matches[0])
        existingPoint = queryDBGivenParams("pointNum", {'pointNum': self.pointNumber},
                                           cursor, self.trackTableName)
        if existingPoint:
            print("POINT NUM %s EXISTS IN RECORD TABLE %s" % (existingPoint[0], self.trackTableName))
        else:
            print("Inserting new point into record table '%s'" % self.trackTableName)
            self._insertRow(self.trackTableName, entry)
        # Advance even when skipping so a resumed run moves past existing points.
        self.pointNumber += 1
        print("")
        self.dbConnection.commit()
def returnGradientXsYsZsForPlot(gradDictList, grad, normalize=True):
    """Split the ``grad`` gradient at each point into x, y and z component lists.

    :param gradDictList: sequence of dicts mapping gradient name -> vector.
    :param grad: key of the gradient to extract from every dict.
    :param normalize: divide each vector by ``minimizeAlgorithm.norm`` first.
    :returns: tuple ``(xs, ys, zs)`` of lists of the first three components.
    """
    components = ([], [], [])
    for gradsAtPoint in gradDictList:
        vec = array(gradsAtPoint[grad])
        if normalize:
            vec = vec / minimizeAlgorithm.norm(vec)
        # Explicit indexing keeps the IndexError for vectors shorter than 3.
        for axis in range(3):
            components[axis].append(vec[axis])
    return components
class trackPlotter(object):
    """Load recorded tracks from one or more SQLite files and plot them.

    After construction, ``trackData`` holds one dict per database file. Each
    dict contains the metadata fields (see ``readTrackMetadata``), one list per
    requested track variable, and the lists 'points', 'gradDicts', 'projGrads'
    and 'normProjs', all ordered by point number.
    """
    dbFilenames = []
    trackTableName = "track"
    trackVars = ()
    # Class-level default only; __init__ gives every instance its own list.
    trackData = []
    trackMetaTable = "trackMetadata"
    # Gradient labels drawn by the plotting method are padded to this width.
    maxLabelLength = 9

    def __init__(self, dbFilenames, trackTableName, trackVars):
        """Read the track table and its metadata from each database file.

        :param dbFilenames: iterable of SQLite file paths.
        :param trackTableName: name of the track table (assumed the same in
            every file); metadata is read from ``trackTableName + "Metadata"``.
        :param trackVars: tuple of extra track-table columns to load.
        """
        assert isinstance(trackVars, tuple)
        self.dbFilenames = dbFilenames
        # TODO: what if track tablenames differ between dbNames!
        self.trackTableName = trackTableName
        self.trackMetaTable = trackTableName + "Metadata"
        self.trackVars = trackVars
        # Appending to the class-level list would share, and accumulate,
        # tracks across every trackPlotter instance.
        self.trackData = []
        for dbFilename in self.dbFilenames:
            connection = sqlite3.connect(dbFilename)
            try:
                thisFilesData = self._readTrack(connection)
            finally:
                connection.close()
            print(thisFilesData)
            self.trackData.append(thisFilesData)

    def _readTrack(self, connection):
        """Return one database's metadata dict merged with its per-point track data."""
        thisFilesData = {key: [] for key in self.trackVars}
        thisFilesData.update(self.readTrackMetadata(connection))
        # The five bookkeeping columns come first, then the requested trackVars.
        baseColumns = ["pointNum", "point", "gradDict", "projGrad", "normProj"]
        rawData = queryDBGivenParams(baseColumns + list(self.trackVars),
                                     {}, connection.cursor(), self.trackTableName,
                                     (), " ORDER BY pointNum")
        pointList = []
        gradDictList = []
        projGradList = []
        normProjList = []
        for entry in rawData:
            # point, gradDict and projGrad are stored as Python-literal text.
            pointList.append(ast.literal_eval(entry[1]))
            gradDictList.append(ast.literal_eval(entry[2]))
            projGradList.append(ast.literal_eval(entry[3]))
            normProjList.append(entry[4])
            for key, value in zip(self.trackVars, entry[len(baseColumns):]):
                thisFilesData[key].append(value)
        thisFilesData.update({'points': pointList, 'gradDicts': gradDictList,
                              'projGrads': projGradList, 'normProjs': normProjList})
        return thisFilesData

    @staticmethod
    def _parseMetadataValue(raw):
        """Decode one stored metadata value.

        Text that is a Python literal (tuple, dict, number, ...) is evaluated
        with ``ast.literal_eval``. Any other text, such as a bare name or a
        string starting with a digit like '2ndOrder', is returned unchanged.
        Non-text values such as integers and None are also returned unchanged.
        """
        # Python 2's sqlite3 returns TEXT as unicode; convert it to the native
        # str type. Encoding as UTF-8 avoids UnicodeEncodeError on non-ASCII text.
        if isinstance(raw, type(u"")) and not isinstance(raw, str):
            raw = raw.encode("utf-8")
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            return raw

    def readTrackMetadata(self, dbConnection):
        """Return the single row of ``self.trackMetaTable`` as a dict.

        Keys are the column names listed in ``metadataColumns``; values are
        decoded with ``_parseMetadataValue``.

        :raises AssertionError: if the table has more than one row.
        :raises IndexError: if the table is empty.
        """
        c = dbConnection.cursor()
        c.execute("SELECT * FROM " + self.trackMetaTable)
        rows = c.fetchall()
        assert len(rows) < 2, "Metadata table shouldn't have multiple entries!"
        if not rows:
            raise IndexError("Metadata table '%s' is empty" % self.trackMetaTable)
        keys = metadataColumns.strip('()').replace(" ", "").replace('\n', "").split(',')
        metadataDict = {key: self._parseMetadataValue(raw)
                        for key, raw in zip(keys, rows[0])}
        dbConnection.commit()
        return metadataDict

    if not MAYAVI_OFF:
        def trackPlotter(self, plotVars, plotGradients=None, plotProjGrad=False):
            '''Plot every loaded track in 3D with mayavi, then show the scene.

            plotVars: three names, each either one of the track's independent
                variables or a key of the loaded track data.
            plotGradients: one gradient name, or a tuple/list of names. Gradients
                must be one of minimizedFunc or fixedFuncs, as these are the only
                ones stored in the gradient dict. They are drawn as normalized arrows.
            plotProjGrad: if True, also draw the projected gradient as white arrows.
            Arrows are scaled by the norm of the track's 'deltas'. Drawing any
            arrows requires plotVars to equal the track's independentVars, in the
            same order. Returns 0.
            '''
            assert len(plotVars) == 3, "Track plotter requires 3 plotVars for 3D track! You gave: %s" % len(plotVars)
            # Normalize once. The original rebound plotGradients inside the track
            # loop, which re-wrapped it for every track after the first.
            if isinstance(plotGradients, (tuple, list)):
                gradNames = list(plotGradients)
            else:
                gradNames = [plotGradients]
            for i, track in enumerate(self.trackData):
                stepScale = minimizeAlgorithm.norm(array(track['deltas']))
                indVars = list(track['independentVars'])
                coords = list(zip(*track['points']))  # one sequence per independent variable
                # Compare as tuples so that a list of plotVars also matches.
                sameSpace = tuple(plotVars) == tuple(indVars)
                xs_ys_zs = []
                for plotVar in plotVars:
                    if plotVar in indVars:
                        print('got ind var %s' % plotVar)
                        xs_ys_zs.append(coords[indVars.index(plotVar)])
                    else:
                        assert plotVar in track, "uh oh didn't find our variable to plot '%s'" % plotVar
                        print('got other track var %s' % plotVar)
                        xs_ys_zs.append(track[plotVar])
                # Debug output; .get avoids a KeyError when these were not loaded.
                print('redmax:  %s' % (track.get('RedMax'),))
                print('baryMass:  %s' % (track.get('baryMass'),))
                mlab.plot3d(*xs_ys_zs,
                            color=(1 - (1. / (i % 3 + 1)), 1, 1. / (i % 2 + 1.)),
                            reset_zoom=False,
                            tube_radius=None)
                if plotGradients:
                    assert sameSpace, "Bad! WARNING your gradients are not in the same space as your tracks!!"
                    print("Mkay good, you're plotting gradients in same space as your tracks")
                    for j, grad in enumerate(gradNames):
                        vxs, vys, vzs = returnGradientXsYsZsForPlot(track['gradDicts'], grad)
                        print(vxs)
                        print(vys)
                        print(vzs)
                        thisColor = ((j % 2) / 2. + 0.5, 1. / (j + 1.3), ((j + 1) % 2.) / 1.7)
                        mlab.quiver3d(xs_ys_zs[0], xs_ys_zs[1], xs_ys_zs[2],
                                      vxs, vys, vzs,
                                      color=thisColor,
                                      scale_factor=stepScale)
                        mlab.text(0.01, 0.1 + j * 0.2, grad.ljust(self.maxLabelLength),
                                  color=thisColor, width=0.1)
                if plotProjGrad:
                    assert sameSpace
                    vxs, vys, vzs = zip(*track['projGrads'])
                    mlab.quiver3d(xs_ys_zs[0], xs_ys_zs[1], xs_ys_zs[2],
                                  vxs, vys, vzs,
                                  color=(1, 1, 1),
                                  scale_factor=stepScale)
                    mlab.text(0.5, 0.1, "projGrad", color=(1, 1, 1), width=0.1)
            mlab.show()
            return 0