def test_EKFState_with_NcvContinuous():
    """Exercise EKFState construction, prediction and update with an NCV model."""
    dim = 6
    model = NcvContinuous(dimension=dim, sa2=2.0)
    mean0 = torch.rand(dim)
    cov0 = torch.eye(dim)
    t0, step = 0.0, 2.0

    def close(expected, actual):
        # Shared tolerance for every approximate comparison in this test.
        assert_equal(expected, actual, prec=1e-5)

    # Construction: the state keeps its model, dimensions, moments and time.
    initial = EKFState(dynamic_model=model, mean=mean0, cov=cov0, time=t0)
    assert initial.dynamic_model.__class__ == NcvContinuous
    assert initial.dimension == dim
    assert initial.dimension_pv == dim
    close(mean0, initial.mean)
    close(cov0, initial.cov)
    # For this model the *_pv views are asserted to equal the full moments.
    close(mean0, initial.mean_pv)
    close(cov0, initial.cov_pv)
    close(t0, initial.time)

    # Prediction: advancing a scaled copy of the state keeps the model type.
    predicted = EKFState(model, 2 * mean0, 2 * cov0, t0).predict(step)
    assert predicted.dynamic_model.__class__ == NcvContinuous

    # Update: a random position measurement taken at the predicted time.
    measurement = PositionMeasurement(
        mean=torch.rand(dim), cov=torch.eye(dim), time=t0 + step)
    log_likelihood = predicted.log_likelihood_of_update(measurement)
    assert (log_likelihood < 0.).all()

    updated, (innovation, innovation_cov) = predicted.update(measurement)
    m = measurement.dimension
    assert innovation.shape == (m,)
    assert innovation_cov.shape == (m, m)
    assert_not_equal(updated.mean, predicted.mean, prec=1e-5)
def filter_states(self, value):
    """
    Run the extended Kalman filter over a sequence of measurements and
    return the updated EKF state after each one.

    Filtering starts from ``EKFState(self.dynamic_model, self.x0, self.P0,
    time=0.)``. The first measurement is applied to that initial state
    directly. Each later measurement is preceded by a ``predict(self.dt)``
    step, and every measurement is stamped with the current state time.

    :param value: measurement means, one entry per time step along the first
        dimension; the last dimension must equal ``self.event_shape[-1]``.
    :type value: torch.Tensor
    :return: the updated ``EKFState`` for each time step, in order.
    :rtype: list
    """
    # Only the trailing (measurement) dimension is validated, so the number
    # of time steps is not constrained by this check.
    assert value.shape[-1] == self.event_shape[-1], (
        "expected last dimension {}, got value of shape {}".format(
            self.event_shape[-1], tuple(value.shape)))
    state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.)
    states = []
    for i, measurement_mean in enumerate(value):
        if i:
            # Advance the state before incorporating every measurement after
            # the first one.
            state = state.predict(self.dt)
        measurement = PositionMeasurement(measurement_mean, self.measurement_cov,
                                          time=state.time)
        # update() also returns the innovation pair (dz, S); it is not needed
        # when only the filtered states are wanted.
        state, _ = state.update(measurement)
        states.append(state)
    return states
def log_prob(self, value):
    """
    Return the joint log probability of the filter innovations for a
    sequence of measurements.

    Starting from ``EKFState(self.dynamic_model, self.x0, self.P0, time=0.)``,
    each measurement is applied with ``state.update`` (preceded by
    ``predict(self.dt)`` for every step after the first). The per-step term is
    ``MultivariateNormal(dz, S).log_prob(0)``, where ``(dz, S)`` is the pair
    returned by ``update`` (presumably the innovation and its covariance).
    The terms are summed over time steps.

    :param value: measurement means; must have shape exactly
        ``self.event_shape``, with time steps along the first dimension.
    :type value: torch.Tensor
    :return: the summed log probability as a tensor (a zero scalar tensor
        when ``value`` contains no time steps).
    :rtype: torch.Tensor
    """
    assert value.shape == self.event_shape, (
        "expected value of shape {}, got {}".format(
            tuple(self.event_shape), tuple(value.shape)))
    state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.)
    result = None
    for i, measurement_mean in enumerate(value):
        if i:
            # Advance the state before incorporating every measurement after
            # the first one.
            state = state.predict(self.dt)
        measurement = PositionMeasurement(measurement_mean, self.measurement_cov,
                                          time=state.time)
        state, (dz, S) = state.update(measurement)
        # The zero reference is built from dz rather than from `value`, so it
        # matches the innovation's shape, dtype and device.
        step_log_prob = dist.MultivariateNormal(dz, S).log_prob(torch.zeros_like(dz))
        result = step_log_prob if result is None else result + step_log_prob
    if result is None:
        # No time steps: the empty sum is zero, returned as a tensor so the
        # return type does not depend on the sequence length.
        return value.new_zeros(())
    return result