Ejemplo n.º 1
0
    def _get_data_for_psd(self, inquiry_timing):
        """Collect, filter and reshape the data recorded for one inquiry.

        The data for ``inquiry_timing`` is obtained through
        ``process_data_for_decision``, passed through the notch filter, the
        bandpass filter and the downsampler, and finally cut into trials by
        ``trial_reshaper`` in 'calibration' mode.

        Args:
            inquiry_timing: timing information for the inquiry; forwarded
                unchanged to ``process_data_for_decision``.

        Returns:
            The first of the four values returned by ``trial_reshaper``
            (the reshaped data).
        """
        # Acquire the data for this inquiry from the DAQ.
        raw_data, triggers, target_info = process_data_for_decision(
            inquiry_timing,
            self.daq,
            self.window,
            self.parameters,
            self.rsvp.first_stim_time,
            self.static_offset,
            buf_length=self.trial_length)

        # Filter chain: notch -> bandpass -> downsample.
        notched = notch.notch_filter(
            raw_data, self.fs, self.notch_filter_frequency)
        band_limited = bandpass.butter_bandpass_filter(
            notched,
            self.filter_low,
            self.filter_high,
            self.fs,
            order=self.filter_order)
        downsampled = downsample.downsample(
            band_limited, factor=self.downsample_rate)

        # The letters returned by letter_info are not needed here; only the
        # times and the (re-assigned) target info feed the reshaper.
        _, times, target_info = self.letter_info(triggers, target_info)

        # Cut the filtered signal into trials of the configured length.
        trials, _, _, _ = trial_reshaper(
            target_info,
            times,
            downsampled,
            fs=self.fs,
            k=self.downsample_rate,
            mode='calibration',
            channel_map=self.channel_map,
            trial_length=self.trial_length)
        return trials
Ejemplo n.º 2
0
    def test_trial_reshaper_calibration(self):
        """Reshape the fixture data in 'calibration' mode and check sizes."""
        trials, _labels, sequence_count, per_sequence = trial_reshaper(
            trial_target_info=self.target_info,
            timing_info=self.timing_info,
            eeg_data=self.eeg,
            fs=256,
            k=2,
            mode='calibration',
            channel_map=self.channel_map)

        # The outer dimension of the result must equal the channel count.
        channel_count_matches = len(trials) == self.channel_number
        self.assertTrue(
            channel_count_matches,
            f'len is {len(trials)} not {self.channel_number}')

        # Three trials in the first channel, grouped as one sequence of three.
        self.assertEqual(len(trials[0]), 3)
        self.assertEqual(sequence_count, 1)
        self.assertEqual(per_sequence, 3)
Ejemplo n.º 3
0
    def evaluate_sequence(self, raw_data, triggers, target_info,
                          window_length):
        """Infer a decision from the data collected for one sequence.

        The raw data is notch filtered, bandpass filtered and downsampled,
        reshaped into trials, scored by the signal model, fused with the
        existing evidence and handed to the decision maker.

        Args:
            raw_data(ndarray[float]): C x L eeg data where C is number of
                channels and L is the signal length
            triggers(list[tuple(str,float)]): triggers e.g. ('A', 1)
                as letter and flash time for the letter
            target_info(list[str]): target information about the stimuli
            window_length(int): The length of the time between stimuli
                presentation; passed to the reshaper as the trial length

        Returns:
            tuple: ``(decision, sti)`` where ``decision`` is the first value
                returned by the decision maker and ``sti`` is the 'stimuli'
                entry of its second value, or None when that entry is absent.
        """
        letters, times, target_info = self.letter_info(triggers, target_info)

        # Notch filter at the configured frequency.
        notched = notch.notch_filter(
            raw_data,
            self.sampling_rate,
            frequency_to_remove=self.notch_filter_frequency)

        # Bandpass filter between the configured low and high cutoffs.
        band_limited = bandpass.butter_bandpass_filter(
            notched,
            self.filter_low,
            self.filter_high,
            self.sampling_rate,
            order=self.filter_order)

        # Reduce the sampling rate by the configured factor.
        downsampled = downsample.downsample(
            band_limited, factor=self.downsample_rate)

        # Cut the signal into trials; only the trial data is used.
        trials, _, _, _ = trial_reshaper(
            target_info,
            times,
            downsampled,
            fs=self.sampling_rate,
            k=self.downsample_rate,
            mode=self.mode,
            channel_map=self.channel_map,
            trial_length=window_length)

        # Score the trials, fuse the evidence, then ask for a decision.
        likelihoods = inference(trials, letters, self.signal_model, self.alp)
        fused = self.conjugator.update_and_fuse({'ERP': likelihoods})
        decision, arg = self.decision_maker.decide(fused)

        sti = arg['stimuli'] if 'stimuli' in arg else None
        return decision, sti
