コード例 #1
0
    def test_getSingleWaves(self):
        """Check every single-band getter against the combined getWaves() result.

        For each band (delta, theta, alpha, beta, gamma) the band-specific
        getter must return one value per input sample, and every returned
        value must also appear in the same band's entry of ``getWaves()``.
        """
        eegData = FileUtil().getDto(self.getData32CSV())
        eeg = eegData.getColumn("F3")
        nEeg = len(eeg)
        samplingRate = eegData.getSamplingRate()
        waves = self.util.getWaves(eeg, samplingRate)

        bandGetters = [
            ("delta", self.util.getDeltaWaves),
            ("theta", self.util.getThetaWaves),
            ("alpha", self.util.getAlphaWaves),
            ("beta", self.util.getBetaWaves),
            ("gamma", self.util.getGammaWaves),
        ]
        for band, getter in bandGetters:
            bandWave = getter(eeg, samplingRate)
            self.assertEqual(len(bandWave), nEeg, band)
            # Compare each band with its *own* key, so no band (e.g. beta)
            # can accidentally be checked against another band's data.
            self.assertTrue(all(x in waves[band] for x in bandWave), band)
コード例 #2
0
def getSplit(proband, fileName):
    """Load one proband's recording and cut it into two partial DTOs.

    The four boundaries come from ``_getStartStopPercent``. The first pair
    selects the awake part and the second pair the drowsy part.

    :return: ``[awake, drowsy]``
    """
    fileUtil = FileUtil()
    dto = fileUtil.getDto("%s%s/%s" % (experimentDir, proband, fileName))
    awakeStart, awakeEnd, drowsyStart, drowsyEnd = _getStartStopPercent(dto)
    return [fileUtil.getPartialDto(dto, awakeStart, awakeEnd),
            fileUtil.getPartialDto(dto, drowsyStart, drowsyEnd)]
コード例 #3
0
    def test_getWaves(self):
        """getWaves() returns five bands, each as long as the input channel."""
        dto = FileUtil().getDto(self.getData32CSV())
        channel = dto.getColumn("F3")
        expectedLength = len(channel)
        waves = self.util.getWaves(channel, dto.getSamplingRate())

        self.assertEqual(len(waves), 5)
        for wave in waves.values():
            self.assertEqual(len(wave), expectedLength)
コード例 #4
0
ファイル: feature_plotter.py プロジェクト: ppasler/PoSDBoS
def plot(proband, filename):
    """Load ``<experiment filePath>/<proband>/<filename>`` and show its feature plot."""
    experimentDir = ConfigProvider().getExperimentConfig()["filePath"]
    filePath = "%s/%s/%s" % (experimentDir, proband, filename)

    dto = FileUtil().getDto(filePath)
    plotter = FeaturePlotter(dto.getData(), dto.getHeader(), filePath)
    plotter.doPlot()
コード例 #5
0
 def __init__(self, filePath=None, infinite=True):
     '''
     Store the source configuration and reset the read state.

     ``filePath`` is normalized by ``self.setFilePath``. No data is read
     here; the data list starts empty with length and index at zero.
     '''
     self.filePath = self.setFilePath(filePath)
     self.infinite = infinite
     self.fileUtil = FileUtil()
     # empty read state until data is loaded
     self.hasMore = False
     self.data = []
     self.len = self.index = 0
コード例 #6
0
    def __init__(self, dataUrls, maxFps):
        """Load the first data URL, build the plot and lay out toolbar + canvas."""
        super(DataWidget, self).__init__()
        self.fileUtil = FileUtil()

        firstUrl = dataUrls[0]  # only the first URL is displayed
        self._initData(firstUrl)
        self.maxFps = maxFps
        self.curSecond = 0  # must be set before _initPlot()
        self._initPlot()

        layout = QtGui.QVBoxLayout(self)
        for widget in (self.toolbar, self.canvas):
            layout.addWidget(widget)

        self.setObjectName("datawidget")
コード例 #7
0
    def _get():
        """Build a PoSDBoS instance with fresh queues and its own FileUtil."""
        factory = Factory
        posdbos = factory._initPoSDBoS(True)
        posdbos.collectedQueue, posdbos.extractedQueue = Queue(), Queue()
        posdbos.fileUtil = FileUtil()
        return posdbos
