def sce_twitches(ms, before_extension, after_extension):
    """
    List which cells are active during each (extended) twitch period of a session.

    Twitch frames come from ``ms.shift_data_dict["shift_twitch"]``. Each twitch frame is
    extended by ``before_extension`` frames before it and ``after_extension`` frames after
    it, both bounds inclusive and clipped to the recording. Contiguous extended frames then
    form the twitch periods.

    The session data is left untouched: the extension is computed on a boolean copy of the
    stored mask.

    :param ms: mouse session; reads ms.spike_struct.spike_nums_dur,
        ms.spike_struct.spike_nums and ms.shift_data_dict["shift_twitch"]
    :param before_extension: number of frames added before each twitch frame
    :param after_extension: number of frames added after each twitch frame
    :return: tuple (cells_in_twitches, n_cells_with_spikes, twitch_periods):
        - cells_in_twitches: int16 array (n_cells x n_periods), 1 when the cell has a
          non-zero value in spike_nums_dur during the period
        - n_cells_with_spikes: number of cells with at least one spike in spike_nums
        - twitch_periods: output of get_continous_time_periods on the extended mask
    """
    spike_nums_dur = ms.spike_struct.spike_nums_dur
    spike_nums = ms.spike_struct.spike_nums
    n_cells_with_spikes = int(np.count_nonzero(np.sum(spike_nums, axis=1) >= 1))
    n_cells = spike_nums_dur.shape[0]

    # Boolean copy: never mutate the session's stored twitch mask, and make sure the
    # mask is boolean (an int mask would turn the indexing below into fancy indexing).
    twitch_mask = np.array(ms.shift_data_dict["shift_twitch"], dtype=bool)
    n_frames = len(twitch_mask)

    extended_mask = twitch_mask.copy()
    for frame in np.flatnonzero(twitch_mask):
        first_frame = max(0, frame - before_extension)
        last_frame = min(n_frames - 1, frame + after_extension)
        # last_frame is an inclusive index (clipped to n_frames - 1)
        extended_mask[first_frame:last_frame + 1] = True

    twitch_periods = get_continous_time_periods(extended_mask.astype("int8"))

    cells_in_twitches = np.zeros((n_cells, len(twitch_periods)), dtype="int16")
    for index, period in enumerate(twitch_periods):
        # np.any avoids summing into an int16 column, which could wrap around
        active_cells = np.any(spike_nums_dur[:, period[0]:period[1] + 1], axis=1)
        cells_in_twitches[active_cells, index] = 1

    return cells_in_twitches, n_cells_with_spikes, twitch_periods
def update_periods(self):
    """
    Rebuild ``self.periods_as_tuples`` from the masks stored in ``self.periods``.

    Every entry of ``self.periods`` gets a matching entry in ``self.periods_as_tuples``:
    - an empty list when the mask sums to 0;
    - otherwise the result of get_continous_time_periods on the mask cast to int8.
    """
    for name, mask in self.periods.items():
        if np.sum(mask) != 0:
            as_tuples = get_continous_time_periods(mask.astype("int8"))
        else:
            as_tuples = []
        self.periods_as_tuples[name] = as_tuples
def build_spike_nums_and_peak_nums(spike_nums_dur):
    """
    Derive an onset raster and a peak raster from a 2D raster dur.

    For each cell, every period returned by get_continous_time_periods puts a 1 in
    ``spike_nums`` at the period's first frame (period[0]). It also puts a 1 in
    ``peak_nums`` at the period's last frame (period[1]), which is treated as the peak.

    :param spike_nums_dur: 2D array (cells x frames)
    :return: tuple (spike_nums, peak_nums), two int8 arrays of shape (n_cells, n_frames)
    """
    n_cells, n_frames = spike_nums_dur.shape
    onsets_raster = np.zeros((n_cells, n_frames), dtype="int8")
    peaks_raster = np.zeros((n_cells, n_frames), dtype="int8")
    for cell_index, cell_raster in enumerate(spike_nums_dur):
        for period in get_continous_time_periods(cell_raster):
            onsets_raster[cell_index, period[0]] = 1
            peaks_raster[cell_index, period[1]] = 1
    return onsets_raster, peaks_raster
