Example #1
import numpy as np

# `get_continous_time_periods` is assumed to be provided by the surrounding codebase;
# a minimal stand-in is sketched right after this snippet.


def sce_twitches(ms, before_extension, after_extension):
    """
    Build, for each twitch period (extended by before_extension / after_extension frames),
    the binary matrix of cells active during it.
    Returns (cells_in_twitches, n_cells_with_spikes, SCE_times).
    """
    spike_nums_dur = ms.spike_struct.spike_nums_dur
    spike_nums = ms.spike_struct.spike_nums
    n_cells_with_spikes = len(np.where(np.sum(spike_nums, axis=1) >= 1)[0])
    n_cells = spike_nums_dur.shape[0]
    # using twitch periods ("shift_twitch") instead of SCEs when that info is available
    sce_times_bool = ms.shift_data_dict["shift_twitch"]

    n_frames = len(sce_times_bool)

    # period_extension
    extension_frames_after = after_extension
    extension_frames_before = before_extension
    shift_bool_tmp = np.copy(sce_times_bool)
    true_frames = np.where(sce_times_bool)[0]
    for frame in true_frames:
        first_frame = max(0, frame - extension_frames_before)
        last_frame = min(n_frames - 1, frame + extension_frames_after)
        shift_bool_tmp[first_frame:last_frame + 1] = True  # include the clamped last frame
    # note: this updates ms.shift_data_dict["shift_twitch"] in place
    sce_times_bool[shift_bool_tmp] = True
    SCE_times = get_continous_time_periods(sce_times_bool.astype("int8"))
    sce_times_numbers = np.ones(len(sce_times_bool), dtype="int16")
    sce_times_numbers *= -1
    cells_in_twitches = np.zeros((n_cells, len(SCE_times)), dtype="int16")
    for index, period in enumerate(SCE_times):
        sce_times_numbers[period[0]:period[1] + 1] = index
        cells_in_twitches[:, index] = np.sum(
            spike_nums_dur[:, period[0]:period[1] + 1], axis=1)
        cells_in_twitches[cells_in_twitches[:, index] > 0, index] = 1

    return cells_in_twitches, n_cells_with_spikes, SCE_times
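
# All snippets in this file rely on `get_continous_time_periods` from the surrounding
# codebase. Below is a minimal stand-in, an assumption that matches the contract the
# snippets rely on (a list of (first_frame, last_frame) tuples, last frame included,
# for each run of non-zero values in a 1D array); it is not the original implementation.
def get_continous_time_periods(binary_array):
    periods = []
    start = None
    for index, value in enumerate(binary_array):
        if value and start is None:
            # a new period starts on this frame
            start = index
        elif (not value) and (start is not None):
            # the period ended on the previous frame
            periods.append((start, index - 1))
            start = None
    if start is not None:
        # the last period runs until the last frame
        periods.append((start, len(binary_array) - 1))
    return periods

# Example: get_continous_time_periods(np.array([0, 1, 1, 0, 1], dtype="int8"))
# returns [(1, 2), (4, 4)]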
Example #2
def update_periods(self):
    for period_name, period in self.periods.items():
        if np.sum(period) == 0:
            self.periods_as_tuples[period_name] = []
        else:
            self.periods_as_tuples[period_name] = get_continous_time_periods(
                period.astype("int8"))
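
# Minimal usage sketch for update_periods, assuming a toy holder object that only
# exposes the `periods` and `periods_as_tuples` attributes the method relies on.
class _PeriodsHolder:
    def __init__(self):
        self.periods = {"twitch": np.array([0, 1, 1, 0, 1], dtype="int8"),
                        "still": np.zeros(5, dtype="int8")}
        self.periods_as_tuples = {}


holder = _PeriodsHolder()
update_periods(holder)  # calling the snippet above as a plain function
# holder.periods_as_tuples -> {"twitch": [(1, 2), (4, 4)], "still": []}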
Example #3
def build_spike_nums_and_peak_nums(spike_nums_dur):
    n_cells, n_frames = spike_nums_dur.shape
    spike_nums = np.zeros((n_cells, n_frames), dtype="int8")
    peak_nums = np.zeros((n_cells, n_frames), dtype="int8")
    for cell in np.arange(n_cells):
        transient_periods = get_continous_time_periods(spike_nums_dur[cell])
        for transient_period in transient_periods:
            onset = transient_period[0]
            peak = transient_period[1]
            # if onset == peak:
            #     print("onset == peak")
            spike_nums[cell, onset] = 1
            peak_nums[cell, peak] = 1
    return spike_nums, peak_nums
    # Class-method variant of the same routine, excerpted from a class that holds
    # spike_nums_dur, spike_nums and peak_nums.
    def build_spike_nums_and_peak_nums(self):
        if self.spike_nums_dur is None:
            return

        n_cells = len(self.spike_nums_dur)
        n_frames = self.spike_nums_dur.shape[1]
        self.spike_nums = np.zeros((n_cells, n_frames), dtype="int8")
        self.peak_nums = np.zeros((n_cells, n_frames), dtype="int8")
        for cell in np.arange(n_cells):
            transient_periods = get_continous_time_periods(
                self.spike_nums_dur[cell])
            for transient_period in transient_periods:
                onset = transient_period[0]
                peak = transient_period[1]
                # if onset == peak:
                #     print("onset == peak")
                self.spike_nums[cell, onset] = 1
                self.peak_nums[cell, peak] = 1
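
# Usage sketch for the module-level build_spike_nums_and_peak_nums above (assumes
# numpy and get_continous_time_periods, as sketched earlier): onsets land on the
# first frame of each transient, peaks on its last frame.
spike_nums_dur_demo = np.array([[0, 1, 1, 1, 0],
                                [1, 1, 0, 0, 0]], dtype="int8")
spike_nums_demo, peak_nums_demo = build_spike_nums_and_peak_nums(spike_nums_dur_demo)
# cell 0: onset at frame 1, peak at frame 3; cell 1: onset at frame 0, peak at frame 1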
def compute_stats(spike_nums_dur,
                  predicted_spike_nums_dur,
                  traces,
                  gt_predictions,
                  with_threshold=None):
    """
    Compute the stats based on raster dur
    :param spike_nums_dur: should not be "uint" dtype
    :param predicted_spike_nums_dur: should not be "uint" dtype
    :return: two dicts: first one with stats on frames, the other one with stats on transients
    Frames dict has the following keys (as String):
    TP: True Positive
    FP: False Positive
    FN: False Negative
    TN: True Negative
    sensitivity or TPR: True Positive Rate or Recall
    specificity or TNR: True Negative Rate or Selectivity
    FPR: False Positive Rate or Fall-out
    FNR: False Negative Rate or Miss Rate
    ACC: accuracy
    Prevalence: sum positive conditions / total population (for frames only)
    PPV: Positive Predictive Value or Precision
    FDR: False Discovery Rate
    FOR: False Omission Rate
    NPV: Negative Predictive Value
    LR+: Positive Likelihood ratio
    LR-: Negative likelihood ratio

    transients dict has just the following keys:
    TP
    FN
    sensitivity or TPR: True Positive Rate or Recall
    FNR: False Negative Rate or Miss Rate
    """

    if spike_nums_dur.shape != predicted_spike_nums_dur.shape:
        raise Exception("spike_nums_dur and predicted_spike_nums_dur must have the same shape")
    if len(spike_nums_dur.shape) == 1:
        # transform them into 2-dimensional arrays
        spike_nums_dur = spike_nums_dur.reshape(1, spike_nums_dur.shape[0])
        predicted_spike_nums_dur = predicted_spike_nums_dur.reshape(
            1, predicted_spike_nums_dur.shape[0])
        if gt_predictions is not None:
            gt_predictions = gt_predictions.reshape(1, gt_predictions.shape[0])
        if traces is not None:
            traces = traces.reshape(1, traces.shape[0])

    frames_stat = dict()
    transients_stat = dict()

    n_frames = spike_nums_dur.shape[1]
    n_cells = spike_nums_dur.shape[0]

    # full raster dur represents the raster dur built from all potential onsets and peaks
    full_raster_dur = None
    if traces is not None:
        full_raster_dur = get_raster_dur_from_traces(
            traces, with_threshold=with_threshold)

    # positive means active frame, negative means non-active frames
    # condition is the ground truth
    # predicted is the one computed (RNN, CaImAn, etc.)

    tp_frames = 0
    fp_frames = 0
    fn_frames = 0
    tn_frames = 0

    tp_transients = 0
    fn_transients = 0
    fp_transients = 0
    tn_transients = 0

    # will keep the prediction values of the TP/FP/FN/TN transients and frames
    # for transients, the max predicted value over the transient is kept
    fn_transients_predictions = []
    fp_transients_predictions = []
    tp_transients_predictions = []
    tn_transients_predictions = []
    fn_frames_predictions = []
    fp_frames_predictions = []
    tn_frames_predictions = []
    tp_frames_predictions = []

    proportion_of_frames_detected_in_transients = []

    for cell in np.arange(n_cells):
        raster_dur = spike_nums_dur[cell]
        predicted_raster_dur = predicted_spike_nums_dur[cell]
        if gt_predictions is not None:
            gt_predictions_for_cell = gt_predictions[cell]
        else:
            gt_predictions_for_cell = None
        predicted_positive_frames = np.where(predicted_raster_dur)[0]
        predicted_negative_frames = np.where(predicted_raster_dur == 0)[0]

        tp_frames += len(
            np.where(raster_dur[predicted_positive_frames] == 1)[0])
        fp_frames += len(
            np.where(raster_dur[predicted_positive_frames] == 0)[0])
        fn_frames += len(
            np.where(raster_dur[predicted_negative_frames] == 1)[0])
        tn_frames += len(
            np.where(raster_dur[predicted_negative_frames] == 0)[0])

        if gt_predictions is not None:
            fp_frames_indices = np.where(
                raster_dur[gt_predictions_for_cell >= 0.5] == 0)[0]
            fn_frames_indices = np.where(
                raster_dur[gt_predictions_for_cell < 0.5] == 1)[0]
            tp_frames_indices = np.where(
                raster_dur[gt_predictions_for_cell >= 0.5] == 1)[0]
            tn_frames_indices = np.where(
                raster_dur[gt_predictions_for_cell < 0.5] == 0)[0]
            fp_frames_predictions.extend(
                list(gt_predictions_for_cell[fp_frames_indices]))
            fn_frames_predictions.extend(
                list(gt_predictions_for_cell[fn_frames_indices]))
            tp_frames_predictions.extend(
                list(gt_predictions_for_cell[tp_frames_indices]))
            tn_frames_predictions.extend(
                list(gt_predictions_for_cell[tn_frames_indices]))

        n_fake_transients = 0
        # transients section
        transient_periods = get_continous_time_periods(raster_dur)
        if full_raster_dur is not None:
            full_transient_periods = get_continous_time_periods(
                full_raster_dur[cell])
            fake_transients_periods = []
            # keeping only the fake ones
            for transient_period in full_transient_periods:
                if np.sum(raster_dur[transient_period[0]:transient_period[1] +
                                     1]) > 0:
                    # it means it's a real transient
                    continue
                fake_transients_periods.append(transient_period)

            n_fake_transients = len(fake_transients_periods)
        # positive condition
        n_transients = len(transient_periods)
        tp = 0
        for transient_period in transient_periods:
            frames = np.arange(transient_period[0], transient_period[1] + 1)
            if np.sum(predicted_raster_dur[frames]) > 0:
                tp += 1
                # keeping only transients with at least one detected frame
                proportion_of_frames_detected_in_transients.append(
                    (np.sum(predicted_raster_dur[frames]) / len(frames)) * 100)

            if gt_predictions is not None:
                if np.max(gt_predictions_for_cell[frames]) >= 0.5:
                    # adding the max of the predicted frames in the transient
                    tp_transients_predictions.append(
                        np.max(gt_predictions_for_cell[frames]))
                else:
                    # then it's a FN
                    fn_transients_predictions.append(
                        np.max(gt_predictions_for_cell[frames]))
        tn = 0
        if full_raster_dur is not None:
            for transient_period in fake_transients_periods:
                frames = np.arange(transient_period[0],
                                   transient_period[1] + 1)
                if np.sum(predicted_raster_dur[frames]) == 0:
                    tn += 1
                if gt_predictions is not None:
                    if np.max(gt_predictions_for_cell[frames]) < 0.5:
                        # adding the max of the predicted frames in the transient
                        # by max we know how far we are from 0.5
                        tn_transients_predictions.append(
                            np.max(gt_predictions_for_cell[frames]))
                    else:
                        # then it's a FP
                        # taking the max because we want to see how far we are from 0.5
                        fp_transients_predictions.append(
                            np.max(gt_predictions_for_cell[frames]))

        tp_transients += tp
        fn_transients += (n_transients - tp)
        tn_transients += tn
        fp_transients += (n_fake_transients - tn)

    frames_stat["TP"] = tp_frames
    frames_stat["FP"] = fp_frames
    frames_stat["FN"] = fn_frames
    frames_stat["TN"] = tn_frames

    # frames_stat["TPR"] = tp_frames / (tp_frames + fn_frames)
    if (tp_frames + fn_frames) > 0:
        frames_stat["sensitivity"] = tp_frames / (tp_frames + fn_frames)
    else:
        frames_stat["sensitivity"] = 1
    frames_stat["TPR"] = frames_stat["sensitivity"]

    if (tn_frames + fp_frames) > 0:
        frames_stat["specificity"] = tn_frames / (tn_frames + fp_frames)
    else:
        frames_stat["specificity"] = 1
    frames_stat["TNR"] = frames_stat["specificity"]

    if (tp_frames + tn_frames + fp_frames + fn_frames) > 0:
        frames_stat["ACC"] = (tp_frames + tn_frames) / (tp_frames + tn_frames +
                                                        fp_frames + fn_frames)
    else:
        frames_stat["ACC"] = 1

    if (tp_frames + fp_frames) > 0:
        frames_stat["PPV"] = tp_frames / (tp_frames + fp_frames)
    else:
        frames_stat["PPV"] = 1
    if (tn_frames + fn_frames) > 0:
        frames_stat["NPV"] = tn_frames / (tn_frames + fn_frames)
    else:
        frames_stat["NPV"] = 1

    frames_stat["FNR"] = 1 - frames_stat["TPR"]

    frames_stat["FPR"] = 1 - frames_stat["TNR"]

    if "PPV" in frames_stat:
        frames_stat["FDR"] = 1 - frames_stat["PPV"]

    if "NPV" in frames_stat:
        frames_stat["FOR"] = 1 - frames_stat["NPV"]

    if frames_stat["FPR"] > 0:
        frames_stat["LR+"] = frames_stat["TPR"] / frames_stat["FPR"]
    else:
        frames_stat["LR+"] = 1

    if frames_stat["TNR"] > 0:
        frames_stat["LR-"] = frames_stat["FNR"] / frames_stat["TNR"]
    else:
        frames_stat["LR-"] = 1

    # transients dict
    transients_stat["TP"] = tp_transients
    transients_stat["FN"] = fn_transients
    if traces is not None:
        # print(f"tn_transients {tn_transients}")
        transients_stat["TN"] = tn_transients
        transients_stat["FP"] = fp_transients

    if (tp_transients + fn_transients) > 0:
        transients_stat["sensitivity"] = tp_transients / (tp_transients +
                                                          fn_transients)
    else:
        transients_stat["sensitivity"] = 1

    # print(f'transients_stat["sensitivity"] {transients_stat["sensitivity"]}')
    transients_stat["TPR"] = transients_stat["sensitivity"]

    if traces is not None:
        if (tn_transients + fp_transients) > 0:
            transients_stat["specificity"] = tn_transients / (tn_transients +
                                                              fp_transients)
        else:
            transients_stat["specificity"] = 1
        transients_stat["TNR"] = transients_stat["specificity"]

        if (tp_transients + tn_transients + fp_transients + fn_transients) > 0:
            transients_stat["ACC"] = (tp_transients + tn_transients) / \
                                 (tp_transients + tn_transients + fp_transients + fn_transients)
        else:
            transients_stat["ACC"] = 1

        if (tp_transients + fp_transients) > 0:
            transients_stat["PPV"] = tp_transients / (tp_transients +
                                                      fp_transients)
        else:
            transients_stat["PPV"] = 1
        if (tn_transients + fn_transients) > 0:
            transients_stat["NPV"] = tn_transients / (tn_transients +
                                                      fn_transients)
        else:
            transients_stat["NPV"] = 1

    transients_stat["FNR"] = 1 - transients_stat["TPR"]

    predictions_stat = dict()
    predictions_stat["fn_transients_predictions"] = fn_transients_predictions
    predictions_stat["fp_transients_predictions"] = fp_transients_predictions
    predictions_stat["tp_transients_predictions"] = tp_transients_predictions
    predictions_stat["tn_transients_predictions"] = tn_transients_predictions
    predictions_stat["fn_frames_predictions"] = fn_frames_predictions
    predictions_stat["fp_frames_predictions"] = fp_frames_predictions
    predictions_stat["tn_frames_predictions"] = tn_frames_predictions
    predictions_stat["tp_frames_predictions"] = tp_frames_predictions

    return frames_stat, transients_stat, predictions_stat, proportion_of_frames_detected_in_transients
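
# Usage sketch for compute_stats on a tiny example, with traces and gt_predictions
# left to None so that only the frame stats and the TP/FN transient stats are filled.
gt_raster_dur = np.array([[0, 1, 1, 0, 0, 1, 1, 0]], dtype="int8")
pred_raster_dur = np.array([[0, 1, 0, 0, 0, 0, 0, 0]], dtype="int8")
frames_stat_demo, transients_stat_demo, predictions_stat_demo, proportions_demo = \
    compute_stats(spike_nums_dur=gt_raster_dur,
                  predicted_spike_nums_dur=pred_raster_dur,
                  traces=None,
                  gt_predictions=None)
# frames: TP=1, FP=0, FN=3, TN=4 -> sensitivity 0.25, specificity 1.0
# transients: TP=1 (one frame of the first transient is detected), FN=1
# proportions_demo -> [50.0] (1 of the 2 frames of the detected transient)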
def detect_sce_with_sliding_window(spike_nums,
                                   window_duration,
                                   perc_threshold=95,
                                   with_refractory_period=-1,
                                   non_binary=False,
                                   activity_threshold=None,
                                   debug_mode=False,
                                   no_redundancy=False,
                                   keep_only_the_peak=False):
    """
    Use a sliding window to detect SCEs (defined as peaks of activity above the perc_threshold
    percentile obtained after randomisation, over a time corresponding to window_duration)
    :param spike_nums: 2D array, rows=cells, columns=time
    :param window_duration: length of the sliding window, in frames
    :param perc_threshold: percentile of the surrogate distribution used as the activity threshold
    :param no_redundancy: if True, then when using the sliding window, a second spike of a cell is not taken into
    consideration when looking for a new SCE
    :param keep_only_the_peak: keep only the frame with the maximum number of co-active cells
    :return: ** one boolean array (mask) containing True for indices (times) that are part of an SCE,
    ** a list of tuples corresponding to the first and last index of each SCE (last index included in the SCE),
    ** sce_nums: a 2D array (cells x SCEs) with 1 if the cell is active during a given SCE,
    ** an array of len n_times that for each time gives the SCE number, or -1 if part of no SCE,
    ** activity_threshold

    """

    if non_binary:
        binary_spikes = np.zeros((len(spike_nums), len(spike_nums[0, :])),
                                 dtype="int8")
        for neuron, spikes in enumerate(spike_nums):
            binary_spikes[neuron, spikes > 0] = 1
        spike_nums = binary_spikes

    if activity_threshold is None:
        activity_threshold = get_sce_detection_threshold(
            spike_nums=spike_nums,
            n_surrogate=1000,
            window_duration=window_duration,
            perc_threshold=perc_threshold,
            non_binary=False)

    n_cells = len(spike_nums)
    n_times = len(spike_nums[0, :])

    if window_duration == 1:
        # using a diff method
        sum_spike_nums = np.sum(spike_nums, axis=0)
        binary_sum = np.zeros(n_times, dtype="int8")
        binary_sum[sum_spike_nums >= activity_threshold] = 1
        sce_tuples = get_continous_time_periods(binary_sum)
        if keep_only_the_peak:
            new_sce_tuples = []
            for sce_index, sce_tuple in enumerate(sce_tuples):
                index_max = np.argmax(
                    np.sum(spike_nums[:, sce_tuple[0]:sce_tuple[1] + 1],
                           axis=0))
                new_sce_tuples.append(
                    (sce_tuple[0] + index_max, sce_tuple[0] + index_max))
            sce_tuples = new_sce_tuples
        sce_bool = np.zeros(n_times, dtype="bool")
        sce_times_numbers = np.ones(n_times, dtype="int16")
        sce_times_numbers *= -1
        for sce_index, sce_tuple in enumerate(sce_tuples):
            sce_bool[sce_tuple[0]:sce_tuple[1] + 1] = True
            sce_times_numbers[sce_tuple[0]:sce_tuple[1] + 1] = sce_index

    else:
        start_sce = -1
        # keep a trace of which cells have been added to an SCE
        cells_in_sce_so_far = np.zeros(n_cells, dtype="bool")
        sce_bool = np.zeros(n_times, dtype="bool")
        sce_tuples = []
        sce_times_numbers = np.ones(n_times, dtype="int16")
        sce_times_numbers *= -1
        if debug_mode:
            print(f"n_times {n_times}")
        for t in np.arange(0, (n_times - window_duration)):
            if debug_mode:
                if t % 10**6 == 0:
                    print(f"t {t}")
            cells_has_been_removed_due_to_redundancy = False
            sum_value_test = np.sum(spike_nums[:, t:(t + window_duration)])
            sum_spikes = np.sum(spike_nums[:, t:(t + window_duration)], axis=1)
            pos_cells = np.where(sum_spikes)[0]
            # neurons with sum > 0 are active during a SCE
            sum_value = len(pos_cells)
            if no_redundancy and (start_sce > -1):
                # removing from the count the cells that are in the previous SCE
                nb_cells_already_in_sce = np.sum(
                    cells_in_sce_so_far[pos_cells])
                sum_value -= nb_cells_already_in_sce
                if nb_cells_already_in_sce > 0:
                    cells_has_been_removed_due_to_redundancy = True
            # print(f"Sum value, test {sum_value_test}, rxeal {sum_value}")
            if sum_value >= activity_threshold:
                if start_sce == -1:
                    start_sce = t
                    if no_redundancy:
                        # keeping only cells spiking at time t, as the window will shift by one frame at the next step
                        sum_spikes = np.sum(spike_nums[:, t])
                        pos_cells = np.where(sum_spikes)[0]
                        cells_in_sce_so_far[pos_cells] = True
                else:
                    if no_redundancy:
                        # updating which cells are already in the SCE
                        # keeping only cells spiking at time t, as the window will shift by one frame at the next step
                        sum_spikes = np.sum(spike_nums[:, t])
                        pos_cells = np.where(sum_spikes)[0]
                        cells_in_sce_so_far[pos_cells] = True
                    else:
                        pass
            else:
                if start_sce > -1:
                    if keep_only_the_peak:
                        # keep only the frame with the maximum number of co-active cells
                        index_max = np.argmax(
                            np.sum(spike_nums[:, start_sce:(t + window_duration) - 1],
                                   axis=0))
                        sce_tuples.append((start_sce + index_max,
                                           start_sce + index_max))
                        sce_bool[start_sce + index_max] = True
                        # sce_tuples.append((start_sce, t-1))
                        sce_times_numbers[start_sce +
                                          index_max] = len(sce_tuples) - 1
                    else:
                        # then a new SCE is detected
                        sce_bool[start_sce:(t + window_duration) - 1] = True
                        sce_tuples.append(
                            (start_sce, (t + window_duration) - 2))
                        # sce_tuples.append((start_sce, t-1))
                        sce_times_numbers[start_sce:(t + window_duration) -
                                          1] = len(sce_tuples) - 1

                    start_sce = -1
                    cells_in_sce_so_far = np.zeros(n_cells, dtype="bool")
                if no_redundancy and cells_has_been_removed_due_to_redundancy:
                    sum_value += nb_cells_already_in_sce
                    if sum_value >= activity_threshold:
                        # then a new SCE start right after the old one
                        start_sce = t
                        cells_in_sce_so_far = np.zeros(n_cells, dtype="bool")
                        if no_redundancy:
                            # keeping only cells spiking at time t, as the window will shift by one frame at the next step
                            sum_spikes = np.sum(spike_nums[:, t])
                            pos_cells = np.where(sum_spikes)[0]
                            cells_in_sce_so_far[pos_cells] = True

    n_sces = len(sce_tuples)
    sce_nums = np.zeros((n_cells, n_sces), dtype="int16")
    for sce_index, sce_tuple in enumerate(sce_tuples):
        sum_spikes = np.sum(spike_nums[:, sce_tuple[0]:(sce_tuple[1] + 1)],
                            axis=1)
        # neurons with sum > 0 are active during a SCE
        active_cells = np.where(sum_spikes)[0]
        sce_nums[active_cells, sce_index] = 1

    # print(f"number of sce {len(sce_tuples)}")

    return sce_bool, sce_tuples, sce_nums, sce_times_numbers, activity_threshold
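
# Usage sketch for detect_sce_with_sliding_window. An explicit activity_threshold is
# passed so that get_sce_detection_threshold (surrogate-based, from the surrounding
# codebase) is not needed, and window_duration=1 exercises the simple per-frame path.
demo_spikes = np.array([[0, 1, 1, 0, 0],
                        [0, 1, 0, 0, 1],
                        [0, 0, 1, 0, 1]], dtype="int8")
sce_bool_demo, sce_tuples_demo, sce_nums_demo, sce_times_numbers_demo, threshold_demo = \
    detect_sce_with_sliding_window(spike_nums=demo_spikes,
                                   window_duration=1,
                                   activity_threshold=2)
# frames 1, 2 and 4 have >= 2 co-active cells: sce_tuples_demo -> [(1, 2), (4, 4)]
# sce_nums_demo has shape (3 cells, 2 SCEs); sce_times_numbers_demo -> [-1, 0, 0, -1, 1]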