    def kf_update(self, obs, x, P):
        """
        The 'update' step of the Kalman filter.

        :param obs: The observations for a single timestep: a 3D tensor (group * variable * 1).
        :param x: Mean
        :param P: Covariance
        :return: The new (mean, covariance)
        """

        # handle missing values
        obs_nm, x_nm, P_nm, is_nan_slice = self.nan_remove(obs, x, P)

        # expand design-matrices to match batch-size:
        bs = P_nm.data.shape[0]  # batch-size
        H_expanded = expand(self.H, bs)
        R_expanded = expand(self.R, bs)

        # residual:
        residual = obs_nm - torch.bmm(H_expanded, x_nm)

        # kalman-gain:
        K = self.kalman_gain(P_nm, H_expanded, R_expanded)

        # update mean and covariance:
        x_new, P_new = x.clone(), P.clone()
        x_new[is_nan_slice == 0] = x_nm + torch.bmm(K, residual)
        P_new[is_nan_slice == 0] = self.covariance_update(
            P_nm, K, H_expanded, R_expanded)

        return x_new, P_new

    def kf_predict(self, x, P):
        """
        The 'predict' step of the Kalman filter. The F matrix dictates how the
        mean (x) and covariance (P) change.

        :param x: Mean
        :param P: Covariance
        :return: The new (mean, covariance)
        """
        bs = P.data.shape[0]  # batch-size

        # expand design-matrices to match batch-size:
        F_expanded = expand(self.F, bs)
        Ft_expanded = expand(self.F.t(), bs)
        Q_expanded = expand(self.Q, bs)

        x = torch.bmm(F_expanded, x)
        P = torch.bmm(torch.bmm(F_expanded, P), Ft_expanded) + Q_expanded
        return x, P

    @staticmethod
    def kalman_gain(P, H_expanded, R_expanded):
        """
        Compute the Kalman gain, K = P H^T (H P H^T + R)^-1, batchwise.
        """
        bs = P.data.shape[0]  # batch-size
        Ht_expanded = expand(H_expanded[0].t(), bs)
        S = torch.bmm(torch.bmm(H_expanded, P),
                      Ht_expanded) + R_expanded  # total covariance
        Sinv = torch.cat(
            [torch.inverse(S[i, :, :]).unsqueeze(0) for i in range(bs)],
            0)  # invert, batchwise
        K = torch.bmm(torch.bmm(P, Ht_expanded), Sinv)  # kalman gain
        return K

    @staticmethod
    def covariance_update(P, K, H_expanded, R_expanded):
        """
        "Joseph stabilized" covariance correction.

        :param P: Process covariance.
        :param K: Kalman-gain.
        :param H_expanded: The H design-matrix, expanded for each batch.
        :param R_expanded: The R design-matrix, expanded for each batch.
        :return: The new process covariance.
        """
        rank = H_expanded.data.shape[2]
        I = expand(Variable(torch.eye(rank, rank)), P.data.shape[0])
        p1 = (I - torch.bmm(K, H_expanded))
        p2 = torch.bmm(torch.bmm(p1, P), batch_transpose(p1))
        p3 = torch.bmm(torch.bmm(K, R_expanded), batch_transpose(K))
        return p2 + p3

    def predict_ahead(self, x, n_ahead):
        """
        Given a time-series X -- a 3D tensor with group * variable * time -- generate predictions -- a 4D tensor with
        group * variable * time * n_ahead.

        :param x: A 3D tensor with group * variable * time.
        :param n_ahead: The number of steps ahead for prediction. Minimum is 1.
        :return: Predictions: a 4D tensor with group * variable * time * n_ahead.
        """

        # data shape:
        if len(x.data.shape) != 3:
            raise Exception(
                "`x` should have three dimensions: group*variable*time. "
                "If there's only one variable and the current structure is group*time, "
                "reshape with x[:, None, :].")
        num_series, num_variables, num_timesteps = x.data.shape

        # initial values:
        k_mean, k_cov = self.initializer(x)

        # preallocate:
        output = Variable(torch.zeros(list(x.data.shape) + [n_ahead]))

        # fill one timestep at a time
        for i in range(num_timesteps - 1):
            # update. note `[i]` instead of `i` (keeps dimensionality)
            k_mean, k_cov = self.kf_update(x[:, :, [i]], k_mean, k_cov)

            # predict n-ahead
            for nh in range(n_ahead):
                k_mean, k_cov = self.kf_predict(k_mean, k_cov)
                if nh == 0:
                    k_mean_next, k_cov_next = k_mean, k_cov
                output[:, :, i + 1, nh] = torch.bmm(expand(self.H, num_series),
                                                    k_mean)[:, :, 0]

            # but for next timestep, only use 1-ahead:
            # noinspection PyUnboundLocalVariable
            k_mean, k_cov = k_mean_next, k_cov_next

        # forward-pass is done, so make sure design-mats will be re-instantiated next time:
        if not self.design_mats_as_args:
            self.destroy_design_mats()

        return output
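
    # A minimal usage sketch (hypothetical names; the surrounding class, its
    # `initializer`, and the H/F/Q/R design-matrices are assumed to be defined
    # elsewhere):
    #
    #     kf = KalmanFilterNet(...)               # whatever class holds these methods
    #     x = Variable(torch.randn(5, 1, 100))    # 5 groups, 1 variable, 100 timesteps
    #     preds = kf.predict_ahead(x, n_ahead=3)  # -> 5 * 1 * 100 * 3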