コード例 #8
0
    def __init__(self, queue, filePath, signals=None, save=True, plot=True, logScale=False, name=""):
        """Load the EEG file and prepare helpers, plotters and the window size.

        The statement order matters: ``_initSignals`` reads ``self.config``,
        and ``_initPlotter`` reads ``name``/``save``/``plot``/``eegData``/``signals``.
        """
        self.queue, self.filePath = queue, filePath
        self._initStatsDict()
        self.config = ConfigProvider()
        self.eegData = FileUtil().getDto(filePath)
        self._initSignals(signals)
        self.su, self.qu, self.eu, self.fft = SignalUtil(), QualityUtil(), EEGUtil(), FFTUtil()
        self._initFields()
        self.save, self.plot, self.name = save, plot, name
        self._initPlotter(logScale)
        self.ssPrint = SignalStatisticPrinter(filePath)
        self.preProcessor, self.processor = SignalPreProcessor(), SignalProcessor()

        collectorConfig = self.config.getCollectorConfig()
        windowSeconds = collectorConfig.get("windowSeconds")
        self.windowSize = EEGDataCollector.calcWindowSize(windowSeconds, self.eegData.samplingRate)
コード例 #9
0
class TestFileUtil(BaseTest):
    '''Tests for FileUtil: CSV detection, DTO loading/conversion, MNE file
    naming and partial DTO copies.'''

    def setUp(self):
        self.util = FileUtil()
        self.dto = self._readData()
        self.mneObj = self._getMNEObject(self.dto)

    def _getMNEObject(self, dto):
        return MNEUtil().createMNEObjectFromEEGDto(dto)

    def _readData(self):
        return self.util.getDto(self.getData1024CSV())

    def test_isCSVFile(self):
        self.assertTrue(self.util.isCSVFile("path/to/sth.csv"))

        self.assertFalse(self.util.isCSVFile("path/.csv/sth.txt"))
        self.assertFalse(self.util.isCSVFile("path/to/sth.raw.fif"))
        self.assertFalse(self.util.isCSVFile("path/to/sth.ica.fif"))

    def test_getDto(self):
        # CSV and FIF versions of the same recording must yield equal DTOs
        dto = self.util.getDto(self.getData1024CSV())
        dto2 = self.util.getDto(self.getData1024FIF())

        self.assertAlmostEqual(dto.samplingRate, dto2.samplingRate, delta=0.1)
        assert_array_equal(dto.getEEGHeader(), dto2.getEEGHeader())
        assert_array_equal(dto.getEEGData(), dto2.getEEGData())

    def test_getDto_dtoinput(self):
        # passing a DTO in must return that very same object
        dto = self.util.getDto(self.getData1024CSV())
        dto2 = self.util.getDto(dto)

        self.assertIs(dto, dto2)
        self.assertAlmostEqual(dto.samplingRate, dto2.samplingRate, delta=0.1)
        assert_array_equal(dto.getEEGHeader(), dto2.getEEGHeader())
        assert_array_equal(dto.getEEGData(), dto2.getEEGData())

    def test_convertMNEToTableDto(self):
        dto2 = self.util.convertMNEToTableDto(self.mneObj)

        self.assertListEqual([TIMESTAMP_STRING] + self.dto.getEEGHeader() +
                             self.dto.getGyroHeader(), dto2.getHeader())
        assert_array_equal(self.dto.getEEGData(), dto2.getEEGData())
        self.assertEqual(self.dto.filePath, dto2.filePath)
        self.assertTrue(dto2.hasEEGData)

    def test_getMNEFileName_given(self):
        filePath = "path/to/sth"
        mneFilePath = self.util.getMNEFileName(self.mneObj, filePath)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_givenCSV(self):
        filePath = "path/to/sth"
        mneFilePath = self.util.getMNEFileName(self.mneObj, filePath + ".csv")
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_givenCSVmiddle(self):
        filePath = "path/.csv/sth"
        mneFilePath = self.util.getMNEFileName(self.mneObj, filePath)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_extension(self):
        filePath = "path/to/sth"
        self.mneObj.info["description"] = filePath
        mneFilePath = self.util.getMNEFileName(self.mneObj, None)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_CSV(self):
        filePath = "path/to/sth"
        self.mneObj.info["description"] = filePath + ".csv"
        mneFilePath = self.util.getMNEFileName(self.mneObj, None)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getPartialDto(self):
        # Integer division: a float end index breaks slicing on Python 3.
        end = len(self.dto) // 2
        copyDto = self.util.getPartialDto(self.dto, 0, end)

        self.assertIsNot(self.dto, copyDto)
        self.assertIsNot(self.dto.header, copyDto.header)
        assert_array_equal(self.dto.header, copyDto.header)
        self.assertEqual(self.dto.filePath, copyDto.filePath)
        self.assertEqual(self.dto.samplingRate, copyDto.samplingRate)

        partData = copyDto.data
        fullData = self.dto.data

        self.assertIsNot(fullData, partData)
        assert_array_equal(fullData[0:end, :], partData)
        self.assertEqual(len(partData), end)
        self.assertGreater(fullData.shape[0], partData.shape[0])
        self.assertEqual(fullData.shape[1], partData.shape[1])
コード例 #10
0
 def getEEGSignal(self):
     """Return the DTO loaded from ``data/example_1024.csv``."""
     loader = FileUtil()
     return loader.getDto("data/example_1024.csv")
