コード例 #1
0
def test_linear_2d_simplex():
    """Run a UKF with simplex sigma points on a linear 4-state model.

    Twenty noisy 2-D measurements are batch filtered and the result is fed to
    the RTS smoother. Nothing is asserted; the test only verifies that the
    pipeline runs, and optionally plots the output when ``DO_PLOT`` is set.
    """
    def fx(x, dt):
        # Linear transition: components 0 and 2 advance by dt times
        # components 1 and 3 respectively.
        transition = np.array(
            [[1, dt, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 1, dt],
             [0, 0, 0, 1]],
            dtype=float)
        return np.dot(transition, x)

    def hx(x):
        # Only state components 0 and 2 are observed.
        return np.array([x[0], x[2]])

    dt = 0.1
    sigma_points = SimplexSigmaPoints(n=4)
    kf = UKF(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx, points=sigma_points)

    kf.x = np.array([-1., 1., -1., 1])
    kf.P *= 0.0001

    # One noisy measurement pair per step; both channels track the step index.
    zs = [np.array([step + randn() * 0.1, step + randn() * 0.1])
          for step in range(20)]

    Ms, Ps = kf.batch_filter(zs)
    smooth_x, _, _ = kf.rts_smoother(Ms, Ps, dt=dt)

    if not DO_PLOT:
        return

    zs = np.asarray(zs)
    plt.plot(Ms[:, 0])
    plt.plot(smooth_x[:, 0], smooth_x[:, 2])
    print(smooth_x)
コード例 #2
0
def test_batch_missing_data():
    """Batch filter should accept missing data with None in the measurements.

    Builds a linear 4-state model, replaces one of the 20 noisy measurements
    with ``None`` and runs ``batch_filter``. The test passes as long as the
    call completes without raising.
    """
    def fx(x, dt):
        # Linear transition: components 0 and 2 advance by dt times
        # components 1 and 3 respectively.
        F = np.array(
            [[1, dt, 0, 0], [0, 1, 0, 0], [0, 0, 1, dt], [0, 0, 0, 1]],
            dtype=float)

        return np.dot(F, x)

    def hx(x):
        # Only state components 0 and 2 are observed.
        return np.array([x[0], x[2]])

    dt = 0.1
    points = MerweScaledSigmaPoints(4, .1, 2., -1)
    kf = UKF(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx, points=points)

    kf.x = np.array([-1., 1., -1., 1])
    kf.P *= 0.0001

    zs = []
    for i in range(20):
        z = np.array([i + randn() * 0.1, i + randn() * 0.1])
        zs.append(z)

    # Simulate a dropped sample in the middle of the sequence.
    zs[2] = None

    # Only successful completion matters; the returned estimates are not
    # inspected.
    kf.batch_filter(zs)
コード例 #3
0
ファイル: test_ukf.py プロジェクト: Censio/filterpy
def test_linear_2d():
    """Should work like a linear KF if the problem is linear.

    Batch filters 20 noisy 2-D measurements with a UKF using Merwe scaled
    sigma points, then runs the RTS smoother on the result. No assertions are
    made; the test only checks that both calls complete.
    """
    def fx(x, dt):
        # Linear transition matrix applied to the 4-element state.
        transition = np.array(
            [[1, dt, 0, 0], [0, 1, 0, 0], [0, 0, 1, dt], [0, 0, 0, 1]],
            dtype=float)
        return np.dot(transition, x)

    def hx(x):
        # Measurement picks out state components 0 and 2.
        return np.array([x[0], x[2]])

    dt = 0.1
    sigma_points = MerweScaledSigmaPoints(4, .1, 2., -1)
    kf = UKF(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx, points=sigma_points)

    kf.x = np.array([-1., 1., -1., 1])
    kf.P *= 0.0001

    # Both measurement channels follow the step index plus Gaussian noise.
    zs = [np.array([step + randn() * 0.1, step + randn() * 0.1])
          for step in range(20)]

    Ms, Ps = kf.batch_filter(zs)
    smooth_x, _, _ = kf.rts_smoother(Ms, Ps, dt=dt)
コード例 #4
0
ファイル: test_ukf.py プロジェクト: SiChiTong/filterpy
def test_batch_missing_data():
    """Batch filter should accept missing data with None in the measurements.

    One entry of an otherwise complete list of 20 noisy measurements is set to
    ``None`` before calling ``batch_filter``. The test succeeds if the call
    does not raise.
    """
    def fx(x, dt):
        # Linear transition matrix applied to the 4-element state.
        F = np.array([[1, dt, 0, 0],
                      [0,  1, 0, 0],
                      [0, 0,  1, dt],
                      [0, 0, 0,  1]], dtype=float)

        return np.dot(F, x)

    def hx(x):
        # Measurement picks out state components 0 and 2.
        return np.array([x[0], x[2]])

    dt = 0.1
    points = MerweScaledSigmaPoints(4, .1, 2., -1)
    kf = UKF(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx, points=points)

    kf.x = np.array([-1., 1., -1., 1])
    kf.P *= 0.0001

    zs = []
    for i in range(20):
        z = np.array([i+randn()*0.1, i+randn()*0.1])
        zs.append(z)

    # Replace the third measurement with None to represent missing data.
    zs[2] = None

    # The estimates themselves are not checked, only that filtering completes.
    kf.batch_filter(zs)
コード例 #5
0
def test_linear_2d_merwe():
    """Should work like a linear KF if the problem is linear.

    Uses Merwe scaled sigma points, checks that converting the filter to a
    string does not raise, then batch filters and RTS-smooths 20 noisy
    measurements. Plots the result when ``DO_PLOT`` is set.
    """
    def fx(x, dt):
        # Linear transition matrix applied to the 4-element state.
        transition = np.array(
            [[1, dt, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 1, dt],
             [0, 0, 0, 1]],
            dtype=float)
        return np.dot(transition, x)

    def hx(x):
        # Measurement picks out state components 0 and 2.
        return np.array([x[0], x[2]])

    dt = 0.1
    sigma_points = MerweScaledSigmaPoints(4, .1, 2., -1)
    kf = UKF(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx, points=sigma_points)

    kf.x = np.array([-1., 1., -1., 1])
    kf.P *= 1.1

    # The string conversion of the filter must not crash.
    str(kf)

    # Measurements are kept as plain two-element lists here.
    zs = []
    for step in range(20):
        first = step + randn() * 0.1
        second = step + randn() * 0.1
        zs.append([first, second])

    Ms, Ps = kf.batch_filter(zs)
    smooth_x, _, _ = kf.rts_smoother(Ms, Ps, dt=dt)

    if not DO_PLOT:
        return

    plt.figure()
    zs = np.asarray(zs)
    plt.plot(zs[:, 0], marker='+')
    plt.plot(Ms[:, 0], c='b')
    plt.plot(smooth_x[:, 0], smooth_x[:, 2], c='r')
    print(smooth_x)
コード例 #6
0
def test_rts():
    """Check batch filtering against step-by-step filtering, then smooth.

    A 3-state UKF is driven one measurement at a time with ranges from
    ``RadarSim``. The filter is then reset and the same measurements are run
    through ``batch_filter``; both runs must produce identical means. Finally
    the RTS smoother is run on the batch output, with optional plotting.
    """
    def fx(x, dt):
        # Identity plus dt times a matrix whose only non-zero entry couples
        # component 1 into component 0.
        coupling = np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]])
        return np.dot(np.eye(3) + dt * coupling, x)

    def hx(x):
        # Euclidean norm of state components 0 and 2.
        return np.sqrt(x[0]**2 + x[2]**2)

    dt = 0.05

    sigma_points = JulierSigmaPoints(n=3, kappa=1.)
    kf = UKF(3, 1, dt, fx=fx, hx=hx, points=sigma_points)

    kf.Q *= 0.01
    kf.R = 10
    kf.x = np.array([0., 90., 1100.])
    kf.P *= 100.
    radar = RadarSim(dt)

    t = np.arange(0, 20 + dt, dt)
    n = len(t)
    xs = np.zeros((n, 3))

    random.seed(200)
    rs = []
    for step in range(n):
        measured_range = radar.get_range()
        kf.predict()
        kf.update(z=[measured_range])

        xs[step, :] = kf.x
        rs.append(measured_range)

    # Reset to the same initial conditions before re-running in batch mode.
    kf.x = np.array([0., 90., 1100.])
    kf.P = np.eye(3) * 100
    M, P = kf.batch_filter(rs)
    assert np.array_equal(M, xs), "Batch filter generated different output"

    Qs = [kf.Q] * n
    M2, P2, K = kf.rts_smoother(Xs=M, Ps=P, Qs=Qs)

    if not DO_PLOT:
        return

    print(xs[:, 0].shape)

    plt.figure()
    # One subplot per state component: filtered estimate vs. smoothed (green).
    for column, cell in enumerate((311, 312, 313)):
        plt.subplot(cell)
        plt.plot(t, xs[:, column])
        plt.plot(t, M2[:, column], c='g')
コード例 #7
0
ファイル: test_ukf.py プロジェクト: Censio/filterpy
def test_rts():
    """Verify batch filtering matches incremental filtering, then smooth.

    Runs predict/update once per simulated radar range, records the state
    after each step, then resets the filter and requires ``batch_filter`` to
    reproduce exactly the same means. The RTS smoother is run on the batch
    output afterwards.
    """
    def fx(x, dt):
        # Transition: identity plus dt times a single off-diagonal coupling.
        step_matrix = np.eye(3) + dt * np.array([[0, 1, 0],
                                                 [0, 0, 0],
                                                 [0, 0, 0]])
        return np.dot(step_matrix, x)

    def hx(x):
        # Euclidean norm of state components 0 and 2.
        return np.sqrt(x[0]**2 + x[2]**2)

    dt = 0.05

    sp = JulierSigmaPoints(n=3, kappa=1.)
    kf = UKF(3, 1, dt, fx=fx, hx=hx, points=sp)

    kf.Q *= 0.01
    kf.R = 10
    kf.x = np.array([0., 90., 1100.])
    kf.P *= 100.
    radar = RadarSim(dt)

    t = np.arange(0, 20+dt, dt)
    sample_count = len(t)
    xs = np.zeros((sample_count, 3))

    random.seed(200)
    rs = []
    for index in range(sample_count):
        reading = radar.get_range()
        kf.predict()
        kf.update(z=[reading])

        xs[index, :] = kf.x
        rs.append(reading)

    # Restore the starting state and covariance for the batch run.
    kf.x = np.array([0., 90., 1100.])
    kf.P = np.eye(3)*100
    M, P = kf.batch_filter(rs)
    assert np.array_equal(M, xs), "Batch filter generated different output"

    Qs = [kf.Q]*sample_count
    M2, P2, K = kf.rts_smoother(Xs=M, Ps=P, Qs=Qs)
コード例 #8
0
def Unscentedfilter(zs):  # Filter function
    """Filter and RTS-smooth ``zs`` with a 2-state, 1-measurement UKF.

    Relies on ``fx``, ``hx`` and ``dt`` from the enclosing scope.

    Returns column 0 of the smoothed state estimates.
    """
    sigma_points = MerweScaledSigmaPoints(2, alpha=.1, beta=2., kappa=1)
    ukf = UnscentedKalmanFilter(dim_x=2, dim_z=1, fx=fx, hx=hx,
                                points=sigma_points, dt=dt)

    # Noise and initial covariance settings.
    ukf.Q = array(([50, 0], [0, 50]))
    ukf.R = 100
    ukf.P = eye(2) * 2

    filtered_means, filtered_covs = ukf.batch_filter(zs)
    smoothed, _, _ = ukf.rts_smoother(filtered_means, filtered_covs)

    return smoothed[:, 0]
コード例 #9
0
ファイル: test_ukf.py プロジェクト: PepSalehi/filterpy
def test_linear_2d():
    """Should work like a linear KF if the problem is linear.

    Constructs the UKF with ``kappa=0``, batch filters 20 noisy 2-D
    measurements and RTS-smooths the result. Plots when ``DO_PLOT`` is set;
    no assertions are made.
    """
    def fx(x, dt):
        # Linear transition matrix applied to the 4-element state.
        transition = np.array(
            [[1, dt, 0, 0], [0, 1, 0, 0], [0, 0, 1, dt], [0, 0, 0, 1]],
            dtype=float)
        return np.dot(transition, x)

    def hx(x):
        # Measurement picks out state components 0 and 2.
        return np.array([x[0], x[2]])

    dt = 0.1
    kf = UKF(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx, kappa=0)

    kf.x = np.array([-1., 1., -1., 1])
    kf.P *= 0.0001

    # Both measurement channels follow the step index plus Gaussian noise.
    zs = [np.array([step + randn()*0.1, step + randn()*0.1])
          for step in range(20)]

    Ms, Ps = kf.batch_filter(zs)
    smooth_x, _, _ = kf.rts_smoother(Ms, Ps, dt=dt)

    if not DO_PLOT:
        return

    zs = np.asarray(zs)
    plt.plot(Ms[:, 0])
    plt.plot(smooth_x[:, 0], smooth_x[:, 2])
    print(smooth_x)
コード例 #10
0
def smooth(data: ndarray, dt: float):
    """Batch filter ``data`` with a 3-state UKF and return state component 1.

    The initial state is built from the first two samples of ``data``.
    Negative values in the returned component are replaced by zero.
    """
    sigma_points = MerweScaledSigmaPoints(3, alpha=1e-3, beta=2.0, kappa=0)
    noisy_kalman = UnscentedKalmanFilter(
        dim_x=3,
        dim_z=1,
        dt=dt,
        hx=state_to_measurement,
        fx=state_transition,
        points=sigma_points,
    )

    initial_state = [0, data[1], data[1] - data[0]]
    noisy_kalman.x = array(initial_state, dtype="float32")
    noisy_kalman.R *= 20**2  # sensor variance
    noisy_kalman.P = diag([5**2, 5**2, 1**2])  # variance of the system
    noisy_kalman.Q = Q_discrete_white_noise(3, dt=dt, var=0.05)

    means, _ = noisy_kalman.batch_filter(data)

    # ``second_component`` is a view into ``means``, so the clipping below is
    # applied in place, exactly as a masked assignment on ``means[:, 1]``.
    second_component = means[:, 1]
    second_component[second_component < 0] = 0  # clip velocity
    return second_component
