# main.py -- forked from tmramalho/smallParticleFilter
'''
Created on Jan 8, 2013
@author: tiago
'''
from polContainer import PolContainer
from yxContainer import XYContainer
from linearContainer import LinearContainer
from generateTestData import GenerateTestData
from particleFilter import ParticleFilter
from kalmanFilter import KalmanFilter
from extPlotter import extPlot
import numpy as np
def _build_scenario(test):
    """Return ``(x0, oscillator)`` for the scenario selected by ``test``.

    Scenarios:
        0 -- ``PolContainer`` with explicit ``mu``/``w``/``A`` parameters,
             two-component initial state.
        1 -- ``PolContainer`` with whatever parameters it has after
             construction, two-component initial state.
        2 -- ``LinearContainer``, three-component initial state.
        3 -- ``XYContainer``, two-component initial state.

    Raises:
        ValueError: if ``test`` is not one of 0, 1, 2 or 3.  Previously an
            unknown value fell through silently and surfaced later as a
            NameError on ``x0``/``oscillator``.
    """
    if test == 0:
        oscillator = PolContainer()
        oscillator.setParams({
            'mu': 0.5,
            'w': np.array([2.0]),
            'A': np.array([[0]]),
        })
        return np.array([0, 0.8]), oscillator
    if test == 1:
        return np.array([0, 0.8]), PolContainer()
    if test == 2:
        return np.array([6, 5.4, 6.8]), LinearContainer()
    if test == 3:
        return np.array([0.3, 0.5]), XYContainer()
    raise ValueError("unknown test scenario %r (expected 0, 1, 2 or 3)"
                     % (test,))


def _plot_estimate(pltWorker, x, y, estimate, finalTime, filename):
    """Plot path ``x``, markers ``y`` and ``estimate``, then save the figure.

    The figure is not cleared afterwards; the caller decides whether to call
    ``pltWorker.clear()``.
    """
    pltWorker.plotPath(x, finalTime, 0.5)
    pltWorker.plotMarkers(y, finalTime)
    pltWorker.plotMarkers(estimate, finalTime, mk='x', ms=8)
    pltWorker.save(filename)


def main(test=3):
    """Generate sample data, run both filters on it and save the plots.

    Args:
        test: scenario index passed to ``_build_scenario`` (default 3, the
            value that was hard-coded in the original script).

    Writes ``dists.pdf``, ``sir.pdf`` and ``kalman.pdf`` through the plotter
    and prints the value returned by ``calculateAccuracy`` for each filter.

    Raises:
        ValueError: if ``test`` does not name a known scenario.
    """
    x0, oscillator = _build_scenario(test)

    dataGen = GenerateTestData(oscillator, x0.size)
    pltWorker = extPlot()

    numSamples = 10
    samplingTime = 10.0 / float(numSamples)
    dt = 0.001      # step passed to the data generator
    dtf = 0.001     # step passed to both filters
    # NOTE(review): presumably process / observation noise levels -- confirm
    # against GenerateTestData and the filter classes.
    procs = 0.01
    obs = 0.01

    (x, y) = dataGen.generateSamplePointsGG(samplingTime, numSamples, dt,
                                            x0, procs, obs)
    # Single-argument print() emits the same text on Python 2 and Python 3.
    print('samplePoints generated')
    finalTime = samplingTime * (numSamples - 1)

    # --- particle filter ---
    pf = ParticleFilter(oscillator, x0.size)
    pf.runFilter(y, samplingTime, dtf, procs, obs)

    pltWorker.plotPath(x, finalTime, 0.5)
    pltWorker.plotMarkers(y, finalTime, mk='*', ms=10, a=0.4)
    pltWorker.plotPFMarkers(pf.x, pf.w, finalTime)
    pltWorker.save("dists.pdf")
    pltWorker.clear()

    z = pf.getAveragePath()
    _plot_estimate(pltWorker, x, y, z, finalTime, 'sir.pdf')
    pltWorker.clear()
    ac = dataGen.calculateAccuracy(x, z)
    # "%s" % (ac,) formats via str(ac), matching the old `print "...", ac`.
    print("Particle filter error: %s" % (ac,))

    # --- Kalman filter ---
    kf = KalmanFilter(oscillator, x0.size)
    kf.runFilter(y, samplingTime, dtf, procs, obs)
    k = kf.getAveragePath()
    # As in the original script, the plotter is not cleared after this save.
    _plot_estimate(pltWorker, x, y, k, finalTime, 'kalman.pdf')
    ac = dataGen.calculateAccuracy(x, k)
    print("Kalman filter error: %s" % (ac,))


if __name__ == '__main__':
    main()