def build_spike_nums_and_peak_nums(self):
    """
    Fill ``self.spike_nums`` (onsets) and ``self.peak_nums`` (peaks) from ``self.spike_nums_dur``.

    Does nothing when ``self.spike_nums_dur`` is None. Otherwise both attributes are
    replaced by int8 arrays of shape (n_cells, n_frames). For each period returned by
    get_continous_time_periods on a cell's row:
    - spike_nums gets a 1 at the period's first frame (period[0]);
    - peak_nums gets a 1 at its last frame (period[1]).
    """
    if self.spike_nums_dur is None:
        return
    n_cells = len(self.spike_nums_dur)
    n_frames = self.spike_nums_dur.shape[1]
    self.spike_nums = np.zeros((n_cells, n_frames), dtype="int8")
    self.peak_nums = np.zeros((n_cells, n_frames), dtype="int8")
    for cell in np.arange(n_cells):
        for transient_period in get_continous_time_periods(self.spike_nums_dur[cell]):
            onset = transient_period[0]
            peak = transient_period[1]
            self.spike_nums[cell, onset] = 1
            self.peak_nums[cell, peak] = 1
def _ratio_or_default(numerator, denominator, default=1):
    """Return numerator / denominator, or ``default`` when the denominator is not > 0."""
    if denominator > 0:
        return numerator / denominator
    return default