コード例 #11
0
ファイル: test_ukf.py プロジェクト: PepSalehi/filterpy
def test_fixed_lag():
    """Exercise an RTS smoothing pass over a short window during filtering.

    The filter is stepped through simulated radar ranges. At step 20 the RTS
    smoother is applied to the last ``N`` filtered estimates and the result
    replaces the tail of ``flxs``. Afterwards the filter is reset, the same
    measurements are batch filtered and fully RTS-smoothed. The three
    trajectories are plotted only when ``DO_PLOT`` is set; no assertions are
    made.
    """
    def fx(x, dt):
        # Transition: identity plus dt times a single off-diagonal coupling.
        A = np.eye(3) + dt * np.array([[0, 1, 0],
                                       [0, 0, 0],
                                       [0, 0, 0]])
        return np.dot(A, x)

    def hx(x):
        # Euclidean norm of state components 0 and 2.
        return np.sqrt(x[0]**2 + x[2]**2)

    dt = 0.05

    kf = UKF(3, 1, dt, fx=fx, hx=hx, kappa=0.)

    kf.Q *= 0.01
    kf.R = 10
    kf.x = np.array([0., 90., 1100.])
    kf.P *= 1.
    radar = RadarSim(dt)

    t = np.arange(0, 20+dt, dt)
    n = len(t)
    xs = np.zeros((n, 3))

    random.seed(200)
    rs = []

    M = []
    P = []
    N = 10  # window length for the partial smoothing pass
    flxs = []
    for i in range(n):
        r = radar.get_range()
        kf.predict()
        kf.update(z=[r])

        xs[i, :] = kf.x
        flxs.append(kf.x)
        rs.append(r)
        M.append(kf.x)
        P.append(kf.P)

        if i == 20 and len(M) >= N:
            try:
                M2, P2, K = kf.rts_smoother(Xs=np.asarray(M)[-N:],
                                            Ps=np.asarray(P)[-N:])
            except Exception as err:
                # Best effort: a failure of the windowed pass is reported but
                # does not abort the test. Only ordinary errors are caught so
                # that KeyboardInterrupt/SystemExit still propagate.
                print('except', i, err)
            else:
                flxs[-N:] = M2

    # Reset the filter and process all measurements in batch mode.
    kf.x = np.array([0., 90., 1100.])
    kf.P = np.eye(3)*100
    M, P = kf.batch_filter(rs)

    Qs = [kf.Q]*len(t)
    M2, P2, K = kf.rts_smoother(Xs=M, Ps=P, Qs=Qs)

    # Plot only on request, as the other tests in this file do.
    if DO_PLOT:
        flxs = np.asarray(flxs)
        print(xs[:, 0].shape)

        plt.figure()
        plt.subplot(311)
        plt.plot(t, xs[:, 0])
        plt.plot(t, flxs[:, 0], c='r')
        plt.plot(t, M2[:, 0], c='g')
        plt.subplot(312)
        plt.plot(t, xs[:, 1])
        plt.plot(t, flxs[:, 1], c='r')
        plt.plot(t, M2[:, 1], c='g')

        plt.subplot(313)
        plt.plot(t, xs[:, 2])
        plt.plot(t, flxs[:, 2], c='r')
        plt.plot(t, M2[:, 2], c='g')
コード例 #12
0
ファイル: test_ukf.py プロジェクト: PepSalehi/filterpy
def test_rts():
    """Check batch filtering against step-by-step filtering, then smooth.

    The UKF (constructed with ``kappa=1.``) is stepped through simulated
    radar ranges, then reset and run in batch mode on the same ranges; the
    means from both runs must be identical. The batch result is RTS-smoothed
    and optionally plotted.
    """
    def fx(x, dt):
        # Transition: identity plus dt times a single off-diagonal coupling.
        coupling = np.array([[0, 1, 0],
                             [0, 0, 0],
                             [0, 0, 0]])
        return np.dot(np.eye(3) + dt * coupling, x)

    def hx(x):
        # Euclidean norm of state components 0 and 2.
        return np.sqrt(x[0]**2 + x[2]**2)

    dt = 0.05

    kf = UKF(3, 1, dt, fx=fx, hx=hx, kappa=1.)

    kf.Q *= 0.01
    kf.R = 10
    kf.x = np.array([0., 90., 1100.])
    kf.P *= 100.
    radar = RadarSim(dt)

    t = np.arange(0, 20+dt, dt)
    n = len(t)
    xs = np.zeros((n, 3))

    random.seed(200)
    rs = []
    for step in range(n):
        measured_range = radar.get_range()
        kf.predict()
        kf.update(z=[measured_range])

        xs[step, :] = kf.x
        rs.append(measured_range)

    # Same initial conditions as the incremental run.
    kf.x = np.array([0., 90., 1100.])
    kf.P = np.eye(3)*100
    M, P = kf.batch_filter(rs)
    assert np.array_equal(M, xs), "Batch filter generated different output"

    Qs = [kf.Q]*n
    M2, P2, K = kf.rts_smoother(Xs=M, Ps=P, Qs=Qs)

    if not DO_PLOT:
        return

    print(xs[:, 0].shape)

    plt.figure()
    # One subplot per state component: filtered estimate vs. smoothed (green).
    for column, cell in enumerate((311, 312, 313)):
        plt.subplot(cell)
        plt.plot(t, xs[:, column])
        plt.plot(t, M2[:, column], c='g')
コード例 #13
0
def two_radar():
    """Track a simulated aircraft with two radar stations using a UKF.

    Generates noisy range/bearing readings from two stations, batch filters
    them, runs the RTS smoother, and plots positions and velocities when
    ``DO_PLOT`` is set.
    """

    # Code is not complete - it was used to test the RTS smoother. Very
    # similar to two_radary.py in the book.

    import numpy as np
    import matplotlib.pyplot as plt

    from numpy import array
    from numpy.linalg import norm
    from numpy.random import randn
    from math import atan2

    from filterpy.common import Q_discrete_white_noise

    # NOTE(review): `asarray` is used below but is not among the local imports
    # above; presumably it is imported at module level - confirm.

    class RadarStation(object):
        """Radar at a fixed position producing range/bearing readings."""

        def __init__(self, pos, range_std, bearing_std):
            self.pos = asarray(pos)

            # Standard deviations of the noise added by noisy_reading().
            self.range_std = range_std
            self.bearing_std = bearing_std

        def reading_of(self, ac_pos):
            """ Returns range and bearing to aircraft as tuple. bearing is in
            radians.
            """

            # The difference is taken as station position minus aircraft
            # position.
            diff = np.subtract(self.pos, ac_pos)
            rng = norm(diff)
            brg = atan2(diff[1], diff[0])
            return rng, brg

        def noisy_reading(self, ac_pos):
            """Return reading_of(ac_pos) with Gaussian noise added to both
            the range and the bearing.
            """
            rng, brg = self.reading_of(ac_pos)
            rng += randn() * self.range_std
            brg += randn() * self.bearing_std
            return rng, brg

    class ACSim(object):
        """Aircraft simulation holding a position and a nominal velocity."""

        def __init__(self, pos, vel, vel_std):
            self.pos = asarray(pos, dtype=float)
            self.vel = asarray(vel, dtype=float)
            self.vel_std = vel_std

        def update(self):
            """Advance the position by one noisy velocity step.

            Returns the internal position array itself (not a copy), which is
            modified in place on every call.
            """
            # A single random draw is added to every velocity component.
            vel = self.vel + (randn() * self.vel_std)
            self.pos += vel

            return self.pos

    dt = 1.

    def hx(x):
        # Measurement function: range and bearing from both stations to the
        # position taken from state components 0 and 2. The stations are
        # attached to this function as attributes further down.
        r1, b1 = hx.R1.reading_of((x[0], x[2]))
        r2, b2 = hx.R2.reading_of((x[0], x[2]))

        return array([r1, b1, r2, b2])

    def fx(x, dt):
        # State transition: components 0 and 2 advance by dt times
        # components 1 and 3; the other components are left unchanged.
        x_est = x.copy()
        x_est[0] += x[1] * dt
        x_est[2] += x[3] * dt
        return x_est

    vx, vy = 0.1, 0.1

    # NOTE(review): other UKF constructions pass a `points=` sigma-point
    # object instead of `kappa=`; confirm which filter API version this
    # targets.
    f = UnscentedKalmanFilter(dim_x=4, dim_z=4, dt=dt, hx=hx, fx=fx, kappa=0)
    aircraft = ACSim((100, 100), (vx * dt, vy * dt), 0.00000002)

    range_std = 0.001  # 1 meter
    bearing_std = 1. / 1000  # 1mrad

    R1 = RadarStation((0, 0), range_std, bearing_std)
    R2 = RadarStation((200, 0), range_std, bearing_std)

    # Make the stations available to hx() via function attributes.
    hx.R1 = R1
    hx.R2 = R2

    f.x = array([100, vx, 100, vy])

    f.R = np.diag([range_std**2, bearing_std**2, range_std**2, bearing_std**2])
    # The same 2x2 process-noise block is used for both state pairs.
    q = Q_discrete_white_noise(2, var=0.0002, dt=dt)
    f.Q[0:2, 0:2] = q
    f.Q[2:4, 2:4] = q
    f.P = np.diag([.1, 0.01, .1, 0.01])

    track = []
    zs = []

    for i in range(int(300 / dt)):
        pos = aircraft.update()

        r1, b1 = R1.noisy_reading(pos)
        r2, b2 = R2.noisy_reading(pos)

        z = np.array([r1, b1, r2, b2])
        zs.append(z)
        # Copy, because update() returns the same array object every time.
        track.append(pos.copy())

    zs = asarray(zs)

    # NOTE(review): other callers of batch_filter unpack two return values
    # (means, covariances); confirm that this five-value unpacking matches
    # the UKF version in use.
    xs, Ps, Pxz, pM, pP = f.batch_filter(zs)
    ms, _, _ = f.rts_smoother(xs, Ps)

    track = asarray(track)
    time = np.arange(0, len(xs) * dt, dt)

    if DO_PLOT:
        # Panel 1: simulated vs. filtered x position.
        plt.figure()
        plt.subplot(411)
        plt.plot(time, track[:, 0])
        plt.plot(time, xs[:, 0])
        plt.legend(loc=4)
        plt.xlabel('time (sec)')
        plt.ylabel('x position (m)')
        plt.tight_layout()

        # Panel 2: simulated vs. filtered y position.
        plt.subplot(412)
        plt.plot(time, track[:, 1])
        plt.plot(time, xs[:, 2])
        plt.legend(loc=4)
        plt.xlabel('time (sec)')
        plt.ylabel('y position (m)')
        plt.tight_layout()

        # Panel 3: filtered vs. smoothed x velocity.
        plt.subplot(413)
        plt.plot(time, xs[:, 1])
        plt.plot(time, ms[:, 1])
        plt.legend(loc=4)
        plt.ylim([0, 0.2])
        plt.xlabel('time (sec)')
        plt.ylabel('x velocity (m/s)')
        plt.tight_layout()

        # Panel 4: filtered vs. smoothed y velocity.
        plt.subplot(414)
        plt.plot(time, xs[:, 3])
        plt.plot(time, ms[:, 3])
        plt.ylabel('y velocity (m/s)')
        plt.legend(loc=4)
        plt.xlabel('time (sec)')
        plt.tight_layout()
        plt.show()
