new_temporal_factors, axis=0) new_temporal_factors_normed[np.isnan( new_temporal_factors_normed)] = 0 # fix 0/0 = nan assert not np.isnan(new_init_data).any() assert new_temporal_factors.shape[ 1] == new_rand_trial_factor.shape[1] assert new_temporal_factors.shape[1] == new_cell_factor.shape[ 1] # run TCA on this small init_factors = KTensor( (new_cell_factor, new_temporal_factors_normed, new_rand_trial_factor)) temp_fit_options = deepcopy(fit_options) temp_fit_options['init'] = init_factors temp_ensemble = tt.Ensemble(fit_method=method, fit_options=temp_fit_options) temp_ensemble.fit(new_init_data, ranks=rr, replicates=replicates, verbose=False) # save trial factors into a single vector now concatenated together output_trial_factors[ this_run_bool, :] = temp_ensemble.results[rr][0].factors[2] # hold onto model results new_KT = KTensor( [mouse_cell_factor, scaled_traces, output_trial_factors]) mouse_kt_list.append(new_KT) mouse_tfac_list.append(output_trial_factors)
I, J, K, R = 100, 100, 100, 4 X = tt.rand_ktensor((I, J, K), rank=R) # Add noise. Xn = np.maximum(0, X.full() + .1 * np.random.randn(I, J, K)) # Fit ensemble of unconstrained tensor decompositions. methods = ( 'cp_als', 'ncp_bcd', 'ncp_hals', ) ensembles = {} for m in methods: ensembles[m] = tt.Ensemble(fit_method=m, fit_options=dict(tol=1e-4)) ensembles[m].fit(Xn, ranks=range(1, 9), replicates=3) # Plotting options for the unconstrained and nonnegative models. plot_options = { 'cp_als': { 'line_kw': { 'color': 'black', 'label': 'cp_als', }, 'scatter_kw': { 'color': 'black', }, }, 'ncp_hals': { 'line_kw': {
def run_unwrapped_tca(thresh=4, force=False, verbose=False): """ Run script to build and save TCA inputs as well as TCA model outputs :param thresh: int -log10(p-value) threshold for calling a cell driven for any stage of learning. :param force: boolean Force a rerun even if output file exists. :param verbose: boolean View TCA progress in the terminal. """ # Do not overwrite existing files if os.path.isfile( cas.paths.analysis_file( f'tca_ensemble_v{thresh}_nneg_nT0_20210209.npy', 'tca_dfs')) and not force: return # parameters # -------------------------------------------------------------------------------------------------- # input data params mice = cas.lookups.mice['allFOV'] words = ['respondent' if s in 'OA27' else 'computation' for s in mice] group_by = 'all3' with_model = False nan_thresh = 0.95 # TCA params method = 'ncp_hals' replicates = 3 fit_options = {'tol': 0.0001, 'max_iter': 500, 'verbose': False} ranks = list(np.arange(1, 21, dtype=int)) ranks.extend([40]) ranks.extend([80]) tca_ranks = [int(s) for s in ranks] # plot params hue_order = [ 'becomes_unrewarded', 'remains_unrewarded', 'becomes_rewarded' ] heatmap_rank = 12 # load in a full size data # -------------------------------------------------------------------------------------------------- tensor_list = [] id_list = [] bhv_list = [] meta_list = [] for mouse, word in zip(mice, words): # return ids, tensor, meta, bhv out = cas.load.load_all_groupday(mouse, word=word, with_model=with_model, group_by=group_by, nan_thresh=nan_thresh) tensor_list.append(out[2]) id_list.append(out[1]) bhv_list.append(out[4]) meta_list.append(cas.utils.add_stages_to_meta(out[3], 'parsed_11stage')) # load Oren's better offset classification # -------------------------------------------------------------------------------------------------- off_df_all_mice = pd.read_pickle( '/twophoton_analysis/Data/analysis/core_dfs/offsets_dfs.pkl') df_list = [] for k, v in off_df_all_mice.items(): v['mouse'] = k df_list.append(v) updated_off_df = pd.concat(df_list, axis=0) updated_off_df = updated_off_df.set_index(['mouse', 'cell_id']) updated_off_df.head() # build input for TCA, selecting cells driven any stage with a given negative log10 p-value 'thresh' # -------------------------------------------------------------------------------------------------- driven_on_tensor_flat_run = [] driven_off_tensor_flat_run = [] driven_on_tensor_flat = [] driven_off_tensor_flat = [] off_mouse_vec_flat, on_mouse_vec_flat, off_cell_vec, on_cell_vec = [], [], [], [] for meta, ids, tensor in zip(meta_list, id_list, tensor_list): # skip LM and seizure mouse if cas.utils.meta_mouse(meta) in ['AS20', 'AS23']: continue if cas.utils.meta_mouse(meta) in ['AS41', 'AS47', 'OA38']: continue if cas.utils.meta_mouse(meta) in cas.lookups.mice['lml5']: continue # calculate drivenness across stages, using Oren's offset boolean off_df = updated_off_df.loc[updated_off_df.reset_index( ).mouse.isin([cas.utils.meta_mouse(meta)]).values, ['offset_test']] off_df = off_df.reindex(ids, level=1) assert np.array_equal(off_df.reset_index().cell_id.values, ids) offset_bool = off_df.offset_test.values with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) drive_df = cas.drive.multi_stat_drive(meta, ids, tensor, alternative='less', offset_bool=offset_bool, neg_log10_pv_thresh=thresh) # flatten tensor and unwrap it to look across stages flat_tensors = {} off_flat_tensors = {} with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) for cue in ['plus', 'minus', 'neutral']: meta_bool = meta.initial_condition.isin([cue]).values # get mean per stage per cue (also split on running speed) stage_mean_tensor_slow = cas.utils.balanced_mean_per_stage( meta, tensor, meta_bool=meta_bool, filter_running='low_pre_speed_only') stage_mean_tensor_fast = cas.utils.balanced_mean_per_stage( meta, tensor, meta_bool=meta_bool, filter_running='high_pre_speed_only') stage_mean_tensor = cas.utils.balanced_mean_per_stage( meta, tensor, meta_bool=meta_bool) # limit to 2 seconds across mice flat_tensors[cue] = cas.utils.unwrap_tensor( stage_mean_tensor[:, :int(np.ceil(15.5 * 3)), :]) flat_tensors[cue + '_slow'] = cas.utils.unwrap_tensor( stage_mean_tensor_slow[:, :int(np.ceil(15.5 * 3)), :]) flat_tensors[cue + '_fast'] = cas.utils.unwrap_tensor( stage_mean_tensor_fast[:, :int(np.ceil(15.5 * 3)), :]) # for offset, use 1 second pre offset and 2 seconds post, assumes 15.5 Hz off_int = int( np.ceil(cas.lookups.stim_length[cas.utils.meta_mouse(meta)] + 1) * 15.5) off_flat_tensors[cue] = cas.utils.unwrap_tensor( stage_mean_tensor[:, (off_int - 15):(off_int + 42), :]) off_flat_tensors[cue + '_slow'] = cas.utils.unwrap_tensor( stage_mean_tensor_slow[:, (off_int - 15):(off_int + 42), :]) off_flat_tensors[cue + '_fast'] = cas.utils.unwrap_tensor( stage_mean_tensor_fast[:, (off_int - 15):(off_int + 42), :]) # get driven ids for different behaviors driven_onset = [] driven_offset = [] for cc, cue in enumerate(['plus', 'minus', 'neutral']): for c, stages in enumerate(cas.lookups.staging['parsed_11stage']): # skip naive when considering which cells are driven if stages in ['T0 naive', 'L0 naive']: continue # Onset cells on_cells = drive_df.loc[~drive_df.offset_cell & drive_df.driven] on_cells = on_cells.loc[on_cells.reset_index().parsed_11stage. isin([stages]).values, :] on_cells = on_cells.loc[ on_cells.reset_index().initial_cue.isin([cue]).values, :] # make sure you can't double count cells assert on_cells.groupby( ['mouse', 'cell_id']).nunique().gt(1).sum(axis=0).eq(0).all() id_vec = on_cells.cell_id.unique() driven_onset.extend(id_vec) # Offset cells off_cells = drive_df.loc[drive_df.offset_cell & drive_df.driven] off_cells = off_cells.loc[off_cells.reset_index( ).parsed_11stage.isin([stages]).values, :] off_cells = off_cells.loc[ off_cells.reset_index().initial_cue.isin([cue]).values, :] # make sure you can't double count cells assert off_cells.groupby( ['mouse', 'cell_id']).nunique().gt(1).sum(axis=0).eq(0).all() id_vec = off_cells.cell_id.unique() driven_offset.extend(id_vec) driven_onset = np.unique(driven_onset) driven_offset = np.unique(driven_offset) driven_on_bool = np.isin(np.array(ids), driven_onset) driven_off_bool = np.isin(np.array(ids), driven_offset) on_cells_in_order = np.array(ids)[driven_on_bool] off_cells_in_order = np.array(ids)[driven_off_bool] # keep track of mice and cell ids on_mouse_vec_flat.append([cas.utils.meta_mouse(meta)] * len(on_cells_in_order)) off_mouse_vec_flat.append([cas.utils.meta_mouse(meta)] * len(off_cells_in_order)) on_cell_vec.append(on_cells_in_order) off_cell_vec.append(off_cells_in_order) # ensure that cells are not counted in both onset and offset groups assert all(~np.isin(driven_onset, driven_offset)) # Onset cells ten_list = [] for cc, cue in enumerate( ['becomes_unrewarded', 'remains_unrewarded', 'becomes_rewarded']): invert_lookup = { v: k for k, v in cas.lookups.lookup_mm[cas.utils.meta_mouse( meta)].items() } ten_list.append(flat_tensors[invert_lookup[cue]][:, :, None]) new_on_tensor_unwrap = np.dstack(ten_list)[driven_on_bool, :, :] driven_on_tensor_flat.append(new_on_tensor_unwrap) # Offset cells ten_list = [] for cc, cue in enumerate( ['becomes_unrewarded', 'remains_unrewarded', 'becomes_rewarded']): invert_lookup = { v: k for k, v in cas.lookups.lookup_mm[cas.utils.meta_mouse( meta)].items() } ten_list.append(off_flat_tensors[invert_lookup[cue]][:, :, None]) new_off_tensor_unwrap = np.dstack(ten_list)[driven_off_bool, :, :] driven_off_tensor_flat.append(new_off_tensor_unwrap) # Onset cells with speed ten_list = [] for speedi in ['_fast', '_slow']: for cue in [ 'becomes_unrewarded', 'remains_unrewarded', 'becomes_rewarded' ]: invert_lookup = { v: k for k, v in cas.lookups.lookup_mm[cas.utils.meta_mouse( meta)].items() } ten_list.append(flat_tensors[invert_lookup[cue] + speedi][:, :, None]) new_onset_tensor_unwrap = np.dstack(ten_list)[driven_on_bool, :, :] driven_on_tensor_flat_run.append(new_onset_tensor_unwrap) # Offset cells with speed ten_list = [] for speedi in ['_fast', '_slow']: for cue in [ 'becomes_unrewarded', 'remains_unrewarded', 'becomes_rewarded' ]: invert_lookup = { v: k for k, v in cas.lookups.lookup_mm[cas.utils.meta_mouse( meta)].items() } ten_list.append(off_flat_tensors[invert_lookup[cue] + speedi][:, :, None]) new_offset_tensor_unwrap = np.dstack(ten_list)[driven_off_bool, :, :] driven_off_tensor_flat_run.append(new_offset_tensor_unwrap) # concatenate tensors on_mega_tensor_flat_run = np.vstack(driven_on_tensor_flat_run) off_mega_tensor_flat_run = np.vstack(driven_off_tensor_flat_run) on_mega_tensor_flat = np.vstack(driven_on_tensor_flat) off_mega_tensor_flat = np.vstack(driven_off_tensor_flat) # concatenate lists of mice and cells on_mega_mouse_flat = np.hstack(on_mouse_vec_flat) off_mega_mouse_flat = np.hstack(off_mouse_vec_flat) on_mega_cell_flat = np.hstack(on_cell_vec) off_mega_cell_flat = np.hstack(off_cell_vec) # make sure that you get the shapes you expected assert on_mega_tensor_flat_run.shape[0] == on_mega_tensor_flat.shape[0] assert off_mega_tensor_flat_run.shape[0] == off_mega_tensor_flat.shape[0] assert len(on_mega_mouse_flat) == len(on_mega_cell_flat) assert len(off_mega_mouse_flat) == len(off_mega_cell_flat) assert len(on_mega_cell_flat) == on_mega_tensor_flat.shape[0] assert len(off_mega_cell_flat) == off_mega_tensor_flat.shape[0] # Normalize per cell to max cell_max = np.nanmax(np.nanmax(off_mega_tensor_flat, axis=1), axis=1) off_mega_tensor_flat_norm = off_mega_tensor_flat / cell_max[:, None, None] cell_max = np.nanmax(np.nanmax(on_mega_tensor_flat, axis=1), axis=1) on_mega_tensor_flat_norm = on_mega_tensor_flat / cell_max[:, None, None] cell_max = np.nanmax(np.nanmax(off_mega_tensor_flat_run, axis=1), axis=1) off_mega_tensor_flat_run_norm = off_mega_tensor_flat_run / cell_max[:, None, None] cell_max = np.nanmax(np.nanmax(on_mega_tensor_flat_run, axis=1), axis=1) on_mega_tensor_flat_run_norm = on_mega_tensor_flat_run / cell_max[:, None, None] # initial restructure of data and save of input data # -------------------------------------------------------------------------------------------------- data_dict = { f'v{thresh}_on_nneg_nT0': on_mega_tensor_flat, f'v{thresh}_off_nneg_nT0': off_mega_tensor_flat, f'v{thresh}_norm_on_nneg_nT0': on_mega_tensor_flat_norm, f'v{thresh}_norm_off_nneg_nT0': off_mega_tensor_flat_norm, f'v{thresh}_speed_on_nneg_nT0': on_mega_tensor_flat_run, f'v{thresh}_speed_off_nneg_nT0': off_mega_tensor_flat_run, f'v{thresh}_speed_norm_on_nneg_nT0': on_mega_tensor_flat_run_norm, f'v{thresh}_speed_norm_off_nneg_nT0': off_mega_tensor_flat_run_norm } # add mouse and cell data to dict data_dict[f'v{thresh}_off_mouse_nneg_nT0'] = off_mega_mouse_flat data_dict[f'v{thresh}_on_mouse_nneg_nT0'] = on_mega_mouse_flat data_dict[f'v{thresh}_off_cell_nneg_nT0'] = off_mega_cell_flat data_dict[f'v{thresh}_on_cell_nneg_nT0'] = on_mega_cell_flat # remove cells with more negative than positive values (using max across cues) # this will basically only remove cells that are suppressed to all three cues # onset best_cue_resp = np.nanmax(data_dict[f'v{thresh}_norm_on_nneg_nT0'], axis=2) nneg = np.sum(best_cue_resp >= 0, axis=1) neg = np.sum(best_cue_resp < 0, axis=1) nneg_bool_on = np.greater(nneg, neg) if verbose: print(f'ONSET cells dropped for nneg: {np.sum(~nneg_bool_on)}') # offset best_cue_resp = np.nanmax(data_dict[f'v{thresh}_norm_off_nneg_nT0'], axis=2) nneg = np.sum(best_cue_resp >= 0, axis=1) neg = np.sum(best_cue_resp < 0, axis=1) nneg_bool_off = np.greater(nneg, neg) if verbose: print(f'OFFSET cells dropped for nneg: {np.sum(~nneg_bool_off)}') for k, v in data_dict.items(): if '_on_' in k: data_dict[k] = v[ nneg_bool_on] # index 0 dim and first dimension equivlaently elif '_off_' in k: data_dict[k] = v[nneg_bool_off] else: raise ValueError data_dict_path = cas.paths.analysis_file( f'input_data_v{thresh}_nneg_nT0_20210209.npy', 'tca_dfs') np.save(data_dict_path, data_dict, allow_pickle=True) if verbose: print(f'Input data saved to:\n\t{data_dict_path}') # run TCA and save # -------------------------------------------------------------------------------------------------- ensemble = {} for k, v in data_dict.items(): if '_mouse_' in k or '_cell_' in k: continue mask = ~np.isnan(v) fit_options['mask'] = mask if verbose: print(f'TCA starting: {k} --> n={v.shape[0]} cells') ensemble[k] = tt.Ensemble(fit_method=method, fit_options=deepcopy(fit_options)) ensemble[k].fit(v, ranks=tca_ranks, replicates=3, verbose=True) np.save(cas.paths.analysis_file( f'tca_ensemble_v{thresh}_nneg_nT0_20210209.npy', 'tca_dfs'), ensemble, allow_pickle=True) # plot and save relevant results # -------------------------------------------------------------------------------------------------- # sort cell factors sort_ensembles = {} sort_orders = {} for k, v in ensemble.items(): sort_ensembles[k], sort_orders[k] = cas.utils.sortfactors(v) # plot model performance for k, v in ensemble.items(): fig, ax = plt.subplots(2, 2, figsize=(10, 10), sharex='row', sharey='col') ax = ax.reshape([2, 2]) # full plot tt.visualization.plot_objective(v, ax=ax[0, 0], line_kw={'color': 'red'}, scatter_kw={ 'facecolor': 'black', 'alpha': 0.5 }) tt.visualization.plot_similarity(v, ax=ax[0, 1], line_kw={'color': 'blue'}, scatter_kw={ 'facecolor': 'black', 'alpha': 0.5 }) ax[0, 0].set_title( f'{k}: Objective function\ndata: cells x times-stages x cues') ax[0, 1].set_title( f'{k}: Model similarity\ndata: cells x times-stages x cues') ax[0, 0].axvline(-1, linestyle=':', color='grey') ax[0, 0].axvline(21, linestyle=':', color='grey') ax[0, 1].axvline(-1, linestyle=':', color='grey') ax[0, 1].axvline(21, linestyle=':', color='grey') # zoom in on 1-20 tt.visualization.plot_objective(v, ax=ax[1, 0], line_kw={'color': 'red'}, scatter_kw={ 'facecolor': 'black', 'alpha': 0.5 }) tt.visualization.plot_similarity(v, ax=ax[1, 1], line_kw={'color': 'blue'}, scatter_kw={ 'facecolor': 'black', 'alpha': 0.5 }) ax[1, 1].set_xlim([-1, 21]) ax[1, 0].set_title(f'Zoom: Objective function') ax[1, 1].set_title(f'Zoom: Model similarity') plt.savefig(cas.paths.analysis_file(f'{k}_nneg_nT0.png', 'tca_dfs/TCA_qc'), bbox_inches='tight') # plot factors for k, v in ensemble.items(): for rr in range(5, 20): fig, ax, _ = tt.visualization.plot_factors( v.results[rr][0].factors.rebalance(), plots=['scatter', 'line', 'line'], scatter_kw=cas.lookups.tt_plot_options['ncp_hals'] ['scatter_kw'], line_kw=cas.lookups.tt_plot_options['ncp_hals']['line_kw'], bar_kw=cas.lookups.tt_plot_options['ncp_hals']['bar_kw']) cell_count = v.results[rr][0].factors[0].shape[0] for i in range(ax.shape[0]): ax[i, 0].set_ylabel(f' Component {i+1}', size=16, ha='right', rotation=0) ax[0, 1].set_title(f'{k}, rank {rr} (n = {cell_count})\n\n', size=20) plt.savefig(cas.paths.analysis_file( f'{k}_rank{rr}_facs_nneg_nT0.png', f'tca_dfs/TCA_factors/{k}_nneg_nT0'), bbox_inches='tight') # plot heatmap for mod, mmod in zip([f'v{thresh}_norm_on', f'v{thresh}_norm_off'], [f'v{thresh}_on_mouse', f'v{thresh}_off_mouse']): mat2d_norm = data_dict[mod] mouse_dict = data_dict[mmod] mouse_mapper = {k: c for c, k in enumerate(np.unique(mouse_dict))} number_mouse_mat = np.array([mouse_mapper[s] for s in mouse_dict]) # ensemble sort ensort = sort_orders[mod][heatmap_rank - 1] clabel = 'normalized \u0394F/F' # clabel = '\u0394F/F (z-score)' #sort mat2d_norm = mat2d_norm[ensort, :] number_mouse_mat = number_mouse_mat[ensort] ax = [] fig = plt.figure(figsize=(30, 15)) gs = fig.add_gridspec(100, 110) ax.append(fig.add_subplot(gs[:, 3:5])) ax.append(fig.add_subplot(gs[:, 10:38])) ax.append(fig.add_subplot(gs[:, 40:68])) ax.append(fig.add_subplot(gs[:, 70:98])) ax.append(fig.add_subplot(gs[:30, 105:108])) # plot "categorical" heatmap using defined color mappings sns.heatmap(number_mouse_mat[:, None], cmap='Set2', ax=ax[0], cbar=False) ax[0].set_xticklabels(['mouse'], rotation=45, ha='right', size=18) ax[0].set_yticklabels([]) ax[0].set_ylabel('cell number', size=14) for i in range(1, 4): if i == 3: g = sns.heatmap(mat2d_norm[:, :, i - 1], ax=ax[i], center=0, vmax=1, vmin=-0.5, cmap='vlag', cbar_ax=ax[4], cbar_kws={'label': clabel}) cbar = g.collections[0].colorbar cbar.set_label(clabel, size=16) else: g = sns.heatmap(mat2d_norm[:, :, i - 1], ax=ax[i], center=0, vmax=1, vmin=-0.5, cmap='vlag', cbar=False) g.set_facecolor('#c5c5c5') ax[i].set_title(f'initial cue: {hue_order[i-1]}\n', size=20) stim_starts = [ 15.5 + 47 * s for s in np.arange(len(cas.lookups.staging['parsed_11stage'])) ] stim_labels = [ f'0\n\n{s}' if c % 2 == 0 else f'0\n{s}' for c, s in enumerate(cas.lookups.staging['parsed_11stage_T']) ] ax[i].set_xticks(stim_starts) ax[i].set_xticklabels(stim_labels, rotation=0) if i == 1: ax[i].set_ylabel('cell number', size=18) ax[i].set_xlabel('\ntime from stimulus onset (sec)', size=18) if i > 1: ax[i].set_yticks([]) plt.savefig(cas.paths.analysis_file( f'{mod}_rank{heatmap_rank}_heatmap_nneg_nT0.png', f'tca_dfs/TCA_heatmaps/v{thresh}_nneg_nT0'), bbox_inches='tight')
def run_unwrapped_tca(thresh=4, iteration=10, force=False, verbose=False, debug=False, skip_tca=False): """ Run script to build and save TCA inputs as well as TCA model outputs :param thresh: int -log10(p-value) threshold for calling a cell driven for any stage of learning. :param force: boolean Force a rerun even if output file exists. :param verbose: boolean View TCA progress in the terminal. """ # parameters # -------------------------------------------------------------------------------------------------- # model naming # mod_suffix = '_noT0_shuffle' mod_suffix = '_noT0_shuffle2' # mod_date = 20210215 mod_date = 20210307 # db in balanced mean # input data params mice = cas.lookups.mice['allFOV'] words = ['respondent' if s in 'OA27' else 'computation' for s in mice] group_by = 'all3' with_model = False nan_thresh = 0.95 # TCA params method = 'ncp_hals' replicates = iteration fit_options = {'tol': 0.0001, 'max_iter': 500, 'verbose': False} ranks = list(np.arange(1, 21, dtype=int)) ranks.extend([40]) ranks.extend([80]) tca_ranks = [int(s) for s in ranks] # plot params hue_order = ['becomes_unrewarded', 'remains_unrewarded', 'becomes_rewarded'] heatmap_rank = 12 # Do not overwrite existing files if os.path.isfile( cas.paths.analysis_file( f'tca_ensemble_v{thresh}i{iteration}{mod_suffix}_{mod_date}.npy', 'tca_dfs')) and not force: return # load in a full size data # -------------------------------------------------------------------------------------------------- tensor_list = [] id_list = [] bhv_list = [] meta_list = [] for mouse, word in zip(mice, words): # return ids, tensor, meta, bhv out = cas.load.load_all_groupday(mouse, word=word, with_model=with_model, group_by=group_by, nan_thresh=nan_thresh) tensor_list.append(out[2]) id_list.append(out[1]) bhv_list.append(out[4]) meta_list.append(cas.utils.add_stages_to_meta(out[3], 'parsed_11stage')) # load Oren's better offset classification # -------------------------------------------------------------------------------------------------- off_df_all_mice = pd.read_pickle('/twophoton_analysis/Data/analysis/core_dfs/offsets_dfs.pkl') df_list = [] for k, v in off_df_all_mice.items(): v['mouse'] = k df_list.append(v) updated_off_df = pd.concat(df_list, axis=0) updated_off_df = updated_off_df.set_index(['mouse', 'cell_id']) updated_off_df.head() # build input for TCA, selecting cells driven any stage with a given negative log10 p-value 'thresh' # -------------------------------------------------------------------------------------------------- # get data that has T0 naive removed data_dict = build_data_dict( meta_list, id_list, tensor_list, updated_off_df, thresh=thresh, iteration=iteration, debug=debug, mod_suffix=mod_suffix, add_suffix='' ) # get data that has all stages data_dict_allstages = build_data_dict( meta_list, id_list, tensor_list, updated_off_df, thresh=thresh, iteration=iteration, debug=debug, mod_suffix=mod_suffix, add_suffix='_allstages' ) data_dict.update(data_dict_allstages) # combine dicts # loop over keys and apply your T0 normalization to 'allstages' data pseudonorm_dict = renormalize_allstages_to_t0( data_dict, thresh=thresh, iteration=iteration, mod_suffix=mod_suffix, add_suffix='_allstages' ) data_dict.update(pseudonorm_dict) # filter out suppressed cells and cells with only a single stage of data (normed data used as benchmark) data_dict = filter_data_dict( data_dict, thresh=thresh, iteration=iteration, verbose=verbose, mod_suffix=mod_suffix ) # add a scaled-by-mouse zscore tensor data_dict = add_scaled_zscore_to_data_dict( data_dict, thresh=thresh, iteration=iteration, verbose=verbose, mod_suffix=mod_suffix ) # save data data_dict_path = cas.paths.analysis_file(f'input_data_v{thresh}i{iteration}{mod_suffix}_{mod_date}.npy', 'tca_dfs') np.save(data_dict_path, data_dict, allow_pickle=True) if verbose: print(f'Input data saved to:\n\t{data_dict_path}') # run TCA and save # -------------------------------------------------------------------------------------------------- if skip_tca: ensemble = np.load( cas.paths.analysis_file(f'tca_ensemble_v{thresh}i{iteration}{mod_suffix}_{mod_date}.npy', 'tca_dfs'), allow_pickle=True ).item() else: ensemble = {} # shuffle your timepoints unshuffle_dict = {} for k, v in data_dict.items(): if '_mouse_' in k or '_cell_' in k: continue if 'allstages' in k: continue shuffle_vec = np.arange(v.shape[1], dtype=int) np.random.shuffle(shuffle_vec) unshuffle_vec = np.argsort(shuffle_vec) unshuffle_dict[k] = unshuffle_vec v = v[:, shuffle_vec, :] mask = ~np.isnan(v) fit_options['mask'] = mask if verbose: print(f'TCA starting: {k} --> n={v.shape[0]} cells') ensemble[k] = tt.Ensemble(fit_method=method, fit_options=deepcopy(fit_options)) ensemble[k].fit(v, ranks=tca_ranks, replicates=replicates, verbose=True) np.save( cas.paths.analysis_file(f'tca_ensemble_v{thresh}i{iteration}{mod_suffix}_{mod_date}.npy', 'tca_dfs'), ensemble, allow_pickle=True ) np.save( cas.paths.analysis_file(f'unshuffler_v{thresh}i{iteration}{mod_suffix}_{mod_date}.npy', 'tca_dfs'), ensemble, allow_pickle=True ) # plot and save relevant results # -------------------------------------------------------------------------------------------------- # plot model performance for k, v in ensemble.items(): fig, ax = plt.subplots(2,2, figsize=(10,10), sharex='row', sharey='col') ax = ax.reshape([2,2]) # full plot tt.visualization.plot_objective(v, ax=ax[0,0], line_kw={'color': 'red'}, scatter_kw={'facecolor': 'black', 'alpha': 0.5}) tt.visualization.plot_similarity(v, ax=ax[0,1], line_kw={'color': 'blue'}, scatter_kw={'facecolor': 'black', 'alpha': 0.5}) ax[0, 0].set_title(f'{k}: Objective function\ndata: cells x times-stages x cues') ax[0, 1].set_title(f'{k}: Model similarity\ndata: cells x times-stages x cues') ax[0, 0].axvline(-1, linestyle=':', color='grey') ax[0, 0].axvline(21, linestyle=':', color='grey') ax[0, 1].axvline(-1, linestyle=':', color='grey') ax[0, 1].axvline(21, linestyle=':', color='grey') # zoom in on 1-20 tt.visualization.plot_objective(v, ax=ax[1,0], line_kw={'color': 'red'}, scatter_kw={'facecolor': 'black', 'alpha': 0.5}) tt.visualization.plot_similarity(v, ax=ax[1,1], line_kw={'color': 'blue'}, scatter_kw={'facecolor': 'black', 'alpha': 0.5}) ax[1, 1].set_xlim([-1, 21]) ax[1, 0].set_title(f'Zoom: Objective function') ax[1, 1].set_title(f'Zoom: Model similarity') plt.savefig(cas.paths.analysis_file(f'{k}_obj_sim.png', 'tca_dfs/TCA_qc'), bbox_inches='tight') # plot factors after sorting sort_ensembles, sort_orders = {}, {} for k, v in ensemble.items(): sort_ensembles[k], sort_orders[k] = cas.utils.sortfactors(v) for k, v in sort_ensembles.items(): unshuffle_vec = unshuffle_dict[k] for rr in range(5, 20): factors = v.results[rr][0].factors.rebalance() fig, ax, _ = tt.visualization.plot_factors(factors, plots=['scatter', 'line', 'line'], scatter_kw=cas.lookups.tt_plot_options['ncp_hals']['scatter_kw'], line_kw=cas.lookups.tt_plot_options['ncp_hals']['line_kw'], bar_kw=cas.lookups.tt_plot_options['ncp_hals']['bar_kw']); unshuffle_facs = deepcopy(factors) unshuffle_facs[1] = unshuffle_facs[1][unshuffle_vec, :] for ri in range(rr): ax[ri, 1].plot(unshuffle_facs[1][:, ri], color='blue', alpha=0.7) cell_count = v.results[rr][0].factors[0].shape[0] for i in range(ax.shape[0]): ax[i, 0].set_ylabel(f' Component {i+1}', size=16, ha='right', rotation=0) ax[0, 1].set_title(f'{k}, rank {rr} (n = {cell_count})\n\n', size=20) plt.savefig(cas.paths.analysis_file(f'{k}_rank{rr}_facs.png', f'tca_dfs/TCA_factors/{k}'), bbox_inches='tight') # plot heatmap mod_heat = [f'v{thresh}i{iteration}_norm_on{mod_suffix}', f'v{thresh}i{iteration}_norm_off{mod_suffix}'] mmod_heat = [f'v{thresh}i{iteration}_on_mouse{mod_suffix}', f'v{thresh}i{iteration}_off_mouse{mod_suffix}'] for mod, mmod in zip(mod_heat, mmod_heat): mat2d_norm = data_dict[mod] mouse_dict = data_dict[mmod] mouse_mapper = {k: c for c, k in enumerate(np.unique(mouse_dict))} number_mouse_mat = np.array([mouse_mapper[s] for s in mouse_dict]) # ensemble sort ensort = sort_orders[mod][heatmap_rank - 1] clabel = 'normalized \u0394F/F' # clabel = '\u0394F/F (z-score)' #sort mat2d_norm = mat2d_norm[ensort, :] number_mouse_mat = number_mouse_mat[ensort] ax = [] fig = plt.figure(figsize=(30, 15)) gs = fig.add_gridspec(100, 110) ax.append(fig.add_subplot(gs[:, 3:5])) ax.append(fig.add_subplot(gs[:, 10:38])) ax.append(fig.add_subplot(gs[:, 40:68])) ax.append(fig.add_subplot(gs[:, 70:98])) ax.append(fig.add_subplot(gs[:30, 105:108])) # plot "categorical" heatmap using defined color mappings sns.heatmap(number_mouse_mat[:, None], cmap='Set2', ax=ax[0], cbar=False) ax[0].set_xticklabels(['mouse'], rotation=45, ha='right', size=18) ax[0].set_yticklabels([]) ax[0].set_ylabel('cell number', size=14) for i in range(1,4): if i == 3: g = sns.heatmap(mat2d_norm[:,:,i-1], ax=ax[i], center=0, vmax=1, vmin=-0.5, cmap='vlag', cbar_ax=ax[4], cbar_kws={'label': clabel}) cbar = g.collections[0].colorbar cbar.set_label(clabel, size=16) else: g = sns.heatmap(mat2d_norm[:,:,i-1], ax=ax[i], center=0, vmax=1, vmin=-0.5, cmap='vlag', cbar=False) g.set_facecolor('#c5c5c5') ax[i].set_title(f'initial cue: {hue_order[i-1]}\n', size=20) stage_labels = cas.lookups.staging['parsed_11stage_T'][1:] stim_starts = [15.5 + 47*s for s in np.arange(len(stage_labels))] stim_labels = [f'0\n\n{s}' if c%2 == 0 else f'0\n{s}' for c, s in enumerate(stage_labels)] ax[i].set_xticks(stim_starts) ax[i].set_xticklabels(stim_labels, rotation=0) if i == 1: ax[i].set_ylabel('cell number', size=18) ax[i].set_xlabel('\ntime from stimulus onset (sec)', size=18) if i > 1: ax[i].set_yticks([]) plt.savefig( cas.paths.analysis_file(f'{mod}_rank{heatmap_rank}_heatmap.png', f'tca_dfs/TCA_heatmaps/v{thresh}i{iteration}{mod_suffix}'), bbox_inches='tight')