Ejemplo n.º 4
0
def offline_analysis(data_folder: str = None,
                     parameters: dict = None,
                     alert_finished: bool = True):
    """ Gets calibration data and trains the model in an offline fashion.
        pickle dumps the model into a .pkl file
        Args:
            data_folder(str): folder of the data
                save all information and load all from this folder;
                when empty or None, load_experimental_data() supplies it
            parameters(dict): parameters for running offline analysis;
                None (the default) is treated as an empty dict
            alert_finished(bool): whether or not to alert the user offline analysis complete

        Returns:
            the model returned by train_pca_rda_kde_model

        How it Works:
        - reads data and information from a .csv calibration file
        - reads trigger information from a .txt trigger file
        - filters data
        - reshapes and labels the data for the training procedure
        - fits the model to the data
            - uses cross validation to select parameters
            - based on the parameters, trains system using all the data
        - pickle dumps model into .pkl file
        - generates and saves offline analysis screen
        - [optional] alert the user finished processing
    """
    # Use a None sentinel instead of a mutable `{}` default so that no dict
    # object is shared between calls.
    if parameters is None:
        parameters = {}

    if not data_folder:
        data_folder = load_experimental_data()

    mode = 'calibration'
    trial_length = parameters.get('collection_window_after_trial_length')

    raw_data_name = parameters.get('raw_data_name', 'raw_data.csv')
    raw_dat, _, channels, type_amp, fs = read_data_csv(
        f'{data_folder}/{raw_data_name}')

    log.info(f'Channels read from csv: {channels}')
    log.info(f'Device type: {type_amp}')

    downsample_rate = parameters.get('down_sampling_rate', 2)

    # NOTE(review): the filter settings below (60 Hz notch, 2-45 Hz bandpass,
    # order 2) are hard-coded here rather than read from `parameters` --
    # confirm they match the settings used when the model is applied.

    # Remove 60hz noise with a notch filter
    notch_filter_data = notch.notch_filter(raw_dat, fs, frequency_to_remove=60)

    # bandpass filter from 2-45hz
    filtered_data = bandpass.butter_bandpass_filter(notch_filter_data,
                                                    2,
                                                    45,
                                                    fs,
                                                    order=2)

    # downsample
    data = downsample.downsample(filtered_data, factor=downsample_rate)

    # Process triggers.txt
    triggers_file = parameters.get('trigger_file_name', 'triggers.txt')
    _, t_t_i, t_i, offset = trigger_decoder(
        mode=mode, trigger_path=f'{data_folder}/{triggers_file}')

    # Add the configured static offset to the offset decoded from the triggers.
    static_offset = parameters.get('static_trigger_offset', 0)
    offset = offset + static_offset

    # Channel map can be checked from raw_data.csv file.
    # read_data_csv already removes the timestamp column.
    channel_map = analysis_channels(channels, type_amp)

    x, y, _, _ = trial_reshaper(t_t_i,
                                t_i,
                                data,
                                mode=mode,
                                fs=fs,
                                k=downsample_rate,
                                offset=offset,
                                channel_map=channel_map,
                                trial_length=trial_length)

    k_folds = parameters.get('k_folds', 10)

    model, auc = train_pca_rda_kde_model(x, y, k_folds=k_folds)

    log.info('Saving offline analysis plots!')

    # NOTE(review): the result of transform() is discarded; the call is kept
    # as-is in case it has side effects on the model -- confirm.
    model.transform(x)
    generate_offline_analysis_screen(
        x,
        y,
        model=model,
        folder=data_folder,
        down_sample_rate=downsample_rate,
        fs=fs,
        save_figure=True,
        show_figure=False,
        channel_names=analysis_channel_names_by_pos(channels, channel_map))

    log.info('Saving the model!')
    # The AUC returned by training is embedded in the model file name.
    with open(f'{data_folder}/model_{auc}.pkl', 'wb') as output:
        pickle.dump(model, output)

    if alert_finished:
        offline_analysis_tone = parameters.get('offline_analysis_tone')
        play_sound(offline_analysis_tone)

    return model