コード例 #14
0
def two_radar_constalt():
    """Track a simulated aircraft with two radars using a 3-state UKF.

    Noisy range/bearing readings from two ``RadarStation`` objects are batch
    filtered and RTS-smoothed, and the filtered and smoothed estimates are
    plotted. ``ACSim`` and ``RadarStation`` are taken from the enclosing
    scope.
    """
    dt = 0.05

    def hx(x):
        # Measurement function: range and bearing from both stations to the
        # position taken from state components 0 and 2. The stations are
        # attached to this function as attributes further down.
        r1, b1 = hx.R1.reading_of((x[0], x[2]))
        r2, b2 = hx.R2.reading_of((x[0], x[2]))

        return array([r1, b1, r2, b2])

    def fx(x, dt):
        # State transition: only component 0 advances, by dt times
        # component 1; the other components are left unchanged.
        x_est = x.copy()
        x_est[0] += x[1] * dt
        return x_est

    # Float literal keeps this a true division even without
    # `from __future__ import division`, matching the std values below.
    vx = 100 / 1000.0  # meters/sec
    vz = 0.0

    f = UKF(dim_x=3, dim_z=4, dt=dt, hx=hx, fx=fx, kappa=0)
    aircraft = ACSim((0, 1), (vx * dt, vz * dt), 0.00)

    range_std = 1 / 1000.0
    bearing_std = 1 / 1000000.0

    R1 = RadarStation((0, 0), range_std, bearing_std)
    R2 = RadarStation((60, 0), range_std, bearing_std)

    # Make the stations available to hx() via function attributes.
    hx.R1 = R1
    hx.R2 = R2

    f.x = array([aircraft.pos[0], vx, aircraft.pos[1]])
    f.R = np.diag([range_std ** 2, bearing_std ** 2, range_std ** 2, bearing_std ** 2])
    q = Q_discrete_white_noise(2, var=0.0002, dt=dt)
    # q = np.array([[0,0],[0,0.0002]])
    f.Q[0:2, 0:2] = q
    f.Q[2, 2] = 0.0002
    f.P = np.diag([0.1, 0.01, 0.1]) * 0.1

    track = []
    zs = []

    for i in range(int(500 / dt)):
        pos = aircraft.update()

        r1, b1 = R1.noisy_reading(pos)
        r2, b2 = R2.noisy_reading(pos)

        z = np.array([r1, b1, r2, b2])
        zs.append(z)
        track.append(pos.copy())

    zs = asarray(zs)

    xs, Ps = f.batch_filter(zs)
    ms, _, _ = f.rts_smoother(xs, Ps)

    track = asarray(track)
    time = np.arange(0, len(xs) * dt, dt)

    plt.figure()
    # Panel 1: simulated vs. filtered x position.
    plt.subplot(311)
    plt.plot(time, track[:, 0])
    plt.plot(time, xs[:, 0])
    plt.legend(loc=4)
    plt.xlabel("time (sec)")
    plt.ylabel("x position (m)")

    # Panel 2: filtered vs. smoothed state component 1, scaled by 1000.
    plt.subplot(312)
    plt.plot(time, xs[:, 1] * 1000, label="UKF")
    plt.plot(time, ms[:, 1] * 1000, label="RTS")
    plt.legend(loc=4)
    plt.xlabel("time (sec)")
    plt.ylabel("velocity (m/s)")

    # Panel 3: filtered vs. smoothed state component 2, scaled by 1000.
    plt.subplot(313)
    plt.plot(time, xs[:, 2] * 1000, label="UKF")
    plt.plot(time, ms[:, 2] * 1000, label="RTS")
    plt.legend(loc=4)
    plt.xlabel("time (sec)")
    plt.ylabel("altitude (m)")
    plt.ylim([900, 1100])

    # Convert the first measurements back through each station's z_to_x.
    # The scatter plots of the results are currently disabled, so the
    # converted points are not used.
    for z in zs[:10]:
        p = R1.z_to_x(z[0], z[1])
        # plt.scatter(p[0], p[1], marker='+', c='k')

        p = R2.z_to_x(z[2], z[3])
        # plt.scatter(p[0], p[1], marker='+', c='b')

    plt.show()
コード例 #15
0
        # calculate weight update
        synapse_1_update = np.dot(np.atleast_2d(layer_1).T, (layer_2_delta))
        synapse_h_update = np.dot(
            np.atleast_2d(context_layer).T, (layer_1_delta))
        synapse_0_update = np.dot(X.T, (layer_1_delta))

        # concatenate weight
        synapse_0_c = np.reshape(synapse_0, (-1, 1))
        synapse_h_c = np.reshape(synapse_h, (-1, 1))
        synapse_1_c = np.reshape(synapse_1, (-1, 1))
        w_concat = np.concatenate((synapse_0_c, synapse_h_c, synapse_1_c),
                                  axis=0)

        zs = data
        Xs, Ps = UKF.batch_filter(zs=w_concat)
# =============================================================================
#         w_concat_eye = np.eye(w_concat.size)
#         # sigma points dari mean dan kovarian pada synapse_1
#
#         # ukf points (mean dan kovarian) setiap synapse layer
#         synapse_0_sig = np.zeros((len(synapse_0_c),dim_x,1))
#         synapse_h_sig = np.zeros((len(synapse_h_c),dim_x,1))
#         synapse_1_sig = np.zeros((len(synapse_1_c),dim_x,1))
#         w_concat_sig = np.concatenate((synapse_0_sig,synapse_h_sig,synapse_1_sig), axis=0)
#
#         # hitung bobot tiap points tersebut
#         # Unscented transform untuk trasformasi mean dan kovarian ke tuple (measurement space)
#         # new mean dan sigmas
#         # update pengukuran
#         # kalman gain
コード例 #16
0
def test_linear_rts():
    """ for a linear model the Kalman filter and UKF should produce nearly
    identical results.

    Test code mostly due to user gboehl as reported in GitHub issue #97, though
    I converted it from an AR(1) process to constant velocity kinematic
    model.

    Both filters are given the same model, noise matrices and observations.
    Their filtered means and their RTS-smoothed means are then compared
    within absolute tolerances. Returns the UKF instance.
    """
    dt = 1.0
    F = np.array([[1., dt], [.0, 1]])
    H = np.array([[1., .0]])

    def t_func(x, dt):
        # State transition used by the UKF; same matrix as F above.
        F = np.array([[1., dt], [.0, 1]])
        return np.dot(F, x)

    def o_func(x):
        # Observation function used by the UKF; same matrix as H above.
        return np.dot(H, x)

    sig_t = .1  # process
    sig_o = .00000001  # measurement

    N = 50

    # Observations: position i + 1 at step i plus measurement noise.
    X_obs = np.array(
        [i + 1 + np.random.normal(scale=sig_o) for i in range(N)])

    oc = np.ones((1, 1)) * sig_o**2
    tc = Q_discrete_white_noise(dim=2, dt=dt, var=sig_t**2)
    points = MerweScaledSigmaPoints(n=2, alpha=.1, beta=2., kappa=1)

    ukf = UKF(dim_x=2, dim_z=1, dt=dt, hx=o_func, fx=t_func, points=points)
    ukf.x = np.array([0., 1.])
    ukf.R = oc[:]
    ukf.Q = tc[:]

    # Also exercise Saver on the UKF; its output is not inspected.
    s = Saver(ukf)
    s.save()
    s.to_array()

    kf = KalmanFilter(dim_x=2, dim_z=1)
    kf.x = np.array([[0., 1]]).T
    kf.R = oc[:]
    kf.Q = tc[:]
    kf.H = H[:]
    kf.F = F[:]

    mu_ukf, cov_ukf = ukf.batch_filter(X_obs)
    x_ukf, _, _ = ukf.rts_smoother(mu_ukf, cov_ukf)

    mu_kf, cov_kf, _, _ = kf.batch_filter(X_obs)
    x_kf, _, _, _ = kf.rts_smoother(mu_kf, cov_kf)

    # check results of filtering are correct
    # (the KF means carry an extra trailing axis, hence the extra index)
    kfx = mu_kf[:, 0, 0]
    ukfx = mu_ukf[:, 0]
    kfxx = mu_kf[:, 1, 0]
    ukfxx = mu_ukf[:, 1]

    dx = kfx - ukfx
    dxx = kfxx - ukfxx

    # error in position should be smaller than error in velocity, hence
    # atol is different for the two tests.
    assert np.allclose(dx, 0, atol=1e-7)
    assert np.allclose(dxx, 0, atol=1e-6)

    # now ensure the RTS smoothers gave nearly identical results
    kfx = x_kf[:, 0, 0]
    ukfx = x_ukf[:, 0]
    kfxx = x_kf[:, 1, 0]
    ukfxx = x_ukf[:, 1]

    dx = kfx - ukfx
    dxx = kfxx - ukfxx

    assert np.allclose(dx, 0, atol=1e-7)
    assert np.allclose(dxx, 0, atol=1e-6)
    return ukf
