def get_gradient_R2_over_each(states, meta_curr, bin_selections=None):
    """
    Score R2 separately on each bin and report the average across bins.

    For every row selection in `bin_selections`, utils.get_R2_highD is
    called on the selected rows of `meta_curr` and `states` (in that
    argument order), and the nan-ignoring mean of the per-bin scores is
    reported. Without bin selections, a single score over all rows is used.

    Upside: the result is not sensitive to variance across bins, and a small
    number of badly predicted bins cannot dominate it.
    Downside: if bins are unevenly sampled (in number of samples, or in total
    variance per bin), the average is not very representative.

    NOTE(review): leave_one_out, min_nobs, max_k, k_folds, niter,
    preproc_method and return_full_results are not parameters of this
    function; they are looked up from the surrounding scope. Confirm where
    they are defined.

    Args:
        states: array indexed as states[selection, :].
        meta_curr: array indexed as meta_curr[selection, :].
        bin_selections: optional iterable of row selections (presumably
            boolean masks, since they are counted with np.nansum -- TODO
            confirm). None means no binning.

    Returns:
        The scalar score `r`, or, when return_full_results is truthy, a dict
        with keys 'r2_over_bins', 'r2_over_bins_niter', 'bin_selections'
        and 'r'.
    """
    per_bin_r2 = []
    per_bin_r2_splits = []

    if bin_selections is None:
        r = utils.get_R2_highD(meta_curr, states, k_folds=k_folds)
    else:
        for rows in bin_selections:
            n_rows = np.nansum(rows)

            if leave_one_out:
                # One fold per observation, raised to at least min_nobs and
                # capped at max_k.
                folds = np.min([np.max([n_rows, min_nobs]), max_k])
            else:
                folds = k_folds

            if n_rows < folds:
                # Too few observations for the requested number of folds.
                bin_r2, bin_splits = np.nan, [np.nan] * niter
            else:
                # should not resample over R2, since this will be biased up.
                # resample (niter) in regression to calculate r
                bin_r2, bin_splits = utils.get_R2_highD(
                    meta_curr[rows, :],
                    states[rows, :],
                    niter=niter,
                    k_folds=folds,
                    preproc_method=preproc_method,
                    return_splits=True,
                )

            per_bin_r2.append(bin_r2)
            per_bin_r2_splits.append(bin_splits)

        r = np.nanmean(per_bin_r2)

    res_full = {
        'r2_over_bins': np.array(per_bin_r2),
        'r2_over_bins_niter': np.array(per_bin_r2_splits),
        'bin_selections': bin_selections,
        'r': r,
    }
    return res_full if return_full_results else r
def get_model_decoding_of_gradient_masked(model_data, return_results=False):
    """
    Score the gradient decoding for each (label, mask) combination.

    Loads the base data once via data_utils.load_one_base (with
    mask_early_late=True), then for each label in ('xy', 'dxdy') and each
    mask of interest extracts the 'X'/'y' matrices and scores them with
    utils.get_R2_highD using 20 folds.

    Args:
        model_data: passed through to data_utils.load_one_base.
        return_results: when truthy, the output of utils.linear_regress for
            each combination is stored as well.

    Returns:
        dict with one 'grad_r2_<label>_<mask>' entry per combination, plus
        'grad_res_<label>_<mask>' entries when return_results is truthy.
    """
    n_folds = 20
    mask_names = [
        'output_sim_roll5_nobounce',
        'output_vis_nobounce',
        'output_sim_roll5_nobounce_early',
        'output_vis_nobounce_early',
        'output_sim_roll5_nobounce_late',
        'output_vis_nobounce_late',
    ]

    res, full_meta, masks = data_utils.load_one_base(
        model_data, mask_early_late=True)

    scores = {}
    for label in ['xy', 'dxdy']:
        for mask_fn in mask_names:
            mats = data_utils.get_matrices_from_res(
                res, full_meta, masks, mask_fn=mask_fn, label=label)
            # Shared suffix for both result keys of this combination.
            suffix = '%s_%s' % (label, mask_fn)

            scores['grad_r2_' + suffix] = utils.get_R2_highD(
                mats['y'], mats['X'], k_folds=n_folds)

            if return_results:
                scores['grad_res_' + suffix] = utils.linear_regress(
                    mats['X'], mats['y'], k_folds=n_folds)
    return scores
