def test_train_virtual_sensor_ekf_e2e(subsequence_dataloader, buddy):
    """End-to-end training of a virtual sensor EKF should reduce the error of
    both its dynamics model and its virtual sensor model.
    """
    # Trainable components, wrapped together in a single filter
    dynamics = LinearDynamicsModel(trainable=True)
    sensor = LinearVirtualSensorModel(trainable=True)
    ekf = torchfilter.filters.VirtualSensorExtendedKalmanFilter(
        dynamics_model=dynamics,
        virtual_sensor_model=sensor,
    )

    # Snapshot both errors before any optimization happens
    dynamics_error_before = get_trainable_model_error(dynamics)
    sensor_error_before = get_trainable_model_error(sensor)

    # One epoch of training on the whole filter; since the filter owns both
    # sub-models, end-to-end training is expected to update each of them
    buddy.attach_model(ekf)
    initial_covariance = torch.eye(state_dim) * 0.01
    torchfilter.train.train_filter(
        buddy,
        ekf,
        subsequence_dataloader,
        initial_covariance=initial_covariance,
    )

    # Both components should have improved
    dynamics_error_after = get_trainable_model_error(dynamics)
    assert dynamics_error_after < dynamics_error_before
    sensor_error_after = get_trainable_model_error(sensor)
    assert sensor_error_after < sensor_error_before
def test_virtual_sensor_srukf_consistency(generated_data):
    """Check that our virtual sensor SRUKF and a standard EKF produce consistent
    results for a linear system.
    """
    # Create filters
    ekf = torchfilter.filters.ExtendedKalmanFilter(
        dynamics_model=LinearDynamicsModel(),
        measurement_model=LinearKalmanFilterMeasurementModel(),
    )
    virtual_sensor_srukf = (
        torchfilter.filters.VirtualSensorSquareRootUnscentedKalmanFilter(
            dynamics_model=LinearDynamicsModel(),
            virtual_sensor_model=LinearVirtualSensorModel(),
        )
    )

    # Run both filters over the same data
    _run_filter(ekf, generated_data)
    _run_filter(virtual_sensor_srukf, generated_data)

    # Check final beliefs. The mean is compared with assert_close's default
    # dtype-based tolerances.
    torch.testing.assert_close(ekf.belief_mean, virtual_sensor_srukf.belief_mean)
    # Covariances use looser explicit tolerances -- presumably to absorb
    # numerical differences between the EKF and the square-root sigma-point
    # formulation (TODO confirm).
    torch.testing.assert_close(
        ekf.belief_covariance,
        virtual_sensor_srukf.belief_covariance,
        rtol=1e-4,
        atol=5e-4,
    )
def test_virtual_sensor_eif(generated_data):
    """Smoke test: run a virtual sensor extended information filter over data."""
    eif = torchfilter.filters.VirtualSensorExtendedInformationFilter(
        dynamics_model=LinearDynamicsModel(),
        virtual_sensor_model=LinearVirtualSensorModel(),
    )
    # Only exercises the filtering path; no numerical checks are made here
    _run_filter(eif, generated_data)
def test_virtual_sensor_srukf(generated_data):
    """Smoke test: run a virtual sensor SRUKF configured with Julier sigma points."""
    dynamics = LinearDynamicsModel()
    sensor = LinearVirtualSensorModel()
    # The sigma point strategy argument is optional; this test explicitly
    # exercises the Julier variant
    strategy = torchfilter.utils.JulierSigmaPointStrategy()

    srukf = torchfilter.filters.VirtualSensorSquareRootUnscentedKalmanFilter(
        dynamics_model=dynamics,
        virtual_sensor_model=sensor,
        sigma_point_strategy=strategy,
    )
    _run_filter(srukf, generated_data)
def test_train_virtual_sensor(single_step_dataloader, buddy):
    """Training a standalone virtual sensor should reduce its model error."""
    model = LinearVirtualSensorModel(trainable=True)
    error_before = get_trainable_model_error(model)

    # One epoch of virtual sensor training on single-step data
    buddy.attach_model(model)
    torchfilter.train.train_virtual_sensor(
        buddy,
        model,
        single_step_dataloader,
    )

    error_after = get_trainable_model_error(model)
    assert error_after < error_before