コード例 #17
0
ファイル: ProcessSEIR1R2D.py プロジェクト: SDerrode/divoc
def fit(sysargv):
    """
		Program to process Covid Data.
 
		:Example:

		For countries (European database)
		>> python ProcessSEIR1R2D.py 
		>> python ProcessSEIR1R2D.py France 0 1 0 0 1 1
		>> python ProcessSEIR1R2D.py France 2 3 8 0 1 1          # 3 périodes pour les femmes en France avec un décalage de 8 jours
		>> python ProcessSEIR1R2D.py France,Germany 1 1 0 0 1 1 # 1 période pour les hommes francais et les hommes allemands 

		For French Region (French database)
		>> python ProcessSEIR1R2D.py FRANCE,D69         0 -1 13 0 1 1 # Code Insee Dpt 69 (Rhône)
		>> python ProcessSEIR1R2D.py FRANCE,R84         0 -1 13 0 1 1 # Tous les dpts de la Région dont le code Insee est 
		>> python ProcessSEIR1R2D.py FRANCE,R32+        0 -1 13 0 1 1 # Somme de tous les dpts de la Région 32 (Hauts-de-France)
		>> python ProcessSEIR1R2D.py FRANCE,MetropoleD  0 -1 13 0 1 1 # Tous les départements de la France métropolitaine
		>> python ProcessSEIR1R2D.py FRANCE,MetropoleD+ 0 -1 13 0 1 1 # Toute la France métropolitaine (en sommant les dpts)
		>> python ProcessSEIR1R2D.py FRANCE,MetropoleR+ 0 -1 13 0 1 1 # Somme des dpts de toutes les régions françaises
		Toute combinaison est possible de lieu : exemple FRANCE,R32+,D05,R84
		
		argv[1] : List of countries (ex. France,Germany,Italy), or see above.          Default: France 
		argv[2] : Sex (male:1, female:2, male+female:0). Only for french database      Default: 0 
		argv[3] : Periods ('1' -> 1 period ('all-in-on'), '!=1' -> severall periods).  Default: -1
		argv[4] : Delay (in days).                                                     Default: 13
		argv[5] : UKF filtering of data (0/1).                                         Default: 0
		argv[6] : Verbose level (debug: 3, ..., almost mute: 0).                       Default: 1
		argv[7] : Plot graphique (0/1).                                                Default: 1
	"""

    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Lithuania,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    # Il y a 18 pays

    if len(sysargv) > 7:
        print('  CAUTION : bad number of arguments - see help')
        exit(1)

    # Constantes
    ######################################################@
    fileLocalCopy = True  # if we upload the file from the url (to get latest data) or from a local copy file
    readStartDateStr = "2020-03-01"  # "2020-03-01" Le 8 mars, pour inclure un grand nombre de pays européens dont la date de premier était postérieur au 1er mars
    readStopDateStr = None
    recouvrement = -1
    dt = 1
    France = 'France'
    thresholdSignif = 1.5E-6

    # Interpetation of arguments - reparation
    ######################################################@

    # Default value for parameters
    listplaces = ['France']
    sexe, sexestr = 0, 'male+female'
    nbperiodes = -1
    decalage = 13
    UKF_filt = False
    verbose = 1
    plot = True

    # Parameters from argv
    if len(sysargv) > 0: liste = list(sysargv[0].split(','))
    if len(sysargv) > 1: sexe = int(sysargv[1])
    if len(sysargv) > 2: nbperiodes = int(sysargv[2])
    if len(sysargv) > 3: decalage = int(sysargv[3])
    if len(sysargv) > 4 and int(sysargv[4]) == 1: UKF_filt = True
    if len(sysargv) > 5: verbose = int(sysargv[5])
    if len(sysargv) > 6 and int(sysargv[6]) == 0: plot = False
    if nbperiodes == 1:
        decalage = 0  # nécessairement pas de décalage (on compense le recouvrement)
    if sexe not in [0, 1, 2]:
        sexe, sexestr = 0, 'male+female'  # sexe indiférencié
    if sexe == 1: sexestr = 'male'
    if sexe == 2: sexestr = 'female'

    listplaces = []
    listnames = []
    if liste[0] == 'FRANCE':
        FrDatabase = True
        liste = liste[1:]
        for el in liste:
            l, n = getPlace(el)
            if el == 'MetropoleR+':
                for l1, n1 in zip(l, n):
                    listplaces.extend(l1)
                    listnames.extend([n1])
            else:
                listplaces.extend(l)
                listnames.extend(n)
    else:
        listplaces = liste[:]
        FrDatabase = False

    if verbose > 0:
        print('  Full command line : ' + sysargv[0] + ' ' + str(nbperiodes) +
              ' ' + str(decalage) + ' ' + str(UKF_filt) + ' ' + str(verbose) +
              ' ' + str(plot),
              flush=True)

    # Data reading to get first and last date available in the data set
    ######################################################@
    if FrDatabase == True:
        pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, _ = readDataFrance(
            ['D69'],
            readStartDateStr,
            readStopDateStr,
            fileLocalCopy,
            sexe,
            verbose=0)
    else:
        pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, _ = readDataEurope(
            France,
            readStartDateStr,
            readStopDateStr,
            fileLocalCopy,
            verbose=0)
    dataLength = pd_exerpt.shape[0]

    readStartDate = datetime.strptime(readStartDateStr, strDate)
    if readStartDate < pd_exerpt.index[0]:
        readStartDate = pd_exerpt.index[0]
        readStartDateStr = pd_exerpt.index[0].strftime(strDate)
    readStopDate = datetime.strptime(readStopDateStr, strDate)
    if readStopDate < pd_exerpt.index[-1]:
        readStopDate = pd_exerpt.index[-1]
        readStopDateStr = pd_exerpt.index[-1].strftime(strDate)

    dataLength = pd_exerpt.shape[0]
    if verbose > 1:
        print('readStartDateStr=', readStartDateStr, ', readStopDateStr=',
              readStopDateStr)
        print('readStartDate   =', readStartDate, ', readStopDate   =',
              readStopDate)
        print('dataLength      =', dataLength)
        #input('pause')

    # Collections of data return by this function
    modelSEIR1R2D = np.zeros(shape=(len(listplaces), dataLength, 6))
    data_deriv = np.zeros(shape=(len(listplaces), dataLength, 2))
    modelR1_deriv = np.zeros(shape=(len(listplaces), dataLength, 2))
    data_all = np.zeros(shape=(len(listplaces), dataLength, 2))
    modelR1_all = np.zeros(shape=(len(listplaces), dataLength, 2))
    Listepd = []
    ListetabParamModel = []

    # data observed
    data = np.zeros(shape=(dataLength, 2))

    # Paramètres sous forme de chaines de caractères
    ListeTextParam = []
    ListeDateI0 = []

    # Loop on the places to process
    for indexplace, place in enumerate(listplaces):

        # Get the full name of the place to process, and the special dates corresponding to the place
        if FrDatabase == True:
            placefull = 'France-' + listnames[indexplace][0]
            DatesString = readDates(France, verbose)
        else:
            placefull = place
            DatesString = readDates(place, verbose)

        print('PROCESSING of', placefull, 'in', listnames)

        # data reading of the observations
        #############################################################################
        if FrDatabase == True:
            pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, dateFirstNonZeroStr = readDataFrance(
                place,
                readStartDateStr,
                readStopDateStr,
                fileLocalCopy,
                sexe,
                verbose=0)
        else:
            pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, dateFirstNonZeroStr = readDataEurope(
                place,
                readStartDateStr,
                readStopDateStr,
                fileLocalCopy,
                verbose=0)

        shift_0value = getNbDaysBetweenDateFromString(readStartDateStr,
                                                      dateFirstNonZeroStr)

        # UKF Filtering ?
        if UKF_filt == True:
            data2Filt = pd_exerpt[[HeadData[0],
                                   HeadData[1]]].to_numpy(copy=True)
            sigmas = MerweScaledSigmaPoints(n=2, alpha=.5, beta=2.,
                                            kappa=1.)  #1-3.)
            ukf = UKF(dim_x=2, dim_z=2, fx=fR1D, hx=hR1D, dt=dt, points=sigmas)
            # Filter init
            ukf.x[:] = data2Filt[0, :]
            ukf.Q = np.diag([30., 15.])
            ukf.R = np.diag([170., 100.])
            if verbose > 1:
                print('ukf.x[:]=', ukf.x[:])
                print('ukf.R   =', ukf.R)
                print('ukf.Q   =', ukf.Q)

            # UKF filtering and smoothing, batch mode
            R1Ffilt, _ = ukf.batch_filter(data2Filt)
            HeadData[0] += ' filt'
            HeadData[1] += ' filt'
            pd_exerpt[HeadData[0]] = R1Ffilt[:, 0]
            pd_exerpt[HeadData[1]] = R1Ffilt[:, 1]

        # Get the list of dates to process
        ListDates, ListDatesStr = GetPairListDates(readStartDate, readStopDate,
                                                   DatesString, decalage,
                                                   nbperiodes, recouvrement)
        if verbose > 1:
            #print('ListDates   =', ListDates)
            print('ListDatesStr=', ListDatesStr)
            #input('pause')

        # Solveur edo
        solveur = SolveEDO_SEIR1R2D(N, dt, verbose)
        indexdata = solveur.indexdata
        E0, I0, R10, R20, D0 = 0, 1, 0, 0, 0

        # Repertoire des figures
        if plot == True:
            repertoire = getRepertoire(
                UKF_filt, './figures/SEIR1R2D_UKFilt/' + placefull + '/sexe_' +
                str(sexe) + '_delay_' + str(decalage), './figures/SEIR1R2D/' +
                placefull + '/sexe_' + str(sexe) + '_delay_' + str(decalage))
            prefFig = repertoire + '/Process_'

        # Remise à 0 des données
        data.fill(0.)

        # Boucle pour traiter successivement les différentes fenêtres
        ###############################################################

        ListeTextParamPlace = []
        ListetabParamModelPlace = []
        ListeEQM = []

        DEGENERATE_CASE = False

        for i in range(len(ListDatesStr)):

            # dates of the current period
            fitStartDate, fitStopDate = ListDates[i]
            fitStartDateStr, fitStopDateStr = ListDatesStr[i]

            # Est-on dans un CAS degénéré?
            # print(getNbDaysBetweenDateFromString(dateFirstNonZeroStr, fitStopDateStr))
            if getNbDaysBetweenDateFromString(
                    dateFirstNonZeroStr, fitStopDateStr
            ) < 5:  # Il faut au moins 5 données pour fitter
                DEGENERATE_CASE = True

            if i > 0:
                DatesString.addOtherDates(fitStartDateStr)

            # Récupérations des données observées
            dataLengthPeriod = 0
            indMinPeriod = (fitStartDate - readStartDate).days

            for j, z in enumerate(pd_exerpt.loc[
                    fitStartDateStr:addDaystoStrDate(fitStopDateStr, -1),
                    HeadData[0]]):
                data[indMinPeriod + j, 0] = z
                dataLengthPeriod += 1
            for j, z in enumerate(pd_exerpt.loc[
                    fitStartDateStr:addDaystoStrDate(fitStopDateStr, -1),
                    HeadData[1]]):
                data[indMinPeriod + j, 1] = z
            slicedata = slice(indMinPeriod, indMinPeriod + dataLengthPeriod)
            slicedataderiv = slice(slicedata.start + 1, slicedata.stop)
            if verbose > 0:
                print('  dataLength      =', dataLength)
                print('  indMinPeriod    =', indMinPeriod)
                print('  dataLengthPeriod=', dataLengthPeriod)
                print('  fitStartDateStr =', fitStartDateStr)
                print('  fitStopDateStr  =', fitStopDateStr)
                #input('attente')

            # Set initialisation data for the solveur
            ############################################################################

            # paramètres initiaux à optimiser
            if i == 0:
                datelegend = fitStartDateStr
                # ts=getNbDaysBetweenDateFromString(DatesString.listFirstCaseDates[0], readStartDateStr)
                # En premiere approximation, on prend la date du premier cas estimé pour le pays (même si c'est faux pour les régions et dpts)
                ts = getNbDaysBetweenDateFromString(
                    DatesString.listFirstCaseDates[0], dateFirstNonZeroStr)
                if ts < 0:
                    continue  # On passe au pays suivant
                if nbperiodes != 1:  # pour plusieurs périodes
                    #l, b0, c0, f0 = 0.255, 1./5.2, 1./12, 0.08
                    #a0 = (l+c0)*(1.+l/b0)
                    #a0, b0, c0, f0, mu0, xi0 = 0.55, 0.34, 0.12, 0.25, 0.0005, 0.0001
                    a0, b0, c0, f0, mu0, xi0 = 0.60, 0.55, 0.30, 0.50, 0.0005, 0.0001
                    T = 150
                else:  # pour une période
                    #a0, b0, c0, f0, mu0, xi0  = 0.10, 0.29, 0.10, 0.0022, 0.00004, 0.
                    a0, b0, c0, f0, mu0, xi0 = 0.70, 0.25, 0.05, 0.003, 0.0005, 0.0001
                    T = 350

            if i == 1 or i == 2:
                datelegend = None

                _, a0, b0, c0, f0, mu0, xi0 = solveur.modele.getParam()
                R10 = int(data[indMinPeriod,
                               0])  # on corrige R1 à la valeur numérique
                F0 = int(data[indMinPeriod,
                              1])  # on corrige F à la valeur numérique
                if i == 1:
                    a0 /= 4.  # le confinement réduit drastiquement (pour aider l'optimisation)
                T = 120
                ts = 0

            time = np.linspace(0, T - 1, T)

            solveur.modele.setParam(N=N,
                                    a=a0,
                                    b=b0,
                                    c=c0,
                                    f=f0,
                                    mu=mu0,
                                    xi=xi0)
            solveur.setParamInit(N=N, E0=E0, I0=I0, R10=R10, R20=R20, D0=D0)

            # Before optimization
            ###############################

            # Solve ode avant optimization
            sol_ode = solveur.solveEDO(time)
            # calcul time shift initial (ts) with respect to data
            if i == 0:
                ts = solveur.compute_tsfromEQM(data[slicedata, :], T,
                                               indexdata)
            else:
                solveur.TS = ts = 0
            sliceedo = slice(ts, min(ts + dataLengthPeriod, T))
            if verbose > 0:
                print(solveur)
                print('  ts=' + str(ts))

            # plot
            if plot == True and DEGENERATE_CASE == False:
                commontitre = placefull + '- Period ' + str(i) + '\\' + str(
                    len(ListDatesStr) - 1
                ) + ' - [' + fitStartDateStr + '\u2192' + addDaystoStrDate(
                    fitStopDateStr, -1)
                if sewe == 0:
                    titre = commontitre + '] (Delay (delta)=' + str(
                        decalage) + ')'
                else:
                    titre = commontitre + '] (Sex=', +sexestr + ', Delay (delta)=' + str(
                        decalage) + ')'

                listePlot = indexdata
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Init.png'
                solveur.plotEDO(filename,
                                titre,
                                sliceedo,
                                slicedata,
                                plot=listePlot,
                                data=data,
                                text=solveur.getTextParam(datelegend,
                                                          Period=i))

            # Parameters optimization
            ############################################################################

            solveur.paramOptimization(
                data[slicedata, :],
                time)  # version lorsque ts est calculé automatiquement
            #solveur.paramOptimization(data[slicedata, :], time, ts) # version lorsque l'on veut fixer ts
            _, a1, b1, c1, f1, mu1, xi1 = solveur.modele.getParam()
            R0 = solveur.modele.getR0()
            if verbose > 0:
                print('Solver' 's state after optimization=', solveur)
                print('  Reproductivité après: ', R0)

            # After optimization
            ###############################

            # Solve ode avant optimization
            sol_ode = solveur.solveEDO(time)
            # calcul time shift with respect to data
            if i == 0:
                ts = solveur.compute_tsfromEQM(data[slicedata, :], T,
                                               indexdata)
            else:
                solveur.TS = ts = 0
            sliceedo = slice(ts, min(ts + dataLengthPeriod, T))
            sliceedoderiv = slice(sliceedo.start + 1, sliceedo.stop)
            if verbose > 0:
                print(solveur)
                print('  ts=' + str(ts))
            if i == 0:  # on se souvient de la date du premier infesté
                dateI0 = addDaystoStrDate(fitStartDateStr, -ts + shift_0value)
                if verbose > 2:
                    print('dateI0=', dateI0)
                    input('attente')

            # sauvegarde des param (tableau et texte)
            seuil = (data[slicedata.stop - 1, 0] - data[slicedata.start, 0]
                     ) / getNbDaysBetweenDateFromString(
                         fitStartDateStr, fitStopDateStr) / N
            #print('seuil=', seuil)
            #print('DEGENERATE_CASE=', DEGENERATE_CASE)
            if DEGENERATE_CASE == True:
                ROsignificatif = False
                ListetabParamModelPlace.append(
                    [-1., -1., -1., -1., -1., -1., -1.])
            else:
                if seuil < thresholdSignif:
                    ROsignificatif = False
                    ListetabParamModelPlace.append(
                        [a1, b1, c1, f1, mu1, xi1, -1.])
                else:
                    ROsignificatif = True
                    ListetabParamModelPlace.append(
                        [a1, b1, c1, f1, mu1, xi1, R0])
                # print('seuil=', seuil)
                # print('ROsignificatif=', ROsignificatif)
                # print('R0=', R0)
                # input('pause')

            ListeTextParamPlace.append(
                solveur.getTextParamWeak(datelegend, ROsignificatif, Period=i))

            data_deriv_period = (
                data[slicedataderiv, :] -
                data[slicedataderiv.start - 1:slicedataderiv.stop - 1, :]) / dt
            modelR1_deriv_period = (
                sol_ode[sliceedoderiv, indexdata] -
                sol_ode[sliceedoderiv.start - 1:sliceedoderiv.stop - 1,
                        indexdata]) / dt
            data_all_period = data[slicedataderiv, :]
            modelR1_all_period = sol_ode[sliceedoderiv, indexdata]

            if plot == True and DEGENERATE_CASE == False:
                commontitre = placefull + '- Period ' + str(i) + '\\' + str(
                    len(ListDatesStr) - 1
                ) + ' - [' + fitStartDateStr + '\u2192' + addDaystoStrDate(
                    fitStopDateStr, -1)
                if sexe == 0:
                    titre = commontitre + '] (Delay (delta)=' + str(
                        decalage) + ')'
                else:
                    titre = commontitre + '] (Sex=', +sexestr + ', Delay (delta)=' + str(
                        decalage) + ')'

                # listePlot = [0,1,2,3,4,5]
                # filename  = prefFig + str(decalage) + '_Period' + str(i) + '_' + ''.join(map(str, listePlot)) + '.png'
                # solveur.plotEDO(filename, titre, sliceedo, slicedata, plot=listePlot, data=data, text=solveur.getTextParam(datelegend, ROsignificatif, Period=i))
                listePlot = [1, 2, 3, 5]
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Final.png'
                solveur.plotEDO(filename,
                                titre,
                                sliceedo,
                                slicedata,
                                plot=listePlot,
                                data=data,
                                text=solveur.getTextParam(datelegend,
                                                          ROsignificatif,
                                                          Period=i))
                listePlot = indexdata
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Final.png'
                solveur.plotEDO(filename,
                                titre,
                                sliceedo,
                                slicedata,
                                plot=listePlot,
                                data=data,
                                text=solveur.getTextParam(datelegend,
                                                          ROsignificatif,
                                                          Period=i))

                # dérivée  numérique de R1 et F
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Deriv.png'
                solveur.plotEDO_deriv(filename,
                                      titre,
                                      sliceedoderiv,
                                      slicedataderiv,
                                      data_deriv_period,
                                      indexdata,
                                      text=solveur.getTextParam(datelegend,
                                                                ROsignificatif,
                                                                Period=i))

            # sol_ode_withSwitch = solveur.solveEDO_withSwitch(T, timeswitch=ts+dataLengthPeriod)

            # ajout des données dérivées
            data_all[indexplace, slicedataderiv, :] = data_all_period
            modelR1_all[indexplace, slicedataderiv, :] = modelR1_all_period
            data_deriv[indexplace, slicedataderiv, :] = data_deriv_period
            modelR1_deriv[indexplace, slicedataderiv, :] = modelR1_deriv_period

            # ajout des SEIR1R2D
            modelSEIR1R2D[indexplace,
                          slicedata.start:slicedata.stop, :] = sol_ode[
                              ts:ts + sliceedo.stop - sliceedo.start, :]

            # preparation for next iteration
            _, E0, I0, R10, R20, D0 = map(
                int, sol_ode[ts + dataLengthPeriod + recouvrement, :])
            #print('A LA FIN : E0, I0, R10, R20, D0=', E0, I0, R10, R20, D0)

            if verbose > 1:
                input('next step')

        Listepd.append(pd_exerpt)
        ListeDateI0.append(dateI0)

        # calcul de l'EQM sur les données (et non sur les dérivées des données)
        #EQM = mean_squared_error(data_deriv[indexplace, :], modelR1_deriv[indexplace, :])
        EQM = mean_squared_error(data_all[indexplace, :],
                                 modelR1_all[indexplace, :])
        ListeEQM.append(EQM)

        # udpate des listes pour transmission
        ListeTextParam.append(ListeTextParamPlace)
        ListetabParamModel.append(ListetabParamModelPlace)

    return modelSEIR1R2D, ListeTextParam, Listepd, data_deriv, modelR1_deriv, ListetabParamModel, ListeEQM, ListeDateI0