コード例 #11
0
 def __init__(self):
     """Create the MNE/file helpers and load the blink template ICA."""
     self.mneUtil, self.fileUtil = MNEUtil(), FileUtil()
     self.eogChans = [2]
     templatePath = TEMPLATE_ICA_PATH + "blink_.ica.fif"
     self.templateICA = self.fileUtil.loadICA(templatePath)
コード例 #12
0
class EOGExtractor(object):
    '''
    Class to extract an EOG signal from EEG with the help of ICA.

    A template ICA (loaded from TEMPLATE_ICA_PATH) provides reference
    components. Matching components of other ICAs are labelled with
    BLINK_LABEL and are then used to build or remove an EOG channel.
    '''

    def __init__(self):
        self.mneUtil = MNEUtil()
        self.fileUtil = FileUtil()
        # template ICA component index/indices used as blink references
        self.eogChans = [2]
        self.templateICA = self.fileUtil.loadICA(TEMPLATE_ICA_PATH +
                                                 "blink_.ica.fif")

    def _plot(self):
        '''Debug helper: show the template ICA components and their sources.'''
        self.templateRaw = self.fileUtil.load(TEMPLATE_ICA_PATH +
                                              "blink_.raw.fif")
        self.templateICA.plot_components(show=False)
        self.templateICA.plot_sources(self.templateRaw, show=False)
        plt_show()

    def labelEOGChannel(self, icas):
        '''Label components in ``icas`` that match the template's EOG components.

        Each index in ``self.eogChans`` is passed to MNEUtil.labelArtefact
        together with BLINK_LABEL. The labels are presumably stored in
        ``ica.labels_`` (which _getEOGIndex reads) -- confirm against the
        corrmap documentation.
        '''
        for eogChan in self.eogChans:
            self.mneUtil.labelArtefact(self.templateICA, eogChan, icas,
                                       BLINK_LABEL)

    def getEOGChannel(self, raw, ica, eogInds=None):
        '''Build a single-channel "EOG" Raw from the ICA sources of ``raw``.

        :param raw: Raw object the sources are computed from; only a copy of
            it is passed on, so the caller's data is left untouched
        :param ica: fitted ICA
        :param eogInds: source indices to average; defaults to the components
            stored under BLINK_LABEL in ``ica.labels_``
        :return: the sources Raw reduced to one channel named "EOG" (type
            "eog") that holds the mean of the selected sources
        '''
        eogInds = self._getEOGIndex(ica, eogInds)

        eog = ica.get_sources(raw.copy())
        # Average the selected ICA *sources*. eogInds are component indices,
        # so they must index eog._data, not the EEG channels in raw._data.
        eogChan = mean(eog._data[eogInds], axis=0)

        # Keep only the first source row and reuse it for the averaged signal.
        eog.drop_channels(self._createDropNames(ica, [0]))
        eog._data[0] = eogChan
        eog.rename_channels({self._getICAName(0): "EOG"})
        eog.set_channel_types({"EOG": "eog"})

        return eog

    def _getEOGIndex(self, ica, eogInds):
        '''Return ``eogInds``, or the BLINK_LABEL components when it is None.'''
        if eogInds is None:
            logging.info("has EOG channel %s", ica.labels_)
            eogInds = ica.labels_[BLINK_LABEL]
        return eogInds

    def _createDropNames(self, ica, eogInds):
        '''Names of all ICA source channels whose index is not in ``eogInds``.'''
        ind = range(ica.n_components_)
        return [self._getICAName(i) for i in ind if i not in eogInds]

    def _getICAName(self, number):
        '''Channel name MNE gives to ICA source ``number`` (1-based "ICA %03d").'''
        # TODO remove '+ 1' after #3889
        return 'ICA %03d' % (number + 1)

    def removeEOGChannel(self, raw, ica, eogInd=None):
        '''Return ``ica.apply(raw)`` with the EOG components excluded.

        NOTE(review): MNE documents ICA.apply as modifying ``raw`` in place;
        confirm this for the MNE version in use.
        '''
        eogInd = self._getEOGIndex(ica, eogInd)
        return ica.apply(raw, exclude=eogInd)
