예제 #1
0
def test_EKFState_with_NcvContinuous():
    """Exercise EKFState construction, predict and update with an NcvContinuous model."""
    dim = 6
    tol = 1e-5
    model = NcvContinuous(dimension=dim, sa2=2.0)
    mean0 = torch.rand(dim)
    cov0 = torch.eye(dim)
    t0 = 0.0
    step = 2.0

    # Construction via keyword arguments: the state must echo back its inputs.
    initial = EKFState(dynamic_model=model, mean=mean0, cov=cov0, time=t0)

    assert initial.dynamic_model.__class__ == NcvContinuous
    assert initial.dimension == dim
    assert initial.dimension_pv == dim

    # For this model the "pv" views are expected to equal the full mean/cov.
    expectations = (
        (mean0, "mean"),
        (cov0, "cov"),
        (mean0, "mean_pv"),
        (cov0, "cov_pv"),
        (t0, "time"),
    )
    for expected, attribute in expectations:
        assert_equal(expected, getattr(initial, attribute), prec=tol)

    # Construction via positional arguments, followed by a prediction step.
    scaled = EKFState(model, 2 * mean0, 2 * cov0, t0)
    predicted = scaled.predict(step)
    assert predicted.dynamic_model.__class__ == NcvContinuous

    # Measurement stamped at the predicted time.
    obs_mean = torch.rand(dim)
    obs_cov = torch.eye(dim)
    measurement = PositionMeasurement(mean=obs_mean, cov=obs_cov, time=t0 + step)

    log_likelihood = predicted.log_likelihood_of_update(measurement)
    assert (log_likelihood < 0.).all()

    # update() returns the new state plus a (dz, S) pair whose shapes are
    # tied to the measurement dimension.
    updated, (dz, S) = predicted.update(measurement)
    assert dz.shape == (measurement.dimension,)
    assert S.shape == (measurement.dimension, measurement.dimension)
    assert_not_equal(updated.mean, predicted.mean, prec=tol)
예제 #2
0
    def filter_states(self, value):
        """
        Run the filter over a sequence of measurements and collect the state
        obtained after each update.

        :param value: measurement means of shape `(time_steps, event_shape)`;
            only the last dimension is checked against ``self.event_shape``
        :type value: torch.Tensor
        :return: list of the states returned by ``update``, one per
            measurement, in input order
        """
        current = EKFState(self.dynamic_model, self.x0, self.P0, time=0.)
        assert value.shape[-1] == self.event_shape[-1]
        filtered = []
        for step, z_mean in enumerate(value):
            # The first measurement is applied to the initial state directly;
            # every later one is preceded by a predict step of length dt.
            if step > 0:
                current = current.predict(self.dt)
            observation = PositionMeasurement(
                z_mean, self.measurement_cov, time=current.time)
            # The second return value of update() is unpacked but not used here.
            current, (_dz, _S) = current.update(observation)
            filtered.append(current)
        return filtered
예제 #3
0
    def log_prob(self, value):
        """
        Accumulate the joint log probability of the innovations produced
        while filtering a tensor of measurements.

        :param value: measurement means of shape `(time_steps, event_shape)`;
            its shape must equal ``self.event_shape`` exactly (asserted)
        :type value: torch.Tensor
        :return: sum over time steps of the per-measurement log probabilities
        """
        current = EKFState(self.dynamic_model, self.x0, self.P0, time=0.)
        assert value.shape == self.event_shape
        # Point at which every per-step density is evaluated.
        origin = torch.zeros(self.event_shape[-1], dtype=value.dtype, device=value.device)
        total = 0.
        for step, z_mean in enumerate(value):
            # No predict step before the first measurement.
            if step > 0:
                current = current.predict(self.dt)
            observation = PositionMeasurement(
                z_mean, self.measurement_cov, time=current.time)
            current, (dz, S) = current.update(observation)
            # Density of a normal located at dz, evaluated at zero.
            step_dist = dist.MultivariateNormal(dz, S)
            total = total + step_dist.log_prob(origin)
        return total