Beispiel #1
0
    def __init__(self,
                 min_num_seq,
                 max_num_seq,
                 signal_model=None,
                 fs=300,
                 k=2,
                 alp=None,
                 evidence_names=None,
                 task_list=None,
                 lmodel=None,
                 is_txt_stim=True,
                 device_name='LSL',
                 device_channels=None,
                 stimuli_timing=None,
                 decision_threshold=0.8,
                 backspace_prob=0.05,
                 backspace_always_shown=False,
                 filter_high=45,
                 filter_low=2,
                 filter_order=2,
                 notch_filter_frequency=60):
        """Wire up evidence fusion and decision making for a copy phrase task.

        Args:
            min_num_seq: forwarded positionally to DecisionMaker.
            max_num_seq: forwarded positionally to DecisionMaker.
            signal_model: stored as ``self.signal_model``.
            fs: stored as ``self.sampling_rate``.
            k: stored as ``self.downsample_rate``.
            alp: alphabet (collection of symbols). Required even though the
                signature default is None; its length sizes the evidence
                distribution and its members form ``self.valid_targets``.
            evidence_names(list): names handed to EvidenceFusion.
                Defaults to ['LM', 'ERP'].
            task_list(list[tuple]): task descriptions; element ``[0][1]`` is
                used as the initial DecisionMaker state.
                Defaults to [('I_LOVE_COOKIES', 'I_LOVE_')].
            lmodel: stored as ``self.lmodel``.
            is_txt_stim(bool): forwarded to DecisionMaker.
            device_name(str): passed to analysis_channels with
                device_channels to build ``self.channel_map``.
            device_channels: see device_name.
            stimuli_timing(list): forwarded to DecisionMaker.
                Defaults to [1, .2].
            decision_threshold(float): forwarded to DecisionMaker.
            backspace_prob(float): stored as ``self.backspace_prob``.
            backspace_always_shown(bool): when True and BACKSPACE_CHAR is in
                alp, BACKSPACE_CHAR is passed to DecisionMaker as a sequence
                constant.
            filter_high, filter_low, filter_order, notch_filter_frequency:
                filter settings stored on the instance under the same names.

        Raises:
            TypeError: if alp is None.
        """
        if alp is None:
            # len(alp) below would raise an opaque TypeError; raise the same
            # exception type with an actionable message instead.
            raise TypeError('alp (the alphabet) is required and cannot be None')

        # List defaults are created per call. Using list literals as default
        # argument values would share one list object between every instance
        # (task_list in particular is stored on self).
        if evidence_names is None:
            evidence_names = ['LM', 'ERP']
        if task_list is None:
            task_list = [('I_LOVE_COOKIES', 'I_LOVE_')]
        if stimuli_timing is None:
            stimuli_timing = [1, .2]

        self.conjugator = EvidenceFusion(evidence_names, len_dist=len(alp))

        seq_constants = []
        if backspace_always_shown and BACKSPACE_CHAR in alp:
            seq_constants.append(BACKSPACE_CHAR)
        self.decision_maker = DecisionMaker(
            min_num_seq,
            max_num_seq,
            decision_threshold=decision_threshold,
            state=task_list[0][1],
            alphabet=alp,
            is_txt_stim=is_txt_stim,
            stimuli_timing=stimuli_timing,
            seq_constants=seq_constants)
        self.alp = alp
        # non-letter target labels include the fixation cross and calibration.
        self.nonletters = ['+', 'PLUS', 'calibration_trigger']
        self.valid_targets = set(self.alp)

        self.signal_model = signal_model
        self.sampling_rate = fs
        self.downsample_rate = k
        self.filter_high = filter_high
        self.filter_low = filter_low
        self.filter_order = filter_order
        self.notch_filter_frequency = notch_filter_frequency

        self.mode = 'copy_phrase'
        self.task_list = task_list
        self.lmodel = lmodel
        self.channel_map = analysis_channels(device_channels, device_name)
        self.backspace_prob = backspace_prob