コード例 #13
0
class SignalStatisticUtil(object):
    '''
    Collect, print and plot windowed statistics for the signals of one
    EEG recording.
    '''

    def __init__(self, queue, filePath, signals=None, save=True, plot=True, logScale=False, name=""):
        '''
        :param queue: receives the collected ``self.stats`` at the end of main()
        :param filePath: recording to load via FileUtil().getDto
        :param signals: channel names; defaults to the Emotiv EEG + gyro fields
        :param save: save the printed stats (and passed on to the plotters)
        :param plot: passed on to the plotters
        :param logScale: passed on to the plotters
        :param name: suffix used for plot titles and the saved stats file
        '''
        self.queue = queue
        self.filePath = filePath
        self._initStatsDict()
        self.config = ConfigProvider()
        self.eegData = FileUtil().getDto(filePath)
        self._initSignals(signals)
        self.su = SignalUtil()
        self.qu = QualityUtil()
        self.eu = EEGUtil()
        self.fft = FFTUtil()
        self._initFields()
        self.save = save
        self.plot = plot
        self.name = name
        self._initPlotter(logScale)
        self.ssPrint = SignalStatisticPrinter(filePath)
        self.preProcessor = SignalPreProcessor()
        self.processor = SignalProcessor()

        windowSeconds = self.config.getCollectorConfig().get("windowSeconds")
        self.windowSize = EEGDataCollector.calcWindowSize(windowSeconds, self.eegData.samplingRate)

    def _initStatsDict(self):
        self.stats = OrderedDict()
        self.stats[GENERAL_KEY] = OrderedDict()
        self.stats[SIGNALS_KEY] = OrderedDict()

    def _initFields(self):
        self.statFields = STAT_FIELDS
        addMethods(self)

    def _initSignals(self, signals):
        if not signals:
            emoConfig = self.config.getEmotivConfig()
            signals = emoConfig.get("eegFields") +  emoConfig.get("gyroFields")
        self.signals = signals

    def _initPlotter(self, logScale):
        '''Create (but do not start) one plotting process per PLOTTER class.'''
        self.plotter = []
        for clazz in PLOTTER:
            plotter = clazz(self.name, self.eegData, self.signals, self.filePath, self.save, self.plot, logScale)
            thread = Process(target=plotter.doPlot)
            self.plotter.append(thread)

    def main(self):
        '''Start the plotters, collect and print stats, then publish them on the queue.'''
        self.doPlot()

        self.collect_stats()
        self.printStats()
        # plotFFT() is available but intentionally not called here.
        for plotProcess in self.plotter:
            plotProcess.join()

        self.queue.put(self.stats)

    def doPlot(self):
        for thread in self.plotter:
            thread.start()

    def collect_stats(self):
        self.collectGeneralStats()
        self.fftData = {}
        for signal in self.signals:
            self.stats[SIGNALS_KEY][signal] = {}
            self.collectRawStats(signal)

    def plotFFT(self):
        for freq in FREQ_RANGE:
            plotter = FrequencyPlotter(str(freq)+"_"+self.name, self.eegData, self.signals, self.filePath, self.fftData, freq, self.save, self.plot)
            thread = Process(target=plotter.doPlot)
            self.plotter.append(thread)
            thread.start()

    def collectGeneralStats(self):
        self._addGeneralStatValue("file path", self.filePath)
        # "%.2f" = two decimals; the old "%f.2" printed e.g. "128.000000.2"
        self._addGeneralStatValue("sampleRate", ("%.2f" % self.eegData.getSamplingRate()))
        self._addGeneralStatValue("dataLength", ("%d" % self.eegData.len))
        self._addGeneralStatValue("bound", ("%d - %d" % (self.qu.lowerBound, self.qu.upperBound)))

        self._addGeneralTimeStat("start time", "getStartTime", TIME_FORMAT_STRING)
        self._addGeneralTimeStat("end time", "getEndTime", TIME_FORMAT_STRING)
        self._addGeneralTimeStat("duration", "getDuration", DURATION_FORMAT_STRING)

    def _addGeneralTimeStat(self, name, method, formatString):
        time = getattr(self.eegData, method)()
        value = self._buildFormattedTime(time, formatString)
        self._addGeneralStatValue(name, value)

    def _buildFormattedTime(self, time, formatString):
        value = datetime.fromtimestamp(time).strftime(formatString)
        return value

    def _addGeneralStatValue(self, name, value):
        self.stats[GENERAL_KEY][name] = value

    def collectRawStats(self, signal):
        data = self.eegData.getColumn(signal)
        self._collectSignalStat(signal, RAW_KEY, data)

    def _collectSignalStat(self, signal, category, data):
        '''Compute every stat field per window, merge across windows, then add FFT values.'''
        self.stats[SIGNALS_KEY][signal][category] = OrderedDict()
        windows = self.getWindows(data)
        for field, attributes in self.statFields.items():
            fieldValues = []
            for window in windows:
                sigStat = self._getSignalStat(attributes["method"], window)
                fieldValues.append(sigStat)
            merged = self._mergeValues(fieldValues, field)
            self._addSignalStatValue(signal, category, field, merged)
        self._addFFT(signal, category, windows)

    def _addFFT(self, signal, category, windows):
        ffts = []
        for window in windows:
            ffts.append(self._getFreqValues(window))
        # rows become frequencies, columns windows
        ffts = array(ffts).transpose()
        fftDict = {}
        for i, fft in zip(FREQ_RANGE, ffts):
            fftDict[str(i)] = fft
            merged = nanmean(fft)
            self._addSignalStatValue(signal, category, str(i), merged)
        self.fftData[signal] = fftDict

    def getWindows(self, raw):
        '''Split ``raw`` into full windows of ``self.windowSize`` with 50% overlap.

        A trailing partial window is dropped.
        '''
        # Integer step keeps range() valid on Python 3; max(1, ...) avoids a
        # zero step (ValueError) when windowSize is 1.
        step = max(1, self.windowSize // 2)
        windows = []
        for start in range(0, len(raw), step):
            end = start + self.windowSize
            if end <= len(raw):
                windows.append(raw[start:end])
        return windows

    def _getSignalStat(self, method, raw):
        return method(raw)

    def _getFreqValues(self, raw):
        fft = self.fft.fft(raw)
        return [fft[freq] for freq in FREQ_RANGE]

    def _addSignalStatValue(self, signal, category, name, value):
        self.stats[SIGNALS_KEY][signal][category][name] = value

    def _mergeValues(self, values, field):
        '''Merge per-window values according to the field's TYPE; None if unknown.'''
        typ = self.statFields[field][TYPE]

        if typ == MAX_TYPE:
            return nanmax(values)
        if typ == MIN_TYPE:
            return nanmin(values)
        if typ == AGGREGATION_TYPE:
            return nansum(values)
        if typ in (MEAN_TYPE, DIFF_TYPE):
            return nanmean(values)
        return None

    def printStats(self):
        content = self.ssPrint.getSignalStatsString(self.stats)
        print(content)
        if self.save:
            filePath = getNewFileName(str(self.filePath), "txt", "_" + self.name)
            self.ssPrint.saveStats(filePath, content)
コード例 #14
0
class DataWidget(QtGui.QWidget):
    '''Qt widget that plots one second of every EEG channel of a recording.'''

    def __init__(self, dataUrls, maxFps):
        '''
        :param dataUrls: file paths; only the first one is loaded
        :param maxFps: stored as ``self.maxFps`` (not read by this class)
        '''
        super(DataWidget, self).__init__()
        self.fileUtil = FileUtil()

        self._initData(dataUrls[0])
        self.maxFps = maxFps
        self.curSecond = 0
        self._initPlot()

        layout = QtGui.QVBoxLayout(self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas)

        self.setObjectName("datawidget")

    def _initData(self, filePath):
        '''Load a CSV or (otherwise) FIF file and cache EEG data and metadata.'''
        if self.fileUtil.isCSVFile(filePath):
            dto = self.fileUtil.getDtoFromCsv(filePath)
        else:
            dto = self.fileUtil.getDtoFromFif(filePath)

        self.eegHeader = dto.getEEGHeader()
        self.eegData = dto.getEEGData()
        # eegData is indexed channel-first below (eegData[i], eegData[0])
        self.numChannels = len(self.eegData)
        self.samplingRate = dto.getSamplingRate()
        self.length = len(self.eegData[0])
        logging.info("plotter\t#%d\t%.2fHz", self.length, self.samplingRate)

    def _initPlot(self):
        '''Create one subplot per channel showing the current one-second range.'''
        self.figure = plt.figure()
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)

        self.axes = []
        for i, _ in enumerate(self.eegData):
            self.axes.append(
                self.figure.add_subplot(self.numChannels, 1, i + 1))

        start, end = self._getRange()
        x_values = list(range(start, end))

        self.lines = []
        for i, ax in enumerate(self.axes):
            data = self.eegData[i]
            line, = ax.plot(x_values, data[start:end], '-')
            self.lines.append(line)

            ax.set_xlim([start, end])
            # fixed y-limits over the whole channel so redraws do not rescale
            ax.set_ylim([min(data), max(data)])
            ax.set_ylabel(self.eegHeader[i])

    def show(self, curSecond):
        '''Redraw only when ``curSecond`` differs from the last shown second.'''
        if self._isReplot(curSecond):
            self.plot()

    def plot(self):
        '''Update every line with the samples of the current second.'''
        start, end = self._getRange()
        if self._isInDataRange(start, end):
            for i, line in enumerate(self.lines):
                line.set_ydata(self.eegData[i][start:end])

            self.canvas.draw()
        else:
            logging.warning("no data found for index range [%d:%d]",
                            start, end)

    def _isInDataRange(self, start, end):
        # Slicing [start:end] is valid up to end == length; the previous
        # strict '<' skipped the last full second of the recording.
        return 0 <= start and end <= self.length

    def _getRange(self):
        '''Sample index range [start, end) covering ``self.curSecond``.'''
        start = int(self.curSecond * self.samplingRate)
        end = int(start + self.samplingRate)
        return start, end

    # TODO method does 2 things
    def _isReplot(self, curSecond):
        if curSecond != self.curSecond:
            self.curSecond = curSecond
            return True
        return False