コード例 #18
0
def main(sysargv):
    """
        Program to plot data and generate figures on Covid Data.

        For every requested place the data are read, their day-to-day
        differences are added as 'Diff ...' columns, the raw series are
        plotted, and three unscented Kalman filters (UKF) are run in batch
        mode: one on the third data column, one on the second data column,
        and one jointly on the first two data columns. Every figure is
        written below './figures/data/<place>/sexe_<sexe>/'.

        :Example:

        For country all around the world (European database)
        >> python PlotDataCovid.py 
        >> python PlotDataCovid.py United_Kingdom
        >> python PlotDataCovid.py Italy 2 1        # Only Italian women
        >> python PlotDataCovid.py France,Germany 1 # Shortcut for processing the two countries successively
        >> python PlotDataCovid.py France,Spain,Italy,United_Kingdom,Germany,Belgium 0

        For French geographical areas (French database)
        >> python PlotDataCovid.py FRANCE,D69         # Department 69 (Rhône), INSEE numbering
        >> python PlotDataCovid.py FRANCE,R84         # Process successively all the dpts of region #84 ('Auvergne-Rhone-Alpes')
        >> python PlotDataCovid.py FRANCE,R32+        # Sum of all dpts of region 32 ('Hauts de France')
        >> python PlotDataCovid.py FRANCE,MetropoleD  # All the dpts of the metropolitan France
        >> python PlotDataCovid.py FRANCE,MetropoleD+ # France (by summing dpts)
        >> python PlotDataCovid.py FRANCE,MetropoleR+ # All the regions by summing their dpts
        Every combination is possible, e.g.: FRANCE,R32+,D05,R84

        argv[1] : List of countries (ex. France,Germany,Italy), or see above for France. Default: France 
        argv[2] : Sex (male:1, female:2, male+female:0). Only for french database        Default: 0
        argv[3] : Verbose level (debug: 3, ..., almost mute: 0).                         Default: 1

        :param sysargv: the command line as a list of strings (program name
            first), at most 4 entries; the process exits with status 1
            when more are given.
    """

    # Example list of countries:
    # Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Lithuania,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine

    print('Command line : ', sysargv, flush=True)
    if len(sysargv) > 4:
        print('  CAUTION : bad number of arguments - see help')
        exit(1)

    # Constants
    ######################################################@
    dt = 1
    readStartDateStr = "2020-03-01"
    readStopDateStr = None
    France = 'France'

    # Interpretation of arguments
    ######################################################@

    # Default values for parameters. `liste` must have a default: it is
    # read below even when no place is given on the command line.
    liste = ['France']
    sexe, sexestr = 0, 'male+female'
    verbose = 1

    # Parameters from argv
    if len(sysargv) > 1:
        liste = sysargv[1].split(',')
    if len(sysargv) > 2:
        sexe = int(sysargv[2])
    if len(sysargv) > 3:
        verbose = int(sysargv[3])
    # Any unknown sex code falls back to male+female
    if sexe not in (0, 1, 2):
        sexe, sexestr = 0, 'male+female'
    if sexe == 1:
        sexestr = 'male'
    if sexe == 2:
        sexestr = 'female'

    # List of places to process
    listplaces = []
    listnames = []
    if liste[0] == 'FRANCE':
        # French database: the remaining items are resolved by getPlace
        FrDatabase = True
        liste = liste[1:]
        for el in liste:
            l, n = getPlace(el)
            if el == 'MetropoleR+':
                for l1, n1 in zip(l, n):
                    listplaces.extend(l1)
                    listnames.extend([n1])
            else:
                listplaces.extend(l)
                listnames.extend(n)
    else:
        listplaces = liste[:]
        FrDatabase = False

    # Loop for all places
    ############################################################@
    for indexplace, place in enumerate(listplaces):

        # Get the full name of the place to process, and the special dates corresponding to the place
        if FrDatabase:
            placefull = 'France-' + listnames[indexplace][0]
            DatesString = readDates(France, verbose)
        else:
            placefull = place
            DatesString = readDates(place, verbose)

        # Figures repository
        repertoire = getRepertoire(
            True, './figures/data/' + placefull + '/sexe_' + str(sexe))
        prefFig = repertoire + '/'

        # Data reading and plot
        ##########################################################
        # NOTE(review): the start/stop date strings are overwritten by the
        # reader, so the values returned for one place are the ones passed
        # in for the next place - confirm this is intended.
        if FrDatabase:
            pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, dateFirstNonZeroStr = readDataFrance(
                place,
                readStartDateStr,
                readStopDateStr,
                fileLocalCopy,
                sexe,
                verbose=0)
        else:
            pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, dateFirstNonZeroStr = readDataEurope(
                place,
                readStartDateStr,
                readStopDateStr,
                fileLocalCopy,
                verbose=0)

        readStartDate = datetime.strptime(readStartDateStr, "%Y-%m-%d")
        readStopDate = datetime.strptime(readStopDateStr, "%Y-%m-%d")
        dataLength = pd_exerpt.shape[0]
        if verbose > 0:
            print('readStartDateStr=', readStartDateStr, ', readStopDateStr=',
                  readStopDateStr)
            print('readStartDate   =', readStartDate, ', readStopDate   =',
                  readStopDate)
            print('dateFirstNonZeroStr=', dateFirstNonZeroStr)

        # Adding the gradient (row-to-row difference of each data column)
        pd_exerpt['Diff ' + HeadData[0]] = pd_exerpt[HeadData[0]].diff()
        pd_exerpt['Diff ' + HeadData[1]] = pd_exerpt[HeadData[1]].diff()
        pd_exerpt['Diff ' + HeadData[2]] = pd_exerpt[HeadData[2]].diff()

        # Title shared by every figure of this place
        if sexe == 0:
            titre = placefull
        else:
            titre = placefull + ' - Sex=' + sexestr

        # Plot and store the figures in the directory
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + HeadData[0] + '.png',
                 y=HeadData[0],
                 color='red',
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + HeadData[1] + '.png',
                 y=HeadData[1],
                 color='black',
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + HeadData[2] + '.png',
                 y=HeadData[2],
                 color='black',
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + HeadData[0] + HeadData[1] + '.png',
                 y=[HeadData[0], HeadData[1]],
                 color=['red', 'black'],
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'Diff' + HeadData[0] + '.png',
                 y=['Diff ' + HeadData[0]],
                 color='red',
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'Diff' + HeadData[1] + '.png',
                 y=['Diff ' + HeadData[1]],
                 color='black',
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'Diff' + HeadData[0] + HeadData[1] +
                 '.png',
                 y=['Diff ' + HeadData[0], 'Diff ' + HeadData[1]],
                 color=['red', 'black'],
                 Dates=DatesString)
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'Diff' + HeadData[0] + HeadData[1] +
                 HeadData[2] + '.png',
                 y=[
                     'Diff ' + HeadData[0], 'Diff ' + HeadData[1],
                     'Diff ' + HeadData[2]
                 ],
                 color=['red', 'black', 'blue'],
                 Dates=DatesString)

        # Data filtering and plot
        ######################################################

        # UKF filtering of the third data column
        # (labelled "R1+D" by the original author)
        data = pd_exerpt[HeadData[2]].tolist()
        sigmas = MerweScaledSigmaPoints(n=1, alpha=.5, beta=2.,
                                        kappa=0.)  #1-3.)
        ukf = UKF(dim_x=1, dim_z=1, fx=fR1, hx=hR1, dt=dt, points=sigmas)
        # Filter init: start from the first observation
        ukf.x[0] = data[0]
        ukf.Q = np.diag([30.])
        ukf.R = np.diag([170.])
        if verbose > 1:
            print('ukf.x[0]=', ukf.x[0])
            print('ukf.R   =', ukf.R)
            print('ukf.Q   =', ukf.Q)

        # UKF filtering, batch mode (the returned covariances are not used)
        R1filt, _ = ukf.batch_filter(data)

        # plotting
        pd_exerpt[HeadData[2] + ' filt'] = R1filt
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'filt' + HeadData[2] + '.png',
                 y=[HeadData[2], HeadData[2] + ' filt'],
                 color=['red', 'darkred'],
                 Dates=DatesString)
        pd_exerpt['Diff ' + HeadData[2] + ' filt'] = pd_exerpt[HeadData[2] +
                                                               ' filt'].diff()
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'diff_filt' + HeadData[2] + '.png',
                 y=['Diff ' + HeadData[2], 'Diff ' + HeadData[2] + ' filt'],
                 color=['red', 'darkred'],
                 Dates=DatesString)

        # NOTE: filtering the 'Diff cases' series directly was also tried by
        # the original author; it worked but gave a plot identical to the
        # previous one, so it is not done here.

        # UKF filtering of the second data column
        # (labelled "F" by the original author)
        #############################################################################
        data = pd_exerpt[HeadData[1]].tolist()
        sigmas = MerweScaledSigmaPoints(n=1, alpha=.5, beta=2.,
                                        kappa=0.)  #1-3.)
        ukf = UKF(dim_x=1, dim_z=1, fx=fF, hx=hF, dt=dt, points=sigmas)
        # Filter init: start from the first observation
        ukf.x[0] = data[0]
        ukf.Q = np.diag([15.])
        ukf.R = np.diag([100.])
        if verbose > 1:
            print('ukf.x[0]=', ukf.x[0])
            print('ukf.R   =', ukf.R)
            print('ukf.Q   =', ukf.Q)

        # UKF filtering, batch mode (the returned covariances are not used)
        Ffilt, _ = ukf.batch_filter(data)

        # plotting
        pd_exerpt[HeadData[1] + ' filt'] = Ffilt
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'filt' + HeadData[1] + '.png',
                 y=[HeadData[1], HeadData[1] + ' filt'],
                 color=['gray', 'black'],
                 Dates=DatesString)
        pd_exerpt['Diff ' + HeadData[1] + ' filt'] = pd_exerpt[HeadData[1] +
                                                               ' filt'].diff()
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'diff_filt' + HeadData[1] + '.png',
                 y=['Diff ' + HeadData[1], 'Diff ' + HeadData[1] + ' filt'],
                 color=['gray', 'black'],
                 Dates=DatesString)

        # Simultaneous UKF filtering of the first two data columns
        # (labelled "R1 and F" by the original author)
        #############################################################################
        data = pd_exerpt[[HeadData[0], HeadData[1]]].to_numpy(copy=True)
        sigmas = MerweScaledSigmaPoints(n=2, alpha=.5, beta=2.,
                                        kappa=1.)  #1-3.)
        ukf = UKF(dim_x=2, dim_z=2, fx=fR1F, hx=hR1F, dt=dt, points=sigmas)
        # Filter init: start from the first row of observations
        ukf.x[:] = data[0, :]
        ukf.Q = np.diag([30., 15.])
        ukf.R = np.diag([170., 100.])
        if verbose > 1:
            print('ukf.x[:]=', ukf.x[:])
            print('ukf.R   =', ukf.R)
            print('ukf.Q   =', ukf.Q)

        # UKF filtering, batch mode (the returned covariances are not used)
        R1Ffilt, _ = ukf.batch_filter(data)

        # plotting
        pd_exerpt[HeadData[0] + ' filtboth'] = R1Ffilt[:, 0]
        pd_exerpt[HeadData[1] + ' filtboth'] = R1Ffilt[:, 1]
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'filtboth' + HeadData[0] +
                 HeadData[1] + '.png',
                 y=[
                     HeadData[0], HeadData[0] + ' filtboth', HeadData[1],
                     HeadData[1] + ' filtboth'
                 ],
                 color=['red', 'darkred', 'gray', 'black'],
                 Dates=DatesString)

        pd_exerpt['Diff ' + HeadData[0] +
                  ' filtboth'] = pd_exerpt[HeadData[0] + ' filtboth'].diff()
        pd_exerpt['Diff ' + HeadData[1] +
                  ' filtboth'] = pd_exerpt[HeadData[1] + ' filtboth'].diff()
        PlotData(pd_exerpt,
                 titre=titre,
                 filenameFig=prefFig + 'diff_filt' + HeadData[0] +
                 HeadData[1] + '.png',
                 y=[
                     'Diff ' + HeadData[0],
                     'Diff ' + HeadData[0] + ' filtboth',
                     'Diff ' + HeadData[1],
                     'Diff ' + HeadData[1] + ' filtboth',
                 ],
                 color=['red', 'darkred', 'gray', 'black'],
                 Dates=DatesString)
