コード例 #1
0
def test_virtual_sensor_srukf_consistency(generated_data):
    """On a linear system, a virtual-sensor SRUKF should agree with a standard EKF.

    Both filters are run over the same generated data. Final belief means are compared
    with default tolerances; final covariances with rtol=1e-4 / atol=5e-4.
    """
    reference_ekf = torchfilter.filters.ExtendedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )
    vs_srukf = torchfilter.filters.VirtualSensorSquareRootUnscentedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        virtual_sensor_model=LinearVirtualSensorModel(),
    )

    # Run the EKF first, then the SRUKF (same order as construction).
    for candidate in (reference_ekf, vs_srukf):
        _run_filter(candidate, generated_data)

    torch.testing.assert_allclose(reference_ekf.belief_mean, vs_srukf.belief_mean)
    covariance_tolerances = dict(rtol=1e-4, atol=5e-4)
    torch.testing.assert_allclose(
        reference_ekf.belief_covariance,
        vs_srukf.belief_covariance,
        **covariance_tolerances,
    )
コード例 #2
0
def test_dynamics_jacobian():
    """The autograd Jacobian of the linear dynamics model must equal the matrix `A`."""
    batch_size = 10
    model = LinearDynamicsModel()
    jacobians = model.jacobian(
        initial_states=torch.zeros((batch_size, state_dim)),
        controls=torch.zeros((batch_size, control_dim)),
    )

    # Compare each batch entry individually against the module-level `A`.
    for row in range(batch_size):
        torch.testing.assert_allclose(jacobians[row], A)
コード例 #3
0
def test_train_virtual_sensor_ekf_e2e(subsequence_dataloader, buddy):
    """End-to-end training of a virtual sensor EKF should reduce the error of both
    its dynamics model and its virtual sensor model.
    """
    dynamics = LinearDynamicsModel(trainable=True)
    virtual_sensor = LinearVirtualSensorModel(trainable=True)
    ekf = torchfilter.filters.VirtualSensorExtendedKalmanFilter(
        dynamics_model=dynamics,
        virtual_sensor_model=virtual_sensor,
    )

    # Snapshot both errors before any training happens.
    dynamics_error_before = get_trainable_model_error(dynamics)
    virtual_sensor_error_before = get_trainable_model_error(virtual_sensor)

    # Train the whole filter (original comment: one epoch).
    buddy.attach_model(ekf)
    torchfilter.train.train_filter(
        buddy,
        ekf,
        subsequence_dataloader,
        initial_covariance=torch.eye(state_dim) * 0.01,
    )

    assert get_trainable_model_error(dynamics) < dynamics_error_before
    assert get_trainable_model_error(virtual_sensor) < virtual_sensor_error_before
コード例 #4
0
def test_train_ukf_e2e(subsequence_dataloader, buddy):
    """End-to-end training of a UKF should reduce the error of both its dynamics
    model and its measurement model.
    """
    dynamics = LinearDynamicsModel(trainable=True)
    measurement = LinearKalmanFilterMeasurementModel(trainable=True)
    ukf = torchfilter.filters.UnscentedKalmanFilter(
        dynamics_model=dynamics,
        measurement_model=measurement,
    )

    # Snapshot both errors before any training happens.
    dynamics_error_before = get_trainable_model_error(dynamics)
    measurement_error_before = get_trainable_model_error(measurement)

    # Train the whole filter (original comment: one epoch).
    buddy.attach_model(ukf)
    torchfilter.train.train_filter(
        buddy,
        ukf,
        subsequence_dataloader,
        initial_covariance=torch.eye(state_dim) * 0.01,
    )

    assert get_trainable_model_error(dynamics) < dynamics_error_before
    assert get_trainable_model_error(measurement) < measurement_error_before
コード例 #5
0
def test_ekf(generated_data):
    """Smoke test: an EKF with linear models runs over the generated data."""
    ekf = torchfilter.filters.ExtendedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )
    _run_filter(ekf, generated_data)
コード例 #6
0
def test_srukf(generated_data):
    """Smoke test: a square-root UKF (Julier-style sigma points) runs over the data."""
    srukf = torchfilter.filters.SquareRootUnscentedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )
    _run_filter(srukf, generated_data)
コード例 #7
0
def test_particle_filter(generated_data):
    """Smoke test: a particle filter with linear models runs over the generated data."""
    particle_filter = torchfilter.filters.ParticleFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearParticleFilterMeasurementModel(),
    )
    _run_filter(particle_filter, generated_data)
コード例 #8
0
def test_virtual_sensor_eif(generated_data):
    """Smoke test: an extended information filter with a virtual sensor runs."""
    vs_eif = torchfilter.filters.VirtualSensorExtendedInformationFilter(
        dynamics_model=LinearDynamicsModel(),
        virtual_sensor_model=LinearVirtualSensorModel(),
    )
    _run_filter(vs_eif, generated_data)
コード例 #9
0
def test_particle_filter_resample(generated_data):
    """Smoke test: a particle filter with resampling enabled runs over the data."""
    resampling_filter = torchfilter.filters.ParticleFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearParticleFilterMeasurementModel(),
        resample=True,
    )
    _run_filter(resampling_filter, generated_data)
コード例 #10
0
def test_virtual_sensor_srukf(generated_data):
    """Smoke test: a virtual-sensor SRUKF with Julier-style sigma points runs."""
    # The sigma point strategy argument is optional; it is passed explicitly here.
    julier = torchfilter.utils.JulierSigmaPointStrategy()
    vs_srukf = torchfilter.filters.VirtualSensorSquareRootUnscentedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        virtual_sensor_model=LinearVirtualSensorModel(),
        sigma_point_strategy=julier,
    )
    _run_filter(vs_srukf, generated_data)
