class SignalStatisticUtil(object):
    '''
    Collects and prints statistical values for the signals of one EEG table file.

    On construction the file at ``filePath`` is read, the per-field statistic
    methods are wired up and one plotter process per entry in ``PLOTTER`` is
    prepared (but not started). Call ``main`` to start the plots, collect the
    statistics and print them.
    '''

    def __init__(self, person, filePath, signals=None, save=True, plot=True, logScale=False):
        '''
        :param person: value handed to the plotters and to the stats printer
        :param filePath: path of the file that is read via EEGTableFileUtil
        :param signals: column names to analyse; when empty or None the
            "eegFields" entry of the Emotiv config is used instead
        :param save: if True, printStats also writes the stats to a text file
        :param plot: forwarded unchanged to every plotter
        :param logScale: forwarded unchanged to every plotter
        '''
        self.person = person
        self.filePath = filePath
        self._initStatsDict()
        self._readData()
        self._initSignals(signals)
        self.su = SignalUtil()
        self.qu = QualityUtil()
        self.eu = EEGUtil()
        self._initFields()
        # must be set before _initPlotter, which reads self.save
        self.save = save
        self._initPlotter(person, plot, logScale)
        self.ssPrint = SignalStatisticPrinter(person)
        self.preProcessor = SignalPreProcessor()
        self.processor = SignalProcessor()
        self.windowSize = ConfigProvider().getCollectorConfig().get("windowSize")

    def _initStatsDict(self):
        '''Create the empty, ordered result structure (general + per-signal).'''
        self.stats = OrderedDict()
        self.stats[GENERAL_KEY] = OrderedDict()
        self.stats[SIGNALS_KEY] = OrderedDict()

    def _readData(self):
        '''Read the file at self.filePath into self.eegData.'''
        self.reader = EEGTableFileUtil()
        self.eegData = self.reader.readFile(self.filePath)

    def _initSignals(self, signals):
        '''Store the signals to analyse, falling back to the configured "eegFields".'''
        if not signals:
            signals = ConfigProvider().getEmotivConfig().get("eegFields")
        self.signals = signals

    def _initFields(self):
        '''Attach the callable that computes each statistic to its field entry.'''
        # NOTE: this is the module-level STAT_FIELDS object itself, not a copy,
        # so the METHOD entries are written into the shared structure.
        self.statFields = STAT_FIELDS
        fieldMethods = (
            ("max", self.su.maximum),
            ("min", self.su.minimum),
            ("mean", self.su.mean),
            ("std", self.su.std),
            ("var", self.su.var),
            ("zeros", self.qu.countZeros),
            ("seq", self.qu.countSequences),
            ("out", self.qu.countOutliners),
            ("nrgy", self.su.energy),
            ("zcr", self.su.zcr),
        )
        for field, method in fieldMethods:
            self.statFields[field][METHOD] = method

    def _initPlotter(self, person, plot, logScale):
        '''Prepare (without starting) one process per plotter class in PLOTTER.'''
        self.plotter = []
        for clazz in PLOTTER:
            plotter = clazz(person, self.eegData, self.signals, self.filePath, self.save, plot, logScale)
            process = multiprocessing.Process(target=plotter.doPlot)
            self.plotter.append(process)

    def main(self):
        '''Start the plot processes, then collect and print the statistics.'''
        self.doPlot()

        self.collect_stats()
        self.printStats()

    def doPlot(self):
        '''Start every prepared plot process; they are not joined here.'''
        for process in self.plotter:
            process.start()

    def collect_stats(self):
        '''Fill self.stats with the general values and the raw stats of each signal.'''
        self.collectGeneralStats()
        for signal in self.signals:
            self.stats[SIGNALS_KEY][signal] = {}
            self.collectRawStats(signal)

    def collectGeneralStats(self):
        '''Record file-level values: path, sampling rate, length, bounds and times.'''
        self._addGeneralStatValue("file path", self.filePath)
        # "%.2f" (two decimals); the former "%f.2" appended a literal ".2"
        self._addGeneralStatValue("sampleRate", ("%.2f" % self.eegData.getSamplingRate()))
        self._addGeneralStatValue("dataLength", ("%d" % self.eegData.len))
        self._addGeneralStatValue("bound", ("%d - %d" % (self.qu.lowerBound, self.qu.upperBound)))

        self._addGeneralTimeStat("start time", "getStartTime", TIME_FORMAT_STRING)
        self._addGeneralTimeStat("end time", "getEndTime", TIME_FORMAT_STRING)
        self._addGeneralTimeStat("duration", "getDuration", DURATION_FORMAT_STRING)

    def _addGeneralTimeStat(self, name, method, formatString):
        '''
        Call the eegData method named ``method`` and store its formatted result.

        :param name: key under which the value is stored in the general stats
        :param method: name of a no-argument method on self.eegData
        :param formatString: strftime format applied to the returned value
        '''
        timestamp = getattr(self.eegData, method)()
        value = self._buildFormattedTime(timestamp, formatString)
        self._addGeneralStatValue(name, value)

    def _buildFormattedTime(self, time, formatString):
        '''
        Format ``time`` with ``formatString`` via datetime.fromtimestamp.

        NOTE(review): the duration is formatted through the same path, so the
        local timezone offset presumably influences it -- confirm.
        '''
        return datetime.fromtimestamp(time).strftime(formatString)

    def _addGeneralStatValue(self, name, value):
        '''Store one entry in the general section of self.stats.'''
        self.stats[GENERAL_KEY][name] = value

    def collectRawStats(self, signal):
        '''Collect all statistics for the unprocessed column of ``signal``.'''
        data = self.eegData.getColumn(signal)
        self._collectSignalStat(signal, RAW_KEY, data)

    def _collectSignalStat(self, signal, category, data):
        '''
        Compute every configured statistic window by window and store the
        merged value under stats[SIGNALS_KEY][signal][category].
        '''
        self.stats[SIGNALS_KEY][signal][category] = OrderedDict()
        # the windows do not depend on the field, so build them only once
        windows = self.getWindows(data)
        for field, attributes in self.statFields.items():
            # same key constant that _initFields used to store the callable
            method = attributes[METHOD]
            fieldValues = [
                self._getSignalStat(signal, category, field, method, window, 0)
                for window in windows
            ]
            merged = self._mergeValues(fieldValues, field)
            self._addSignalStatValue(signal, category, field, merged)

    def getWindows(self, raw):
        '''
        Split ``raw`` into windows of self.windowSize items that overlap by half.

        Consecutive windows start ``windowSize // 2`` items apart (at least
        one); a trailing remainder shorter than a full window is dropped.

        :param raw: sliceable sequence supporting len()
        :return: list of slices of ``raw``, empty if ``raw`` is too short
        '''
        size = self.windowSize
        # integer step; never 0, which range() would reject
        step = max(1, size // 2)
        lastStart = len(raw) - size
        return [raw[start:start + size] for start in range(0, lastStart + 1, step)]

    def _getSignalStat(self, signal, category, name, method, raw, decPlace=2):
        '''
        Apply ``method`` to ``raw`` and return the result.

        The remaining parameters are accepted but not used.
        '''
        return method(raw)

    def _addSignalStatValue(self, signal, category, name, value):
        '''Store one statistic of a signal/category pair in self.stats.'''
        self.stats[SIGNALS_KEY][signal][category][name] = value

    def _mergeValues(self, values, field):
        '''
        Reduce the per-window values of ``field`` to a single number.

        The reduction depends on the TYPE entry of the field: maximum, minimum,
        sum or mean (mean is also used for DIFF_TYPE), each ignoring NaN.

        :return: the merged value, or None for an unrecognised type
        '''
        typ = self.statFields[field][TYPE]

        if typ == MAX_TYPE:
            return nanmax(values)
        if typ == MIN_TYPE:
            return nanmin(values)
        if typ == AGGREGATION_TYPE:
            return nansum(values)
        if typ in (MEAN_TYPE, DIFF_TYPE):
            return nanmean(values)
        # unknown type: keep the historical behaviour of yielding None
        return None

    def setStats(self, stats):
        '''Replace the collected stats, e.g. to print externally built values.'''
        self.stats = stats

    def printStats(self):
        '''Print the formatted stats and, if self.save is set, write them to a file.'''
        content = self.ssPrint.getSignalStatsString(self.stats)
        # parenthesised single argument prints identically on Python 2 and 3
        print(content)
        if self.save:
            filePath = getNewFileName(self.filePath, "txt", "_stats")
            self.ssPrint.saveStats(filePath, content)
# Example #2
0
class TestEEGTableFileUtil(unittest.TestCase):
    '''
    Tests for reading and writing table files with EEGTableFileUtil.

    Fixture files are looked up below PATH; files written by a test are
    removed again at the end of that test.
    '''

    def setUp(self):
        self.reader = EEGTableFileUtil()

    def _requireFile(self, fileName):
        '''
        Return PATH + fileName, skipping the running test if no such file exists.

        Skipping (instead of printing and passing) keeps a missing fixture
        visible in the test report.
        '''
        file_path = PATH + fileName
        if not os.path.isfile(file_path):
            self.skipTest("'%s' not found" % file_path)
        return file_path

    def test_readData(self):
        '''readData runs without error on the example file.'''
        self.reader.readData(self._requireFile("example_32.csv"))

    def test_readHeader(self):
        '''readHeader runs without error on the example file.'''
        self.reader.readHeader(self._requireFile("example_32.csv"))

    def test_readFile(self):
        '''readFile runs without error on the example file.'''
        self.reader.readFile(self._requireFile("example_32.csv"))

    def test_writeFile(self):
        '''Values written with writeFile are read back within a 0.001 tolerance.'''
        filePath = PATH + "test.csv"
        header = ["A", "B", "C"]
        data = np.array([[1.123456789, 2, 3], [-4.123456789, 5, 6], [7.123456789, 8, 99.123]])
        self.reader.writeFile(filePath, data, header)

        try:
            # without this check the test would pass even if nothing was written
            self.assertTrue(os.path.isfile(filePath), "'%s' was not written" % filePath)
            read = self.reader.readFile(filePath)
            for (i, j), expected in np.ndenumerate(data):
                self.assertAlmostEqual(expected, read.data[i, j], delta=0.001)
        finally:
            # clean up even when an assertion above fails
            removeFile(filePath)

    def test_writeStructredFile(self):
        '''The "value" lists written with writeStructredFile are read back per column.'''
        filePath = PATH + "test_structured.csv"
        data = {
            "A": {
                "value": [1, 2, 3],
                "quality": [-1, -1, -1]
            },
            "B": {
                "value": [4, 5, 6],
                "quality": [-2, -2, -2]
            },
            "C": {
                "value": [7, 8, 9],
                "quality": [-3, -3, -3]
            }
        }
        self.reader.writeStructredFile(filePath, data)

        try:
            # without this check the test would pass even if nothing was written
            self.assertTrue(os.path.isfile(filePath), "'%s' was not written" % filePath)
            read = self.reader.readFile(filePath)
            for key, values in data.items():
                self.assertTrue(sameEntries(values["value"], read.getColumn(key)))
        finally:
            # clean up even when an assertion above fails
            removeFile(filePath)

    def test_readFile_NaNValues(self):
        '''Column "Y" of the example file contains NaN, column "F3" does not.'''
        eegData = self.reader.readFile(PATH + "example_32_empty.csv")
        emptyCol = eegData.getColumn("Y")
        self.assertTrue(np.isnan(emptyCol).any())

        nonEmptyCol = eegData.getColumn("F3")
        self.assertFalse(np.isnan(nonEmptyCol).any())

    def test_readFile_SeparatorFallback(self):
        '''Column "F3" is identical in the two example files compared here.'''
        # NOTE(review): variable names suggest the first file is semicolon
        # separated and the second comma separated -- confirm against fixtures
        eegData = self.reader.readFile(PATH + "example_32_empty.csv")
        semicolonData = eegData.getColumn("F3")

        eegData = self.reader.readFile(PATH + "example_32_comma.csv")
        commaData = eegData.getColumn("F3")

        self.assertTrue((semicolonData == commaData).all())