Example #1
def filter_forward_predictive_distribution(m, V, Q, C, d=None):
    """ The predictive distirbution p(yt | y_{1:t-1}),
    obtained by marginalising the state prediction distribution in the measurement model. """
    Vpy = symmetrize(matmul(C, matmul(V, C.transpose(-1, -2))) + Q)
    mpy = matvec(C, m)
    if d is not None:
        mpy += d
    return mpy, Vpy
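
These snippets rely on small linear-algebra helpers from the surrounding codebase (`matmul`, `matvec`, `symmetrize`, `cholesky`, `batch_cholesky_inverse`). Below is a minimal sketch of plausible stand-ins plus an illustrative call of the predictive distribution p(y_t | y_{1:t-1}) = N(C m + d, C V C^T + Q); it assumes batched tensors and a recent PyTorch where `torch.cholesky_inverse` accepts batched inputs, and the shapes are only for illustration.

import torch
from torch import matmul

def matvec(mat, vec):
    # batched matrix-vector product: (..., n, m) x (..., m) -> (..., n)
    return matmul(mat, vec.unsqueeze(-1)).squeeze(-1)

def symmetrize(mat):
    # restore exact symmetry lost to floating-point round-off
    return 0.5 * (mat + mat.transpose(-1, -2))

def cholesky(mat):
    # lower-triangular Cholesky factor
    return torch.linalg.cholesky(mat)

def batch_cholesky_inverse(chol):
    # inverse of the full matrix, given its lower Cholesky factor
    return torch.cholesky_inverse(chol)

# illustrative call of the function above
batch, state, obs = 4, 3, 2
m = torch.randn(batch, state)
V = torch.eye(state).expand(batch, state, state).clone()
C = torch.randn(obs, state)
Q = 0.1 * torch.eye(obs)
mpy, Vpy = filter_forward_predictive_distribution(m=m, V=V, Q=Q, C=C)
print(mpy.shape, Vpy.shape)  # torch.Size([4, 2]) torch.Size([4, 2, 2])
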
Example #2
def filter_forward_prediction_step(m, V, R, A, b=None):
    """ Single prediction step in forward filtering cycle
    (prediction --> measurement) """
    mp = matvec(A, m) if A is not None else m
    if b is not None:
        mp = mp + b  # avoid 'mp += b': the in-place op breaks autograd
    Vp = symmetrize((
        matmul(A, matmul(V, A.transpose(-1, -2))) if A is not None else V) + R)
    return mp, Vp
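
The prediction step computes m_p = A m + b and V_p = A V A^T + R, with A treated as the identity when it is None. A quick shape check, reusing the stand-in helpers sketched under Example #1 (shapes are illustrative only):

import torch

batch, state = 4, 3
A = torch.randn(state, state)
R = 0.1 * torch.eye(state)
m = torch.randn(batch, state)
V = torch.eye(state).expand(batch, state, state).clone()

mp, Vp = filter_forward_prediction_step(m=m, V=V, R=R, A=A)
assert mp.shape == (batch, state) and Vp.shape == (batch, state, state)
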
Example #3
def smooth_backward_step(m_sm, V_sm, m_fw, V_fw, A, b, R):
    """ Single backward (Rauch-Tung-Striebel) smoothing step,
    combining the filtered and the next smoothed distribution. """
    # filter one-step predictive variance
    P = (matmul(A, matmul(V_fw, A.transpose(-1, -2)))
         if A is not None else V_fw) + R
    Pinv = batch_cholesky_inverse(cholesky(P))
    # m and V share the gain J in the conditioning operation (joint to backward posterior)
    J = (matmul(V_fw, matmul(A.transpose(-1, -2), Pinv))
         if A is not None else matmul(V_fw, Pinv))

    m_sm_t = m_fw + matvec(
        J, m_sm - (matvec(A, m_fw) if A is not None else m_fw) - b)
    V_sm_t = symmetrize(V_fw +
                        matmul(J, matmul(V_sm - P, J.transpose(-1, -2))))
    Cov_sm_t = matmul(J, V_sm)
    return m_sm_t, V_sm_t, Cov_sm_t
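
This is the Rauch-Tung-Striebel backward recursion: with P = A V_fw A^T + R and gain J = V_fw A^T P^{-1}, it returns m_sm_t = m_fw + J (m_sm - A m_fw - b), V_sm_t = V_fw + J (V_sm - P) J^T and the cross-covariance J V_sm. A shape check under the same stand-in helpers as above (values are illustrative only):

import torch

batch, state = 4, 3
A = torch.randn(state, state)
b = torch.zeros(state)
R = 0.1 * torch.eye(state)
m_fw = torch.randn(batch, state)
V_fw = torch.eye(state).expand(batch, state, state).clone()
m_sm = torch.randn(batch, state)
V_sm = torch.eye(state).expand(batch, state, state).clone()

m_t, V_t, Cov_t = smooth_backward_step(m_sm, V_sm, m_fw, V_fw, A, b, R)
assert m_t.shape == (batch, state)
assert V_t.shape == Cov_t.shape == (batch, state, state)
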
Example #4
def filter_forward_measurement_step(y,
                                    m,
                                    V,
                                    Q,
                                    C,
                                    d=None,
                                    return_loss_components=False):
    """ Single measurement/update step in forward filtering cycle
    (prediction --> measurement) """
    mpy, Vpy = filter_forward_predictive_distribution(m=m, V=V, Q=Q, C=C, d=d)
    LVpyinv = torch.inverse(cholesky(Vpy))
    # C V is the transposed cross-covariance; V C^T could alternatively be
    # returned by the predictive distribution.
    S = matmul(LVpyinv, matmul(C, V))
    dobs = y - mpy
    dobs_norm = matvec(LVpyinv, dobs)
    mt = m + matvec(S.transpose(-1, -2), dobs_norm)
    Vt = symmetrize(V - matmul(S.transpose(-1, -2), S))
    if return_loss_components:  # not elegant, but convenient for the loss computation
        return mt, Vt, dobs_norm, LVpyinv
    else:
        return mt, Vt
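
The update works on the whitened innovation: with L = chol(Vpy) and S = L^{-1} C V, the posterior is m_t = m + S^T L^{-1}(y - mpy) and V_t = V - S^T S, which is the standard Kalman gain K = V C^T Vpy^{-1} written in square-root form. One full predict/update cycle, again reusing the stand-in helpers from Example #1 (shapes illustrative):

import torch

batch, state, obs = 4, 3, 2
A, C = torch.randn(state, state), torch.randn(obs, state)
R, Q = 0.1 * torch.eye(state), 0.1 * torch.eye(obs)
m = torch.zeros(batch, state)
V = torch.eye(state).expand(batch, state, state).clone()
y = torch.randn(batch, obs)

mp, Vp = filter_forward_prediction_step(m=m, V=V, R=R, A=A)
mt, Vt = filter_forward_measurement_step(y=y, m=mp, V=Vp, Q=Q, C=C)
assert mt.shape == (batch, state) and Vt.shape == (batch, state, state)
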
Example #5
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ compute posterior by direct inversion of unrolled model """
    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state, dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device,
                          dtype=dtype)

    # pre-compute biases (b is None without state controls so the loop below
    # can skip those terms; d = 0 broadcasts fine in y - d)
    b = matvec(B, u_state) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(torch.cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(torch.cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(torch.cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * dims.state,
    ))
    Q_obs = kron(
        torch.eye(dims.timesteps, dtype=dtype, device=device),
        matmul(C.transpose(-1, -2), matmul(Qinv, C)),
    )

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat((dims.batch, ) + (1, ) *
                                                       (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat((dims.batch, ) +
                                                         (1, ) *
                                                         (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        idx = t * dims.state
        if b is not None:
            h_field[:, idx:idx + dims.state] += -matvec(
                RinvA.transpose(-1, -2), b[t])
            h_field[:, idx + dims.state:idx + 2 * dims.state] += matvec(
                Rinv, b[t])
        Q_field[:, idx:idx + dims.state, idx:idx + dims.state] += AtRinvA
        Q_field[:, idx:idx + dims.state,
                idx + dims.state:idx + 2 * dims.state] += -RinvA.transpose(-1, -2)
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx:idx + dims.state] += -RinvA
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx + dims.state:idx + 2 * dims.state] += Rinv

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-order (column-major) reshape, hence reshape + transpose.
    m = m_all.reshape((dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        idx = t * dims.state
        V.append(V_all[:, idx:idx + dims.state, idx:idx + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, idx:idx + dims.state,
                             idx + dims.state:idx + 2 * dims.state])
        else:
            Cov.append(
                torch.zeros(
                    (dims.batch, dims.state, dims.state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov
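
`smooth_global` assembles the trajectory posterior in information form: `Q_field + Q_obs` is the joint precision over all `timesteps * state` latent dimensions and `h_field + h_obs` the matching linear term, so a single Cholesky-based inversion yields all marginal means, variances, and adjacent-step covariances at once (`cov_from_invcholesky_param` is assumed to be a repo helper that rebuilds a covariance from its parametrised inverse Cholesky factor). A minimal check of the information-form identity the final block relies on, using an arbitrary positive-definite precision (illustrative only):

import torch

n = 6
X = torch.randn(n, n)
Lam = X @ X.T + n * torch.eye(n)   # stand-in joint precision (Q_field + Q_obs)
h = torch.randn(n)                 # stand-in linear term (h_field + h_obs)

L_inv = torch.inverse(torch.linalg.cholesky(Lam))
V_all = L_inv.transpose(-1, -2) @ L_inv   # (L L^T)^{-1} = L^{-T} L^{-1}
m_all = V_all @ h

assert torch.allclose(V_all, torch.inverse(Lam), atol=1e-4)
assert torch.allclose(m_all, torch.linalg.solve(Lam, h), atol=1e-4)
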
def smooth_global(
    dims: TensorDims,
    A: torch.Tensor,
    B: Optional[torch.Tensor],
    C: torch.Tensor,
    D: Optional[torch.Tensor],
    LQinv_tril: torch.Tensor,
    LQinv_logdiag: torch.Tensor,
    LRinv_tril: torch.Tensor,
    LRinv_logdiag: torch.Tensor,
    LV0inv_tril: torch.Tensor,
    LV0inv_logdiag: torch.Tensor,
    m0: torch.Tensor,
    y: torch.Tensor,
    u_state: Optional[torch.Tensor] = None,
    u_obs: Optional[torch.Tensor] = None,
):
    """ compute posterior by direct inversion of unrolled model """
    # This implementation works only if all mats have time and batch dimension.
    #  Otherwise does not broadcast correctly.
    assert A.ndim == 4
    assert B.ndim == 4
    assert C.ndim == 4
    assert D.ndim == 4
    assert LQinv_tril.ndim == 4
    assert LQinv_logdiag.ndim == 3
    assert LRinv_tril.ndim == 4
    assert LRinv_logdiag.ndim == 3

    device, dtype = A.device, A.dtype

    R = cov_from_invcholesky_param(LRinv_tril, LRinv_logdiag)
    Q = cov_from_invcholesky_param(LQinv_tril, LQinv_logdiag)
    V0 = cov_from_invcholesky_param(LV0inv_tril, LV0inv_logdiag)

    Q_field = torch.zeros(
        (dims.batch, dims.timesteps * dims.state, dims.timesteps * dims.state),
        device=device,
        dtype=dtype,
    )
    h_field = torch.zeros((dims.batch, dims.timesteps * dims.state),
                          device=device,
                          dtype=dtype)

    # pre-compute biases (b is None without state controls so the loop below
    # can skip those terms; d = 0 broadcasts fine in y - d)
    b = matvec(B, u_state[:-1]) if u_state is not None else None
    d = matvec(D, u_obs) if u_obs is not None else 0

    Rinv = symmetrize(batch_cholesky_inverse(cholesky(R)))
    Qinv = symmetrize(batch_cholesky_inverse(cholesky(Q)))
    V0inv = symmetrize(batch_cholesky_inverse(cholesky(V0)))

    CtQinvymd = matvec(matmul(C.transpose(-1, -2), Qinv), y - d)
    h_obs = CtQinvymd.transpose(1, 0).reshape((
        dims.batch,
        dims.timesteps * dims.state,
    ))

    CtQinvC = matmul(C.transpose(-1, -2), matmul(Qinv, C))
    assert len(CtQinvC) == dims.timesteps
    Q_obs = make_block_diagonal(mats=tuple(mat_t for mat_t in CtQinvC))

    AtRinvA = matmul(A.transpose(-1, -2), matmul(Rinv, A))
    RinvA = matmul(Rinv, A)

    h_field[:, :dims.state] = matmul(V0inv, m0).repeat((dims.batch, ) + (1, ) *
                                                       (h_field.ndim - 1))
    Q_field[:, :dims.state, :dims.state] += V0inv.repeat((dims.batch, ) +
                                                         (1, ) *
                                                         (Q_field.ndim - 1))
    for t in range(dims.timesteps - 1):
        idx = t * dims.state
        if b is not None:
            h_field[:, idx:idx + dims.state] += -matvec(
                RinvA[t].transpose(-1, -2), b[t])
            h_field[:, idx + dims.state:idx + 2 * dims.state] += matvec(
                Rinv[t], b[t])
        Q_field[:, idx:idx + dims.state, idx:idx + dims.state] += AtRinvA[t]
        Q_field[:, idx:idx + dims.state,
                idx + dims.state:idx + 2 * dims.state] += -RinvA[t].transpose(-1, -2)
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx:idx + dims.state] += -RinvA[t]
        Q_field[:, idx + dims.state:idx + 2 * dims.state,
                idx + dims.state:idx + 2 * dims.state] += Rinv[t]

    L_all_inv = torch.inverse(cholesky(Q_field + Q_obs))
    V_all = matmul(L_all_inv.transpose(-1, -2), L_all_inv)
    m_all = matvec(V_all, h_obs + h_field)

    # PyTorch has no Fortran-order (column-major) reshape, hence reshape + transpose.
    m = m_all.reshape((dims.batch, dims.timesteps, dims.state)).transpose(0, 1)
    V, Cov = [], []
    for t in range(0, dims.timesteps):
        idx = t * dims.state
        V.append(V_all[:, idx:idx + dims.state, idx:idx + dims.state])
        if t < (dims.timesteps - 1):
            Cov.append(V_all[:, idx:idx + dims.state,
                             idx + dims.state:idx + 2 * dims.state])
        else:
            Cov.append(
                torch.zeros(
                    (dims.batch, dims.state, dims.state),
                    device=device,
                    dtype=dtype,
                ))
    V = torch.stack(V, dim=0)
    Cov = torch.stack(Cov, dim=0)

    return m, V, Cov
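
This second variant assumes time-varying, batched parameters (hence the ndim asserts and the `[t]` indexing), so the observation-precision block is assembled with `make_block_diagonal` instead of a Kronecker product with the identity. A plausible stand-in for that helper, assuming each block is a `(batch, n, n)` tensor; the actual repo helper may differ:

import torch

def make_block_diagonal(mats):
    # stack batched square blocks, each (batch, n, n), into a
    # (batch, T * n, T * n) block-diagonal matrix
    batch, n, _ = mats[0].shape
    out = torch.zeros(batch, len(mats) * n, len(mats) * n,
                      dtype=mats[0].dtype, device=mats[0].device)
    for t, mat in enumerate(mats):
        out[:, t * n:(t + 1) * n, t * n:(t + 1) * n] = mat
    return out

# quick check against torch.block_diag on a single batch element
blocks = tuple(torch.randn(2, 3, 3) for _ in range(4))
full = make_block_diagonal(blocks)
assert full.shape == (2, 12, 12)
assert torch.equal(full[0], torch.block_diag(*[blk[0] for blk in blocks]))
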