Beispiel #1
0
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Run the forward (filtering) pass with time-invariant parameters.

    The same A, C, R and Q are used at every time step; only the observations
    and the optional control biases are indexed by time.

    Note on naming: in this code R is handed to the prediction step and Q to
    the measurement step.

    Args:
        dims: Supplies the timesteps, batch and state sizes of the outputs.
        A: Passed to the prediction step.
        B: Applied to u_state via matvec; only used when u_state is given.
        C: Passed to the measurement step.
        D: Applied to u_obs via matvec; only used when u_obs is given.
        LQinv_tril, LQinv_logdiag: Parameters from which Q is built.
        LRinv_tril, LRinv_logdiag: Parameters from which R is built.
        LV0inv_tril, LV0inv_logdiag: Parameters from which V0 is built.
        m0: Initial state mean.
        y: Observations, indexed by time along the first axis.
        u_state: Optional controls for the state transition.
        u_obs: Optional controls for the observation model.

    Returns:
        Tuple (m_fw, V_fw) with shapes (timesteps, batch, state) and
        (timesteps, batch, state, state).
    """
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # Pre-compute biases. None marks "no control input": the previous scalar
    # placeholder 0 could not be indexed by time and raised a TypeError.
    b = matvec(B, u_state) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else None

    # NOTE(review): results are written in place into pre-allocated tensors;
    # confirm this is acceptable if gradients must flow through the filter.
    m_fw = torch.zeros(
        (dims.timesteps, dims.batch, dims.state),
        device=device,
        dtype=dtype,
    )
    V_fw = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )
    for t in range(dims.timesteps):
        if t == 0:
            # The first step starts from the initial state, no prediction.
            mp, Vp = m0, V0
        else:
            mp, Vp = filter_forward_prediction_step(
                m=m_fw[t - 1],
                V=V_fw[t - 1],
                R=R,
                A=A,
                b=b[t - 1] if b is not None else 0,
            )
        m_fw[t], V_fw[t] = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=Q,
            C=C,
            d=d[t] if d is not None else 0,
        )
    return m_fw, V_fw
Beispiel #2
0
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Accumulate the per-batch forward (filter) loss, time-invariant params.

    Runs the same prediction/measurement recursion as the filter and sums,
    over time, a loss term built from the components returned by the
    measurement step (dobs_norm and LVpyinv).

    Args:
        dims: Supplies timesteps, batch and target sizes.
        A: Passed to the prediction step.
        B: Applied to u_state via matvec; only used when u_state is given.
        C: Passed to the measurement step.
        D: Applied to u_obs via matvec; only used when u_obs is given.
        LQinv_tril, LQinv_logdiag: Parameters from which Q is built.
        LRinv_tril, LRinv_logdiag: Parameters from which R is built.
        LV0inv_tril, LV0inv_logdiag: Parameters from which V0 is built.
        m0: Initial state mean.
        y: Observations, indexed by time along the first axis.
        u_state: Optional controls for the state transition.
        u_obs: Optional controls for the observation model.

    Returns:
        Loss tensor of shape (batch,).
    """
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # Pre-compute biases. None marks "no control input": the previous scalar
    # placeholder 0 could not be indexed by time and raised a TypeError.
    b = matvec(B, u_state) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else None

    # Note: We can not use (more readable) in-place operations due to backprop
    # problems, hence per-step results are kept in Python lists.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps
    loss = torch.zeros((dims.batch, ), device=device, dtype=dtype)
    for t in range(dims.timesteps):
        if t == 0:
            # The first step starts from the initial state, no prediction.
            mp, Vp = m0, V0
        else:
            mp, Vp = filter_forward_prediction_step(
                m=m_fw[t - 1],
                V=V_fw[t - 1],
                R=R,
                A=A,
                b=b[t - 1] if b is not None else 0,
            )
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=Q,
            C=C,
            d=d[t] if d is not None else 0,
            return_loss_components=True,
        )
        # Squared-norm term, log-diagonal term and the 2*pi constant.
        loss += (
            0.5 * torch.sum(dobs_norm**2, dim=-1) -
            0.5 * 2 * torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1, )) +
            0.5 * dims.target * LOG_2PI)

    return loss
 def dist_params(self):
     """Bundle the distribution parameters of this object into a Box.

     The Box carries `loc` (taken from self.m) and `covariance_matrix`
     (obtained from self.LVinv_tril and self.LVinv_logdiag via
     cov_from_invcholesky_param).
     """
     loc = self.m
     covariance = cov_from_invcholesky_param(
         Linv_tril=self.LVinv_tril,
         Linv_logdiag=self.LVinv_logdiag,
     )
     return Box(loc=loc, covariance_matrix=covariance)