コード例 #19
0
def two_radar_constalt():
    """Track a simulated aircraft with two radar stations and plot the result.

    An `ACSim` object is stepped 500/dt times; at every step both
    `RadarStation` objects deliver a noisy (range, bearing) pair. The
    stacked measurements are run through a 3-state UKF in batch mode and
    then through its RTS smoother. Three panels are shown: the tracked and
    filtered first state, and UKF-versus-RTS curves for states 1 and 2.
    """
    dt = .05

    def measure(x):
        # Readings both stations would report for the point (x[0], x[2]).
        # The stations are attached as function attributes further down.
        point = (x[0], x[2])
        rng_a, brg_a = measure.R1.reading_of(point)
        rng_b, brg_b = measure.R2.reading_of(point)
        return array([rng_a, brg_a, rng_b, brg_b])

    def propagate(x, dt):
        # Only state 0 moves: it advances by state 1 times dt.
        nxt = x.copy()
        nxt[0] += x[1]*dt
        return nxt

    vx = 100/1000 # meters/sec
    vz = 0.

    f = UKF(dim_x=3, dim_z=4, dt=dt, hx=measure, fx=propagate, kappa=0)
    aircraft = ACSim ((0, 1), (vx*dt, vz*dt), 0.00)

    range_std = 1/1000.
    bearing_std =1/1000000.

    R1 = RadarStation ((  0, 0), range_std, bearing_std)
    R2 = RadarStation ((60, 0), range_std, bearing_std)
    measure.R1 = R1
    measure.R2 = R2

    # Filter initialisation
    f.x = array([aircraft.pos[0], vx, aircraft.pos[1]])
    f.R = np.diag([range_std**2, bearing_std**2] * 2)
    f.Q[0:2, 0:2] = Q_discrete_white_noise(2, var=0.0002, dt=dt)
    f.Q[2,2] = 0.0002
    f.P = np.diag([.1, 0.01, .1])*0.1

    # Simulation: keep the true positions and the noisy measurements
    track = []
    zs = []
    n_steps = int(500/dt)
    for _ in range(n_steps):
        pos = aircraft.update()

        rng_a, brg_a = R1.noisy_reading(pos)
        rng_b, brg_b = R2.noisy_reading(pos)

        zs.append(np.array([rng_a, brg_a, rng_b, brg_b]))
        track.append(pos.copy())

    zs = asarray(zs)

    # Batch filtering followed by smoothing
    xs, Ps = f.batch_filter(zs)
    ms, _, _ = f.rts_smoother(xs, Ps)

    track = asarray(track)
    time = np.arange(0,len(xs)*dt, dt)

    plt.figure()

    # Top panel: first track column against the filtered state 0
    plt.subplot(311)
    plt.plot(time, track[:,0])
    plt.plot(time, xs[:,0])
    plt.legend(loc=4)
    plt.xlabel('time (sec)')
    plt.ylabel('x position (m)')

    # Lower panels: UKF versus RTS for states 1 and 2, scaled by 1000
    for cell, col, ytext in ((312, 1, 'velocity (m/s)'),
                             (313, 2, 'altitude (m)')):
        plt.subplot(cell)
        plt.plot(time, xs[:,col]*1000, label="UKF")
        plt.plot(time, ms[:,col]*1000, label='RTS')
        plt.legend(loc=4)
        plt.xlabel('time (sec)')
        plt.ylabel(ytext)
    # Still on the last panel (313): restrict its y range
    plt.ylim([900,1100])

    # Convert the first ten measurements back through each station; the
    # results are discarded (the scatter plots that used them are disabled).
    for z in zs[:10]:
        R1.z_to_x(z[0], z[1])
        R2.z_to_x(z[2], z[3])

    plt.show()
コード例 #20
0
def two_radar_constvel():
    """Track a simulated aircraft with a 4-state UKF fed by two radar stations.

    An `ACSim` object is stepped 300/dt times while two `RadarStation`
    objects each deliver a noisy (range, bearing) pair per step. The
    measurements go through the UKF batch filter and `rts_smoother2`, and
    four panels are drawn: two comparing the track with filtered states
    0 and 2, two comparing filtered and smoothed states 1 and 3.
    """
    dt = 5

    def measure(x):
        # Readings both stations would report for the point (x[0], x[2]).
        # The stations are attached as function attributes further down.
        point = (x[0], x[2])
        rng_a, brg_a = measure.R1.reading_of(point)
        rng_b, brg_b = measure.R2.reading_of(point)
        return array([rng_a, brg_a, rng_b, brg_b])

    def propagate(x, dt):
        # States 0 and 2 advance by states 1 and 3 times dt respectively.
        nxt = x.copy()
        nxt[0] += x[1] * dt
        nxt[2] += x[3] * dt
        return nxt

    f = UKF(dim_x=4, dim_z=4, dt=dt, hx=measure, fx=propagate, kappa=0)
    aircraft = ACSim((100, 100), (0.1 * dt, 0.02 * dt), 0.002)

    range_std = 0.2
    bearing_std = radians(0.5)

    R1 = RadarStation((0, 0), range_std, bearing_std)
    R2 = RadarStation((200, 0), range_std, bearing_std)
    measure.R1 = R1
    measure.R2 = R2

    # Filter initialisation
    f.x = array([100, 0.1, 100, 0.02])
    f.R = np.diag([range_std ** 2, bearing_std ** 2] * 2)
    q = Q_discrete_white_noise(2, var=0.002, dt=dt)
    # The same 2x2 noise block is used for both pairs of states
    for k in (0, 2):
        f.Q[k:k + 2, k:k + 2] = q
    f.P = np.diag([0.1, 0.01, 0.1, 0.01])

    # Simulation: keep the true positions and the noisy measurements
    track = []
    zs = []
    n_steps = int(300 / dt)
    for _ in range(n_steps):
        pos = aircraft.update()

        rng_a, brg_a = R1.noisy_reading(pos)
        rng_b, brg_b = R2.noisy_reading(pos)

        zs.append(np.array([rng_a, brg_a, rng_b, brg_b]))
        track.append(pos.copy())

    zs = asarray(zs)

    # Batch filtering followed by smoothing
    xs, Ps, Pxz = f.batch_filter(zs)
    ms, _, _ = f.rts_smoother2(xs, Ps, Pxz)

    track = asarray(track)
    time = np.arange(0, len(xs) * dt, dt)

    plt.figure()

    # Upper panels: track column against the matching filtered state
    for cell, track_col, state_col, ytext in (
            (411, 0, 0, "x position (m)"),
            (412, 1, 2, "y position (m)")):
        plt.subplot(cell)
        plt.plot(time, track[:, track_col])
        plt.plot(time, xs[:, state_col])
        plt.legend(loc=4)
        plt.xlabel("time (sec)")
        plt.ylabel(ytext)

    # Third panel: filtered versus smoothed state 1
    plt.subplot(413)
    plt.plot(time, xs[:, 1])
    plt.plot(time, ms[:, 1])
    plt.legend(loc=4)
    plt.ylim([0, 0.2])
    plt.xlabel("time (sec)")
    plt.ylabel("x velocity (m/s)")

    # Fourth panel: filtered versus smoothed state 3
    plt.subplot(414)
    plt.plot(time, xs[:, 3])
    plt.plot(time, ms[:, 3])
    plt.ylabel("y velocity (m/s)")
    plt.legend(loc=4)
    plt.xlabel("time (sec)")

    plt.show()
コード例 #21
0
def two_radar_constvel():
    """Track a constant-velocity target observed by two radar stations.

    The state vector is ``[x, vx, y, vy]`` (see ``fx``).  Each measurement is
    ``[range_1, bearing_1, range_2, bearing_2]`` as reported by the two
    ``RadarStation`` objects.  The simulated measurements are run through the
    UKF in batch mode, smoothed with ``rts_smoother2`` and the positions and
    velocities are plotted in a four-row figure.
    """
    dt = 5

    def hx(x):
        """Measurement function: range/bearing of (x[0], x[2]) from each station.

        The stations are attached to the function object (``hx.R1``,
        ``hx.R2``) after they have been created further down.
        """
        r1, b1 = hx.R1.reading_of((x[0], x[2]))
        r2, b2 = hx.R2.reading_of((x[0], x[2]))

        return array([r1, b1, r2, b2])

    def fx(x, dt):
        """State transition: advance both positions by velocity * dt."""
        x_est = x.copy()
        x_est[0] += x[1]*dt
        x_est[2] += x[3]*dt
        return x_est

    f = UKF(dim_x=4, dim_z=4, dt=dt, hx=hx, fx=fx, kappa=0)
    # NOTE(review): arguments look like (initial position, per-step velocity,
    # velocity noise) -- confirm against the ACSim definition.
    aircraft = ACSim((100, 100), (0.1*dt, 0.02*dt), 0.002)

    range_std = 0.2
    bearing_std = radians(0.5)

    R1 = RadarStation((0, 0), range_std, bearing_std)
    R2 = RadarStation((200, 0), range_std, bearing_std)

    hx.R1 = R1
    hx.R2 = R2

    f.x = array([100, 0.1, 100, 0.02])
    f.R = np.diag([range_std**2, bearing_std**2, range_std**2, bearing_std**2])
    # The same 2x2 process-noise block is used for the x and the y axis.
    q = Q_discrete_white_noise(2, var=0.002, dt=dt)
    f.Q[0:2, 0:2] = q
    f.Q[2:4, 2:4] = q
    f.P = np.diag([.1, 0.01, .1, 0.01])

    track = []
    zs = []

    for _ in range(int(300/dt)):
        pos = aircraft.update()

        r1, b1 = R1.noisy_reading(pos)
        r2, b2 = R2.noisy_reading(pos)

        zs.append(np.array([r1, b1, r2, b2]))
        track.append(pos.copy())

    zs = asarray(zs)

    xs, Ps, Pxz = f.batch_filter(zs)
    ms, _, _ = f.rts_smoother2(xs, Ps, Pxz)

    track = asarray(track)
    time = np.arange(0, len(xs)*dt, dt)

    # Every curve gets a label so that the legend() calls have entries to show
    # (the unlabelled version produced empty legends and a Matplotlib warning).
    plt.figure()
    plt.subplot(411)
    plt.plot(time, track[:, 0], label='track')
    plt.plot(time, xs[:, 0], label='filter')
    plt.legend(loc=4)
    plt.xlabel('time (sec)')
    plt.ylabel('x position (m)')

    plt.subplot(412)
    plt.plot(time, track[:, 1], label='track')
    plt.plot(time, xs[:, 2], label='filter')
    plt.legend(loc=4)
    plt.xlabel('time (sec)')
    plt.ylabel('y position (m)')

    plt.subplot(413)
    plt.plot(time, xs[:, 1], label='filter')
    plt.plot(time, ms[:, 1], label='smoothed')
    plt.legend(loc=4)
    plt.ylim([0, 0.2])
    plt.xlabel('time (sec)')
    plt.ylabel('x velocity (m/s)')

    plt.subplot(414)
    plt.plot(time, xs[:, 3], label='filter')
    plt.plot(time, ms[:, 3], label='smoothed')
    plt.ylabel('y velocity (m/s)')
    plt.legend(loc=4)
    plt.xlabel('time (sec)')

    plt.show()