コード例 #15
0
def readEEGFile(fileName, eegPath="E:/thesis/experiment/"):
    """Load ``eegPath + fileName`` as a DTO via FileUtil.

    :param fileName: file name (or relative path) below ``eegPath``
    :param eegPath: base directory, prepended verbatim, so it should end with
        a path separator. The default is the original hard-coded location.
    """
    return FileUtil().getDto(eegPath + fileName)
コード例 #16
0
ファイル: feature_plotter.py プロジェクト: ppasler/PoSDBoS
def plotOld():
    """Show the feature plot for ``data/classes.csv`` relative to scriptPath."""
    filePath = scriptPath + "/../../data/classes.csv"
    dto = FileUtil().getDto(filePath)
    plotter = FeaturePlotter(dto.getData(), dto.getHeader(), filePath)
    plotter.doPlot()
コード例 #17
0
class DummyDataSource(object):
    '''
    Base class for dummy data sources that are fed from one or more EEG files.

    ``convert`` loads the files. ``dequeue``, ``_buildDataStructure``,
    ``close`` and ``stop`` are no-ops here and are presumably overridden by
    subclasses.
    '''
    def __init__(self, filePath=None, infinite=True):
        '''
        :param filePath: a path, a list of paths, or None for the bundled
            dummy_4096.csv (see setFilePath)
        :param infinite: if False, _getNextIndex clears ``hasMore`` once the
            end of ``self.data`` is reached

        No file is read here; call ``convert`` to load the data.
        '''
        self.filePath = self.setFilePath(filePath)
        self.infinite = infinite
        self.fileUtil = FileUtil()
        self.hasMore = False
        self.data = []
        self.len = 0
        self.index = 0

    def setFilePath(self, filePath):
        '''Normalize ``filePath`` into a list of paths.'''
        if filePath is None:
            return [scriptPath + "/../../../data/dummy_4096.csv"]
        if not isinstance(filePath, list):
            return [filePath]
        return filePath

    def convert(self):
        '''Read every file and let ``_buildDataStructure`` consume each one.'''
        for filePath in self.filePath:
            dto = self.fileUtil.getDto(filePath)
            self._readHeader(dto)
            self._readRawData(dto)
            self.samplingRate = dto.getSamplingRate()
            self._buildDataStructure()
        logging.info("%s: Using %d dummy datasets",
                     self.__class__.__name__, self.len)

    def _readHeader(self, dto):
        self.header = dto.getHeader()
        self.fields = dto.getEEGHeader() + dto.getGyroHeader()
        self._hasQuality()

    def _hasQuality(self):
        '''Set ``hasQuality`` if every field has a "Q<field>" header column.'''
        self.hasQuality = all(("Q" + field) in self.header
                              for field in self.fields)

    def _readRawData(self, dto):
        # rawData holds the current file only; len accumulates over all files
        self.rawData = dto.getData()
        self.len += len(self.rawData)
        if self.len > 0:
            self.hasMore = True

    def dequeue(self):
        pass

    def _getNextIndex(self):
        # NOTE(review): the end check uses len(self.data) but the wrap uses
        # self.len; this assumes subclasses keep them consistent. Raises
        # ZeroDivisionError if self.len == 0.
        self.index += 1
        if self.index >= len(self.data) and not self.infinite:
            self.hasMore = False
        self.index %= self.len

    def _buildDataStructure(self):
        pass

    def close(self):
        pass

    def stop(self):
        pass