def get_cross_time_state_prediction(model_data):
    """
    Score how well the mean state of one epoch relates to that of another.

    States are aligned to three epochs ('output_vis' -> 'start',
    'output_sim' -> 'occ', 'output_f' -> 'end') with
    data_utils.get_data_aligned_to_epoch_start, nan-averaged over axis 1,
    and every ordered pair of distinct epochs is scored with
    utils.get_R2_highD.

    Args:
        model_data: mapping that contains at least 'output_f'; it is deep
            copied, so the caller's object is not modified here.

    Returns:
        dict keyed '<j>_to_<i>', holding
        utils.get_R2_highD(states[i], states[j]) for each pair with j != i.
    """
    def patch_model_data(md):
        # Work on a copy whose 'output_f' is tiled twice along the third
        # axis and has index 0 of the second axis blanked with NaN.
        # NOTE(review): presumably this makes 'output_f' usable as an
        # alignment epoch like the other outputs -- confirm.
        patched = deepcopy(md)
        doubled = np.tile(np.array(md['output_f']), (1, 1, 2))
        doubled[:, 0, :] = np.nan
        patched['output_f'] = doubled
        return patched

    model_data_patched = patch_model_data(model_data)

    # (short name, epoch to align to), in the order the scores are keyed.
    epochs = [
        ('start', 'output_vis'),
        ('occ', 'output_sim'),
        ('end', 'output_f'),
    ]
    states_to_predict = {}
    for name, epoch in epochs:
        aligned = data_utils.get_data_aligned_to_epoch_start(
            model_data_patched, use_pca_proj=False, epoch_align=epoch)
        states_to_predict[name] = np.nanmean(aligned[0]['state'], axis=1)

    R2_cross_epoch = {}
    for i in states_to_predict:
        for j in states_to_predict:
            if j != i:
                R2_cross_epoch['%s_to_%s' % (j, i)] = utils.get_R2_highD(
                    states_to_predict[i], states_to_predict[j])
    return R2_cross_epoch
def get_variance_explained_metrics(state_matrix, meta_matrices):
    """
    Score the state matrix against several meta-variable sets.

    For each of 'xy', 'xydxdy' and 'xydspeeddtheta', the state matrix and the
    corresponding meta matrix are passed through utils.mask2d and the masked
    pair is scored with utils.get_R2_highD(state, meta).

    Args:
        state_matrix: state matrix handed to utils.mask2d.
        meta_matrices: mapping from meta-variable name to an array-like.

    Returns:
        dict keyed 'prop_var_exp_<name>' with one score per meta-variable set.
    """
    def score(name):
        masked_state, masked_meta = utils.mask2d(
            state_matrix, np.array(meta_matrices[name]))
        return utils.get_R2_highD(masked_state, masked_meta)

    names = ('xy', 'xydxdy', 'xydspeeddtheta')
    return {'prop_var_exp_%s' % name: score(name) for name in names}
def regress_state_from_meta_vars(state_, meta_):
    """
    Score the axis-1 mean state against axis-1 mean meta variables.

    Both the state and each meta variable ('y', 'xydxdy',
    'xydspeeddtheta') are nan-averaged over axis 1, and each pair is scored
    with utils.get_R2_highD(mean_state, mean_meta).

    Args:
        state_: array reduced with np.nanmean over axis 1.
        meta_: mapping from meta-variable name to an array reduced the same
            way.

    Returns:
        dict keyed 'encode_R2_<name>' with one score per meta variable.
    """
    mean_state = np.nanmean(state_, axis=1)

    scores = {}
    for name in ('y', 'xydxdy', 'xydspeeddtheta'):
        mean_meta = np.squeeze(np.nanmean(meta_[name], axis=1))
        # A single-column variable squeezes down to 1-D; restore the column
        # axis so a 2-D matrix is always handed to the scorer.
        if mean_meta.ndim < 2:
            mean_meta = np.expand_dims(mean_meta, axis=1)
        scores['encode_R2_%s' % name] = utils.get_R2_highD(
            mean_state, mean_meta)
    return scores
