class TestFileUtil(BaseTest):
    """Tests for FileUtil: file type detection, DTO loading and MNE file naming."""

    def setUp(self):
        """Create the util under test plus a DTO and the MNE object built from it."""
        self.util = FileUtil()
        self.dto = self._readData()
        self.mneObj = self._getMNEObject(self.dto)

    def _getMNEObject(self, dto):
        """Build an MNE object from the given DTO via MNEUtil."""
        return MNEUtil().createMNEObjectFromEEGDto(dto)

    def _readData(self):
        """Load the 1024 CSV fixture as a DTO."""
        return self.util.getDto(self.getData1024CSV())

    def test_isCSVFile(self):
        """Only a trailing ``.csv`` extension counts as a CSV file."""
        self.assertTrue(self.util.isCSVFile("path/to/sth.csv"))
        self.assertFalse(self.util.isCSVFile("path/.csv/sth.txt"))
        self.assertFalse(self.util.isCSVFile("path/to/sth.raw.fif"))
        self.assertFalse(self.util.isCSVFile("path/to/sth.ica.fif"))

    def test_getDto(self):
        """The CSV and FIF fixtures must yield equivalent DTOs."""
        dto = self.util.getDto(self.getData1024CSV())
        dto2 = self.util.getDto(self.getData1024FIF())

        self.assertAlmostEqual(dto.samplingRate, dto2.samplingRate, delta=0.1)
        assert_array_equal(dto.getEEGHeader(), dto2.getEEGHeader())
        assert_array_equal(dto.getEEGData(), dto2.getEEGData())

    def test_getDto_dtoinput(self):
        """Passing a DTO to getDto must hand back the very same object."""
        dto = self.util.getDto(self.getData1024CSV())
        dto2 = self.util.getDto(dto)

        self.assertIs(dto, dto2)
        self.assertAlmostEqual(dto.samplingRate, dto2.samplingRate, delta=0.1)
        assert_array_equal(dto.getEEGHeader(), dto2.getEEGHeader())
        assert_array_equal(dto.getEEGData(), dto2.getEEGData())

    def test_convertMNEToTableDto(self):
        """Converting the MNE object back must reproduce header, data and path."""
        dto2 = self.util.convertMNEToTableDto(self.mneObj)

        expectedHeader = (
            [TIMESTAMP_STRING]
            + self.dto.getEEGHeader()
            + self.dto.getGyroHeader()
        )
        self.assertListEqual(expectedHeader, dto2.getHeader())
        assert_array_equal(self.dto.getEEGData(), dto2.getEEGData())
        self.assertEqual(self.dto.filePath, dto2.filePath)
        self.assertTrue(dto2.hasEEGData)

    def test_getMNEFileName_given(self):
        """An explicit path without extension gets ``.raw.fif`` appended."""
        filePath = "path/to/sth"
        mneFilePath = self.util.getMNEFileName(self.mneObj, filePath)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_givenCSV(self):
        """A trailing ``.csv`` on the explicit path is replaced by ``.raw.fif``."""
        filePath = "path/to/sth"
        mneFilePath = self.util.getMNEFileName(self.mneObj, filePath + ".csv")
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_givenCSVmiddle(self):
        """A ``.csv`` in the middle of the path must be left untouched."""
        filePath = "path/.csv/sth"
        mneFilePath = self.util.getMNEFileName(self.mneObj, filePath)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_extension(self):
        """Without an explicit path the name comes from the MNE description."""
        filePath = "path/to/sth"
        self.mneObj.info["description"] = filePath
        mneFilePath = self.util.getMNEFileName(self.mneObj, None)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getMNEFileName_CSV(self):
        """A ``.csv`` suffix in the MNE description is replaced by ``.raw.fif``."""
        filePath = "path/to/sth"
        self.mneObj.info["description"] = filePath + ".csv"
        mneFilePath = self.util.getMNEFileName(self.mneObj, None)
        self.assertEqual(filePath + ".raw.fif", mneFilePath)

    def test_getPartialDto(self):
        """A partial DTO is an independent copy holding the first half of the rows."""
        # Floor division keeps `end` an integer on Python 3 as well; a float
        # here would break the slice below.
        end = len(self.dto) // 2
        copyDto = self.util.getPartialDto(self.dto, 0, end)

        # Must be a copy, not an alias of the source DTO or its header.
        self.assertIsNot(self.dto, copyDto)
        self.assertIsNot(self.dto.header, copyDto.header)
        assert_array_equal(self.dto.header, copyDto.header)
        self.assertEqual(self.dto.filePath, copyDto.filePath)
        self.assertEqual(self.dto.samplingRate, copyDto.samplingRate)

        partData = copyDto.data
        fullData = self.dto.data
        self.assertIsNot(fullData, partData)
        assert_array_equal(fullData[0:end, :], partData)
        self.assertEqual(len(partData), end)
        # Fewer rows, same number of columns.
        self.assertGreater(fullData.shape[0], partData.shape[0])
        self.assertEqual(fullData.shape[1], partData.shape[1])