Beispiel #2
0
def offline_analysis(data_folder: str = None,
                     parameters: dict = None,
                     alert_finished: bool = True):
    """ Gets calibration data and trains the model in an offline fashion.
        pickle dumps the model into a .pkl file
        Args:
            data_folder(str): folder of the data
                save all information and load all from this folder.
                When empty, load_experimental_data() is used to obtain it.
            parameters(dict): parameters for running offline analysis.
                Keys read here (with the fallback used when absent):
                'raw_data_name' ('raw_data.csv'),
                'trigger_file_name' ('triggers.txt'),
                'collection_window_after_trial_length' (None),
                'down_sampling_rate' (2), 'static_trigger_offset' (0),
                'notch_filter_frequency' (60), 'filter_low' (2),
                'filter_high' (45), 'filter_order' (2), 'k_folds' (10),
                'offline_analysis_tone' (None).
            alert_finished(bool): whether or not to alert the user offline analysis complete

        Returns:
            the trained model returned by train_pca_rda_kde_model

        How it Works:
        - reads data and information from a .csv calibration file
        - reads trigger information from a .txt trigger file
        - filters data
        - reshapes and labels the data for the training procedure
        - fits the model to the data
            - uses cross validation to select parameters
            - based on the parameters, trains system using all the data
        - pickle dumps model into .pkl file
        - generates and saves offline analysis screen
        - [optional] alert the user finished processing
    """
    # A fresh dict per call; a `{}` literal as the default argument would be
    # one shared object for every invocation.
    if parameters is None:
        parameters = {}

    if not data_folder:
        data_folder = load_experimental_data()

    mode = 'calibration'
    trial_length = parameters.get('collection_window_after_trial_length')

    raw_data_name = parameters.get('raw_data_name', 'raw_data.csv')
    raw_dat, _, channels, type_amp, fs = read_data_csv(
        f'{data_folder}/{raw_data_name}')

    log.info(f'Channels read from csv: {channels}')
    log.info(f'Device type: {type_amp}')

    downsample_rate = parameters.get('down_sampling_rate', 2)

    # Filter settings come from the parameters so that the data used for
    # training is filtered the same way as the configured online processing.
    # The fallbacks are the values that used to be hard-coded here.
    notch_filter_frequency = parameters.get('notch_filter_frequency', 60)
    filter_low = parameters.get('filter_low', 2)
    filter_high = parameters.get('filter_high', 45)
    filter_order = parameters.get('filter_order', 2)

    # Remove line noise with a notch filter
    notch_filter_data = notch.notch_filter(
        raw_dat, fs, frequency_to_remove=notch_filter_frequency)

    # bandpass filter between filter_low and filter_high
    filtered_data = bandpass.butter_bandpass_filter(notch_filter_data,
                                                    filter_low,
                                                    filter_high,
                                                    fs,
                                                    order=filter_order)

    # downsample
    data = downsample.downsample(filtered_data, factor=downsample_rate)

    # Process triggers.txt
    triggers_file = parameters.get('trigger_file_name', 'triggers.txt')
    _, t_t_i, t_i, offset = trigger_decoder(
        mode=mode, trigger_path=f'{data_folder}/{triggers_file}')

    static_offset = parameters.get('static_trigger_offset', 0)

    offset = offset + static_offset

    # Channel map can be checked from raw_data.csv file.
    # read_data_csv already removes the timestamp column.
    channel_map = analysis_channels(channels, type_amp)

    x, y, _, _ = trial_reshaper(t_t_i,
                                t_i,
                                data,
                                mode=mode,
                                fs=fs,
                                k=downsample_rate,
                                offset=offset,
                                channel_map=channel_map,
                                trial_length=trial_length)

    k_folds = parameters.get('k_folds', 10)

    model, auc = train_pca_rda_kde_model(x, y, k_folds=k_folds)

    log.info('Saving offline analysis plots!')

    # NOTE(review): the return value of transform is not used here; the call
    # is kept in case the model relies on it having been invoked - confirm.
    model.transform(x)
    generate_offline_analysis_screen(
        x,
        y,
        model=model,
        folder=data_folder,
        down_sample_rate=downsample_rate,
        fs=fs,
        save_figure=True,
        show_figure=False,
        channel_names=analysis_channel_names_by_pos(channels, channel_map))

    log.info('Saving the model!')
    # The AUC reported by training is embedded in the model file name.
    with open(f'{data_folder}/model_{auc}.pkl', 'wb') as output:
        pickle.dump(model, output)

    if alert_finished:
        offline_analysis_tone = parameters.get('offline_analysis_tone')
        play_sound(offline_analysis_tone)

    return model
    def __init__(self,
                 min_num_inq,
                 max_num_inq,
                 signal_model=None,
                 fs=300,
                 k=2,
                 alp=None,
                 evidence_names=None,
                 task_list=None,
                 lmodel=None,
                 is_txt_stim=True,
                 device_name='LSL',
                 device_channels=None,
                 stimuli_timing=None,
                 decision_threshold=0.8,
                 backspace_prob=0.05,
                 backspace_always_shown=False,
                 filter_high=45,
                 filter_low=2,
                 filter_order=2,
                 notch_filter_frequency=60):
        """Wire up evidence fusion and decision making for a copy phrase task.

        Args:
            min_num_inq: wrapped in MinIterationsCriteria as the continue
                criterion of the stopping evaluator.
            max_num_inq: wrapped in MaxIterationsCriteria as one of the
                commit criteria of the stopping evaluator.
            signal_model: stored as ``self.signal_model``.
            fs: stored as ``self.sampling_rate``.
            k: stored as ``self.downsample_rate``.
            alp: alphabet (collection of symbols). Required even though the
                signature default is None; its length sizes the evidence
                distribution and its members form ``self.valid_targets``.
            evidence_names(list): names handed to EvidenceFusion.
                Defaults to ['LM', 'ERP'].
            task_list(list[tuple]): task descriptions; element ``[0][1]`` is
                used as the initial DecisionMaker state.
                Defaults to [('I_LOVE_COOKIES', 'I_LOVE_')].
            lmodel: stored as ``self.lmodel``.
            is_txt_stim(bool): forwarded to DecisionMaker.
            device_name(str): passed to analysis_channels with
                device_channels to build ``self.channel_map``.
            device_channels: see device_name.
            stimuli_timing(list): forwarded to DecisionMaker.
                Defaults to [1, .2].
            decision_threshold(float): wrapped in ProbThresholdCriteria as
                one of the commit criteria of the stopping evaluator.
            backspace_prob(float): stored as ``self.backspace_prob``.
            backspace_always_shown(bool): when True and BACKSPACE_CHAR is in
                alp, BACKSPACE_CHAR is passed to DecisionMaker as an inquiry
                constant.
            filter_high, filter_low, filter_order, notch_filter_frequency:
                filter settings stored on the instance under the same names.

        Raises:
            TypeError: if alp is None.
        """
        if alp is None:
            # len(alp) below would raise an opaque TypeError; raise the same
            # exception type with an actionable message instead.
            raise TypeError('alp (the alphabet) is required and cannot be None')

        # List defaults are created per call. Using list literals as default
        # argument values would share one list object between every instance
        # (task_list in particular is stored on self).
        if evidence_names is None:
            evidence_names = ['LM', 'ERP']
        if task_list is None:
            task_list = [('I_LOVE_COOKIES', 'I_LOVE_')]
        if stimuli_timing is None:
            stimuli_timing = [1, .2]

        self.conjugator = EvidenceFusion(evidence_names, len_dist=len(alp))

        inq_constants = []
        if backspace_always_shown and BACKSPACE_CHAR in alp:
            inq_constants.append(BACKSPACE_CHAR)

        # Stimuli Selection Module
        stopping_criteria = CriteriaEvaluator(
            continue_criteria=[MinIterationsCriteria(min_num_inq)],
            commit_criteria=[
                MaxIterationsCriteria(max_num_inq),
                ProbThresholdCriteria(decision_threshold)
            ])

        # TODO: Parametrize len_query in the future releases!
        stimuli_agent = NBestStimuliAgent(alphabet=alp, len_query=10)

        self.decision_maker = DecisionMaker(
            stimuli_agent=stimuli_agent,
            stopping_evaluator=stopping_criteria,
            state=task_list[0][1],
            alphabet=alp,
            is_txt_stim=is_txt_stim,
            stimuli_timing=stimuli_timing,
            inq_constants=inq_constants)

        self.alp = alp
        # non-letter target labels include the fixation cross and calibration.
        self.nonletters = ['+', 'PLUS', 'calibration_trigger']
        self.valid_targets = set(self.alp)

        self.signal_model = signal_model
        self.sampling_rate = fs
        self.downsample_rate = k
        self.filter_high = filter_high
        self.filter_low = filter_low
        self.filter_order = filter_order
        self.notch_filter_frequency = notch_filter_frequency

        self.mode = 'copy_phrase'
        self.task_list = task_list
        self.lmodel = lmodel
        self.channel_map = analysis_channels(device_channels, device_name)
        self.backspace_prob = backspace_prob
