Ejemplo n.º 1
0
    def get_rolling_performance_df(self):
        """Assemble rolling behavioral-performance metrics, one row per trial.

        Returns
        -------
        pd.DataFrame
            Indexed like ``self.trials``. ``reward_rate`` is filled for
            every trial. The hit-rate, false-alarm-rate and
            ``rolling_dprime`` columns are assigned from Series indexed by
            the non-aborted trials only. pandas aligns them on the index,
            so aborted trials hold NaN in those columns.

        NOTE(review): the rate helpers are assumed to return one value per
        non-aborted trial, in trial order, because their results are laid
        onto the non-aborted index positionally. TODO confirm against the
        helpers' definitions.
        """
        trials = self.trials
        every_trial = trials.index
        non_aborted_trials = trials[np.logical_not(trials.aborted)].index

        metrics = pd.DataFrame(index=every_trial)
        metrics['reward_rate'] = pd.Series(self.get_reward_rate(),
                                           index=every_trial)

        # Keyword arguments shared by the raw and the trial-count-corrected
        # variant of each rate.
        go_inputs = dict(hit=trials.hit,
                         miss=trials.miss,
                         aborted=trials.aborted)
        catch_inputs = dict(false_alarm=trials.false_alarm,
                            correct_reject=trials.correct_reject,
                            aborted=trials.aborted)

        # (column name, rate function, inputs), in the order the columns
        # are added to the frame.
        rate_specs = (
            ('hit_rate_raw', get_hit_rate, go_inputs),
            ('hit_rate', get_trial_count_corrected_hit_rate, go_inputs),
            ('false_alarm_rate_raw', get_false_alarm_rate, catch_inputs),
            ('false_alarm_rate',
             get_trial_count_corrected_false_alarm_rate, catch_inputs),
        )
        rates = {}
        for column, rate_fn, inputs in rate_specs:
            rates[column] = rate_fn(**inputs)
            metrics[column] = pd.Series(rates[column],
                                        index=non_aborted_trials)

        # d' is derived from the trial-count-corrected rates, not the raw.
        dprime = get_rolling_dprime(rates['hit_rate'],
                                    rates['false_alarm_rate'])
        metrics['rolling_dprime'] = pd.Series(dprime,
                                              index=non_aborted_trials)
        return metrics
Ejemplo n.º 2
0
def test_rolling_dprime_integration(mock_rolling_dprime_fixture):
    """Regression check of rolling d' computed from the mock fixture.

    Computes the trial-count-corrected hit rate and false-alarm rate over a
    100-trial sliding window. It then passes both to ``get_rolling_dprime``
    and compares ``dprime[2]`` against a stored reference value.
    """
    import math

    sliding_window = 100

    hit = mock_rolling_dprime_fixture.hit
    miss = mock_rolling_dprime_fixture.miss
    false_alarm = mock_rolling_dprime_fixture.false_alarm
    correct_reject = mock_rolling_dprime_fixture.correct_reject
    aborted = mock_rolling_dprime_fixture.aborted

    hit_rate = get_trial_count_corrected_hit_rate(
        hit=hit,
        miss=miss,
        aborted=aborted,
        sliding_window=sliding_window)
    # This is the false-alarm rate (it was previously bound to ``cr``,
    # which read like a correct-reject rate).
    false_alarm_rate = get_trial_count_corrected_false_alarm_rate(
        false_alarm=false_alarm,
        correct_reject=correct_reject,
        aborted=aborted,
        sliding_window=sliding_window)
    dprime = get_rolling_dprime(hit_rate, false_alarm_rate)

    # Compare with a tight tolerance rather than ``==``. The value is
    # produced by floating-point computation, and bit-exact equality can
    # break across platforms or numerical-library versions.
    assert math.isclose(dprime[2], 0.6744897501960817, rel_tol=1e-9)