class DataWidget(QtGui.QWidget):
    """Qt widget that plots a one-second window of EEG data, one subplot per channel.

    The window is selected by ``curSecond``: it covers the samples from
    ``curSecond * samplingRate`` up to one sampling-rate worth of samples later.
    """

    def __init__(self, dataUrls, maxFps):
        """Load the first file of ``dataUrls`` and build the plot.

        :param dataUrls: sequence of file paths; only the first entry is read.
        :param maxFps: stored on the instance as ``maxFps``; not used by this class itself.
        """
        super(DataWidget, self).__init__()

        self.fileUtil = FileUtil()
        self._initData(dataUrls[0])
        self.maxFps = maxFps
        self.curSecond = 0
        self._initPlot()

        layout = QtGui.QVBoxLayout(self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas)
        self.setObjectName("datawidget")

    def _initData(self, filePath):
        """Read the EEG header, data and sampling rate from ``filePath``.

        CSV files are read with ``getDtoFromCsv``, everything else with
        ``getDtoFromFif``.
        """
        if self.fileUtil.isCSVFile(filePath):
            dto = self.fileUtil.getDtoFromCsv(filePath)
        else:
            dto = self.fileUtil.getDtoFromFif(filePath)

        self.eegHeader = dto.getEEGHeader()
        self.eegData = dto.getEEGData()
        # eegData is indexed channel-first: one sequence of samples per channel.
        self.numChannels = len(self.eegData)
        self.samplingRate = dto.getSamplingRate()
        self.length = len(self.eegData[0])
        logging.info("plotter\t#%d\t%.2fHz", self.length, self.samplingRate)

    def _initPlot(self):
        """Create figure, canvas, toolbar and one line per channel for the first window."""
        self.figure = plt.figure()
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)

        # One stacked subplot per channel.
        self.axes = [
            self.figure.add_subplot(self.numChannels, 1, i + 1)
            for i in range(self.numChannels)
        ]

        start, end = self._getRange()
        x_values = list(range(start, end))

        self.lines = []
        for i, ax in enumerate(self.axes):
            data = self.eegData[i]
            line, = ax.plot(x_values, data[start:end], '-')
            self.lines.append(line)

            ax.set_xlim([start, end])
            # y-limits span the whole channel so later windows stay in view.
            ax.set_ylim([min(data), max(data)])
            ax.set_ylabel(self.eegHeader[i])

    # NOTE(review): this shadows QWidget.show() with a different signature;
    # Python callers can no longer call show() without an argument - confirm
    # this is intended.
    def show(self, curSecond):
        """Redraw the plot if ``curSecond`` differs from the second currently shown."""
        if self._isReplot(curSecond):
            self.plot()

    def plot(self):
        """Update every line with the data of the current window and redraw.

        Logs a warning and leaves the plot untouched if the window is not
        fully covered by the data.
        """
        start, end = self._getRange()
        if self._isInDataRange(start, end):
            for i, line in enumerate(self.lines):
                line.set_ydata(self.eegData[i][start:end])
            self.canvas.draw()
        else:
            logging.warning(
                "no data found for index range [%d:%d]", start, end)

    def _isInDataRange(self, start, end):
        """Return True if the slice ``[start:end]`` lies completely inside the data.

        ``end`` is exclusive, so ``end == self.length`` is still a valid
        window (the last full one). A negative ``start`` is rejected because
        the slice would not have the expected length.
        """
        return 0 <= start and end <= self.length

    def _getRange(self):
        """Return the ``(start, end)`` sample indices of the current one-second window."""
        start = int(self.curSecond * self.samplingRate)
        end = int(start + self.samplingRate)
        return start, end

    # TODO method does 2 things
    def _isReplot(self, curSecond):
        """Store ``curSecond`` if it changed and report whether it did."""
        if curSecond != self.curSecond:
            self.curSecond = curSecond
            return True
        return False