コード例 #22
0
ファイル: sird_model.py プロジェクト: mjonyh/bd_herd_immunity
class Model:
    """
    SIRD model of Covid-19.

    With Use.DATA the model is fitted to the JHU confirmed/recovered/deaths time series of one country using an
    Unscented Kalman filter whose state is (I, R, D, β, γ, μ, N). With Use.WIKIPEDIA the model is simply integrated
    with fixed parameters and no data are downloaded.
    """

    __CONFIRMED_URL = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv'
    __RECOVERED_URL = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_recovered_global.csv'
    __DEATHS_URL = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv'
    __POPULATION_URL = 'https://bit.ly/2WYjZCD'
    __CSV_CACHE = {}  # Raw JHU tables keyed by URL; filled on first use so that importing needs no network access.
    __JHU_DATA_SHIFT = 4
    __N_FILTERED = 7  # Number of state variables to filter (I, R, D, β, γ, μ and n, the population).
    __N_MEASURED = 3  # Number of measured variables (I, R and D).
    __NB_OF_STEPS = 100
    __DELTA_T = 1 / __NB_OF_STEPS
    __FIG_SIZE = (11, 13)
    __S_COLOR = '#0072bd'
    __I_COLOR = '#d95319'
    __R_COLOR = '#edb120'
    __D_COLOR = '#7e2f8e'
    __BETA_COLOR = '#77ac30'
    __GAMMA_COLOR = '#4dbeee'
    __MU_COLOR = '#a2142f'
    __DATA_ALPHA = 0.3
    __POPULATION = None

    class Use(Enum):
        WIKIPEDIA = auto()
        DATA = auto()

    def __init__(self,
                 use=Use.DATA,
                 country='Bangladesh',
                 max_data=0,
                 people=1e6,
                 tag=1):
        """
        Initialise our Model object.

        use: Use.DATA to fit the model to the JHU data of `country`, Use.WIKIPEDIA to run it with fixed parameters.
        country: value looked up in the 'Country/Region' column of the JHU tables (Use.DATA only).
        max_data: if greater than zero, only the first `max_data` days of data are used.
        people: population of `country` (Use.DATA only).
        tag: if 0, the class-level population table is reset and `people` is recorded for `country`. Otherwise, the
             population previously recorded for `country` is reused, falling back to `people` when there is none (the
             population used to be left unset in that case, which made reset() fail).

        sys.exit() is called if no data can be found for `country` or if `max_data` is not an integer.
        """

        # Keep track of whether to use the data.

        self.__use_data = use == Model.Use.DATA

        # Everything that depends on the data stays None when the data is not used.

        self.__all_data = None
        self.__data = None
        self.__start_date = None
        self.__confirmed_data = None
        self.__recovered_data = None
        self.__deaths_data = None
        self.__population = None

        if self.__use_data:
            # Retrieve the data.

            self.__load_data(country)

            if not isinstance(max_data, int):
                sys.exit('Error: \'max_data\' must be an integer value.')

            self.__data = self.__all_data[:max_data] if max_data > 0 else self.__all_data

            # Retrieve the population.

            self.__population = Model.__resolve_population(country, people, tag)

        # Declare some internal variables (that will then be initialised through our call to reset()).

        self.__beta = None
        self.__gamma = None
        self.__mu = None

        self.__ukf = None

        self.__x = None
        self.__n = None

        self.__data_s_values = None
        self.__data_i_values = None
        self.__data_r_values = None
        self.__data_d_values = None

        self.__s_values = None
        self.__i_values = None
        self.__r_values = None
        self.__d_values = None

        self.__beta_values = None
        self.__gamma_values = None
        self.__mu_values = None

        # Initialise (i.e. reset) our SIRD model.

        self.reset()

    @staticmethod
    def __jhu_csv(url):
        """
        Return the raw JHU table found at the given URL, downloading it only the first time it is requested.
        """

        if url not in Model.__CSV_CACHE:
            Model.__CSV_CACHE[url] = pd.read_csv(url)

        return Model.__CSV_CACHE[url]

    @staticmethod
    def __resolve_population(country, people, tag):
        """
        Return the population to use for the given country (see __init__() for the meaning of `tag`).
        """

        if tag == 0:
            Model.__POPULATION = {country: int(people)}
        elif Model.__POPULATION is None or country not in Model.__POPULATION:
            return int(people)

        return Model.__POPULATION[country]

    def __load_data(self, country):
        """
        Build the (days x 3) array of I, R and D values for the given country, together with the start date and the
        confirmed/recovered/deaths series that are exposed through the public accessors.
        """

        confirmed_data, confirmed_data_start = Model.__jhu_data(
            Model.__jhu_csv(Model.__CONFIRMED_URL), country)
        recovered_data, recovered_data_start = Model.__jhu_data(
            Model.__jhu_csv(Model.__RECOVERED_URL), country)
        deaths_data, deaths_data_start = Model.__jhu_data(
            Model.__jhu_csv(Model.__DEATHS_URL), country)

        # A series that only contains zeros has no start, so it must not take part in the minimum.

        starts = [
            start for start in (confirmed_data_start, recovered_data_start,
                                deaths_data_start) if start is not None
        ]

        if not starts:
            sys.exit(
                'Error: no Covid-19 data is available for {}.'.format(country))

        data_start = min(starts) - Model.__JHU_DATA_SHIFT
        data_end = confirmed_data.shape[1]

        c = np.asarray(confirmed_data.iloc[0, data_start:data_end])
        r = np.asarray(recovered_data.iloc[0, data_start:data_end])
        d = np.asarray(deaths_data.iloc[0, data_start:data_end])

        # One row per day: I (active cases), R and D. The result is always 2D, even with a single day of data.

        self.__all_data = np.array([c - r - d, r, d]).T

        start_date = confirmed_data.columns[data_start].split('/')

        self.__start_date = pd.to_datetime('-'.join(start_date[:3]))

        # NOTE(review): these series stop one column before the end (unlike the array above) -- confirm that leaving
        # out the latest day is intended.

        self.__confirmed_data = confirmed_data.iloc[0].iloc[data_start:-1]
        self.__recovered_data = recovered_data.iloc[0].iloc[data_start:-1]
        self.__deaths_data = deaths_data.iloc[0].iloc[data_start:-1]

    @staticmethod
    def __jhu_data(data, country):
        """
        Extract the time series of the given country from a raw JHU table.

        The country-level row (i.e. the one without a 'Province/State') is used when there is one. Otherwise, the
        rows of all the provinces/states of the country are summed. Return the resulting single-row table, without
        the leading non-data columns, and the index (shifted by the number of non-data columns) of its first non-zero
        value, or None if all the values are zero.
        """

        value_columns = data.columns[Model.__JHU_DATA_SHIFT:]  # Skip non-data columns.
        country_rows = data[data['Country/Region'] == country]
        national_rows = country_rows[country_rows['Province/State'].isnull()]

        if national_rows.shape[0] == 0:
            if country_rows.shape[0] == 0:
                sys.exit(
                    'Error: no Covid-19 data is available for {}.'.format(country))

            data = country_rows[value_columns].sum().to_frame().T
        else:
            data = national_rows[value_columns]

        start = next((Model.__JHU_DATA_SHIFT + i
                      for i, value in enumerate(data.iloc[0]) if value != 0),
                     None)

        return data, start

    def __data_x(self, day, index):
        """
        Return the I/R/D value for the given day.
        """

        return self.__data[day][index] if self.__use_data else math.nan

    def __data_s(self, day):
        """
        Return the S value for the given day.
        """

        if self.__use_data:
            return self.__population - self.__data_i(day) - self.__data_r(
                day) - self.__data_d(day)
        else:
            return math.nan

    def __data_i(self, day):
        """
        Return the I value for the given day.
        """

        return self.__data_x(day, 0)

    def __data_r(self, day):
        """
        Return the R value for the given day.
        """

        return self.__data_x(day, 1)

    def __data_d(self, day):
        """
        Return the D value for the given day.
        """

        return self.__data_x(day, 2)

    def __data_available(self, day):
        """
        Return whether some data is available for the given day.
        """

        return day <= self.__data.shape[0] - 1 if self.__use_data else False

    def __s_value(self):
        """
        Return the S value based on the values of I, R, D and N.
        """

        return self.__n - self.__x.sum()

    def __i_value(self):
        """
        Return the I value.
        """

        return self.__x[0]

    def __r_value(self):
        """
        Return the R value.
        """

        return self.__x[1]

    def __d_value(self):
        """
        Return the D value.
        """

        return self.__x[2]

    def __days(self):
        """
        Return the x-axis values of our results: dates when fitted to data, plain day numbers otherwise.
        """

        if self.__start_date is None:
            return np.arange(self.__s_values.size)

        return convert_to_date(self.__start_date, self.__s_values)

    @staticmethod
    def __f(x, dt, **kwargs):
        """
        State function.

        The ODE system to solve is:
          dI/dt = βIS/N - γI - μI
          dR/dt = γI
          dD/dt = μI

        One explicit step of size dt is applied. With with_ukf=True (the default), x is the full filter state
        (I, R, D, β, γ, μ, N). With with_ukf=False, x is (I, R, D) and the parameters are read from model_self.
        """

        model_self = kwargs.get('model_self')
        with_ukf = kwargs.get('with_ukf', True)

        if with_ukf:
            s = x[6] - x[:3].sum()
            beta = x[3]
            gamma = x[4]
            mu = x[5]
            n = x[6]
        else:
            s = model_self.__n - x.sum()
            beta = model_self.__beta
            gamma = model_self.__gamma
            mu = model_self.__mu
            n = model_self.__n

        # β, γ, μ and N are carried over unchanged (identity rows).

        a = np.array([[1 + dt * (beta * s / n - gamma - mu), 0, 0, 0, 0, 0, 0],
                      [dt * gamma, 1, 0, 0, 0, 0, 0],
                      [dt * mu, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0],
                      [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0],
                      [0, 0, 0, 0, 0, 0, 1]])

        if with_ukf:
            return a @ x
        else:
            return a[:3, :3] @ x

    @staticmethod
    def __h(x):
        """
        Measurement function.
        """

        return x[:Model.__N_MEASURED]

    def reset(self):
        """
        Reset our SIRD model.
        """

        # Reset β, γ and μ to the values mentioned on Wikipedia (see https://bit.ly/2VMvb6h).

        self.__beta = 0.4
        self.__gamma = 0.035
        self.__mu = 0.005

        # Reset I, R and D to the data at day 0 or the values mentioned on Wikipedia (see https://bit.ly/2VMvb6h).

        if self.__use_data:
            self.__x = np.array(
                [self.__data_i(0),
                 self.__data_r(0),
                 self.__data_d(0)])
            self.__n = self.__population
        else:
            self.__x = np.array([3, 0, 0])
            self.__n = 1000

        # Reset our Unscented Kalman filter (if required). Note that we use a dt value of 1 (day) and not the value of
        # Model.__DELTA_T.

        if self.__use_data:
            points = MerweScaledSigmaPoints(
                Model.__N_FILTERED,
                1e-3,  # Alpha value (usually a small positive value like 1e-3).
                2,  # Beta value (a value of 2 is optimal for a Gaussian distribution).
                0,  # Kappa value (usually, either 0 or 3-n).
            )

            self.__ukf = UnscentedKalmanFilter(Model.__N_FILTERED,
                                               Model.__N_MEASURED, 1, self.__h,
                                               Model.__f, points)

            self.__ukf.x = np.array([
                self.__data_i(0),
                self.__data_r(0),
                self.__data_d(0), self.__beta, self.__gamma, self.__mu,
                self.__n
            ])

            self.__ukf.P *= 15

        # Reset our data (if requested).

        if self.__use_data:
            self.__data_s_values = np.array([self.__data_s(0)])
            self.__data_i_values = np.array([self.__data_i(0)])
            self.__data_r_values = np.array([self.__data_r(0)])
            self.__data_d_values = np.array([self.__data_d(0)])

        # Reset our predicted/estimated values.

        self.__s_values = np.array([self.__s_value()])
        self.__i_values = np.array([self.__i_value()])
        self.__r_values = np.array([self.__r_value()])
        self.__d_values = np.array([self.__d_value()])

        # Reset our estimated SIRD model parameters.

        self.__beta_values = np.array([self.__beta])
        self.__gamma_values = np.array([self.__gamma])
        self.__mu_values = np.array([self.__mu])

    def run(self, batch_filter=True, nb_of_days=100):
        """
        Run our SIRD model for the given number of days, taking advantage of the data (if requested) to estimate the
        values of β, γ and μ.
        """

        # Make sure that we were given a valid number of days.

        if not isinstance(nb_of_days, int) or nb_of_days <= 0:
            sys.exit(
                'Error: \'nb_of_days\' must be an integer value greater than zero.'
            )

        # Run our SIRD simulation, which involves computing our predicted/estimated state by computing our SIRD model /
        # Unscented Kalman filter in batch filter mode, if required.

        batch_filter_x = None

        if self.__use_data and batch_filter:
            mu, cov = self.__ukf.batch_filter(self.__data)
            batch_filter_x, _, _ = self.__ukf.rts_smoother(mu, cov)

            # Override the first value of S, I, R and D.

            x = batch_filter_x[0][:3]

            self.__s_values = np.array([self.__n - x.sum()])
            self.__i_values = np.array([x[0]])
            self.__r_values = np.array([x[1]])
            self.__d_values = np.array([x[2]])

        # The new values are gathered in lists and appended to our arrays in one go once all the days are done.

        data_s, data_i, data_r, data_d = [], [], [], []
        new_s, new_i, new_r, new_d = [], [], [], []
        new_beta, new_gamma, new_mu = [], [], []

        for day in range(1, nb_of_days + 1):
            has_data = self.__data_available(day)

            # Compute our predicted/estimated state by computing our SIRD model / Unscented Kalman filter for one day.

            if has_data:
                if batch_filter:
                    estimate = batch_filter_x[day]
                else:
                    self.__ukf.predict(model_self=self)
                    self.__ukf.update(
                        np.array([
                            self.__data_i(day),
                            self.__data_r(day),
                            self.__data_d(day)
                        ]))

                    estimate = self.__ukf.x

                self.__x = estimate[:3]
                self.__beta = estimate[3]
                self.__gamma = estimate[4]
                self.__mu = estimate[5]
            else:
                for _ in range(Model.__NB_OF_STEPS):
                    self.__x = Model.__f(self.__x,
                                         Model.__DELTA_T,
                                         model_self=self,
                                         with_ukf=False)

            # Keep track of our data (if requested).

            if self.__use_data:
                data_s.append(self.__data_s(day) if has_data else math.nan)
                data_i.append(self.__data_i(day) if has_data else math.nan)
                data_r.append(self.__data_r(day) if has_data else math.nan)
                data_d.append(self.__data_d(day) if has_data else math.nan)

            # Keep track of our predicted/estimated values.

            new_s.append(self.__s_value())
            new_i.append(self.__i_value())
            new_r.append(self.__r_value())
            new_d.append(self.__d_value())

            # Keep track of our estimated SIRD model parameters.

            new_beta.append(self.__beta)
            new_gamma.append(self.__gamma)
            new_mu.append(self.__mu)

        if self.__use_data:
            self.__data_s_values = np.append(self.__data_s_values, data_s)
            self.__data_i_values = np.append(self.__data_i_values, data_i)
            self.__data_r_values = np.append(self.__data_r_values, data_r)
            self.__data_d_values = np.append(self.__data_d_values, data_d)

        self.__s_values = np.append(self.__s_values, new_s)
        self.__i_values = np.append(self.__i_values, new_i)
        self.__r_values = np.append(self.__r_values, new_r)
        self.__d_values = np.append(self.__d_values, new_d)

        self.__beta_values = np.append(self.__beta_values, new_beta)
        self.__gamma_values = np.append(self.__gamma_values, new_gamma)
        self.__mu_values = np.append(self.__mu_values, new_mu)

    def plot(self, figure=None, two_axes=False):
        """
        Plot the results using five subplots for 1) S, 2) I and R, 3) D, 4) β, and 5) γ and μ (only the first three
        when the data is not used). In each subplot, we plot the data (if requested) as bars and the computed value
        as a line. If no figure is given, a new one is created and shown.
        """

        days = self.__days()
        nrows = 5 if self.__use_data else 3
        ncols = 1

        if figure is None:
            show_figure = True
            figure, axes = plt.subplots(nrows,
                                        ncols,
                                        figsize=Model.__FIG_SIZE,
                                        sharex=True)
        else:
            figure.clf()

            show_figure = False
            axes = figure.subplots(nrows, ncols, sharex=True)

        # The window title is set through the canvas manager, which is absent for figures not managed by pyplot.

        manager = getattr(figure.canvas, 'manager', None)

        if manager is not None:
            manager.set_window_title('SIRD model fitted to data' if self.
                                     __use_data else 'Wikipedia SIRD model')

        # First subplot: S.

        axis1 = axes[0]
        axis1.plot(days, self.__s_values, Model.__S_COLOR, label='S')
        axis1.legend(loc='best')
        if self.__use_data:
            axis2 = axis1.twinx() if two_axes else axis1
            axis2.bar(days,
                      self.__data_s_values,
                      color=Model.__S_COLOR,
                      alpha=Model.__DATA_ALPHA)

            # Days without data are stored as NaN, hence nanmin(). The range must be positive for log10().

            data_s_range = self.__population - np.nanmin(self.__data_s_values)

            if data_s_range > 0:
                data_block = 10**(math.floor(math.log10(data_s_range)) - 1)
                s_values_shift = data_block * math.ceil(
                    data_s_range / data_block)
                axis2.set_ylim(
                    min(min(self.__s_values),
                        self.__population - s_values_shift), self.__population)

        # Second subplot: I and R.

        axis1 = axes[1]
        axis1.plot(days, self.__i_values, Model.__I_COLOR, label='I')
        axis1.plot(days, self.__r_values, Model.__R_COLOR, label='R')
        axis1.legend(loc='best')
        if self.__use_data:
            axis2 = axis1.twinx() if two_axes else axis1
            axis2.bar(days,
                      self.__data_i_values,
                      color=Model.__I_COLOR,
                      alpha=Model.__DATA_ALPHA)
            axis2.bar(days,
                      self.__data_r_values,
                      color=Model.__R_COLOR,
                      alpha=Model.__DATA_ALPHA)

        # Third subplot: D.

        axis1 = axes[2]
        axis1.plot(days, self.__d_values, Model.__D_COLOR, label='D')
        axis1.legend(loc='best')
        if self.__use_data:
            axis2 = axis1.twinx() if two_axes else axis1
            axis2.bar(days,
                      self.__data_d_values,
                      color=Model.__D_COLOR,
                      alpha=Model.__DATA_ALPHA)

        if self.__use_data:
            # Fourth subplot: β.

            axis1 = axes[3]
            axis1.plot(days, self.__beta_values, Model.__BETA_COLOR, label='β')
            axis1.legend(loc='best')

            # Fifth subplot: γ and μ.

            axis1 = axes[4]
            axis1.plot(days,
                       self.__gamma_values,
                       Model.__GAMMA_COLOR,
                       label='γ')
            axis1.plot(days, self.__mu_values, Model.__MU_COLOR, label='μ')
            axis1.legend(loc='best')

        plt.xlabel('time (day)')

        if show_figure:
            plt.show()

    def movie(self, filename, batch_filter=True, nb_of_days=100):
        """
        Generate, if using the data, a movie showing the evolution of our SIRD model throughout time. Frame number i
        shows the model fitted to the first i days of data. Nothing is done when the data is not used.
        """

        if not self.__use_data:
            return

        data_size = self.__all_data.shape[0]
        figure = plt.figure(figsize=Model.__FIG_SIZE)
        backend = matplotlib.get_backend()
        writer = manimation.writers['ffmpeg']()

        matplotlib.use("Agg")

        try:
            with writer.saving(figure, filename, 96):
                for i in range(1, data_size + 1):
                    print('Processing frame #',
                          i,
                          '/',
                          data_size,
                          '...',
                          sep='')

                    self.__data = self.__all_data[:i]

                    self.reset()
                    self.run(batch_filter=batch_filter, nb_of_days=nb_of_days)
                    self.plot(figure=figure)

                    writer.grab_frame()

                print('All done!')
        finally:
            # Always give the caller their backend back, even if a frame failed.

            matplotlib.use(backend)

    @staticmethod
    def __select(values, day):
        """
        Return all the given values (if day=-1) or the value for the given day.
        """

        return values if day == -1 else values[day]

    def s(self, day=-1):
        """
        Return all the S values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__s_values, day)

    def i(self, day=-1):
        """
        Return all the I values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__i_values, day)

    def r(self, day=-1):
        """
        Return all the R values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__r_values, day)

    def d(self, day=-1):
        """
        Return all the D values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__d_values, day)

    def beta(self, day=-1):
        """
        Return all the β values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__beta_values, day)

    def gamma(self, day=-1):
        """
        Return all the γ values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__gamma_values, day)

    def mu(self, day=-1):
        """
        Return all the μ values (if day=-1) or its value for a given day.
        """

        return Model.__select(self.__mu_values, day)

    def days_array(self, day=-1):
        """
        Return the days (dates when using the data, day numbers otherwise) matching our computed values. The `day`
        argument is not used.
        """

        return self.__days()

    def days_cases(self, day=-1):
        """
        Return the dates matching the confirmed cases data, or None when the data is not used. The `day` argument is
        not used.
        """

        if self.__confirmed_data is None:
            return None

        return convert_to_date(self.__start_date, self.__confirmed_data)

    def confirmed_data(self, day=-1):
        """
        Return the confirmed cases data, or None when the data is not used. The `day` argument is not used.
        """

        return self.__confirmed_data

    def recovered_data(self, day=-1):
        """
        Return the recovered cases data, or None when the data is not used. The `day` argument is not used.
        """

        return self.__recovered_data

    def deaths_data(self, day=-1):
        """
        Return the deaths data, or None when the data is not used. The `day` argument is not used.
        """

        return self.__deaths_data

    def active_data(self, day=-1):
        """
        Return the active cases data (confirmed - recovered - deaths), or None when the data is not used. The `day`
        argument is not used.
        """

        if self.__confirmed_data is None:
            return None

        return self.__confirmed_data - self.__recovered_data - self.__deaths_data