コード例 #18
0
 def __init__(self):
     """Create the configuration provider and the file helper."""
     self.config, self.fileUtil = ConfigProvider(), FileUtil()
コード例 #19
0
class MNEUtil(object):
    '''Helpers to build, filter and analyse MNE objects from EEG/ECG DTOs.'''

    def __init__(self):
        self.config = ConfigProvider()
        self.fileUtil = FileUtil()

    def createMNEObjectFromCSV(self, filePath):
        '''Load a CSV file and convert it into an MNE RawArray.'''
        eegData = self.fileUtil.getDtoFromCsv(filePath)
        return self.createMNEObjectFromEEGDto(eegData)

    def createMNEObjectFromEEGDto(self, eegDto):
        '''Convert an EEG DTO (EEG + gyro channels) into an MNE RawArray.'''
        return self.createMNEObject(eegDto.getEEGData(), eegDto.getEEGHeader(),
                                    eegDto.getGyroData(),
                                    eegDto.getGyroHeader(), eegDto.filePath,
                                    eegDto.getSamplingRate())

    def createMNEObject(self, eegData, eegHeader, gyroData, gyroHeader,
                        filePath, samplingRate):
        '''Build a RawArray; gyro channels (if any) are appended as "misc".'''
        info = self._createEEGInfo(eegHeader, gyroHeader, filePath,
                                   samplingRate)
        data = self._mergeData(eegData, gyroData)
        return mne.io.RawArray(data, info)

    def _mergeData(self, eegData, gyroData):
        '''Stack gyro rows below the EEG rows; EEG only when gyroData is None.'''
        if gyroData is None:
            return eegData
        return concatenate((eegData, gyroData), axis=0)

    def _createEEGInfo(self, eegChannelNames, gyroChannelNames, filePath,
                       samplingRate):
        '''Create an Info with the standard_1020 montage; filePath becomes the description.'''
        # Accept a missing gyro header, matching _mergeData's handling of
        # missing gyro data (len(None) would raise otherwise).
        if gyroChannelNames is None:
            gyroChannelNames = []
        channelTypes = ["eeg"] * len(eegChannelNames) + ['misc'] * len(
            gyroChannelNames)
        channelNames = eegChannelNames + gyroChannelNames
        montage = mne.channels.read_montage("standard_1020")
        info = mne.create_info(channelNames, samplingRate, channelTypes,
                               montage)
        info["description"] = filePath
        return info

    def createMNEObjectFromECGDto(self, ecgDto, resampleFac=None):
        '''Convert an ECG DTO into a single-channel RawArray.

        NOTE(review): scipy.signal.resample's second argument is the target
        number of samples, not a factor; confirm callers pass a sample count.
        '''
        info = self._createECGInfo(ecgDto.getECGHeader(), ecgDto.filePath,
                                   ecgDto.getSamplingRate())
        ecgData = ecgDto.getECGData()
        if resampleFac is not None:
            ecgData = signal.resample(ecgData, resampleFac)
        return mne.io.RawArray(ecgData, info)

    def _createECGInfo(self, channelName, filePath, samplingRate):
        channelTypes = ["ecg"]
        info = mne.create_info([channelName], samplingRate, channelTypes)
        info["description"] = filePath
        return info

    def createMNEEpochsObject(self, eegData, clazz):
        raw = self.createMNEObjectFromEEGDto(eegData)
        return self.createMNEEpochsObjectFromRaw(raw, clazz)

    def createMNEEpochsObjectFromRaw(self, raw, clazz, duration=1):
        '''Cut ``raw`` into non-overlapping epochs with event id ``clazz``.

        :param duration: spacing of the fixed-length events in seconds
            (previously accepted but ignored; the default keeps old behavior)
        '''
        events = self._createEventsArray(raw, clazz, False, duration=duration)
        # NOTE(review): tmax is fixed for 1 s epochs and does not follow duration.
        return mne.Epochs(raw,
                          events=events,
                          tmin=0.0,
                          tmax=0.99,
                          add_eeg_ref=True)

    def _createEventsArray(self, raw, clazz, overlapping=True, duration=1):
        '''Fixed-length events; ``overlapping`` forces a 0.5 s spacing.'''
        if overlapping:
            duration = 0.5
        return mne.make_fixed_length_events(raw, clazz, duration=duration)

    def addECGChannel(self, eegRaw, ecgRaw):
        '''Append ecgRaw's channels if it contains "ecg"; otherwise return None.'''
        if "ecg" in ecgRaw:
            return self._addChannel(eegRaw, ecgRaw)

    def addEOGChannel(self, eegRaw, eogRaw):
        '''Append eogRaw's channels if it contains "eog"; otherwise return None.'''
        if "eog" in eogRaw:
            return self._addChannel(eegRaw, eogRaw)

    def _addChannel(self, eegRaw, otherRaw):
        otherRaw = self.adjustSampleRate(eegRaw, otherRaw)
        otherRaw = self.adjustLength(eegRaw, otherRaw)

        return eegRaw.add_channels([otherRaw], force_update_info=True)

    def addICASources(self, raw, ica):
        icaRaw = ica.get_sources(raw)
        raw.add_channels([icaRaw])
        return raw

    def adjustSampleRate(self, eegRaw, otherRaw):
        '''Resample otherRaw to eegRaw's sampling frequency if they differ.'''
        eegSFreq = eegRaw.info['sfreq']
        otherSFreq = otherRaw.info['sfreq']
        if eegSFreq != otherSFreq:
            otherRaw = otherRaw.resample(eegSFreq, npad='auto')
        return otherRaw

    def adjustLength(self, eegRaw, otherRaw):
        '''Crop otherRaw to eegRaw's duration if their sample counts differ.'''
        eegNTimes = eegRaw.n_times
        otherNTimes = otherRaw.n_times
        if eegNTimes != otherNTimes:
            eegSFreq = eegRaw.info['sfreq']
            tMax = (eegRaw.n_times - 1) / eegSFreq
            otherRaw = otherRaw.crop(0, tMax)
        return otherRaw

    def markBadChannels(self, raw, channels):
        raw.info['bads'] = channels

    def interpolateBadChannels(self, raw):
        return raw.interpolate_bads()

    def createPicks(self, mneObj):
        '''Indices of EEG channels, excluding channels marked bad.'''
        return mne.pick_types(mneObj.info,
                              meg=False,
                              eeg=True,
                              eog=False,
                              stim=False,
                              exclude='bads')

    def bandpassFilterData(self, mneObj):
        '''Band-pass filter with lowerFreq/upperFreq from the processing config.'''
        highFreq = self.config.getProcessingConfig().get("upperFreq")
        lowFreq = self.config.getProcessingConfig().get("lowerFreq")
        return self.filterData(mneObj, lowFreq, highFreq)

    def filterData(self, mneObj, lowFreq, highFreq):
        return mneObj.filter(lowFreq,
                             highFreq,
                             filter_length="auto",
                             l_trans_bandwidth="auto",
                             h_trans_bandwidth="auto",
                             phase='zero',
                             fir_window="hamming")

    def getEEGCannels(self, mneObj):
        '''Copy of mneObj reduced to its EEG channels.'''
        return mneObj.copy().pick_types(meg=False, eeg=True)

    def getChannels(self, mneObj, channels):
        return mneObj.copy().pick_channels(channels)

    def cropChannels(self, mneObj, tmin, tmax):
        return mneObj.copy().crop(tmin, tmax - 1)

    def dropChannels(self, mneObj, channels):
        return mneObj.copy().drop_channels(channels)

    def calcPSD(self, raw, fmin, fmax, picks=None):
        return psd_welch(raw, fmin, fmax, picks=picks)

    def ICA(self, mneObj, icCount=None, random_state=None):
        '''Fit a FastICA on the EEG picks; icCount defaults to the number of picks.'''
        picks = self.createPicks(mneObj)
        reject = dict(eeg=300)

        if icCount is None:
            icCount = len(picks)
        ica = ICA(n_components=icCount,
                  method="fastica",
                  random_state=random_state)
        ica.fit(mneObj, picks=picks, reject=reject)

        return ica

    def labelArtefact(self, templateICA, templateIC, icas, label):
        '''Run corrmap with component ``templateIC`` of templateICA as template.'''
        template = (0, templateIC)
        icas = [templateICA] + icas
        return corrmap(icas,
                       template=template,
                       threshold=0.85,
                       label=label,
                       plot=False,
                       show=False,
                       ch_type='eeg',
                       verbose=True)

    def findCrossCorrelation(self, raw, ica=None):
        '''Plot the cross-correlation of channel "X" with every "ICA*" channel.

        ``ica`` is unused and kept for interface compatibility.
        '''
        import matplotlib.pyplot as plt

        chNames = raw.info["ch_names"]
        data = raw._data
        xChannel = data[chNames.index("X")]
        for idx, chName in enumerate(chNames):
            if not chName.startswith("ICA"):
                continue
            cor = signal.correlate(xChannel, data[idx])
            plt.plot(cor, label=str(idx))
        plt.legend()
        MNEPlotter().plotRaw(raw)
        plt.show()