def compute_stats(spike_nums_dur, predicted_spike_nums_dur, traces, gt_predictions,
                  with_threshold=None):
    """
    Compute detection statistics of a predicted raster dur against a ground truth raster dur.

    :param spike_nums_dur: ground truth raster dur, 1D (one cell) or 2D (cells x frames).
        Should not be "uint" dtype.
    :param predicted_spike_nums_dur: predicted raster dur, same shape as spike_nums_dur.
        Should not be "uint" dtype.
    :param traces: optional traces (same layout as the rasters). When given,
        get_raster_dur_from_traces builds a "full" raster dur. Its periods that contain
        no ground-truth active frame are the "fake" transients used for TN / FP transient
        counts.
    :param gt_predictions: optional classifier outputs (same layout as the rasters). When
        given, their values are collected per TP/FP/FN/TN category, using 0.5 as the
        positive threshold.
    :param with_threshold: forwarded to get_raster_dur_from_traces
    :raises ValueError: if the two rasters have different shapes
    :return: tuple (frames_stat, transients_stat, predictions_stat,
        proportion_of_frames_detected_in_transients)

    frames_stat has the following keys (as String):
        TP: True Positive
        FP: False Positive
        FN: False Negative
        TN: True Negative
        sensitivity or TPR: True Positive Rate or Recall
        specificity or TNR: True Negative Rate or Selectivity
        ACC: accuracy
        Prevalence: sum positive conditions / total population (0 if no frames)
        PPV: Positive Predictive Value or Precision
        NPV: Negative Predictive Value
        FNR: False Negative Rate or Miss Rate
        FPR: False Positive Rate or Fall-out
        FDR: False Discovery Rate
        FOR: False Omission Rate
        LR+: Positive Likelihood ratio
        LR-: Negative likelihood ratio
    Ratios whose denominator is 0 default to 1.

    transients_stat has the keys TP, FN, sensitivity/TPR and FNR. When traces is given,
    it also has TN, FP, specificity/TNR, ACC, PPV and NPV.

    predictions_stat maps "{tp,fp,fn,tn}_{frames,transients}_predictions" to lists of
    gt_predictions values (the max over the transient for transients). These lists are
    empty when gt_predictions is None.

    proportion_of_frames_detected_in_transients lists, for each ground-truth transient
    with at least one predicted frame, the percentage of its frames that are predicted.
    """
    if spike_nums_dur.shape != predicted_spike_nums_dur.shape:
        raise ValueError("both spike_nums_dur should have the same shape")
    if spike_nums_dur.ndim == 1:
        # work on 2D arrays (cells x frames) from here on
        spike_nums_dur = spike_nums_dur.reshape(1, -1)
        predicted_spike_nums_dur = predicted_spike_nums_dur.reshape(1, -1)
        if gt_predictions is not None:
            gt_predictions = gt_predictions.reshape(1, -1)
        if traces is not None:
            traces = traces.reshape(1, -1)

    n_cells = spike_nums_dur.shape[0]
    # gt_predictions value from which a frame / transient is considered predicted positive
    prediction_threshold = 0.5

    # full raster dur represents the raster dur built from all potential onsets and peaks
    full_raster_dur = None
    if traces is not None:
        full_raster_dur = get_raster_dur_from_traces(traces, with_threshold=with_threshold)

    # ---- frames ----
    # condition is the ground truth, predicted is the evaluated raster (RNN, CaImAn, ...)
    predicted_positive = predicted_spike_nums_dur != 0
    predicted_negative = predicted_spike_nums_dur == 0
    condition_positive = spike_nums_dur == 1
    condition_negative = spike_nums_dur == 0
    tp_frames = int(np.count_nonzero(predicted_positive & condition_positive))
    fp_frames = int(np.count_nonzero(predicted_positive & condition_negative))
    fn_frames = int(np.count_nonzero(predicted_negative & condition_positive))
    tn_frames = int(np.count_nonzero(predicted_negative & condition_negative))

    tp_frames_predictions = []
    fp_frames_predictions = []
    fn_frames_predictions = []
    tn_frames_predictions = []
    if gt_predictions is not None:
        # Full-size masks keep each value aligned with the frame it classifies.
        # Boolean indexing walks cells then frames, like a per-cell loop.
        gt_positive = gt_predictions >= prediction_threshold
        gt_negative = gt_predictions < prediction_threshold
        tp_frames_predictions = list(gt_predictions[gt_positive & condition_positive])
        fp_frames_predictions = list(gt_predictions[gt_positive & condition_negative])
        fn_frames_predictions = list(gt_predictions[gt_negative & condition_positive])
        tn_frames_predictions = list(gt_predictions[gt_negative & condition_negative])

    # ---- transients ----
    tp_transients = 0
    fn_transients = 0
    tn_transients = 0
    fp_transients = 0
    tp_transients_predictions = []
    fp_transients_predictions = []
    fn_transients_predictions = []
    tn_transients_predictions = []
    proportion_of_frames_detected_in_transients = []

    for cell in range(n_cells):
        raster_dur = spike_nums_dur[cell]
        predicted_raster_dur = predicted_spike_nums_dur[cell]
        gt_predictions_for_cell = None if gt_predictions is None else gt_predictions[cell]

        # positive condition: transients of the ground truth
        for transient_period in get_continous_time_periods(raster_dur):
            frames = np.arange(transient_period[0], transient_period[1] + 1)
            n_predicted_frames = np.sum(predicted_raster_dur[frames])
            if n_predicted_frames > 0:
                tp_transients += 1
                # keeping only transients with at least one frame detected
                proportion_of_frames_detected_in_transients.append(
                    (n_predicted_frames / len(frames)) * 100)
            else:
                fn_transients += 1
            if gt_predictions_for_cell is not None:
                # max over the transient: tells how far the best frame is from the threshold
                max_prediction = np.max(gt_predictions_for_cell[frames])
                if max_prediction >= prediction_threshold:
                    tp_transients_predictions.append(max_prediction)
                else:
                    fn_transients_predictions.append(max_prediction)

        if full_raster_dur is None:
            continue
        # negative condition: "fake" transients of the full raster, i.e. periods
        # containing no ground-truth active frame
        for transient_period in get_continous_time_periods(full_raster_dur[cell]):
            frames = np.arange(transient_period[0], transient_period[1] + 1)
            if np.sum(raster_dur[frames]) > 0:
                # it means it's a real transient
                continue
            if np.sum(predicted_raster_dur[frames]) == 0:
                tn_transients += 1
            else:
                fp_transients += 1
            if gt_predictions_for_cell is not None:
                max_prediction = np.max(gt_predictions_for_cell[frames])
                if max_prediction < prediction_threshold:
                    tn_transients_predictions.append(max_prediction)
                else:
                    fp_transients_predictions.append(max_prediction)

    # ---- frames stats ----
    n_frames_total = tp_frames + tn_frames + fp_frames + fn_frames
    frames_stat = {"TP": tp_frames, "FP": fp_frames, "FN": fn_frames, "TN": tn_frames}
    frames_stat["sensitivity"] = _ratio_or_default(tp_frames, tp_frames + fn_frames)
    frames_stat["TPR"] = frames_stat["sensitivity"]
    frames_stat["specificity"] = _ratio_or_default(tn_frames, tn_frames + fp_frames)
    frames_stat["TNR"] = frames_stat["specificity"]
    frames_stat["ACC"] = _ratio_or_default(tp_frames + tn_frames, n_frames_total)
    frames_stat["Prevalence"] = _ratio_or_default(tp_frames + fn_frames, n_frames_total,
                                                  default=0)
    frames_stat["PPV"] = _ratio_or_default(tp_frames, tp_frames + fp_frames)
    frames_stat["NPV"] = _ratio_or_default(tn_frames, tn_frames + fn_frames)
    frames_stat["FNR"] = 1 - frames_stat["TPR"]
    frames_stat["FPR"] = 1 - frames_stat["TNR"]
    frames_stat["FDR"] = 1 - frames_stat["PPV"]
    frames_stat["FOR"] = 1 - frames_stat["NPV"]
    frames_stat["LR+"] = _ratio_or_default(frames_stat["TPR"], frames_stat["FPR"])
    frames_stat["LR-"] = _ratio_or_default(frames_stat["FNR"], frames_stat["TNR"])

    # ---- transients stats ----
    transients_stat = {"TP": tp_transients, "FN": fn_transients}
    if traces is not None:
        transients_stat["TN"] = tn_transients
        transients_stat["FP"] = fp_transients
    transients_stat["sensitivity"] = _ratio_or_default(tp_transients,
                                                       tp_transients + fn_transients)
    transients_stat["TPR"] = transients_stat["sensitivity"]
    if traces is not None:
        transients_stat["specificity"] = _ratio_or_default(tn_transients,
                                                           tn_transients + fp_transients)
        transients_stat["TNR"] = transients_stat["specificity"]
        transients_stat["ACC"] = _ratio_or_default(
            tp_transients + tn_transients,
            tp_transients + tn_transients + fp_transients + fn_transients)
        transients_stat["PPV"] = _ratio_or_default(tp_transients, tp_transients + fp_transients)
        transients_stat["NPV"] = _ratio_or_default(tn_transients, tn_transients + fn_transients)
    transients_stat["FNR"] = 1 - transients_stat["TPR"]

    predictions_stat = {
        "fn_transients_predictions": fn_transients_predictions,
        "fp_transients_predictions": fp_transients_predictions,
        "tp_transients_predictions": tp_transients_predictions,
        "tn_transients_predictions": tn_transients_predictions,
        "fn_frames_predictions": fn_frames_predictions,
        "fp_frames_predictions": fp_frames_predictions,
        "tn_frames_predictions": tn_frames_predictions,
        "tp_frames_predictions": tp_frames_predictions,
    }
    return frames_stat, transients_stat, predictions_stat, \
        proportion_of_frames_detected_in_transients
