def predict(self, X):
    """Classify the output for given data

    Parameters
    ----------

    X : list of 2D arrays, element i has shape=[voxels_i, samples_i]
        Each element in the list contains the fMRI data of one subject.
        The number of voxels should be according to each subject at
        the moment of training the model.

    Returns
    -------

    p : list of arrays, element i has shape=[samples_i]
        Predictions for each data sample.

    Raises
    ------

    NotFittedError
        If the model has no ``w_`` attribute, i.e. it was not fitted.

    ValueError
        If ``len(X)`` differs from ``len(self.w_)``.
    """
    # Check if the model exists
    if not hasattr(self, 'w_'):
        raise NotFittedError("The model fit has not been run yet.")

    # Check the number of subjects
    if len(X) != len(self.w_):
        raise ValueError("The number of subjects does not match the one"
                         " in the model.")

    X_shared = self.transform(X)
    p = []
    for shared in X_shared:
        # Class scores of this subject, one column per sample.
        sumexp, _, exponents = utils.sumexp_stable(
            self.theta_.T.dot(shared) + self.bias_)
        # Divide by the per-sample sum and keep, for every sample, the
        # class whose entry along axis 0 is the largest.
        best = (exponents / sumexp[np.newaxis, :]).argmax(axis=0)
        p.append(self.classes_[best])
    return p
def predict(self, X):
    """Classify the output for given data

    Parameters
    ----------

    X : list of 2D arrays, element i has shape=[voxels_i, samples_i]
        Each element in the list contains the fMRI data of one subject.
        The number of voxels should be according to each subject at
        the moment of training the model.

    Returns
    -------

    p : list of arrays, element i has shape=[samples_i]
        Predictions for each data sample.

    Raises
    ------

    NotFittedError
        If the model has no ``w_`` attribute, i.e. it was not fitted.

    ValueError
        If ``len(X)`` differs from ``len(self.w_)``.
    """
    # Check if the model exists
    if not hasattr(self, 'w_'):
        raise NotFittedError("The model fit has not been run yet.")

    # Check the number of subjects
    if len(X) != len(self.w_):
        raise ValueError("The number of subjects does not match the one"
                         " in the model.")

    X_shared = self.transform(X)
    p = []
    for shared in X_shared:
        # Class scores of this subject, one column per sample.
        sumexp, _, exponents = utils.sumexp_stable(
            self.theta_.T.dot(shared) + self.bias_)
        # Divide by the per-sample sum and keep, for every sample, the
        # class whose entry along axis 0 is the largest.
        best = (exponents / sumexp[np.newaxis, :]).argmax(axis=0)
        p.append(self.classes_[best])
    return p
def test_sumexp():
    """Check the shapes and the values returned by ``sumexp_stable``.

    The shape assertions alone cannot tell the two axes apart on the square
    input used here, so the values are checked as well, also on a non-square
    input: ``log(sums) + maxs`` must equal the log-sum-exp of every column,
    which is how callers combine the two outputs.
    """
    from brainiak.utils.utils import sumexp_stable
    import numpy as np

    data = np.array([[1, 1], [0, 1]])
    sums, maxs, exps = sumexp_stable(data)
    assert sums.size == data.shape[1], (
        "Invalid sum(exp(v)) computation (wrong # samples in sums)")
    assert exps.shape[0] == data.shape[0], (
        "Invalid exp(v) computation (wrong # features)")
    assert exps.shape[1] == data.shape[1], (
        "Invalid exp(v) computation (wrong # samples)")
    assert maxs.size == data.shape[1], (
        "Invalid max computation (wrong # samples in maxs)")

    # Value check, repeated on a non-square array so that a reduction over
    # the wrong axis cannot go unnoticed.
    tall = np.array([[0.5, -1.0], [2.0, 0.0], [-3.0, 1.5]])
    for values in (data, tall):
        sums, maxs, _ = sumexp_stable(values)
        expected = np.log(np.exp(values).sum(axis=0))
        assert np.allclose(np.log(sums) + maxs, expected), (
            "Invalid sum(exp(v)) computation (wrong values)")