コード例 #23
0
def test_fixed_lag():
    """Compare plain UKF output, a one-off lagged RTS pass and a full RTS pass.

    A three-element state is filtered against simulated radar ranges.  While
    filtering, the last ``N`` estimates are smoothed once (at step 20) and
    written back into ``flxs``.  Afterwards the filter is re-run in batch
    mode over all ranges and smoothed over the whole record.  The three
    results are plotted per state component.
    """
    def fx(x, dt):
        """State transition: x[0] advances by x[1] * dt, the rest is unchanged."""
        A = np.eye(3) + dt * np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]])
        return np.dot(A, x)

    def hx(x):
        """Measurement function: slant range computed from x[0] and x[2]."""
        return np.sqrt(x[0]**2 + x[2]**2)

    dt = 0.05

    sp = JulierSigmaPoints(n=3, kappa=0)

    kf = UKF(3, 1, dt, fx=fx, hx=hx, points=sp)

    kf.Q *= 0.01
    kf.R = 10
    kf.x = np.array([0., 90., 1100.])
    kf.P *= 1.
    radar = RadarSim(dt)

    t = np.arange(0, 20 + dt, dt)

    xs = np.zeros((len(t), 3))

    random.seed(200)
    rs = []

    M = []
    P = []
    N = 10  # number of trailing estimates handed to the lagged smoother
    flxs = []
    for i in range(len(t)):
        r = radar.get_range()
        kf.predict()
        kf.update(z=[r])

        xs[i, :] = kf.x
        flxs.append(kf.x)
        rs.append(r)
        M.append(kf.x)
        P.append(kf.P)
        print(i)
        # NOTE(review): the lagged smoothing only ever runs at step 20;
        # confirm whether `i >= 20` was intended.
        if i == 20 and len(M) >= N:
            try:
                M2, P2, K = kf.rts_smoother(Xs=np.asarray(M)[-N:],
                                            Ps=np.asarray(P)[-N:])
                flxs[-N:] = M2
            except Exception:
                # Best effort: report a failed smoothing pass and carry on,
                # but let KeyboardInterrupt/SystemExit through.
                print('except', i)

    # Second pass: batch filter plus RTS smoothing over the whole record.
    kf.x = np.array([0., 90., 1100.])
    kf.P = np.eye(3) * 100
    M, P = kf.batch_filter(rs)

    Qs = [kf.Q] * len(t)
    M2, P2, K = kf.rts_smoother(Xs=M, Ps=P, Qs=Qs)

    flxs = np.asarray(flxs)
    print(xs[:, 0].shape)

    plt.figure()
    plt.subplot(311)
    plt.plot(t, xs[:, 0])
    plt.plot(t, flxs[:, 0], c='r')
    plt.plot(t, M2[:, 0], c='g')
    plt.subplot(312)
    plt.plot(t, xs[:, 1])
    plt.plot(t, flxs[:, 1], c='r')
    plt.plot(t, M2[:, 1], c='g')

    plt.subplot(313)
    plt.plot(t, xs[:, 2])
    plt.plot(t, flxs[:, 2], c='r')
    plt.plot(t, M2[:, 2], c='g')
def ukf(pr, fs, meas_noise=3):
    """Use an unscented Kalman filter to smooth the data and predict
    positions px, py, pz when we have missing data.

    Parameters:
    pr = (ntime x nmark x 3) raw data array in mm
    fs = sampling rate
    meas_noise = measurement noise (from Qualisys), default=3

    Returns:
    out = dict that holds filtered position and calculated velocity
        and acceleration. Keys are as follows:
        p, v, a: pos, vel, acc after RTS smoothing. These are the values
            to fill in the missing data gaps
        pf0, vf0, af0: pos, vel, acc from the forward filter pass only
        nans: (ntime x nmark) integer array, 1 where the raw x coordinate
            is NaN and 0 elsewhere
        xs, covs, zs, x0: forward-pass states, covariances, measurement
            list and initial state of the LAST marker processed (None if
            there are no markers)
    """

    ntime, nmark, ncoords = pr.shape

    g = 9810  # mm/s^2
    dim_x = 9  # tracked variables px, vx, ax, py, vy, ay, pz, vz, az
    dim_z = 3  # measured variables px, py, pz
    dt = 1 / fs

    # measurement noise covariance matrix
    R = meas_noise**2 * np.eye(dim_z)

    # process uncertainty matrix (effect of unmodeled behavior)
    sigx, sigy, sigz = .5 * g, .5 * g, .5 * g
    qx = Q_discrete_white_noise(3, dt=dt, var=sigx**2)
    qy = Q_discrete_white_noise(3, dt=dt, var=sigy**2)
    qz = Q_discrete_white_noise(3, dt=dt, var=sigz**2)
    Q = _sub3(qx, qy, qz)

    # uncertainty covariance matrix
    p0 = np.diag([1, 500, 2 * g])**2
    P = _sub3(p0, p0, p0)

    # store the data in 3D arrays
    pf, vf, af = pr.copy(), pr.copy(), pr.copy()  # for after RTS smoothing
    pf0, vf0, af0 = pr.copy(), pr.copy(), pr.copy()  # for original pass
    # the builtin int replaces the np.int alias, which NumPy no longer has
    nans = np.zeros((ntime, nmark), dtype=int)

    # indices for ntime x 9 arrays to get pos, vel, and acc
    pos_idx, vel_idx, acc_idx = [0, 3, 6], [1, 4, 7], [2, 5, 8]

    # defined up front so the output dict can be built even with no markers
    xs = covs = zs = x0 = None

    for j in np.arange(nmark):
        raw = pr[:, j].copy()

        # store the nan values (a sample counts as missing if its x is NaN)
        missing = np.isnan(raw[:, 0])
        idx = np.where(missing)[0]
        nans[idx, j] = 1

        # batch_filter needs a list, with a nan values as None
        zs = raw.tolist()
        for ii in idx:
            zs[ii] = None

        # initial conditions for the smoother; the position comes from the
        # first sample that is not missing (sample 0 whenever it is valid)
        x0 = np.zeros(9)
        valid = np.where(~missing)[0]
        if valid.size > 0:
            x0[pos_idx] = raw[valid[0]]
        x0[vel_idx] = 0, 500, 1000  # guess v0, mm/s
        x0[acc_idx] = .1 * g, .1 * g, -.5 * g  # guess a0, mm/s^2

        # how to calculate the sigma points
        msp = MerweScaledSigmaPoints(n=dim_x, alpha=1e-3, beta=2, kappa=0)

        # setup the filter with our values
        kf = UKF(dim_x, dim_z, dt, _hx, _fx, msp)
        kf.x, kf.P, kf.R, kf.Q = x0, P, R, Q

        # filter 'forward'
        xs, covs = kf.batch_filter(zs)

        # apply RTS smoothing
        Ms, Ps, Ks = kf.rts_smoother(xs, covs, dt=dt)

        # get data out of the (ntime x 9) array
        pf0[:, j] = xs[:, pos_idx]
        vf0[:, j] = xs[:, vel_idx]
        af0[:, j] = xs[:, acc_idx]

        pf[:, j] = Ms[:, pos_idx]
        vf[:, j] = Ms[:, vel_idx]
        af[:, j] = Ms[:, acc_idx]

    # finally store everything in a dictionary
    out = {'p': pf, 'v': vf, 'a': af, 'pf0': pf0, 'vf0': vf0, 'af0': af0,
           'nans': nans, 'xs': xs, 'covs': covs, 'zs': zs, 'x0': x0}

    return out