コード例 #20
0
 def readEEGFile(self, filePath):
     """Return the DTO that FileUtil builds from ``filePath``."""
     fileUtil = FileUtil()
     return fileUtil.getDto(filePath)
コード例 #21
0
 def setUp(self):
     """Create the FileUtil under test, a DTO fixture and its MNE counterpart."""
     self.util = FileUtil()
     dto = self._readData()
     self.dto = dto
     self.mneObj = self._getMNEObject(dto)
コード例 #22
0
# Root logger: INFO level, HH:MM:SS.mmm timestamps, and the emitting
# module/function/line on every record.
logging.basicConfig(
    datefmt='%H:%M:%S',
    level=logging.INFO,
    format='%(asctime)s.%(msecs)03d %(levelname)-8s %(module)s.%(funcName)s:%(lineno)d %(message)s')
from posdbos.util.file_util import FileUtil

import threading
from config.config import ConfigProvider
from posdbos.factory import Factory

# Experiment settings used by the helpers below.
exConfig = ConfigProvider().getExperimentConfig()
probands, experimentDir = exConfig.get("probands"), exConfig.get("filePath")

# One shared FileUtil instance for this script.
fileUtil = FileUtil()


def getFilePaths(fileName):
    """Return ``<experimentDir><proband>/<fileName>`` for every configured proband."""
    return ["%s%s/" % (experimentDir, proband) + fileName
            for proband in probands]


def splitDtos(filePaths):
    awakes, drowsies = [], []
    for filePath in filePaths:
        dto = fileUtil.getDto(filePath)
        s1, e1, s2, e2 = _getStartStopPercent(dto)
コード例 #23
0
 def _readData(self):
     """Load the fixture returned by getData1024CSV() as a DTO."""
     fileUtil = FileUtil()
     return fileUtil.getDto(self.getData1024CSV())