コード例 #11
0
def test_srukf_merwe(generated_data):
    """Smoke test: a square-root UKF with Merwe-style sigma points (alpha=0.1) runs."""
    srukf = torchfilter.filters.SquareRootUnscentedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
        sigma_point_strategy=torchfilter.utils.MerweSigmaPointStrategy(alpha=1e-1),
    )
    _run_filter(srukf, generated_data)
コード例 #12
0
def test_ukf_srukf_consistency(generated_data):
    """On a linear system the UKF and its square-root variant should coincide.

    Final means are compared with default tolerances; covariances are compared with
    rtol=1e-4 / atol=5e-4.
    """
    square_root = torchfilter.filters.SquareRootUnscentedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )
    standard = torchfilter.filters.UnscentedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )

    # Run the SRUKF first, then the UKF.
    for candidate in (square_root, standard):
        _run_filter(candidate, generated_data)

    torch.testing.assert_allclose(square_root.belief_mean, standard.belief_mean)
    torch.testing.assert_allclose(
        square_root.belief_covariance,
        standard.belief_covariance,
        rtol=1e-4,
        atol=5e-4,
    )
コード例 #13
0
def test_eif_ekf_consistency(generated_data):
    """On a linear system the EKF and the extended information filter should coincide.

    Final means are compared with default tolerances; covariances are compared with
    rtol=1e-4 / atol=5e-4.
    """
    kalman = torchfilter.filters.ExtendedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )
    information = torchfilter.filters.ExtendedInformationFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )

    # Run the EKF first, then the EIF.
    for candidate in (kalman, information):
        _run_filter(candidate, generated_data)

    torch.testing.assert_allclose(kalman.belief_mean, information.belief_mean)
    torch.testing.assert_allclose(
        kalman.belief_covariance,
        information.belief_covariance,
        rtol=1e-4,
        atol=5e-4,
    )
コード例 #14
0
def generated_data() -> Tuple[
    types.StatesTorch, types.ObservationsNoDictTorch, types.ControlsNoDictTorch
]:
    """Generate `N` noisy trajectories of length `T` using our dynamics and
    measurement models.

    The global torch RNG is seeded with 0 first, so the generated data is reproducible.

    Returns:
        tuple: (states, observations, controls). All tensors are time-major: the first
            dimension is `T` (timesteps) and the second is `N` (batch). Shapes are
            `(T, N, state_dim)`, `(T, N, observation_dim)` and `(T, N, control_dim)`.
    """
    torch.random.manual_seed(0)

    # Batch size (number of trajectories)
    N = 5

    # Timesteps per trajectory
    T = 100

    dynamics_model = LinearDynamicsModel()
    measurement_model = LinearKalmanFilterMeasurementModel()

    # Preallocate time-major buffers for states and observations
    states = torch.zeros((T, N, state_dim))
    observations = torch.zeros((T, N, observation_dim))

    # Random control inputs for every timestep and trajectory
    controls = torch.randn(size=(T, N, control_dim))

    for t in range(T):
        if t == 0:
            # Random initial state
            states[0, :, :] = torch.randn(size=(N, state_dim))
        else:
            # Propagate the previous state with this timestep's control
            pred_states, Q_tril = dynamics_model(
                initial_states=states[t - 1, :, :], controls=controls[t, :, :]
            )
            assert pred_states.shape == (N, state_dim)
            assert Q_tril.shape == (N, state_dim, state_dim)

            # Q_tril @ eps with eps ~ N(0, I) draws noise with covariance
            # Q_tril @ Q_tril^T.
            states[t, :, :] = pred_states + (
                Q_tril @ torch.randn(size=(N, state_dim, 1))
            ).squeeze(-1)

        # Observe the current state and add measurement noise in the same way
        pred_observations, R_tril = measurement_model(states=states[t, :, :])
        observations[t, :, :] = pred_observations + (
            R_tril @ torch.randn(size=(N, observation_dim, 1))
        ).squeeze(-1)

    return states, observations, controls
コード例 #15
0
def test_train_dynamics_single_step(single_step_dataloader, buddy):
    """Single-step training should lower the dynamics model's error."""
    model = LinearDynamicsModel(trainable=True)
    error_before = get_trainable_model_error(model)

    # Train the dynamics model (original comment: one epoch).
    buddy.attach_model(model)
    torchfilter.train.train_dynamics_single_step(
        buddy, model, single_step_dataloader
    )

    assert get_trainable_model_error(model) < error_before
コード例 #16
0
def test_particle_filter_dynamic_particle_count_resample(generated_data):
    """Smoke test: a resampling particle filter whose particle count changes at runtime.

    The filter starts with 30 particles, then grows to 100 and shrinks back to 30
    without re-initializing beliefs. After each run, dimension 1 of
    `particle_states` must equal the requested count.
    """
    particle_filter = torchfilter.filters.ParticleFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearParticleFilterMeasurementModel(),
        resample=True,
        num_particles=30,
    )
    _run_filter(particle_filter, generated_data)
    assert particle_filter.particle_states.shape[1] == 30

    # Expand, then contract, reusing the existing beliefs each time.
    for requested in (100, 30):
        particle_filter.num_particles = requested
        _run_filter(particle_filter, generated_data, initialize_beliefs=False)
        assert particle_filter.particle_states.shape[1] == requested