def _sce_peak_frame(spike_nums, first_frame, last_frame):
    """Return the frame of [first_frame, last_frame] with most co-active cells (earliest on ties)."""
    n_active_by_frame = np.sum(spike_nums[:, first_frame:last_frame + 1], axis=0)
    return first_frame + int(np.argmax(n_active_by_frame))


def _sliding_window_sce_periods(spike_nums, window_duration, activity_threshold,
                                no_redundancy, debug_mode):
    """Return (first, last) inclusive frame tuples of SCEs found with a sliding window."""
    n_cells, n_times = spike_nums.shape
    periods = []
    start_sce = -1
    # keep a trace of which cells have been added to the current SCE (no_redundancy mode)
    cells_in_sce_so_far = np.zeros(n_cells, dtype="bool")
    if debug_mode:
        print(f"n_times {n_times}")
    # the last complete window starts at n_times - window_duration
    for t in range(n_times - window_duration + 1):
        if debug_mode and t % 10**6 == 0:
            print(f"t {t}")
        # cells with a non-zero sum over the window [t, t + window_duration)
        active_cells = np.flatnonzero(np.sum(spike_nums[:, t:t + window_duration], axis=1))
        n_active = len(active_cells)
        nb_cells_already_in_sce = 0
        if no_redundancy and start_sce > -1:
            # removing from the count the cells already in the current SCE
            nb_cells_already_in_sce = int(np.sum(cells_in_sce_so_far[active_cells]))
            n_active -= nb_cells_already_in_sce

        if n_active >= activity_threshold:
            if start_sce == -1:
                start_sce = t
            if no_redundancy:
                # keeping only cells spiking at time t, as the window shifts by one next step
                cells_in_sce_so_far[spike_nums[:, t] != 0] = True
            continue

        if start_sce > -1:
            # the last window above threshold started at t - 1 and ended at t + wd - 2
            periods.append((start_sce, t + window_duration - 2))
            start_sce = -1
            cells_in_sce_so_far[:] = False
        if nb_cells_already_in_sce > 0 and \
                n_active + nb_cells_already_in_sce >= activity_threshold:
            # below threshold only because of redundancy: a new SCE starts right after
            start_sce = t
            cells_in_sce_so_far[spike_nums[:, t] != 0] = True

    if start_sce > -1:
        # SCE still ongoing when the recording ends
        periods.append((start_sce, n_times - 1))
    return periods