def filter_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Run the forward (filtering) pass with time-indexed parameters.

    A, B, C, D and the R/Q parameters are indexed along their first axis by
    the time step. Step 0 starts from the initial state (m0, V0); every later
    step first predicts from the previous filtered state using the parameters
    at t - 1, then applies the measurement update with the parameters at t.

    Returns:
        Tuple of the time-stacked filtered means and covariances.
    """
    # TODO: assumes B != None, D != None. u_state and u_obs can be None.
    #  Better to work with vectors b, d. matmul Bu should be done outside!
    m_fw = create_zeros_state_vec(dims=dims, device=A.device, dtype=A.dtype)
    V_fw = create_zeros_state_mat(dims=dims, device=A.device, dtype=A.dtype)
    for t in range(dims.timesteps):
        if t > 0:
            prev = t - 1
            R_prev = cov_from_invcholesky_param(
                LRinv_tril[prev],
                LRinv_logdiag[prev],
            )
            # Without state controls the bias is the scalar 0.
            if u_state is not None:
                b_prev = matvec(B[prev], u_state[prev])
            else:
                b_prev = 0
            mp, Vp = filter_forward_prediction_step(
                m=m_fw[prev],
                V=V_fw[prev],
                R=R_prev,
                A=A[prev],
                b=b_prev,
            )
        else:
            V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
            mp, Vp = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

        # Without observation controls the bias is the scalar 0.
        if u_obs is not None:
            d_t = matvec(D[t], u_obs[t])
        else:
            d_t = 0
        m_fw[t], V_fw[t] = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t]),
            C=C[t],
            d=d_t,
        )
    means = torch.stack(m_fw, dim=0)
    covariances = torch.stack(V_fw, dim=0)
    return means, covariances
def smooth_forward_backward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Smooth by a forward filter pass followed by a backward recursion.

    Parameters are indexed by time along their first axis. The last smoothed
    state equals the last filtered state; earlier states are obtained by
    stepping backwards with smooth_backward_step. The entry of Cov_sm for the
    last time step is never written and keeps its initial value.

    Returns:
        Tuple (m_sm, V_sm, Cov_sm), each stacked along the time axis.
    """
    device, dtype = A.device, A.dtype
    m_sm = create_zeros_state_vec(dims=dims, device=device, dtype=dtype)
    V_sm = create_zeros_state_mat(dims=dims, device=device, dtype=dtype)
    Cov_sm = create_zeros_state_mat(dims=dims, device=device, dtype=dtype)

    # Forward pass: all arguments are forwarded unchanged.
    filter_kwargs = dict(
        dims=dims,
        A=A,
        B=B,
        C=C,
        D=D,
        LQinv_tril=LQinv_tril,
        LQinv_logdiag=LQinv_logdiag,
        LRinv_tril=LRinv_tril,
        LRinv_logdiag=LRinv_logdiag,
        LV0inv_tril=LV0inv_tril,
        LV0inv_logdiag=LV0inv_logdiag,
        m0=m0,
        y=y,
        u_state=u_state,
        u_obs=u_obs,
    )
    m_fw, V_fw = filter_forward(**filter_kwargs)

    # Backward pass, initialised with the final filtered state.
    m_sm[-1], V_sm[-1] = m_fw[-1], V_fw[-1]
    for t in range(dims.timesteps - 2, -1, -1):
        R_t = cov_from_invcholesky_param(LRinv_tril[t], LRinv_logdiag[t])
        if u_state is not None:
            b_t = matvec(B[t], u_state[t])
        else:
            b_t = 0
        m_sm[t], V_sm[t], Cov_sm[t] = smooth_backward_step(
            m_sm=m_sm[t + 1],
            V_sm=V_sm[t + 1],
            m_fw=m_fw[t],
            V_fw=V_fw[t],
            A=A[t],
            R=R_t,
            b=b_t,
        )

    stacked_m = torch.stack(m_sm, dim=0)
    stacked_V = torch.stack(V_sm, dim=0)
    stacked_Cov = torch.stack(Cov_sm, dim=0)
    return (stacked_m, stacked_V, stacked_Cov)