Beispiel #4
0
    def __init__(self, win, daq, parameters, file_save):
        """Configure the task around an inner RSVPCalibrationTask.

        An RSVPCalibrationTask is built from the same arguments and many of
        its attributes are mirrored onto this instance. Feedback, filter and
        PSD settings are then read from ``parameters``.

        Args:
            win: forwarded to RSVPCalibrationTask.
            daq: data acquisition object; its ``device_info`` supplies the
                sampling rate, device name and channels.
            parameters: mapping of configuration values (indexed by key).
            file_save: forwarded to RSVPCalibrationTask and stored.
        """
        super(RSVPInterInquiryFeedbackCalibration, self).__init__()
        task = RSVPCalibrationTask(win, daq, parameters, file_save)
        self._task = task

        device = daq.device_info
        self.daq = daq
        self.fs = device.fs

        # Mirror the wrapped calibration task's configuration.
        display = task.rsvp
        self.alp = task.alp
        self.rsvp = display
        self.parameters = parameters
        self.file_save = file_save
        self.enable_breaks = task.enable_breaks
        self.window = task.window
        self.stim_number = task.stim_number
        self.stim_length = task.stim_length
        self.is_txt_stim = display.is_txt_stim
        self.stimuli_height = task.stimuli_height
        self.color = task.color
        self.timing = task.timing
        self.wait_screen_message = task.wait_screen_message
        self.wait_screen_message_color = task.wait_screen_message_color

        self.visual_feedback = LevelFeedback(
            display=self.window,
            parameters=parameters,
            clock=task.experiment_clock)

        self.static_offset = parameters['static_trigger_offset']
        # Labels that are not letters: fixation cross and calibration trigger.
        self.nonletters = ['+', 'PLUS', 'calibration_trigger']
        self.valid_targets = set(self.alp)

        self.time_flash = parameters['time_flash']

        # Sampling rate of the data after downsampling.
        self.downsample_rate = parameters['down_sampling_rate']
        self.filtered_sampling_rate = self.fs / self.downsample_rate

        self.device_name = device.name
        self.channel_map = analysis_channels(device.channels,
                                             self.device_name)

        # EDIT ME FOR FEEDBACK CONFIGURATION

        self.feedback_buffer_time = parameters['feedback_buffer_time']
        self.feedback_line_color = parameters['feedback_line_color']

        self.psd_method = PSD_TYPE.WELCH

        # Index of the channel whose PSD is computed from the RSVP inquiry.
        self.psd_channel_index = self.PSD_CHANNEL_INDEX

        # Filter parameters.
        self.filter_low = parameters['filter_low']
        self.filter_high = parameters['filter_high']
        self.filter_order = parameters['filter_order']
        self.notch_filter_frequency = parameters['notch_filter_frequency']

        # Frequency band of interest used for feedback, kept both as
        # separate limits and as a (low, high) tuple.
        self.psd_lower_limit = parameters['feedback_band_lower_limit']
        self.psd_upper_limit = parameters['feedback_band_upper_limit']
        self.psd_export_band = (self.psd_lower_limit, self.psd_upper_limit)

        # Length of time to use for the PSD calculation.
        self.trial_length = self.time_flash * self.stim_length

        # Thresholds for feedback levels 5 down to 2.
        self.lvl_5_threshold = parameters['feedback_level_5_threshold']
        self.lvl_4_threshold = parameters['feedback_level_4_threshold']
        self.lvl_3_threshold = parameters['feedback_level_3_threshold']
        self.lvl_2_threshold = parameters['feedback_level_2_threshold']

        # True/False: whether the level order is descending from 5 -> 1.
        self.feedback_descending = parameters['feedback_level_descending']