def _loss_lr_subject(self, data, labels, w, theta, bias): """Compute the Loss MLR for a single subject (without regularization) Parameters ---------- data : array, shape=[voxels, samples] The fMRI data of subject i for the classification task. labels : array of int, shape=[samples] The labels for the data samples in data. w : array, shape=[voxels, features] The orthogonal transform (mapping) :math:`W_i` for subject i. theta : array, shape=[classes, features] The MLR class plane parameters. bias : array, shape=[classes] The MLR class biases. Returns ------- loss : float The loss MLR for the subject """ if data is None: return 0.0 samples = data.shape[1] thetaT_wi_zi_plus_bias = theta.T.dot(w.T.dot(data)) + bias sum_exp, max_value, _ = utils.sumexp_stable(thetaT_wi_zi_plus_bias) sum_exp_values = np.log(sum_exp) + max_value aux = 0.0 for sample in range(samples): label = labels[sample] aux += thetaT_wi_zi_plus_bias[label, sample] return self.alpha / samples / self.gamma * (sum_exp_values.sum() - aux)
def test_grid_flatten_num_int():
    """Check grid flattening and numeric integration over SNR in GBRSA.

    The first two parts build a 2 x 2 grid of SNR and rho values in which
    one of the two parameters is constant, and assert that the quantities
    returned by ``_matrix_flattened_grid`` are identical for the grid points
    that share the varying parameter. The last two parts compare the
    log likelihood returned by ``_loglike_marginalized`` with a
    re-computation based on ``scipy.integrate``, for the 'exp' and the
    'lognorm' SNR priors.
    """
    import brainiak.reprsimil.brsa
    import brainiak.utils.utils as utils
    import numpy as np
    # Every SciPy sub-package used below is imported explicitly instead of
    # relying on another module having imported it as a side effect.
    import scipy.integrate
    import scipy.special
    import scipy.stats

    # ``simps`` was removed in SciPy 1.14 in favour of ``simpson``; fall
    # back to the old name on SciPy versions that lack ``simpson``.
    simpson = getattr(scipy.integrate, 'simpson', None)
    if simpson is None:
        simpson = scipy.integrate.simps

    n_V = 30
    n_T = 50
    n_C = 3
    design = np.random.randn(n_T, n_C)
    U_simu = np.asarray([[1.0, 0.1, 0.0], [0.1, 1.0, 0.2], [0.0, 0.2, 1.0]])
    L_simu = np.linalg.cholesky(U_simu)
    SNR = np.random.exponential(size=n_V)
    beta = np.dot(L_simu, np.random.randn(n_C, n_V)) * SNR
    noise = np.random.randn(n_T, n_V)
    Y = np.dot(design, beta) + noise
    X = design
    X_base = None
    scan_onsets = [0]

    s = brainiak.reprsimil.brsa.GBRSA(n_iter=1, auto_nuisance=False,
                                      SNR_prior='exp')
    s.fit(X=[Y], design=[design])
    rank = n_C
    l_idx, rank = s._chol_idx(n_C, rank)
    L = np.zeros((n_C, rank))
    n_l = np.size(l_idx[0])
    current_vec_U_chlsk_l = s.random_state_.randn(n_l) * 10
    L[l_idx] = current_vec_U_chlsk_l

    # Now we change the grids for SNR and rho for testing.
    s.SNR_bins = 2
    s.rho_bins = 2
    SNR_grids, SNR_weights = s._set_SNR_grids()
    # Instead of calling s._set_rho_grids(), we purposefully set all rhos
    # to be equal to test flattening of grids.
    rho_grids = np.ones(2) * 0.1
    rho_weights = np.ones(2) / 2
    n_grid = s.SNR_bins * s.rho_bins

    D, F, run_TRs, n_run = s._prepare_DF(n_T, scan_onsets=scan_onsets)
    XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX, XTDX, XTFX \
        = s._prepare_data_XY(X, Y, D, F)
    X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0, X0TY, X0TDY, X0TFY, X0, \
        X_base, n_X0, idx_DC = s._prepare_data_XYX0(
            X, Y, X_base, None, D, F, run_TRs, no_DC=False)
    X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
        = s._precompute_ar1_quad_forms_marginalized(
            XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
            XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
            X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)
    half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
        sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
            X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
            X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

    # With equal rhos, grid points 0/1 and 2/3 must agree everywhere; only
    # sXTAcorrY is expected to differ between points 0 and 2.
    assert (half_log_det_X0TAX0[0] == half_log_det_X0TAX0[1]
            and half_log_det_X0TAX0[2] == half_log_det_X0TAX0[3]
            and half_log_det_X0TAX0[0] == half_log_det_X0TAX0[2]
            ), '_matrix_flattened_grid has mistake with half_log_det_X0TAX0'
    assert (np.array_equal(X0TAX0[0, :, :], X0TAX0[1, :, :])
            and np.array_equal(X0TAX0[2, :, :], X0TAX0[3, :, :])
            and np.array_equal(X0TAX0[0, :, :], X0TAX0[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0'
    assert (np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[1, :, :])
            and np.array_equal(X0TAX0_i[2, :, :], X0TAX0_i[3, :, :])
            and np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0_i'
    assert np.allclose(
        np.dot(X0TAX0[0, :, :], X0TAX0_i[0, :, :]),
        np.eye(n_X0)), 'X0TAX0_i is not inverse of X0TAX0'
    assert (np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[1, :])
            and np.array_equal(YTAcorrY_diag[2, :], YTAcorrY_diag[3, :])
            and np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[2, :])
            ), '_matrix_flattened_grid has mistake YTAcorrY_diag'
    assert (np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[1, :, :])
            and np.array_equal(sXTAcorrY[2, :, :], sXTAcorrY[3, :, :])
            and not np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[2, :, :])
            ), '_matrix_flattened_grid has mistake sXTAcorrY'
    assert (np.array_equal(X0TAY[0, :, :], X0TAY[1, :, :])
            and np.array_equal(X0TAY[2, :, :], X0TAY[3, :, :])
            and np.array_equal(X0TAY[0, :, :], X0TAY[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAY'
    assert (np.array_equal(XTAX0[0, :, :], XTAX0[1, :, :])
            and np.array_equal(XTAX0[2, :, :], XTAX0[3, :, :])
            and np.array_equal(XTAX0[0, :, :], XTAX0[2, :, :])
            ), '_matrix_flattened_grid has mistake XTAX0'

    # Now we test the other way
    rho_grids, rho_weights = s._set_rho_grids()
    SNR_grids = np.ones(2) * 0.1
    SNR_weights = np.ones(2) / 2
    # We purposefully set all SNR to be equal to test flattening of
    # grids.
    X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
        = s._precompute_ar1_quad_forms_marginalized(
            XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
            XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
            X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)
    half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
        sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
            X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
            X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

    # With equal SNRs, grid points 0/2 and 1/3 must agree, while points 0
    # and 1 (different rho) must differ.
    assert (half_log_det_X0TAX0[0] == half_log_det_X0TAX0[2]
            and half_log_det_X0TAX0[1] == half_log_det_X0TAX0[3]
            and not half_log_det_X0TAX0[0] == half_log_det_X0TAX0[1]
            ), '_matrix_flattened_grid has mistake with half_log_det_X0TAX0'
    assert (np.array_equal(X0TAX0[0, :, :], X0TAX0[2, :, :])
            and np.array_equal(X0TAX0[1, :, :], X0TAX0[3, :, :])
            and not np.array_equal(X0TAX0[0, :, :], X0TAX0[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0'
    assert (np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[2, :, :])
            and np.array_equal(X0TAX0_i[1, :, :], X0TAX0_i[3, :, :])
            and not np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0_i'
    assert np.allclose(
        np.dot(X0TAX0[0, :, :], X0TAX0_i[0, :, :]),
        np.eye(n_X0)), 'X0TAX0_i is not inverse of X0TAX0'
    assert (np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[2, :])
            and np.array_equal(YTAcorrY_diag[1, :], YTAcorrY_diag[3, :])
            and not np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[1, :])
            ), '_matrix_flattened_grid has mistake YTAcorrY_diag'
    assert (np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[2, :, :])
            and np.array_equal(sXTAcorrY[1, :, :], sXTAcorrY[3, :, :])
            and not np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[1, :, :])
            ), '_matrix_flattened_grid has mistake sXTAcorrY'
    assert (np.array_equal(X0TAY[0, :, :], X0TAY[2, :, :])
            and np.array_equal(X0TAY[1, :, :], X0TAY[3, :, :])
            and not np.array_equal(X0TAY[0, :, :], X0TAY[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAY'
    assert (np.array_equal(XTAX0[0, :, :], XTAX0[2, :, :])
            and np.array_equal(XTAX0[1, :, :], XTAX0[3, :, :])
            and not np.array_equal(XTAX0[0, :, :], XTAX0[1, :, :])
            ), '_matrix_flattened_grid has mistake XTAX0'

    # Now test the integration over SNR
    s.SNR_bins = 50
    s.rho_bins = 1
    SNR_grids, SNR_weights = s._set_SNR_grids()
    rho_grids, rho_weights = s._set_rho_grids()
    n_grid = s.SNR_bins * s.rho_bins

    def setup_for_test():
        # This function will be re-used to set up the variables necessary for
        # testing. It reads s, the grids, the weights and n_grid from the
        # enclosing scope at call time, so it picks up their latest values.
        X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
            = s._precompute_ar1_quad_forms_marginalized(
                XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
                XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
                X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)

        half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
            sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
                X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag,
                XTAcorrY, X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

        log_weights = np.reshape(
            np.log(SNR_weights[:, None]) + np.log(rho_weights), n_grid)
        all_rho_grids = np.reshape(
            np.repeat(rho_grids[None, :], s.SNR_bins, axis=0), n_grid)
        log_fixed_terms = - (n_T - n_X0) / 2 * np.log(2 * np.pi) + n_run \
            / 2 * np.log(1 - all_rho_grids**2) + scipy.special.gammaln(
                (n_T - n_X0 - 2) / 2) + (n_T - n_X0 - 2) / 2 * np.log(2)
        return s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, \
            log_weights, log_fixed_terms

    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_total, _ = s._loglike_marginalized(
        current_vec_U_chlsk_l, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY,
        half_log_det_X0TAX0, log_weights, log_fixed_terms, l_idx, n_C, n_T,
        n_V, n_X0, n_grid, rank=rank)
    LL_total = -LL_total

    # Now we re-calculate using scipy.integrate
    s.SNR_bins = 100
    SNR_grids = np.linspace(0, 12, s.SNR_bins)
    SNR_weights = np.exp(-SNR_grids)
    SNR_weights = SNR_weights / np.sum(SNR_weights)
    n_grid = s.SNR_bins * s.rho_bins
    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_raw, _, _, _ = s._raw_loglike_grids(
        L, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0,
        log_weights, log_fixed_terms, n_C, n_T, n_V, n_X0, n_grid, rank)
    result_sum, max_value, result_exp = utils.sumexp_stable(LL_raw)
    scipy_sum = simpson(y=result_exp, axis=0)
    LL_total_scipy = np.sum(np.log(scipy_sum) + max_value)

    tol = 1e-3
    assert np.isclose(LL_total_scipy, LL_total, rtol=tol), \
        'Error of log likelihood calculation exceeds the tolerance'

    # Now test the log normal prior
    s = brainiak.reprsimil.brsa.GBRSA(n_iter=1, auto_nuisance=False,
                                      SNR_prior='lognorm')
    s.SNR_bins = 50
    s.rho_bins = 1
    SNR_grids, SNR_weights = s._set_SNR_grids()
    rho_grids, rho_weights = s._set_rho_grids()
    n_grid = s.SNR_bins * s.rho_bins
    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_total, _ = s._loglike_marginalized(
        current_vec_U_chlsk_l, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY,
        half_log_det_X0TAX0, log_weights, log_fixed_terms, l_idx, n_C, n_T,
        n_V, n_X0, n_grid, rank=rank)
    LL_total = -LL_total

    # Now we re-calculate using scipy.integrate
    s.SNR_bins = 400
    SNR_grids = np.linspace(1e-8, 20, s.SNR_bins)
    log_SNR_weights = scipy.stats.lognorm.logpdf(SNR_grids, s=s.logS_range)
    result_sum, max_value, result_exp = utils.sumexp_stable(
        log_SNR_weights[:, None])
    SNR_weights = np.squeeze(result_exp / result_sum)
    n_grid = s.SNR_bins * s.rho_bins
    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_raw, _, _, _ = s._raw_loglike_grids(
        L, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0,
        log_weights, log_fixed_terms, n_C, n_T, n_V, n_X0, n_grid, rank)
    result_sum, max_value, result_exp = utils.sumexp_stable(LL_raw)
    scipy_sum = simpson(y=result_exp, axis=0)
    LL_total_scipy = np.sum(np.log(scipy_sum) + max_value)

    tol = 1e-3
    assert np.isclose(LL_total_scipy, LL_total, rtol=tol), \
        'Error of log likelihood calculation exceeds the tolerance'
def test_grid_flatten_num_int():
    """Check grid flattening and numeric integration over SNR in GBRSA.

    The first two parts build a 2 x 2 grid of SNR and rho values in which
    one of the two parameters is constant, and assert that the quantities
    returned by ``_matrix_flattened_grid`` are identical for the grid points
    that share the varying parameter. The last two parts compare the
    log likelihood returned by ``_loglike_marginalized`` with a
    re-computation based on ``scipy.integrate``, for the 'exp' and the
    'lognorm' SNR priors.
    """
    import brainiak.reprsimil.brsa
    import brainiak.utils.utils as utils
    import numpy as np
    # Every SciPy sub-package used below is imported explicitly instead of
    # relying on another module having imported it as a side effect.
    import scipy.integrate
    import scipy.special
    import scipy.stats

    # ``simps`` was removed in SciPy 1.14 in favour of ``simpson``; fall
    # back to the old name on SciPy versions that lack ``simpson``.
    simpson = getattr(scipy.integrate, 'simpson', None)
    if simpson is None:
        simpson = scipy.integrate.simps

    n_V = 30
    n_T = 50
    n_C = 3
    design = np.random.randn(n_T, n_C)
    U_simu = np.asarray([[1.0, 0.1, 0.0], [0.1, 1.0, 0.2], [0.0, 0.2, 1.0]])
    L_simu = np.linalg.cholesky(U_simu)
    SNR = np.random.exponential(size=n_V)
    beta = np.dot(L_simu, np.random.randn(n_C, n_V)) * SNR
    noise = np.random.randn(n_T, n_V)
    Y = np.dot(design, beta) + noise
    X = design
    X_base = None
    scan_onsets = [0]

    s = brainiak.reprsimil.brsa.GBRSA(n_iter=1, auto_nuisance=False,
                                      SNR_prior='exp')
    s.fit(X=[Y], design=[design])
    rank = n_C
    l_idx, rank = s._chol_idx(n_C, rank)
    L = np.zeros((n_C, rank))
    n_l = np.size(l_idx[0])
    current_vec_U_chlsk_l = s.random_state_.randn(n_l) * 10
    L[l_idx] = current_vec_U_chlsk_l

    # Now we change the grids for SNR and rho for testing.
    s.SNR_bins = 2
    s.rho_bins = 2
    SNR_grids, SNR_weights = s._set_SNR_grids()
    # Instead of calling s._set_rho_grids(), we purposefully set all rhos
    # to be equal to test flattening of grids.
    rho_grids = np.ones(2) * 0.1
    rho_weights = np.ones(2) / 2
    n_grid = s.SNR_bins * s.rho_bins

    D, F, run_TRs, n_run = s._prepare_DF(n_T, scan_onsets=scan_onsets)
    XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX, XTDX, XTFX \
        = s._prepare_data_XY(X, Y, D, F)
    X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0, X0TY, X0TDY, X0TFY, X0, \
        X_base, n_X0, idx_DC = s._prepare_data_XYX0(
            X, Y, X_base, None, D, F, run_TRs, no_DC=False)
    X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
        = s._precompute_ar1_quad_forms_marginalized(
            XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
            XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
            X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)
    half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
        sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
            X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
            X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

    # With equal rhos, grid points 0/1 and 2/3 must agree everywhere; only
    # sXTAcorrY is expected to differ between points 0 and 2.
    assert (half_log_det_X0TAX0[0] == half_log_det_X0TAX0[1]
            and half_log_det_X0TAX0[2] == half_log_det_X0TAX0[3]
            and half_log_det_X0TAX0[0] == half_log_det_X0TAX0[2]
            ), '_matrix_flattened_grid has mistake with half_log_det_X0TAX0'
    assert (np.array_equal(X0TAX0[0, :, :], X0TAX0[1, :, :])
            and np.array_equal(X0TAX0[2, :, :], X0TAX0[3, :, :])
            and np.array_equal(X0TAX0[0, :, :], X0TAX0[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0'
    assert (np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[1, :, :])
            and np.array_equal(X0TAX0_i[2, :, :], X0TAX0_i[3, :, :])
            and np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0_i'
    assert np.allclose(
        np.dot(X0TAX0[0, :, :], X0TAX0_i[0, :, :]),
        np.eye(n_X0)), 'X0TAX0_i is not inverse of X0TAX0'
    assert (np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[1, :])
            and np.array_equal(YTAcorrY_diag[2, :], YTAcorrY_diag[3, :])
            and np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[2, :])
            ), '_matrix_flattened_grid has mistake YTAcorrY_diag'
    assert (np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[1, :, :])
            and np.array_equal(sXTAcorrY[2, :, :], sXTAcorrY[3, :, :])
            and not np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[2, :, :])
            ), '_matrix_flattened_grid has mistake sXTAcorrY'
    assert (np.array_equal(X0TAY[0, :, :], X0TAY[1, :, :])
            and np.array_equal(X0TAY[2, :, :], X0TAY[3, :, :])
            and np.array_equal(X0TAY[0, :, :], X0TAY[2, :, :])
            ), '_matrix_flattened_grid has mistake X0TAY'
    assert (np.array_equal(XTAX0[0, :, :], XTAX0[1, :, :])
            and np.array_equal(XTAX0[2, :, :], XTAX0[3, :, :])
            and np.array_equal(XTAX0[0, :, :], XTAX0[2, :, :])
            ), '_matrix_flattened_grid has mistake XTAX0'

    # Now we test the other way
    rho_grids, rho_weights = s._set_rho_grids()
    SNR_grids = np.ones(2) * 0.1
    SNR_weights = np.ones(2) / 2
    # We purposefully set all SNR to be equal to test flattening of
    # grids.
    X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
        = s._precompute_ar1_quad_forms_marginalized(
            XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
            XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
            X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)
    half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
        sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
            X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag, XTAcorrY,
            X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

    # With equal SNRs, grid points 0/2 and 1/3 must agree, while points 0
    # and 1 (different rho) must differ.
    assert (half_log_det_X0TAX0[0] == half_log_det_X0TAX0[2]
            and half_log_det_X0TAX0[1] == half_log_det_X0TAX0[3]
            and not half_log_det_X0TAX0[0] == half_log_det_X0TAX0[1]
            ), '_matrix_flattened_grid has mistake with half_log_det_X0TAX0'
    assert (np.array_equal(X0TAX0[0, :, :], X0TAX0[2, :, :])
            and np.array_equal(X0TAX0[1, :, :], X0TAX0[3, :, :])
            and not np.array_equal(X0TAX0[0, :, :], X0TAX0[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0'
    assert (np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[2, :, :])
            and np.array_equal(X0TAX0_i[1, :, :], X0TAX0_i[3, :, :])
            and not np.array_equal(X0TAX0_i[0, :, :], X0TAX0_i[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAX0_i'
    assert np.allclose(
        np.dot(X0TAX0[0, :, :], X0TAX0_i[0, :, :]),
        np.eye(n_X0)), 'X0TAX0_i is not inverse of X0TAX0'
    assert (np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[2, :])
            and np.array_equal(YTAcorrY_diag[1, :], YTAcorrY_diag[3, :])
            and not np.array_equal(YTAcorrY_diag[0, :], YTAcorrY_diag[1, :])
            ), '_matrix_flattened_grid has mistake YTAcorrY_diag'
    assert (np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[2, :, :])
            and np.array_equal(sXTAcorrY[1, :, :], sXTAcorrY[3, :, :])
            and not np.array_equal(sXTAcorrY[0, :, :], sXTAcorrY[1, :, :])
            ), '_matrix_flattened_grid has mistake sXTAcorrY'
    assert (np.array_equal(X0TAY[0, :, :], X0TAY[2, :, :])
            and np.array_equal(X0TAY[1, :, :], X0TAY[3, :, :])
            and not np.array_equal(X0TAY[0, :, :], X0TAY[1, :, :])
            ), '_matrix_flattened_grid has mistake X0TAY'
    assert (np.array_equal(XTAX0[0, :, :], XTAX0[2, :, :])
            and np.array_equal(XTAX0[1, :, :], XTAX0[3, :, :])
            and not np.array_equal(XTAX0[0, :, :], XTAX0[1, :, :])
            ), '_matrix_flattened_grid has mistake XTAX0'

    # Now test the integration over SNR
    s.SNR_bins = 50
    s.rho_bins = 1
    SNR_grids, SNR_weights = s._set_SNR_grids()
    rho_grids, rho_weights = s._set_rho_grids()
    n_grid = s.SNR_bins * s.rho_bins

    def setup_for_test():
        # This function will be re-used to set up the variables necessary for
        # testing. It reads s, the grids, the weights and n_grid from the
        # enclosing scope at call time, so it picks up their latest values.
        X0TAX0, X0TAX0_i, XTAcorrX, XTAcorrY, YTAcorrY_diag, X0TAY, XTAX0 \
            = s._precompute_ar1_quad_forms_marginalized(
                XTY, XTDY, XTFY, YTY_diag, YTDY_diag, YTFY_diag, XTX,
                XTDX, XTFX, X0TX0, X0TDX0, X0TFX0, XTX0, XTDX0, XTFX0,
                X0TY, X0TDY, X0TFY, rho_grids, n_V, n_X0)

        half_log_det_X0TAX0, X0TAX0, X0TAX0_i, s2XTAcorrX, YTAcorrY_diag, \
            sXTAcorrY, X0TAY, XTAX0 = s._matrix_flattened_grid(
                X0TAX0, X0TAX0_i, SNR_grids, XTAcorrX, YTAcorrY_diag,
                XTAcorrY, X0TAY, XTAX0, n_C, n_V, n_X0, n_grid)

        log_weights = np.reshape(
            np.log(SNR_weights[:, None]) + np.log(rho_weights), n_grid)
        all_rho_grids = np.reshape(
            np.repeat(rho_grids[None, :], s.SNR_bins, axis=0), n_grid)
        log_fixed_terms = - (n_T - n_X0) / 2 * np.log(2 * np.pi) + n_run \
            / 2 * np.log(1 - all_rho_grids**2) + scipy.special.gammaln(
                (n_T - n_X0 - 2) / 2) + (n_T - n_X0 - 2) / 2 * np.log(2)
        return s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, \
            log_weights, log_fixed_terms

    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_total, _ = s._loglike_marginalized(
        current_vec_U_chlsk_l, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY,
        half_log_det_X0TAX0, log_weights, log_fixed_terms, l_idx, n_C, n_T,
        n_V, n_X0, n_grid, rank=rank)
    LL_total = -LL_total

    # Now we re-calculate using scipy.integrate
    s.SNR_bins = 100
    SNR_grids = np.linspace(0, 12, s.SNR_bins)
    SNR_weights = np.exp(-SNR_grids)
    SNR_weights = SNR_weights / np.sum(SNR_weights)
    n_grid = s.SNR_bins * s.rho_bins
    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_raw, _, _, _ = s._raw_loglike_grids(
        L, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0,
        log_weights, log_fixed_terms, n_C, n_T, n_V, n_X0, n_grid, rank)
    result_sum, max_value, result_exp = utils.sumexp_stable(LL_raw)
    scipy_sum = simpson(y=result_exp, axis=0)
    LL_total_scipy = np.sum(np.log(scipy_sum) + max_value)

    tol = 1e-3
    assert np.isclose(LL_total_scipy, LL_total, rtol=tol), \
        'Error of log likelihood calculation exceeds the tolerance'

    # Now test the log normal prior
    s = brainiak.reprsimil.brsa.GBRSA(n_iter=1, auto_nuisance=False,
                                      SNR_prior='lognorm')
    s.SNR_bins = 50
    s.rho_bins = 1
    SNR_grids, SNR_weights = s._set_SNR_grids()
    rho_grids, rho_weights = s._set_rho_grids()
    n_grid = s.SNR_bins * s.rho_bins
    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_total, _ = s._loglike_marginalized(
        current_vec_U_chlsk_l, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY,
        half_log_det_X0TAX0, log_weights, log_fixed_terms, l_idx, n_C, n_T,
        n_V, n_X0, n_grid, rank=rank)
    LL_total = -LL_total

    # Now we re-calculate using scipy.integrate
    s.SNR_bins = 400
    SNR_grids = np.linspace(1e-8, 20, s.SNR_bins)
    log_SNR_weights = scipy.stats.lognorm.logpdf(SNR_grids, s=s.logS_range)
    result_sum, max_value, result_exp = utils.sumexp_stable(
        log_SNR_weights[:, None])
    SNR_weights = np.squeeze(result_exp / result_sum)
    n_grid = s.SNR_bins * s.rho_bins
    (s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0, log_weights,
     log_fixed_terms) = setup_for_test()
    LL_raw, _, _, _ = s._raw_loglike_grids(
        L, s2XTAcorrX, YTAcorrY_diag, sXTAcorrY, half_log_det_X0TAX0,
        log_weights, log_fixed_terms, n_C, n_T, n_V, n_X0, n_grid, rank)
    result_sum, max_value, result_exp = utils.sumexp_stable(LL_raw)
    scipy_sum = simpson(y=result_exp, axis=0)
    LL_total_scipy = np.sum(np.log(scipy_sum) + max_value)

    tol = 1e-3
    assert np.isclose(LL_total_scipy, LL_total, rtol=tol), \
        'Error of log likelihood calculation exceeds the tolerance'