/
train.py
67 lines (55 loc) · 1.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import kalman
import matplotlib.pyplot as plt
from numpy import *
from scipy.stats import norm
# --- Simulation parameters -------------------------------------------------
dt = 0.1  # time step between successive filter iterations
accel_sigma = 2.0  # std-dev of the random acceleration applied each step
measurement_sigma = 1.0  # std-dev of the additive measurement noise

# Process model over the 2-element state [position, velocity].
# F: position advances by velocity * dt; velocity is carried over unchanged.
# Q: covariance of the state increment G * a with G = [dt**2/2, dt] and
#    a ~ N(0, accel_sigma**2), i.e. outer(G, G) * accel_sigma**2.
process = kalman.process_model(
    F=array([[1, dt],
             [0, 1]]),
    Q=accel_sigma ** 2 * array([[dt ** 4 / 4, dt ** 3 / 2],
                                [dt ** 3 / 2, dt ** 2]]),
)

# Observation model: H selects the first state component (position) and
# R is the 1x1 variance of the measurement noise.
measure = kalman.observation_model(
    R=[[measurement_sigma ** 2]],
    H=[[1.0, 0.0]],
)

# Initial estimate: zero state with zero covariance, i.e. the filter starts
# fully certain of (0, 0) -- consistent with the simulated truth, which also
# starts at rest at the origin.
filt = kalman.kalman(
    P=zeros((2, 2)),
    x=zeros((2, 1)),
)
# Noise sources: the random acceleration that drives the true trajectory and
# the additive noise corrupting each measurement.
accel_rv = norm(0, accel_sigma)
measurement_rv = norm(0, measurement_sigma)
# Lazily draw one acceleration per simulation step (1000 steps). `range`
# replaces the Python-2-only `xrange` so the script also runs on Python 3.
accelerations = (accel_rv.rvs() for _ in range(1000))

true_positions = []
true_velocities = []
estimated_positions = []
estimated_velocities = []

# True state at t = 0, matching the filter's zero initial estimate.
true_position = 0.0
true_velocity = 0.0
for acceleration in accelerations:
    # Advance the truth FIRST so the measurement below is taken at the same
    # instant the filter predicts to (predict presumably propagates the
    # estimate one dt forward via F). Measuring before advancing fed the
    # filter a sample from the previous step, making the estimate lag.
    true_position += true_velocity * dt + acceleration * (dt ** 2) / 2
    true_velocity += acceleration * dt

    filt.predict(process)
    # The observation model (H = [[1, 0]]) observes position, so the
    # simulated measurement must be the noisy position, not the velocity.
    measurement = true_position + measurement_rv.rvs()
    filt.update(measure, measurement)

    true_positions.append(true_position)
    true_velocities.append(true_velocity)
    # Store plain scalar copies. `filt.x[0]` is a length-1 view into the
    # state array; if the kalman module updates `x` in place (not visible
    # here -- TODO confirm), every stored view would show the latest value.
    state = ravel(filt.x)
    estimated_positions.append(float(state[0]))
    estimated_velocities.append(float(state[1]))
# Two vertically stacked panels, each overlaying the filter estimate on the
# simulated truth: position on top, velocity underneath.
for panel_index, (panel_title, estimate_series, truth_series) in enumerate(
    (
        ('Position', estimated_positions, true_positions),
        ('Velocity', estimated_velocities, true_velocities),
    ),
    start=1,
):
    plt.subplot(2, 1, panel_index)  # same as the 3-digit form 21<k>
    plt.title(panel_title)
    plt.plot(estimate_series, label='Estimate')
    plt.plot(truth_series, label='True')
    plt.legend()
plt.show()