def detect_sce_with_sliding_window(spike_nums, window_duration, perc_threshold=95,
                                   with_refractory_period=-1, non_binary=False,
                                   activity_threshold=None, debug_mode=False,
                                   no_redundancy=False, keep_only_the_peak=False):
    """
    Use a sliding window to detect SCEs.

    An SCE is a period where the number of cells active during a window of
    ``window_duration`` frames is >= activity_threshold.

    :param spike_nums: 2D array, lines=cells, columns=time
    :param window_duration: number of frames in the sliding window. With 1, every frame
        whose summed activity is >= activity_threshold is part of an SCE.
    :param perc_threshold: percentile forwarded to get_sce_detection_threshold when
        activity_threshold is None
    :param with_refractory_period: unused, kept for backward compatibility
    :param non_binary: if True, any value > 0 in spike_nums is treated as a spike
    :param activity_threshold: minimum number of co-active cells; if None it is computed
        with get_sce_detection_threshold (1000 surrogates)
    :param debug_mode: print progress information
    :param no_redundancy: (window_duration > 1 only) a second spike of a cell already
        counted in the current SCE is not taken into consideration
    :param keep_only_the_peak: keep only the frame with the maximum cells co-activating;
        each SCE tuple becomes (peak, peak)
    :return: tuple (sce_bool, sce_tuples, sce_nums, sce_times_numbers, activity_threshold):
        - sce_bool: boolean mask of the frames that are part of an SCE
        - sce_tuples: list of (first, last) frame tuples, last index included
        - sce_nums: int16 array (n_cells x n_sces), 1 if the cell is active during the SCE
        - sce_times_numbers: int16 array of len n_times with the SCE index of each frame,
          or -1 if part of no SCE
        - activity_threshold: the threshold used
    """
    if non_binary:
        spike_nums = (spike_nums > 0).astype("int8")

    if activity_threshold is None:
        activity_threshold = get_sce_detection_threshold(
            spike_nums=spike_nums, n_surrogate=1000, window_duration=window_duration,
            perc_threshold=perc_threshold, non_binary=False)

    n_cells, n_times = spike_nums.shape

    if window_duration == 1:
        # frame-by-frame: contiguous frames above threshold form an SCE
        binary_sum = (np.sum(spike_nums, axis=0) >= activity_threshold).astype("int8")
        sce_tuples = get_continous_time_periods(binary_sum)
    else:
        sce_tuples = _sliding_window_sce_periods(spike_nums, window_duration,
                                                 activity_threshold, no_redundancy,
                                                 debug_mode)

    if keep_only_the_peak:
        peaks = [_sce_peak_frame(spike_nums, period[0], period[1]) for period in sce_tuples]
        sce_tuples = [(peak, peak) for peak in peaks]

    sce_bool = np.zeros(n_times, dtype="bool")
    sce_times_numbers = np.full(n_times, -1, dtype="int16")
    sce_nums = np.zeros((n_cells, len(sce_tuples)), dtype="int16")
    for sce_index, period in enumerate(sce_tuples):
        frames = slice(period[0], period[1] + 1)
        sce_bool[frames] = True
        sce_times_numbers[frames] = sce_index
        # cells with a non-zero sum over the SCE are active during it
        sce_nums[np.sum(spike_nums[:, frames], axis=1) != 0, sce_index] = 1
    return sce_bool, sce_tuples, sce_nums, sce_times_numbers, activity_threshold