def get_gradient_R2_over_all(states, meta_curr, bin_selections=None):
    """
    Fit one regression per bin, then score all predictions together.

    Each bin gets its own utils.linear_regress fit of `meta_curr` on
    `states`; the per-bin predictions are written into one shared prediction
    array, and a single score is computed over that aggregate.

    Downside: poor prediction in a small number of bins can tank the whole
    score.
    Upside: the score is immune to bins being unevenly sampled.

    Args:
        states: array indexed as states[selection, :].
        meta_curr: array indexed as meta_curr[selection, :]; must expose
            `.shape`.
        bin_selections: optional iterable of row selections. None means a
            single regression over all rows.

    Returns:
        The value of utils.get_R2_highD(observed, predicted,
        cross_validate_regression=False), computed after remove_nans has
        been applied to the observed/predicted pair.
    """
    # Rows of bins that are skipped below stay NaN.
    predicted = np.full(meta_curr.shape, np.nan)

    if bin_selections is None:
        predicted = utils.linear_regress(states, meta_curr)['y_pred']
    else:
        for rows in bin_selections:
            # Bins with at most two selected rows are left unpredicted.
            if np.nansum(rows) > 2:
                fit = utils.linear_regress(
                    states[rows, :], meta_curr[rows, :])
                predicted[rows, :] = fit['y_pred']

    observed, predicted = remove_nans(meta_curr, predicted)
    return utils.get_R2_highD(
        observed, predicted, cross_validate_regression=False)
def get_piecewise_encoding(model_data):
    """
    Score states against a piecewise (visible vs. occluded) design matrix.

    State and meta matrices are built separately for the 'output_vis' and
    'output_sim' entries of `model_data`, stacked row-wise, and scored with
    utils.get_R2_highD(states, design) for the 'xy' and 'xydspeeddtheta'
    meta-variable sets.

    Args:
        model_data: mapping with 'state', 'output_vis' and 'output_sim'
            entries; the output entries are indexed as [:, :, 0] and
            [:, :, 1].

    Returns:
        dict with keys 'R2_xy' and 'R2_xydspeeddtheta'.
    """
    def collect_matrices(epoch_mask='output_vis'):
        # Flatten the state and every meta variable to matrices, using the
        # 'x' meta variable as the mask handed to utils.flatten_to_mat.
        xy = np.array(model_data[epoch_mask])
        full_state = np.array(model_data['state'])
        full_meta = data_utils.get_full_meta_from_xy(
            xy[:, :, 0], xy[:, :, 1])
        mask = np.array(full_meta['x'])

        state_matrix = utils.flatten_to_mat(full_state, mask)['X']
        meta_matrices = {}
        for name in full_meta.keys():
            meta_matrices[name] = utils.flatten_to_mat(
                np.array(full_meta[name]), mask)['X']
        return state_matrix, meta_matrices

    def build_design(M1, M2, mfn='xy'):
        # format as visual_meta, occ_meta, visual_vs_occluded_bias
        first, second = M1[mfn], M2[mfn]
        first_rows = np.concatenate(
            (first, np.zeros(first.shape), np.ones((first.shape[0], 1))),
            axis=1)
        second_rows = np.concatenate(
            (np.zeros(second.shape), second, np.zeros((second.shape[0], 1))),
            axis=1)
        return np.concatenate((first_rows, second_rows), axis=0)

    vis_states, vis_meta = collect_matrices(epoch_mask='output_vis')
    sim_states, sim_meta = collect_matrices(epoch_mask='output_sim')
    all_states = np.concatenate((vis_states, sim_states), axis=0)

    scores = {}
    for name in ['xy', 'xydspeeddtheta']:
        design = build_design(vis_meta, sim_meta, mfn=name)
        # Keep only rows whose design entries are all finite.
        keep = np.isfinite(np.mean(design, axis=1))
        scores['R2_%s' % name] = utils.get_R2_highD(
            all_states[keep, :], design[keep, :])
    return scores