Example #1
    def predict(self, X):
        """Classify the output for given data

        Parameters
        ----------

        X : list of 2D arrays, element i has shape=[voxels_i, samples_i]
            Each element in the list contains the fMRI data of one subject.
            The number of voxels of each element must match the number used
            for that subject when the model was fit.

        Returns
        -------
        p: list of arrays, element i has shape=[samples_i]
            Predictions for each data sample.
        """
        # Check that the model has been fitted
        if not hasattr(self, 'w_'):
            raise NotFittedError("The model fit has not been run yet.")

        # Check the number of subjects
        if len(X) != len(self.w_):
            raise ValueError("The number of subjects does not match the one"
                             " in the model.")

        X_shared = self.transform(X)
        p = [None] * len(X_shared)
        for subject in range(len(X_shared)):
            sumexp, _, exponents = utils.sumexp_stable(
                self.theta_.T.dot(X_shared[subject]) + self.bias_)
            p[subject] = self.classes_[
                (exponents / sumexp[np.newaxis, :]).argmax(axis=0)]

        return p
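
Note on the softmax step above: sumexp_stable is used so that exponentiating the class scores cannot overflow. The scores are shifted by their per-sample maximum before taking the exponential, and the shift cancels when exponents is divided by sumexp, leaving ordinary softmax probabilities. A minimal standalone sketch of that stable softmax (the helper stable_softmax below is illustrative only, not part of brainiak):

import numpy as np

def stable_softmax(scores):
    # Shift each column (sample) by its maximum so np.exp never overflows;
    # the shift cancels in the ratio, leaving ordinary softmax probabilities.
    shifted = scores - scores.max(axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=0, keepdims=True)

scores = np.array([[1000.0, 1.0],
                   [1001.0, 0.0]])              # shape [classes, samples]
print(stable_softmax(scores).argmax(axis=0))    # predicted class per sample: [1 0]
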
Example #2
def test_sumexp():
    from brainiak.utils.utils import sumexp_stable
    import numpy as np

    data = np.array([[1, 1], [0, 1]])
    sums, maxs, exps = sumexp_stable(data)
    assert sums.size == data.shape[1], "Invalid sum(exp(v)) computation (wrong # samples in sums)"
    assert exps.shape[0] == data.shape[0], "Invalid exp(v) computation (wrong # features)"
    assert exps.shape[1] == data.shape[1], "Invalid exp(v) computation (wrong # samples)"
    assert maxs.size == data.shape[1], "Invalid max computation (wrong # samples in maxs)"
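
The shapes asserted in this test, together with the call sites in the other examples (np.log(sum_exp) + max_value recovers the log-sum-exp, and exponents / sumexp yields softmax probabilities), suggest that sumexp_stable performs a column-wise, max-shifted exp-and-sum decomposition. The following reference sketch satisfies the same assertions; it is inferred from the usage and may differ in detail from the actual brainiak implementation:

import numpy as np

def sumexp_stable_sketch(data):
    # Per-sample (column-wise) maxima, exponentials of the max-shifted values,
    # and their sums, so that log(exp(data).sum(axis=0)) == np.log(sums) + maxs.
    maxs = data.max(axis=0)
    exps = np.exp(data - maxs[np.newaxis, :])
    sums = exps.sum(axis=0)
    return sums, maxs, exps

data = np.array([[1.0, 1.0], [0.0, 1.0]])
sums, maxs, exps = sumexp_stable_sketch(data)
assert sums.size == data.shape[1] and maxs.size == data.shape[1]
assert exps.shape == data.shape
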
Example #3
    def _loss_lr_subject(self, data, labels, w, theta, bias):
        """Compute the Loss MLR for a single subject (without regularization)

        Parameters
        ----------

        data : array, shape=[voxels, samples]
            The fMRI data of subject i for the classification task.

        labels : array of int, shape=[samples]
            The labels for the data samples in data.

        w : array, shape=[voxels, features]
            The orthogonal transform (mapping) :math:`W_i` for subject i.

        theta : array, shape=[classes, features]
            The MLR class plane parameters.

        bias : array, shape=[classes]
            The MLR class biases.

        Returns
        -------

        loss : float
            The MLR loss for the subject.
        """
        if data is None:
            return 0.0

        samples = data.shape[1]

        thetaT_wi_zi_plus_bias = theta.T.dot(w.T.dot(data)) + bias
        sum_exp, max_value, _ = utils.sumexp_stable(thetaT_wi_zi_plus_bias)
        sum_exp_values = np.log(sum_exp) + max_value

        aux = 0.0
        for sample in range(samples):
            label = labels[sample]
            aux += thetaT_wi_zi_plus_bias[label, sample]
        return self.alpha / samples / self.gamma * (sum_exp_values.sum() - aux)
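
The expression np.log(sum_exp) + max_value relies on the log-sum-exp identity, log(sum(exp(v))) = max(v) + log(sum(exp(v - max(v)))), so sum_exp_values.sum() - aux is the multinomial cross-entropy over the subject's samples before the alpha / samples / gamma scaling. A short numeric check of the identity, purely illustrative:

import numpy as np

logits = np.array([[2.0, -1.0, 700.0],
                   [0.5,  3.0, 699.0]])        # shape [classes, samples]

# The naive form overflows for the third sample (inf, with a RuntimeWarning) ...
naive = np.log(np.exp(logits).sum(axis=0))

# ... while the max-shifted form used above stays finite.
max_value = logits.max(axis=0)
sum_exp = np.exp(logits - max_value).sum(axis=0)
stable = np.log(sum_exp) + max_value
print(naive)    # approx [2.2014 3.0181    inf]
print(stable)   # approx [2.2014 3.0181 700.3133]
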
Example #4
def test_grid_flatten_num_int():
    # Check for numeric integration of SNR, and correctly flattening 2-D grids
    # to 1-D grid.
    import brainiak.reprsimil.brsa
    import brainiak.utils.utils as utils
    import numpy as np
    import scipy.integrate
    import scipy.special
    import scipy.stats
    n_V = 30
    n_T = 50
    n_C = 3
    design = np.random.randn(n_T, n_C)
    U_simu = np.asarray([[1.0, 0.1, 0.0], [0.1, 1.0, 0.2], [0.0, 0.2, 1.0]])
    L_simu = np.linalg.cholesky(U_simu)
    SNR = np.random.exponential(size=n_V)
    beta = np.dot(L_simu, np.random.randn(n_C, n_V)) * SNR
    noise = np.random.randn(n_T, n_V)
    Y = np.dot(design, beta) + noise
    X = design
    X_base = None
    scan_onsets = [0]

    s = brainiak.reprsimil.brsa.GBRSA(n_iter=1,
                                      auto_nuisance=False,
                                      SNR_prior='exp')
    s.fit(X=[Y], design=[design])
    rank = n_C
    l_idx, rank = s._chol_idx(n_C, rank)
    L = np.zeros((n_C, rank))
    n_l = np.size(l_idx[0])
    current_vec_U_chlsk_l = s.random_state_.randn(n_l) * 10
    L[l_idx] = current_vec_U_chlsk_l

    # Now we change the grids for SNR and rho for testing.
    s.SNR_bins = 2
    s.rho_bins = 2
    SNR_grids, SNR_weights = s._set_SNR_grids()
    # rho_grids, rho_weights = s._set_rho_grids()
    rho_grids = np.ones(2) * 0.1
    rho_weights = np.ones(2) / 2
    # We purposefully set all rhos to be equal to test flattening of
    # grids.
    n_grid = s.SNR_bins * s.rho_bins

    D, F, run_TRs, n_run = s._prepare_DF(n_T, scan_onsets=scan_onsets)
    XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX, XTDX, XTFX \
        = s._prepare_data_XY(X, Y, D, F)
    X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0, X0TY, X0TDY, X0TFY, X0, \
        X_base, n_X0, idx_DC = s._prepare_data_XYX0(
            X, Y, X_base, None, D, F, run_TRs, no_DC=False)

    X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
        = s._precompute_ar1_quad_forms_marginalized(
            XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
            XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
            X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)

    half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
        sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
            X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
            X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

    assert (half_log_det_X0TAX0[0] == half_log_det_X0TAX0[1]
            and half_log_det_X0TAX0[2] == half_log_det_X0TAX0[3]
            and half_log_det_X0TAX0[0] == half_log_det_X0TAX0[2]
            ), '_matrix_flattened_grid has mistake with half_log_det_X0TAX0'
    assert (np.array_equal(X0TAX0[0, :, :], X0TAX0[1, :, :])
            and np.array_equal(X0TAX0[2, :, :], X0TAX0[3, :, :])
            and np.array_equal(
                X0TAX0[0, :, :],
                X0TAX0[2, :, :])), '_matrix_flattened_grid has mistake X0TAX0'
    assert (np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[1, :, :])
            and np.array_equal(X0TAX0_i[2, :, :], X0TAX0_i[3, :, :])
            and np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0_i'
    assert np.allclose(np.dot(X0TAX0[0, :, :], X0TAX0_i[0, :, :]),
                       np.eye(n_X0)), 'X0TAX0_i is not inverse of X0TAX0'
    assert (np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[1, :])
            and np.array_equal(YTAcorrY_diag[2, :], YTAcorrY_diag[3, :])
            and np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[2, :])
            ), '_matrix_flattened_grid has mistake YTAcorrY_diag'
    assert (np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[1, :, :])
            and np.array_equal(sXTAcorrY[2, :, :], sXTAcorrY[3, :, :])
            and not np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[2, :, :])
            ), '_matrix_flattened_grid has mistake sXTAcorrY'
    assert (np.array_equal(X0TAY[0, :, :], X0TAY[1, :, :])
            and np.array_equal(X0TAY[2, :, :], X0TAY[3, :, :])
            and np.array_equal(
                X0TAY[0, :, :],
                X0TAY[2, :, :])), '_matrix_flattened_grid has mistake X0TAY'
    assert (np.array_equal(XTAX0[0, :, :], XTAX0[1, :, :])
            and np.array_equal(XTAX0[2, :, :], XTAX0[3, :, :])
            and np.array_equal(
                XTAX0[0, :, :],
                XTAX0[2, :, :])), '_matrix_flattened_grid has mistake XTAX0'

    # Now we test the other way
    rho_grids, rho_weights = s._set_rho_grids()
    # rho_grids, rho_weights = s._set_rho_grids()
    SNR_grids = np.ones(2) * 0.1
    SNR_weights = np.ones(2) / 2
    # We purposefully set all SNR to be equal to test flattening of
    # grids.
    X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
        = s._precompute_ar1_quad_forms_marginalized(
            XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
            XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
            X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)

    half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
        sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
            X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
            X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

    assert (half_log_det_X0TAX0[0] == half_log_det_X0TAX0[2]
            and half_log_det_X0TAX0[1] == half_log_det_X0TAX0[3]
            and not half_log_det_X0TAX0[0] == half_log_det_X0TAX0[1]
            ), '_matrix_flattened_grid has mistake with half_log_det_X0TAX0'
    assert (np.array_equal(X0TAX0[0, :, :], X0TAX0[2, :, :])
            and np.array_equal(X0TAX0[1, :, :], X0TAX0[3, :, :])
            and not np.array_equal(X0TAX0[0, :, :], X0TAX0[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0'
    assert (np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[2, :, :])
            and np.array_equal(X0TAX0_i[1, :, :], X0TAX0_i[3, :, :])
            and not np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0_i'
    assert np.allclose(np.dot(X0TAX0[0, :, :], X0TAX0_i[0, :, :]),
                       np.eye(n_X0)), 'X0TAX0_i is not inverse of X0TAX0'
    assert (np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[2, :])
            and np.array_equal(YTAcorrY_diag[1, :], YTAcorrY_diag[3, :])
            and not np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[1, :])
            ), '_matrix_flattened_grid has mistake YTAcorrY_diag'
    assert (np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[2, :, :])
            and np.array_equal(sXTAcorrY[1, :, :], sXTAcorrY[3, :, :])
            and not np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[1, :, :])
            ), '_matrix_flattened_grid has mistake sXTAcorrY'
    assert (np.array_equal(X0TAY[0, :, :], X0TAY[2, :, :])
            and np.array_equal(X0TAY[1, :, :], X0TAY[3, :, :])
            and not np.array_equal(X0TAY[0, :, :], X0TAY[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAY'
    assert (np.array_equal(XTAX0[0, :, :], XTAX0[2, :, :])
            and np.array_equal(XTAX0[1, :, :], XTAX0[3, :, :])
            and not np.array_equal(XTAX0[0, :, :], XTAX0[1, :, :])
            ), '_matrix_flattened_grid has mistake XTAX0'

    # Now test the integration over SNR
    s.SNR_bins = 50
    s.rho_bins = 1
    SNR_grids, SNR_weights = s._set_SNR_grids()
    rho_grids, rho_weights = s._set_rho_grids()
    n_grid = s.SNR_bins * s.rho_bins

    def setup_for_test():
        # This function will be re-used to set up the variables necessary for
        # testing.

        X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
            = s._precompute_ar1_quad_forms_marginalized(
                XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
                XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
                X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)

        half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
            sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
                X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
                X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

        log_weights = np.reshape(
            np.log(SNR_weights[:, None]) + np.log(rho_weights), n_grid)
        all_rho_grids = np.reshape(
            np.repeat(rho_grids[None, :], s.SNR_bins, axis=0), n_grid)
        log_fixed_terms = - (n_T - n_X0) / 2 * np.log(2 * np.pi) + n_run \
            / 2 * np.log(1 - all_rho_grids**2) + scipy.special.gammaln(
                (n_T - n_X0 - 2) / 2) + (n_T - n_X0 - 2) / 2 * np.log(2)
        return s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, \
            log_weights, log_fixed_terms

    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_total, _ = s._loglike_marginalized(current_vec_U_chlsk_l,
                                          s2XTAcorrX,
                                          YTAcorrY_diag,
                                          sXTAcorrY,
                                          half_log_det_X0TAX0,
                                          log_weights,
                                          log_fixed_terms,
                                          l_idx,
                                          n_C,
                                          n_T,
                                          n_V,
                                          n_X0,
                                          n_grid,
                                          rank=rank)
    LL_total = -LL_total
    # Now we re-calculate using scipy.integrate
    s.SNR_bins = 100
    SNR_grids = np.linspace(0, 12, s.SNR_bins)
    SNR_weights = np.exp(-SNR_grids)
    SNR_weights = SNR_weights / np.sum(SNR_weights)
    n_grid = s.SNR_bins * s.rho_bins

    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_raw, _, _, _ = s._raw_loglike_grids(L, s2XTAcorrX, YTAcorrY_diag,
                                           sXTAcorrY, half_log_det_X0TAX0,
                                           log_weights, log_fixed_terms, n_C,
                                           n_T, n_V, n_X0, n_grid, rank)
    result_sum, max_value, result_exp = utils.sumexp_stable(LL_raw)
    scipy_sum = scipy.integrate.simps(y=result_exp, axis=0)
    LL_total_scipy = np.sum(np.log(scipy_sum) + max_value)

    tol = 1e-3
    assert(np.isclose(LL_total_scipy, LL_total, rtol=tol)), \
        'Error of log likelihood calculation exceeds the tolerance'

    # Now test the log normal prior
    s = brainiak.reprsimil.brsa.GBRSA(n_iter=1,
                                      auto_nuisance=False,
                                      SNR_prior='lognorm')
    s.SNR_bins = 50
    s.rho_bins = 1
    SNR_grids, SNR_weights = s._set_SNR_grids()
    rho_grids, rho_weights = s._set_rho_grids()
    n_grid = s.SNR_bins * s.rho_bins

    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_total, _ = s._loglike_marginalized(current_vec_U_chlsk_l,
                                          s2XTAcorrX,
                                          YTAcorrY_diag,
                                          sXTAcorrY,
                                          half_log_det_X0TAX0,
                                          log_weights,
                                          log_fixed_terms,
                                          l_idx,
                                          n_C,
                                          n_T,
                                          n_V,
                                          n_X0,
                                          n_grid,
                                          rank=rank)
    LL_total = -LL_total
    # Now we re-calculate using scipy.integrate
    s.SNR_bins = 400
    SNR_grids = np.linspace(1e-8, 20, s.SNR_bins)
    log_SNR_weights = scipy.stats.lognorm.logpdf(SNR_grids, s=s.logS_range)
    result_sum, max_value, result_exp = utils.sumexp_stable(
        log_SNR_weights[:, None])
    SNR_weights = np.squeeze(result_exp / result_sum)
    n_grid = s.SNR_bins * s.rho_bins

    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_raw, _, _, _ = s._raw_loglike_grids(L, s2XTAcorrX, YTAcorrY_diag,
                                           sXTAcorrY, half_log_det_X0TAX0,
                                           log_weights, log_fixed_terms, n_C,
                                           n_T, n_V, n_X0, n_grid, rank)
    result_sum, max_value, result_exp = utils.sumexp_stable(LL_raw)
    scipy_sum = scipy.integrate.simps(y=result_exp, axis=0)
    LL_total_scipy = np.sum(np.log(scipy_sum) + max_value)

    tol = 1e-3
    assert(np.isclose(LL_total_scipy, LL_total, rtol=tol)), \
        'Error of log likelihood calculation exceeds the tolerance'
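
Both branches of this test compare a weighted grid sum of per-grid log likelihoods (accumulated through sumexp_stable in log space) against Simpson-rule quadrature from scipy, requiring agreement within rtol=1e-3. The sketch below applies the same marginalization pattern to a toy one-dimensional integrand; the grid, prior, and tolerance are hypothetical values chosen only for illustration:

import numpy as np
from scipy import integrate

# Toy marginalization over an exponential-like prior: compare a log-sum-exp
# grid sum against Simpson's rule on the same grid.
grid = np.linspace(0.0, 12.0, 201)
log_L = -0.5 * (grid - 4.0) ** 2     # toy log likelihood evaluated on the grid
log_prior = -grid                    # exponential prior, up to a constant

log_terms = log_L + log_prior
max_value = log_terms.max()
log_grid_sum = np.log(np.exp(log_terms - max_value).sum()) + max_value

quad = integrate.simpson(np.exp(log_terms), x=grid)
dx = grid[1] - grid[0]
# The grid sum approximates the integral up to the grid spacing dx.
assert np.isclose(log_grid_sum + np.log(dx), np.log(quad), rtol=1e-3)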