Beispiel #6
0
def smooth_forward_backward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Smooth by forward filtering and backward recursion, time-invariant.

    The same A and R are used at every backward step; only the optional state
    bias is indexed by time. The last smoothed state equals the last filtered
    state, and Cov_sm at the last time step stays zero.

    Args:
        dims: Supplies the timesteps, batch and state sizes of the outputs.
        A, B, C, D: Model matrices; all are forwarded to filter_forward, A is
            also used in the backward step, B only when u_state is given.
        LQinv_tril, LQinv_logdiag: Forwarded to filter_forward.
        LRinv_tril, LRinv_logdiag: Parameters from which R is built.
        LV0inv_tril, LV0inv_logdiag: Forwarded to filter_forward.
        m0: Initial state mean, forwarded to filter_forward.
        y: Observations, forwarded to filter_forward.
        u_state: Optional controls for the state transition.
        u_obs: Optional controls for the observation model.

    Returns:
        Tuple (m_sm, V_sm, Cov_sm) with shapes (timesteps, batch, state),
        (timesteps, batch, state, state) and (timesteps, batch, state, state).
    """
    device, dtype = A.device, A.dtype
    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)

    # Pre-compute biases. None marks "no control input": the previous scalar
    # placeholder 0 could not be indexed by time and raised a TypeError.
    b = matvec(B, u_state) if u_state is not None else None

    m_sm = torch.zeros(
        (dims.timesteps, dims.batch, dims.state),
        device=device,
        dtype=dtype,
    )
    V_sm = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )
    Cov_sm = torch.zeros(
        (dims.timesteps, dims.batch, dims.state, dims.state),
        device=device,
        dtype=dtype,
    )

    m_fw, V_fw = filter_forward(
        dims=dims,
        A=A,
        B=B,
        C=C,
        D=D,
        LQinv_tril=LQinv_tril,
        LQinv_logdiag=LQinv_logdiag,
        LRinv_tril=LRinv_tril,
        LRinv_logdiag=LRinv_logdiag,
        LV0inv_tril=LV0inv_tril,
        LV0inv_logdiag=LV0inv_logdiag,
        m0=m0,
        y=y,
        u_state=u_state,
        u_obs=u_obs,
    )

    # The backward recursion starts from the final filtered state.
    m_sm[-1], V_sm[-1] = m_fw[-1], V_fw[-1]
    for t in reversed(range(dims.timesteps - 1)):
        m_sm[t], V_sm[t], Cov_sm[t] = smooth_backward_step(
            m_sm=m_sm[t + 1],
            V_sm=V_sm[t + 1],
            m_fw=m_fw[t],
            V_fw=V_fw[t],
            A=A,
            R=R,
            b=b[t] if b is not None else 0,
        )
    return m_sm, V_sm, Cov_sm
Beispiel #7
0
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Compute the posterior by direct inversion of the unrolled model.

    Builds one joint precision matrix and one linear term over all time steps
    (size timesteps * state per batch element), inverts it, and reads the
    per-step marginals back out. Model parameters are time-invariant here.

    Args:
        dims: Supplies the timesteps, batch and state sizes.
        A, C: Transition and observation matrices as used in the terms below.
        B: Applied to u_state via matvec; only used when u_state is given.
        D: Applied to u_obs via matvec; only used when u_obs is given.
        LQinv_tril, LQinv_logdiag: Parameters from which Q is built.
        LRinv_tril, LRinv_logdiag: Parameters from which R is built.
        LV0inv_tril, LV0inv_logdiag: Parameters from which V0 is built.
        m0: Initial state mean.
        y: Observations.
        u_state: Optional controls for the state transition.
        u_obs: Optional controls for the observation model.

    Returns:
        Tuple (m, V, Cov): m has shape (timesteps, batch, state); V stacks the
        diagonal blocks of the joint covariance; Cov stacks the blocks between
        step t and t + 1, with zeros for the last step.
    """
    device, dtype = A.device, A.dtype
    n_state = dims.state

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * n_state, dims.timesteps * n_state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros(
        (dims.batch, dims.timesteps * n_state),
        device=device,
        dtype=dtype,
    )

    # Pre-compute biases. b is None without state controls: the previous
    # scalar placeholder 0 could not be indexed by time and raised a
    # TypeError. d stays the scalar 0 because it is only subtracted from y.
    b = matvec(B, u_state) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(torch.cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(torch.cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(torch.cholesky_inverse(cholesky(V0)))

    # Observation contributions to the linear term and the precision.
    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * n_state,
    ))
    Q_obs = kron(
        torch.eye(dims.timesteps, dtype=dtype, device=device),
        matmul(C.transpose(-1, -2), matmul(Qinv, C)),
    )

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    # Initial-state contribution occupies the first block.
    h_field[:, :n_state] = matmul(V0inv, m0).repeat(
        (dims.batch, ) + (1, ) * (h_field.ndim - 1))
    Q_field[:, :n_state, :n_state] += V0inv.repeat(
        (dims.batch, ) + (1, ) * (Q_field.ndim - 1))

    # Each transition t -> t + 1 touches the 2x2 block grid at (t, t + 1).
    for t in range(dims.timesteps - 1):
        lo = t * n_state
        mid = lo + n_state
        hi = mid + n_state
        if b is not None:
            # A zero bias contributes nothing, so it is skipped entirely.
            h_field[:, lo:mid] += -matvec(RinvA.transpose(-1, -2), b[t])
            h_field[:, mid:hi] += matvec(Rinv, b[t])
        Q_field[:, lo:mid, lo:mid] += AtRinvA
        Q_field[:, lo:mid, mid:hi] += -RinvA.transpose(-1, -2)
        Q_field[:, mid:hi, lo:mid] += -RinvA
        Q_field[:, mid:hi, mid:hi] += Rinv

    # Invert the joint precision through its Cholesky factor.
    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # Pytorch has no Fortran style reading of indices.
    m = m_all.reshape((dims.batch, dims.timesteps, n_state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(dims.timesteps):
        lo = t * n_state
        mid = lo + n_state
        hi = mid + n_state
        V.append(V_all[:, lo:mid, lo:mid])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, lo:mid, mid:hi])
        else:
            # There is no successor step for the last state.
            Cov.append(
                torch.zeros(
                    (dims.batch, n_state, n_state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov
def loss_forward(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """
    Sample-wise forward / filter loss, e.g. one value per (particle, batch).

    When both particle and batch dimensions are present, tensors are ordered
    TPBF (time, particle, batch, feature).

    Note: summing or averaging over the sample dimensions inside every loss
    term would be cheaper and numerically more precise. The sample-wise form
    is nevertheless required by several algorithms, for instance
    Rao-Blackwellised Particle Filters.
    """
    device, dtype = A.device, A.dtype

    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)
    m0, V0 = add_sample_dims_to_initial_state(m0=m0, V0=V0, dims=dims)

    # Note: We can not use (more readable) in-place operations due to backprop
    # problems, hence per-step results are kept in Python lists.
    m_fw = [None] * dims.timesteps
    V_fw = [None] * dims.timesteps

    # The loss carries a leading particle axis only if particles are in use.
    if dims.particle is None or dims.particle == 0:
        loss_shape = (dims.batch, )
    else:
        loss_shape = (dims.particle, dims.batch)
    loss = torch.zeros(loss_shape, device=device, dtype=dtype)

    for t in range(dims.timesteps):
        if t > 0:
            prev = t - 1
            mp, Vp = filter_forward_prediction_step(
                m=m_fw[prev],
                V=V_fw[prev],
                R=cov_from_invcholesky_param(
                    LRinv_tril[prev],
                    LRinv_logdiag[prev],
                ),
                A=A[prev],
                b=(matvec(B[prev], u_state[prev])
                   if u_state is not None else 0),
            )
        else:
            # The first step starts from the initial state, no prediction.
            mp, Vp = m0, V0

        Q_t = cov_from_invcholesky_param(LQinv_tril[t], LQinv_logdiag[t])
        d_t = matvec(D[t], u_obs[t]) if u_obs is not None else 0
        m_fw[t], V_fw[t], dobs_norm, LVpyinv = filter_forward_measurement_step(
            y=y[t],
            m=mp,
            V=Vp,
            Q=Q_t,
            C=C[t],
            d=d_t,
            return_loss_components=True,
        )

        squared_norm = torch.sum(dobs_norm**2, dim=-1)
        log_diag_sum = torch.sum(torch.log(batch_diag(LVpyinv)), dim=(-1, ))
        loss += 0.5 * (squared_norm - 2 * log_diag_sum +
                       dims.target * LOG_2PI)
    return loss
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """Compute the posterior by direct inversion of the unrolled model.

    Builds one joint precision matrix and one linear term over all time steps
    (size timesteps * state per batch element), inverts it, and reads the
    per-step marginals back out. Model parameters are indexed by time along
    their first axis.

    Args:
        dims: Supplies the timesteps, batch and state sizes.
        A, C: 4-dimensional transition and observation matrices.
        B: 4-dimensional or None; applied to u_state[:-1] via matvec when
            u_state is given.
        D: 4-dimensional or None; applied to u_obs via matvec when u_obs is
            given.
        LQinv_tril, LQinv_logdiag: Parameters from which Q is built
            (4- and 3-dimensional).
        LRinv_tril, LRinv_logdiag: Parameters from which R is built
            (4- and 3-dimensional).
        LV0inv_tril, LV0inv_logdiag: Parameters from which V0 is built.
        m0: Initial state mean.
        y: Observations.
        u_state: Optional controls for the state transition.
        u_obs: Optional controls for the observation model.

    Returns:
        Tuple (m, V, Cov): m has shape (timesteps, batch, state); V stacks the
        diagonal blocks of the joint covariance; Cov stacks the blocks between
        step t and t + 1, with zeros for the last step.
    """
    # This implementation works only if all mats have time and batch dimension.
    #  Otherwise does not broadcast correctly.
    # B and D are Optional, so None is let through instead of failing with an
    # AttributeError on `.ndim`.
    assert A.ndim == 4
    assert B is None or B.ndim == 4
    assert C.ndim == 4
    assert D is None or D.ndim == 4
    assert LQinv_tril.ndim == 4
    assert LQinv_logdiag.ndim == 3
    assert LRinv_tril.ndim == 4
    assert LRinv_logdiag.ndim == 3

    device, dtype = A.device, A.dtype
    n_state = dims.state

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * n_state, dims.timesteps * n_state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros(
        (dims.batch, dims.timesteps * n_state),
        device=device,
        dtype=dtype,
    )

    # b is None without state controls: the previous scalar placeholder 0
    # could not be indexed by time and raised a TypeError. d stays the scalar
    # 0 because it is only subtracted from y.
    b = matvec(B, u_state[:-1]) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(batch_cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(batch_cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(batch_cholesky_inverse(cholesky(V0)))

    # Observation contributions to the linear term and the precision.
    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * n_state,
    ))

    CtQinvC = matmul(C.transpose(-1, -2), matmul(Qinv, C))
    assert len(CtQinvC) == dims.timesteps
    Q_obs = make_block_diagonal(mats=tuple(CtQinvC))

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    # Initial-state contribution occupies the first block.
    h_field[:, :n_state] = matmul(V0inv, m0).repeat(
        (dims.batch, ) + (1, ) * (h_field.ndim - 1))
    Q_field[:, :n_state, :n_state] += V0inv.repeat(
        (dims.batch, ) + (1, ) * (Q_field.ndim - 1))

    # Each transition t -> t + 1 touches the 2x2 block grid at (t, t + 1).
    for t in range(dims.timesteps - 1):
        lo = t * n_state
        mid = lo + n_state
        hi = mid + n_state
        if b is not None:
            # A zero bias contributes nothing, so it is skipped entirely.
            h_field[:, lo:mid] += -matvec(RinvA[t].transpose(-1, -2), b[t])
            h_field[:, mid:hi] += matvec(Rinv[t], b[t])
        Q_field[:, lo:mid, lo:mid] += AtRinvA[t]
        Q_field[:, lo:mid, mid:hi] += -RinvA[t].transpose(-1, -2)
        Q_field[:, mid:hi, lo:mid] += -RinvA[t]
        Q_field[:, mid:hi, mid:hi] += Rinv[t]

    # Invert the joint precision through its Cholesky factor.
    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # Pytorch has no Fortran style reading of indices.
    m = m_all.reshape((dims.batch, dims.timesteps, n_state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(dims.timesteps):
        lo = t * n_state
        mid = lo + n_state
        hi = mid + n_state
        V.append(V_all[:, lo:mid, lo:mid])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, lo:mid, mid:hi])
        else:
            # There is no successor step for the last state.
            Cov.append(
                torch.zeros(
                    (dims.batch, n_state, n_state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov