Beispiel #1
0
class WidgetsWrapper(object):
    """Glade-built main window with an embedded matplotlib plot.

    Widgets declared in the glade file are looked up by name with
    ``wrapper['widgetName']`` (see ``__getitem__``).
    """

    def __init__(self):
        """Load the glade UI, plot a sine curve and embed canvas and toolbar."""
        self.widgets = gtk.glade.XML('mpl_with_glade.glade')
        self.widgets.signal_autoconnect(GladeHandlers.__dict__)

        self['windowMain'].connect('destroy', lambda x: gtk.main_quit())
        self['windowMain'].move(10, 10)
        self.figure = Figure(figsize=(8, 6), dpi=72)
        self.axis = self.figure.add_subplot(111)

        # Demo data: s = sin(2*pi*t) sampled on [0, 3) in steps of 0.01.
        t = arange(0.0, 3.0, 0.01)
        s = sin(2 * pi * t)
        self.axis.plot(t, s)
        self.axis.set_xlabel('time (s)')
        self.axis.set_ylabel('voltage')

        self.canvas = FigureCanvas(self.figure)  # a gtk.DrawingArea
        self.canvas.show()
        self.canvas.set_size_request(600, 400)
        # The canvas must ask for these event classes and be focusable,
        # otherwise the GTK handlers connected below are never invoked.
        self.canvas.set_events(gtk.gdk.BUTTON_PRESS_MASK
                               | gtk.gdk.KEY_PRESS_MASK
                               | gtk.gdk.KEY_RELEASE_MASK)
        self.canvas.set_flags(gtk.HAS_FOCUS | gtk.CAN_FOCUS)
        self.canvas.grab_focus()

        def keypress(widget, event):
            print('key press')

        def buttonpress(widget, event):
            print('button press')

        self.canvas.connect('key_press_event', keypress)
        self.canvas.connect('button_press_event', buttonpress)

        def onselect(xmin, xmax):
            print(xmin, xmax)

        # Keep the selector on the instance: matplotlib widgets only stay
        # responsive while something holds a reference to them, and a local
        # variable would be dropped as soon as __init__ returns.
        self.span = SpanSelector(self.axis,
                                 onselect,
                                 'horizontal',
                                 useblit=False,
                                 rectprops=dict(alpha=0.5, facecolor='red'))

        self['vboxMain'].pack_start(self.canvas, True, True)
        self['vboxMain'].show()

        # below is optional if you want the navigation toolbar
        self.navToolbar = NavigationToolbar(self.canvas, self['windowMain'])
        self.navToolbar.lastDir = '/var/tmp/'
        self['vboxMain'].pack_start(self.navToolbar)
        self.navToolbar.show()

        sep = gtk.HSeparator()
        sep.show()
        self['vboxMain'].pack_start(sep, True, True)

        # Move the glade-defined button below everything packed above.
        self['vboxMain'].reorder_child(self['buttonClickMe'], -1)

    def __getitem__(self, key):
        """Return the glade widget named *key*."""
        return self.widgets.get_widget(key)
Beispiel #2
0
class WidgetsWrapper(object):
    """Wrap the widgets of ``mpl_with_glade.glade`` and embed a matplotlib plot.

    Individual widgets are fetched by their glade name through item access,
    e.g. ``wrapper['windowMain']``.
    """

    def __init__(self):
        """Create the UI: glade widgets, a sine plot, event hooks and toolbar."""
        self.widgets = gtk.glade.XML('mpl_with_glade.glade')
        self.widgets.signal_autoconnect(GladeHandlers.__dict__)

        self['windowMain'].connect('destroy', lambda x: gtk.main_quit())
        self['windowMain'].move(10, 10)
        self.figure = Figure(figsize=(8, 6), dpi=72)
        self.axis = self.figure.add_subplot(111)

        # Example curve drawn on the axes: sin(2*pi*t) for t in [0, 3).
        t = arange(0.0, 3.0, 0.01)
        s = sin(2*pi*t)
        self.axis.plot(t, s)
        self.axis.set_xlabel('time (s)')
        self.axis.set_ylabel('voltage')

        self.canvas = FigureCanvas(self.figure)  # a gtk.DrawingArea
        self.canvas.show()
        self.canvas.set_size_request(600, 400)
        # Subscribe to button/key events and make the canvas focusable so the
        # GTK-level handlers connected below are delivered.
        self.canvas.set_events(
            gtk.gdk.BUTTON_PRESS_MASK |
            gtk.gdk.KEY_PRESS_MASK |
            gtk.gdk.KEY_RELEASE_MASK
            )
        self.canvas.set_flags(gtk.HAS_FOCUS | gtk.CAN_FOCUS)
        self.canvas.grab_focus()

        def keypress(widget, event):
            print('key press')

        def buttonpress(widget, event):
            print('button press')

        self.canvas.connect('key_press_event', keypress)
        self.canvas.connect('button_press_event', buttonpress)

        def onselect(xmin, xmax):
            print(xmin, xmax)

        # The selector has to outlive __init__; matplotlib widgets stop
        # responding once the last reference to them is gone, so it is stored
        # on the instance instead of in a local variable.
        self.span = SpanSelector(self.axis, onselect, 'horizontal',
                                 useblit=False,
                                 rectprops=dict(alpha=0.5, facecolor='red'))

        self['vboxMain'].pack_start(self.canvas, True, True)
        self['vboxMain'].show()

        # below is optional if you want the navigation toolbar
        self.navToolbar = NavigationToolbar(self.canvas, self['windowMain'])
        self.navToolbar.lastDir = '/var/tmp/'
        self['vboxMain'].pack_start(self.navToolbar)
        self.navToolbar.show()

        sep = gtk.HSeparator()
        sep.show()
        self['vboxMain'].pack_start(sep, True, True)

        # Put the glade-defined button last, after the widgets added here.
        self['vboxMain'].reorder_child(self['buttonClickMe'], -1)

    def __getitem__(self, key):
        """Look up the glade widget called *key*."""
        return self.widgets.get_widget(key)
Beispiel #3
0
class WidgetsWrapper:
    """Widgets of ``test.glade`` plus an embedded matplotlib plot.

    Glade widgets are available by name via ``wrapper['widgetName']``.
    """

    def __init__(self):
        """Load the glade UI, plot a sine curve and add a span selector."""
        self.widgets = gtk.glade.XML('test.glade')
        self.widgets.signal_autoconnect(GladeHandlers.__dict__)

        self['windowMain'].connect('destroy', lambda x: gtk.main_quit())
        self['windowMain'].move(10, 10)
        self.figure = Figure(figsize=(8, 6), dpi=72)
        self.axis = self.figure.add_subplot(111)

        # Demo data: sin(2*pi*t) sampled on [0, 3) in steps of 0.01.
        t = arange(0.0, 3.0, 0.01)
        s = sin(2*pi*t)
        self.axis.plot(t, s)
        self.axis.set_xlabel('time (s)')
        self.axis.set_ylabel('voltage')

        self.canvas = FigureCanvas(self.figure)  # a gtk.DrawingArea
        self.canvas.show()
        self.canvas.set_size_request(600, 400)

        def onselect(xmin, xmax):
            # Formatted so the output is "xmin xmax" on both Python 2 and 3
            # (the bare print statement is a syntax error on Python 3).
            print('%s %s' % (xmin, xmax))

        # Stored on the instance: a matplotlib widget only stays responsive
        # while a reference to it is kept, and a local would be released as
        # soon as __init__ returns.
        self.span = HorizontalSpanSelector(
            self.axis, onselect, useblit=False,
            rectprops=dict(alpha=0.5, facecolor='red'))

        self['vboxMain'].pack_start(self.canvas, True, True)
        self['vboxMain'].show()

        # below is optional if you want the navigation toolbar
        self.navToolbar = NavigationToolbar(self.canvas, self['windowMain'])
        self.navToolbar.lastDir = '/var/tmp/'
        self['vboxMain'].pack_start(self.navToolbar)
        self.navToolbar.show()

        sep = gtk.HSeparator()
        sep.show()
        self['vboxMain'].pack_start(sep, True, True)

        # Move the glade-defined button below everything packed above.
        self['vboxMain'].reorder_child(self['buttonClickMe'], -1)

    def __getitem__(self, key):
        """Return the glade widget named *key*."""
        return self.widgets.get_widget(key)
Beispiel #4
0
class WidgetsWrapper:
    """Access the widgets of ``test.glade`` and host a matplotlib plot.

    Use item access with a glade widget name, e.g. ``wrapper['vboxMain']``.
    """

    def __init__(self):
        """Build the window contents: sine plot, span selector and toolbar."""
        self.widgets = gtk.glade.XML('test.glade')
        self.widgets.signal_autoconnect(GladeHandlers.__dict__)

        self['windowMain'].connect('destroy', lambda x: gtk.main_quit())
        self['windowMain'].move(10, 10)
        self.figure = Figure(figsize=(8, 6), dpi=72)
        self.axis = self.figure.add_subplot(111)

        # Example curve: sin(2*pi*t) for t in [0, 3).
        t = arange(0.0, 3.0, 0.01)
        s = sin(2 * pi * t)
        self.axis.plot(t, s)
        self.axis.set_xlabel('time (s)')
        self.axis.set_ylabel('voltage')

        self.canvas = FigureCanvas(self.figure)  # a gtk.DrawingArea
        self.canvas.show()
        self.canvas.set_size_request(600, 400)

        def onselect(xmin, xmax):
            # "%s %s" reproduces what `print xmin, xmax` wrote on Python 2
            # while remaining valid syntax on Python 3.
            print('%s %s' % (xmin, xmax))

        # Hold the selector on the instance; matplotlib widgets must stay
        # referenced to keep reacting to the mouse.
        self.span = HorizontalSpanSelector(self.axis,
                                           onselect,
                                           useblit=False,
                                           rectprops=dict(alpha=0.5,
                                                          facecolor='red'))

        self['vboxMain'].pack_start(self.canvas, True, True)
        self['vboxMain'].show()

        # below is optional if you want the navigation toolbar
        self.navToolbar = NavigationToolbar(self.canvas, self['windowMain'])
        self.navToolbar.lastDir = '/var/tmp/'
        self['vboxMain'].pack_start(self.navToolbar)
        self.navToolbar.show()

        sep = gtk.HSeparator()
        sep.show()
        self['vboxMain'].pack_start(sep, True, True)

        # Put the glade-defined button last, after the widgets added here.
        self['vboxMain'].reorder_child(self['buttonClickMe'], -1)

    def __getitem__(self, key):
        """Look up the glade widget called *key*."""
        return self.widgets.get_widget(key)
Beispiel #5
0
    def add_figure_canvas(self, w=700, h=500):
        """Add a plot canvas with a navigation toolbar underneath it.

        A demo sine curve is drawn on a new figure.  The canvas and its
        toolbar are stacked in a new ``gtk.VBox`` (``self.canvas_vbox``),
        which is then packed into ``self.main_hbox``.  ``self.main_hbox``
        and ``self.window`` must already exist when this is called.

        w, h -- size request for the canvas; the toolbar is requested at
                ``w`` by 50.
        """
        self.fig = Figure(figsize=(8, 6), dpi=100)
        figure = self.fig
        self.ax = figure.add_subplot(111)
        times = arange(0.0, 3.0, 0.01)
        signal = sin(2 * pi * times)
        self.ax.plot(times, signal)

        self.figcanvas = FigureCanvas(figure)  # a gtk.DrawingArea
        canvas = self.figcanvas
        canvas.show()

        self.canvas_vbox = gtk.VBox()
        column = self.canvas_vbox
        nav_bar = NavigationToolbar(canvas, self.window)
        canvas.set_size_request(w, h)
        nav_bar.set_size_request(w, 50)
        nav_bar.show()

        # Canvas first, then the toolbar packed with expand=False.
        column.pack_start(canvas)
        column.pack_start(nav_bar, False)
        self.main_hbox.pack_start(column)
Beispiel #6
0
class HurricaneUI:
    """Main window of the hurricane detection tool.

    The layout comes from ``HurricaneUI.glade``; handler methods of this
    object are hooked up by name through ``builder.connect_signals(self)``.
    The window shows a globe drawn with matplotlib and offers actions to build
    training matrices, train / load / store a classifier and run detection on
    a data file.
    """

    def __init__(self):
        """Load the glade layout, embed the globe canvas and read UI state."""
        gladefile = "HurricaneUI.glade"
        builder = gtk.Builder()
        builder.add_from_file(gladefile)
        self.window = builder.get_object("mainWindow")
        builder.connect_signals(self)

        self.figure = Figure(figsize=(10, 10), dpi=75)
        self.axis = self.figure.add_subplot(111)

        # Initial view passed to GlobeMap (presumably degrees -- TODO confirm).
        self.lat = 50
        self.lon = -100
        self.globe = globDisp.GlobeMap(self.axis, self.lat, self.lon)

        self.canvas = FigureCanvasGTK(self.figure)
        self.canvas.show()
        self.canvas.set_size_request(500, 500)

        self.globeview = builder.get_object("map")
        self.globeview.pack_start(self.canvas, True, True)

        self.navToolbar = NavigationToolbar(self.canvas, self.globeview)
        self.navToolbar.lastDir = '/var/tmp'
        self.globeview.pack_start(self.navToolbar)
        self.navToolbar.show()

        # The grid-size combo needs a text renderer bound to model column 0.
        self.gridcombo = builder.get_object("gridsize")
        cell = gtk.CellRendererText()
        self.gridcombo.pack_start(cell, True)
        self.gridcombo.add_attribute(cell, 'text', 0)

        # read menu configuration
        self.gridopt = builder.get_object("gridopt").get_active()
        self.chkDetected = builder.get_object("detectedopt")
        self.detectedopt = self.chkDetected.get_active()
        self.chkHurricane = builder.get_object("hurricaneopt")
        self.hurricaneopt = self.chkHurricane.get_active()
        model = builder.get_object("liststore1")
        index = self.gridcombo.get_active()
        self.gridsize = model[index][0]
        # The label of whichever classifier radio button is active.
        radio = [
            r for r in builder.get_object("classifieropt1").get_group()
            if r.get_active()
        ][0]
        self.sClassifier = radio.get_label()
        self.start = builder.get_object("startdate")
        self.end = builder.get_object("enddate")

        self.chkUndersample = builder.get_object("undersample")
        self.chkGenKey = builder.get_object("genKey")

        # disable unimplemented classifier selection
        builder.get_object("classifieropt2").set_sensitive(False)
        builder.get_object("classifieropt3").set_sensitive(False)
        builder.get_object("classifieropt4").set_sensitive(False)

        self.btnStore = builder.get_object("store")
        self.datapath = 'GFSdat'
        self.trackpath = 'tracks'
        builder.get_object("btnDatapath").set_current_folder(self.datapath)
        builder.get_object("btnTrackpath").set_current_folder(self.trackpath)
        self.btnDetect = builder.get_object("detect")

        # current operation status
        self.stormlocs = None
        self.detected = None
        self.clssfr = None

        # for test drawing functions: preload a snapshot if one is present
        if os.path.exists('demo.detected'):
            # NOTE(security): unpickling runs code embedded in the file; only
            # keep a trusted 'demo.detected' next to the program.
            # Binary mode matches _dumpDemoSnapshot below.
            with open('demo.detected', 'rb') as f:
                self.detected = pickle.load(f)
                self.stormlocs = pickle.load(f)
            self.chkHurricane.set_label(
                str(self.stormlocs.shape[0]) + " Hurricanes")
            self.chkDetected.set_label(
                str(self.detected.shape[0]) + " Detected")

        self.setDisabledBtns()

        # draw Globe
        self.drawGlobe()

    def setDisabledBtns(self):
        """Enable each control only when the data it depends on exists.

        Identity tests are used on purpose: ``detected`` and ``stormlocs`` can
        be numpy arrays, for which ``!= None`` is evaluated element-wise and
        does not yield a plain boolean.
        """
        self.chkDetected.set_sensitive(self.detected is not None)
        self.chkHurricane.set_sensitive(self.stormlocs is not None)
        self.btnStore.set_sensitive(self.clssfr is not None)
        self.btnDetect.set_sensitive(self.clssfr is not None)

    def drawGlobe(self):
        """Draw the globe and, if enabled, the hurricane and detection layers."""
        self.globe.drawGlobe(self.gridsize, self.gridopt)
        if self.hurricaneopt:
            self.globe.drawHurricanes(self.stormlocs)
        if self.detectedopt:
            self.globe.fillGrids(self.detected)

    def main(self):
        """Show the window and enter the GTK main loop."""
        self.window.show_all()
        gtk.main()

    def redraw(self):
        """Clear the axes, draw the globe again and schedule a canvas repaint."""
        self.axis.cla()
        self.drawGlobe()
        self.canvas.draw_idle()

    def gtk_main_quit(self, widget):
        """Signal handler: leave the GTK main loop."""
        gtk.main_quit()

    def _dumpDemoSnapshot(self):
        """Developer aid: write the 'demo.detected' file that __init__ preloads.

        Not wired to any UI element; call it by hand after a detection run.
        """
        with open('demo.detected', 'wb') as f:
            pickle.dump(self.detected, f)
            pickle.dump(self.stormlocs, f)

    ###############################################################################
    #
    #  utility functions (dialogs)
    #
    def getFilenameToRead(self, stitle, save=False, filter='all'):
        """Run a file chooser and return the chosen path, or None if cancelled.

        stitle -- dialog title.
        save   -- open a "save" dialog instead of an "open" dialog.
        filter -- which file filters to offer before "All files":
                  'mat' (*.mat), 'svm' (*.svm), 'dat' (*.dat),
                  'mdat' (both *.mat and *.dat); anything else adds none.
        """
        if save:
            action = gtk.FILE_CHOOSER_ACTION_SAVE
            accept_stock = gtk.STOCK_SAVE
        else:
            action = gtk.FILE_CHOOSER_ACTION_OPEN
            accept_stock = gtk.STOCK_OPEN
        chooser = gtk.FileChooserDialog(
            title=stitle,
            parent=self.window,
            action=action,
            buttons=(gtk.STOCK_CANCEL, gtk.RESPONSE_CANCEL,
                     accept_stock, gtk.RESPONSE_OK))
        chooser.set_default_response(gtk.RESPONSE_OK)

        # Decide on all filters from the *string* argument first.  Rebinding
        # `filter` to a gtk.FileFilter while still comparing it with strings
        # made 'mdat' lose its "*.dat" entry.
        wanted = []
        if filter in ('mat', 'mdat'):
            wanted.append(("Matrix files", "*.mat"))
        if filter == 'svm':
            wanted.append(("SVM", "*.svm"))
        if filter in ('dat', 'mdat'):
            wanted.append(("Data", "*.dat"))
        wanted.append(("All files", "*"))
        for label, pattern in wanted:
            file_filter = gtk.FileFilter()
            file_filter.set_name(label)
            file_filter.add_pattern(pattern)
            chooser.add_filter(file_filter)

        chooser.set_current_folder(os.getcwd())

        response = chooser.run()
        if response == gtk.RESPONSE_OK:
            filen = chooser.get_filename()
        else:
            filen = None
        chooser.destroy()
        return filen

    def showMessage(self, msg):
        """Show *msg* in a modal information dialog with a Close button."""
        md = gtk.MessageDialog(self.window, gtk.DIALOG_DESTROY_WITH_PARENT,
                               gtk.MESSAGE_INFO, gtk.BUTTONS_CLOSE, msg)
        md.run()
        md.destroy()

    ###############################################################################
    #
    #  read data files and convert into a training matrix input
    #
    def on_btnTrackpath_current_folder_changed(self, widget):
        """Signal handler: remember the newly chosen track folder."""
        self.trackpath = widget.get_current_folder()

    def on_btnDatapath_current_folder_changed(self, widget):
        """Signal handler: remember the newly chosen data folder."""
        self.datapath = widget.get_current_folder()

    def on_createMat_clicked(self, widget):
        """Signal handler: convert the data between the two dates to a matrix.

        Asks for the output file, then calls ``gdtool.createMat`` with the
        date fields and the undersample / key-file check boxes.
        """
        filen = self.getFilenameToRead("Save converted matrix for training",
                                       True,
                                       filter='mat')
        if filen is None:
            return

        start = self.start.get_text()
        end = self.end.get_text()
        bundersmpl = self.chkUndersample.get_active()
        bgenkey = self.chkGenKey.get_active()

        ### FIX ME: currently gridsize for classification is fixed 1 (no clustering of grids - 2x2)
        # createMat appends to an existing file, so delete it if it exists
        if os.path.exists(filen):
            os.unlink(filen)
        gdtool.createMat(self.datapath,
                         self.trackpath,
                         start,
                         end,
                         store=filen,
                         undersample=bundersmpl,
                         genkeyf=bgenkey)
        self.showMessage("Matrix has been stored to " + filen)

    ###############################################################################
    #
    #   train the selected classifier
    #
    def on_train_clicked(self, widget):
        """Signal handler: train the selected classifier on a chosen matrix."""
        # FOR NOW, only SVM is supported
        if self.sClassifier != "SVM":
            self.showMessage("The classifier is not supported yet!")
            return

        filen = self.getFilenameToRead("Open training data", filter='mat')
        if filen is None:
            return
        data = ml.VectorDataSet(filen, labelsColumn=0)
        self.clssfr = ml.SVM()
        self.clssfr.train(data)
        # train finished. need to update button status
        self.setDisabledBtns()
        self.showMessage("Training SVM is done.")

    def on_classifieropt_toggled(self, widget, data=None):
        """Signal handler: remember the label of the toggled classifier."""
        self.sClassifier = widget.get_label()

    ###############################################################################
    #
    #   Classify on test data
    #
    def on_detect_clicked(self, widget):
        """Signal handler: classify a chosen data file and show the result.

        A '.dat' file is first converted to a temporary matrix (with key
        file) through ``gdtool.createMat``; any other file is passed to the
        classifier as it is.  The grid keys predicted as 1 end up in
        ``self.detected`` and the storm locations read from the key file in
        ``self.stormlocs``.
        """
        if self.clssfr is None:
            self.showMessage("There is no trained classifier!")
            return

        filen = self.getFilenameToRead("Open hurricane data", filter='mdat')
        if filen is None:
            return

        fname = os.path.basename(filen)
        key, ext = os.path.splitext(fname)
        if ext == '.dat':
            key = key[1:]  # take 'g' out

            tmpfn = 'f__tmpDetected__'
            if os.path.exists(tmpfn):
                os.unlink(tmpfn)
            # for DEMO, undersampled the normal data -- without undersampling there are too many candidates
            gdtool.createMat(self.datapath,
                             self.trackpath,
                             key,
                             key,
                             store=tmpfn,
                             undersample=True,
                             genkeyf=True)
            bneedDel = True
        else:
            # Use the full path returned by the dialog; the bare file name
            # only resolves when the file lives in the working directory.
            tmpfn = filen
            bneedDel = False

        gdkeyfilen = ''.join([tmpfn, '.keys'])
        try:
            result = self.clssfr.test(
                ml.VectorDataSet(tmpfn, labelsColumn=0))

            # NOTE(review): mode 'r' kept to match whatever gdtool uses when
            # writing the key file -- confirm and switch both to binary.
            with open(gdkeyfilen, 'r') as f:
                gridkeys = pickle.load(f)
                self.stormlocs = pickle.load(f)
            # A real list (not a lazy map object) so numpy builds a float array.
            predicted = np.array(
                [float(label) for label in result.getPredictedLabels()])
            self.detected = np.array(gridkeys)[predicted == 1]
        finally:
            # Remove the temporary matrix and key file even if testing failed.
            if bneedDel:
                for leftover in (tmpfn, gdkeyfilen):
                    if os.path.exists(leftover):
                        os.unlink(leftover)

        snstroms = str(self.stormlocs.shape[0])
        sndetected = str(self.detected.shape[0])
        self.chkHurricane.set_label(snstroms + " Hurricanes")
        self.chkDetected.set_label(sndetected + " Detected")

        self.showMessage(''.join([
            sndetected, "/", snstroms,
            " grids are predicted to have hurricane."
        ]))

        # test data tested. update buttons
        self.setDisabledBtns()
        self.redraw()

    ###############################################################################
    #
    #   load and store trained classifier
    #
    def on_load_clicked(self, widget):
        """Signal handler: load a stored classifier.

        Two files are requested: the stored classifier and the training data,
        because ``loadSVM`` is called with both.  Nothing is loaded, and no
        success message is shown, if either dialog is cancelled.
        """
        filen = self.getFilenameToRead("Load Classifier", filter='svm')
        if filen is None:
            return

        datfn = self.getFilenameToRead("Open Training Data", filter='mat')
        if datfn is None:
            return

        data = ml.VectorDataSet(datfn, labelsColumn=0)
        # NOTE(review): it is unclear why loadSVM needs the training data.
        self.clssfr = loadSVM(filen, data)

        # classifier has been loaded. need to update button status
        self.setDisabledBtns()
        self.showMessage("The classifier has been loaded!")

    def on_store_clicked(self, widget):
        """Signal handler: save the trained classifier to a chosen file."""
        if self.clssfr is None:
            self.showMessage("There is no trained classifier!")
            return

        filen = self.getFilenameToRead("Store Classifier",
                                       True,
                                       filter='svm')
        if filen is not None:
            self.clssfr.save(filen)
            self.showMessage("The classifier has been saved!")

    ###############################################################################
    #
    #   Display Globe
    #
    def on_right_clicked(self, widget):
        """Signal handler: increase ``self.lon`` by 10.

        Only the stored value changes; the globe is not rotated yet.
        """
        self.lon += 10
        # rotate Globe

    def on_left_clicked(self, widget):
        """Signal handler: decrease ``self.lon`` by 10.

        Only the stored value changes; the globe is not rotated yet.
        """
        self.lon -= 10
        # rotate Globe

    def gridsize_changed_cb(self, widget):
        """Signal handler: take the grid size from the combo box and redraw."""
        model = widget.get_model()
        index = widget.get_active()
        if index > -1:  # -1 means nothing is selected
            self.gridsize = model[index][0]
        self.redraw()

    def on_gridopt_toggled(self, widget):
        """Signal handler: flip the grid option and redraw."""
        self.gridopt = not self.gridopt
        self.redraw()

    def on_Hurricane_toggled(self, widget):
        """Signal handler: flip the hurricane layer and redraw."""
        self.hurricaneopt = not self.hurricaneopt
        self.redraw()

    def on_detected_toggled(self, widget):
        """Signal handler: flip the detected-grid layer and redraw."""
        self.detectedopt = not self.detectedopt
        self.redraw()
Beispiel #7
0
class PlotViewer(gtk.VBox):
    """Vertical box with a matplotlib canvas, its toolbar and plot toggles.

    One subplot is created per entry of *plotters*.  Each entry is a
    ``(plotterClass, default)`` pair; the class is instantiated with the
    axes, *fields* and two callbacks through which the plotter reports an
    info string and a position back to this viewer.
    """

    def __init__ (self, plotters, fields):
        """Create canvas, toolbar, one PlottedData per plotter and the legend box."""
        gtk.VBox.__init__ (self)

        self.figure = mpl.figure.Figure ()
        self.canvas = FigureCanvas (self.figure)
        self.canvas.unset_flags (gtk.CAN_FOCUS)
        self.canvas.set_size_request (600, 400)
        self.pack_start (self.canvas, True, True)
        self.canvas.show ()

        self.navToolbar = NavigationToolbar (self.canvas, self.window)
        #self.navToolbar.lastDir = '/tmp'
        self.pack_start (self.navToolbar, False, False)
        self.navToolbar.show ()

        self.checkboxes = gtk.HBox (len (plotters))
        self.pack_start (self.checkboxes, False, False)
        self.checkboxes.show ()

        self.pol = (1+0j, 0j)
        self.pol2 = None

        # Callables registered through onUpdateInfo ().
        self.handlers = []

        # Factories bind the plot index per call, so every plotter gets
        # callbacks tagged with its own position in self.plots.
        def makeUpdateInfo (i): return lambda s: self.__updateInfo (i, s)
        def makeUpdatePos (i): return lambda s: self.__updatePos (i, s)

        self.plots = []
        for i, (plotterClass, default) in enumerate (plotters):
            # Subplot numbers are 1-based; passing 0 is deprecated and later
            # rejected by matplotlib.  __updateChildren () re-arranges the
            # axes afterwards anyway.
            axes = self.figure.add_subplot (len (plotters), 1, i + 1)
            plotter = plotterClass (axes, fields, makeUpdateInfo (i), makeUpdatePos (i))
            d = PlottedData (axes, plotter, default, self.__updateChildren)
            self.checkboxes.pack_start (d.checkBox, False, False)
            self.plots.append (d)
        # Latest info string per plot, and index of the plot that last
        # reported a position.
        self.__infos = [None] * len (self.plots)
        self.__posi = None

        self.legendBox = gtk.CheckButton ("Show legend")
        self.legendBox.set_active (True)
        self.legendBox.connect ("toggled", self.__on_legend_toggled)
        self.checkboxes.pack_start (self.legendBox, False, False)
        self.legendBox.show ()

        self.__updateChildren ()

    def __on_legend_toggled (self, button):
        """Forward the state of the legend check box to every plotter."""
        for pd in self.plots:
            pd.plotter.setLegend (button.get_active ())

    def __updateChildren (self):
        """Stack the visible axes in one column, shrink hidden ones, replot."""
        count = 0
        for axes in self.figure.get_axes ():
            if axes.get_visible ():
                count += 1
        if count == 0:
            count = 1  # keep a valid 1x1 grid when nothing is visible
        nr = 1
        for axes in self.figure.get_axes ():
            axes.change_geometry (count, 1, nr)
            if axes.get_visible ():
                if nr < count:
                    nr += 1
            else:
                axes.set_position ((0, 0, 1e-10, 1e-10)) # Hack to prevent the invisible axes from getting mouse events
        self.figure.canvas.draw ()
        self.__updateGraph ()

    def __updateGraph (self):
        """Replot every visible plot with the current pol / pol2 values."""
        for pd in self.plots:
            if pd.axes.get_visible ():
                pd.plotter.plot (self.pol, self.pol2)

    def __updateInfo (self, i, arg):
        """Store info *arg* of plot *i* and send the combined text to handlers."""
        self.__infos[i] = arg
        s = ''
        for info in self.__infos:
            if info is not None:
                if s != '':
                    s += ' '
                s += info
        for handler in self.handlers:
            handler (s)

    def __updatePos (self, i, arg):
        """Pass position *arg* reported by plot *i* on to all other plots.

        A None position is ignored unless it comes from the plot that
        reported the previous position.
        """
        if arg is None and self.__posi != i:
            return
        self.__posi = i
        for j, pd in enumerate (self.plots):
            if i != j:
                pd.plotter.updateCPos (arg)

    def onUpdateInfo (self, handler):
        """Register *handler*; it is called with the combined info string."""
        self.handlers.append (handler)

    def setPol (self, value):
        """Set ``pol`` and replot if the value changed."""
        oldValue = self.pol
        self.pol = value
        if value != oldValue:
            self.__updateGraph ()

    def setPol2 (self, value):
        """Set ``pol2`` and replot if the value changed."""
        oldValue = self.pol2
        self.pol2 = value
        if value != oldValue:
            self.__updateGraph ()
Beispiel #8
0
	def init(self):
		"""Build the window contents: a toolbar above the plot canvas.

		The toolbar holds save, hide-key and run buttons followed by the
		matplotlib navigation toolbar.  The canvas shares a horizontal box
		with a narrow spacer column on its right.
		"""
		self.fig = Figure(figsize=(5,4), dpi=100)
		self.ax1=None
		self.show_key=True
		self.hbox=gtk.HBox()
		self.exe_command  =  get_exe_command()
		self.edit_list=[]
		self.line_number=[]

		self.draw_graph()
		canvas = FigureCanvas(self.fig)  # a gtk.DrawingArea
		canvas.figure.patch.set_facecolor('white')
		canvas.set_size_request(500, 150)
		canvas.show()

		tooltips = gtk.Tooltips()

		toolbar = gtk.Toolbar()
		toolbar.set_style(gtk.TOOLBAR_ICONS)
		toolbar.set_size_request(-1, 50)

		# Items are inserted left to right at increasing positions.
		tool_bar_pos=0
		save = gtk.ToolButton(gtk.STOCK_SAVE)
		tooltips.set_tip(save, "Save image")
		save.connect("clicked", self.callback_save)
		toolbar.insert(save, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		hide_key = gtk.ToolButton(gtk.STOCK_INFO)
		tooltips.set_tip(hide_key, "Hide key")
		hide_key.connect("clicked", self.callback_hide_key)
		toolbar.insert(hide_key, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		image = gtk.Image()
		image.set_from_file(os.path.join(get_image_file_path(),"play.png"))
		run = gtk.ToolButton(image)
		tooltips.set_tip(run, "Run simulation")
		run.connect("clicked", self.callback_refresh)
		toolbar.insert(run, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		# The matplotlib toolbar is not a ToolItem, so it is wrapped in a
		# box inside a ToolItem before being inserted into the toolbar.
		plot_toolbar = NavigationToolbar(canvas, self)
		plot_toolbar.show()
		box=gtk.HBox(True, 1)
		box.set_size_request(500,-1)
		box.show()
		box.pack_start(plot_toolbar, True, True, 0)
		tb_comboitem = gtk.ToolItem()
		tb_comboitem.add(box)
		tb_comboitem.show()
		toolbar.insert(tb_comboitem, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		# Invisible expanding separator takes up the remaining width.
		sep = gtk.SeparatorToolItem()
		sep.set_draw(False)
		sep.set_expand(True)
		toolbar.insert(sep, tool_bar_pos)
		sep.show()

		toolbar.show_all()
		window_main_vbox=gtk.VBox()
		window_main_vbox.pack_start(toolbar, False, True, 0)

		self.hbox.pack_start(canvas, True, True, 0)

		vbox = gtk.VBox(False, 2)

		#spacer
		label=gtk.Label(" \n\n    ")
		vbox.pack_start(label, False, False, 0)
		label.show()

		self.hbox.pack_start(vbox, False, False, 0)
		vbox.show()
		window_main_vbox.add(self.hbox)
		self.add(window_main_vbox)
		self.set_title("Quantum Efficency calculator - (www.opvdm.com)")
		self.set_icon_from_file(os.path.join(get_image_file_path(),"qe.png"))
		self.connect("delete-event", self.callback_close)
		self.set_position(gtk.WIN_POS_CENTER)
Beispiel #9
0
	def init(self):
		"""Build the window contents: toolbar, plot canvas and mesh editor.

		The toolbar holds save, hide-key, run and help buttons plus the
		matplotlib navigation toolbar.  Below it the canvas and the
		electrical mesh editor sit side by side; the editor's "update"
		signal is routed to ``self.callback_update``.
		"""
		# Single-argument call form prints the same text on Python 2 and 3.
		print("INIT!!")

		self.emesh_editor=electrical_mesh_editor()
		self.emesh_editor.init()

		self.fig = Figure(figsize=(5,4), dpi=100)
		self.ax1=None
		self.show_key=True
		self.hbox=gtk.HBox()
		self.edit_list=[]
		self.line_number=[]

		self.draw_graph()
		canvas = FigureCanvas(self.fig)  # a gtk.DrawingArea
		canvas.figure.patch.set_facecolor('white')
		canvas.set_size_request(500, 150)
		canvas.show()

		tooltips = gtk.Tooltips()

		toolbar = gtk.Toolbar()
		toolbar.set_style(gtk.TOOLBAR_ICONS)
		toolbar.set_size_request(-1, 50)

		# Items are inserted left to right at increasing positions.
		tool_bar_pos=0
		save = gtk.ToolButton(gtk.STOCK_SAVE)
		tooltips.set_tip(save, "Save image")
		save.connect("clicked", self.callback_save)
		toolbar.insert(save, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		hide_key = gtk.ToolButton(gtk.STOCK_INFO)
		tooltips.set_tip(hide_key, "Hide key")
		hide_key.connect("clicked", self.callback_hide_key)
		toolbar.insert(hide_key, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		image = gtk.Image()
		image.set_from_file(os.path.join(get_image_file_path(),"play.png"))
		run = gtk.ToolButton(image)
		tooltips.set_tip(run, "Run simulation")
		run.connect("clicked", self.run_simulation)
		toolbar.insert(run, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		# The matplotlib toolbar is not a ToolItem, so it is wrapped in a
		# box inside a ToolItem before being inserted into the toolbar.
		plot_toolbar = NavigationToolbar(canvas, self)
		plot_toolbar.show()
		box=gtk.HBox(True, 1)
		box.set_size_request(500,-1)
		box.show()
		box.pack_start(plot_toolbar, True, True, 0)
		tb_comboitem = gtk.ToolItem()
		tb_comboitem.add(box)
		tb_comboitem.show()
		toolbar.insert(tb_comboitem, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		# Invisible expanding separator pushes the help button to the right.
		sep = gtk.SeparatorToolItem()
		sep.set_draw(False)
		sep.set_expand(True)
		toolbar.insert(sep, tool_bar_pos)
		sep.show()
		tool_bar_pos=tool_bar_pos+1

		help_button = gtk.ToolButton(gtk.STOCK_HELP)
		toolbar.insert(help_button, tool_bar_pos)
		help_button.connect("clicked", self.callback_help)
		help_button.show()

		toolbar.show_all()
		window_main_vbox=gtk.VBox()
		window_main_vbox.pack_start(toolbar, False, True, 0)

		self.hbox.pack_start(canvas, True, True, 0)

		self.emesh_editor.show()
		self.hbox.pack_start(self.emesh_editor, True, True, 0)
		self.emesh_editor.mesh_dump_ctl.connect("update", self.callback_update)
		window_main_vbox.add(self.hbox)

		self.add(window_main_vbox)
		self.set_title("Electrical Mesh Editor - (www.opvdm.com)")
		self.set_icon_from_file(os.path.join(get_image_file_path(),"mesh.png"))
		self.connect("delete-event", self.callback_close)
		self.set_position(gtk.WIN_POS_CENTER)
Beispiel #10
0
	def init(self,index):
		"""Build this box: toolbar, plot canvas, section list and status bar.

		index -- stored as ``self.index``.

		Data is loaded before the widgets are created; the mesh is built and
		the graph drawn once everything has been packed.
		"""
		self.index=index
		self.fig = Figure(figsize=(5,4), dpi=100)
		self.ax1=None
		self.show_key=True
		self.hbox=gtk.HBox()
		self.edit_list=[]
		self.line_number=[]

		self.list=[]

		self.load_data()
		self.update_scan_tokens()

		canvas = FigureCanvas(self.fig)  # a gtk.DrawingArea
		canvas.figure.patch.set_facecolor('white')
		# One size request is enough: the earlier 500x150 request was
		# overwritten by this one before the canvas was ever packed.
		canvas.set_size_request(700,400)
		canvas.show()

		tooltips = gtk.Tooltips()

		toolbar = gtk.Toolbar()
		toolbar.set_style(gtk.TOOLBAR_ICONS)
		toolbar.set_size_request(-1, 50)

		self.store = self.create_model()
		treeview = gtk.TreeView(self.store)
		treeview.show()

		# Items are inserted left to right at increasing positions.
		tool_bar_pos=0

		save = gtk.ToolButton(gtk.STOCK_SAVE)
		tooltips.set_tip(save, _("Save image"))
		save.connect("clicked", self.callback_save)
		toolbar.insert(save, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		add_section = gtk.ToolButton(gtk.STOCK_ADD)
		tooltips.set_tip(add_section, _("Add section"))
		add_section.connect("clicked", self.callback_add_section,treeview)
		toolbar.insert(add_section, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		delete_section = gtk.ToolButton(gtk.STOCK_CLEAR)
		tooltips.set_tip(delete_section, _("Delete section"))
		delete_section.connect("clicked", self.callback_remove_item,treeview)
		toolbar.insert(delete_section, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		move_down = gtk.ToolButton(gtk.STOCK_GO_DOWN)
		tooltips.set_tip(move_down, _("Move down"))
		move_down.connect("clicked", self.callback_move_down,treeview)
		toolbar.insert(move_down, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		image = gtk.Image()
		image.set_from_file(os.path.join(get_image_file_path(),"start.png"))
		start = gtk.ToolButton(image)
		tooltips.set_tip(start, _("Simulation start frequency"))
		start.connect("clicked", self.callback_start_fx,treeview)
		toolbar.insert(start, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		# The matplotlib toolbar is not a ToolItem, so it is wrapped in a
		# box inside a ToolItem before being inserted into the toolbar.
		plot_toolbar = NavigationToolbar(self.fig.canvas, self)
		plot_toolbar.show()
		box=gtk.HBox(True, 1)
		box.set_size_request(300,-1)
		box.show()
		box.pack_start(plot_toolbar, True, True, 0)
		tb_comboitem = gtk.ToolItem()
		tb_comboitem.add(box)
		tb_comboitem.show()
		toolbar.insert(tb_comboitem, tool_bar_pos)
		tool_bar_pos=tool_bar_pos+1

		# Invisible expanding separator pushes the help button to the right.
		sep = gtk.SeparatorToolItem()
		sep.set_draw(False)
		sep.set_expand(True)
		toolbar.insert(sep, tool_bar_pos)
		sep.show()
		tool_bar_pos=tool_bar_pos+1

		help_button = gtk.ToolButton(gtk.STOCK_HELP)
		toolbar.insert(help_button, tool_bar_pos)
		help_button.connect("clicked", self.callback_help)
		help_button.show()

		toolbar.show_all()
		# Pack the toolbar exactly once: a widget can have only one parent,
		# so a second pack_start of the same toolbar is refused by GTK with
		# a warning and has no effect.
		self.pack_start(toolbar, False, True, 0)

		self.pack_start(canvas, True, True, 0)

		treeview.set_rules_hint(True)

		self.create_columns(treeview)

		self.pack_start(treeview, False, False, 0)

		self.statusbar = gtk.Statusbar()
		self.statusbar.show()
		self.pack_start(self.statusbar, False, False, 0)

		self.build_mesh()
		self.draw_graph()

		self.show()
class MainWindow:
 
    # Block dimensions handed to OpticalFlowFilter in processBag; also used
    # as the spacing between the flow lines drawn by drawOutputOverlay
    OPTICAL_FLOW_BLOCK_WIDTH = 8
    OPTICAL_FLOW_BLOCK_HEIGHT = 8
    OPTICAL_FLOW_RANGE_WIDTH = 8    # Range to look outside of a block for motion
    OPTICAL_FLOW_RANGE_HEIGHT = 8
    
    # Only every PROCESSED_FRAME_DIFF'th image of a bag file is processed.
    # This is just the class level default; __init__ replaces it on each
    # instance with options.frameSkip
    PROCESSED_FRAME_DIFF = 1
    
    # Classes of pixel in GrabCut algorithm
    # NOTE(review): these values are written into the mask given to
    # cv.GrabCut, so presumably they mirror OpenCV's own constants - confirm
    GC_BGD = 0      # background
    GC_FGD = 1      # foreground
    GC_PR_BGD = 2   # most probably background
    GC_PR_FGD = 3   # most probably foreground 

    # GrabCut algorithm flags (passed as the mode argument of cv.GrabCut)
    GC_INIT_WITH_RECT = 0
    GC_INIT_WITH_MASK = 1
    GC_EVAL = 2
 
    #---------------------------------------------------------------------------
    def __init__( self, options, bagFilename = None ):
        """Build the main window from its Glade description and show it.

        options     -- options object; only its frameSkip attribute is read
                       here (it is kept in self.options for later use)
        bagFilename -- optional path of a bag file to start loading as soon
                       as the window is shown
        """
        self.options = options

        # Use an absolute path. os.path.dirname( __file__ ) is an empty string
        # when the script is launched from its own directory, which used to
        # turn the resource paths built from it into paths below '/'
        self.scriptPath = os.path.dirname( os.path.abspath( __file__ ) )

        self.image = None
        self.frameIdx = 0
        self.workerThread = None
        self.numFramesProcessed = 0
        self.graphCanvas = None
        self.graphNavToolbar = None

        # The frame stride is used as a divisor and as a modulus when a bag
        # file is loaded, so never let it drop below 1
        self.PROCESSED_FRAME_DIFF = max( 1, int( self.options.frameSkip ) )

        # Setup the GUI
        builder = gtk.Builder()
        builder.add_from_file( os.path.join(
            self.scriptPath, "GUI", "ObjectDetectorExplorer.glade" ) )

        self.window = builder.get_object( "winMain" )
        self.comboOutput_1_Mode = builder.get_object( "comboOutput_1_Mode" )
        self.comboOutput_2_Mode = builder.get_object( "comboOutput_2_Mode" )
        self.vboxGraphs = builder.get_object( "vboxGraphs" )

        # Wrap each drawing area in a Display, which does the actual drawing
        self.dwgInputDisplay = Display( builder.get_object( "dwgInput" ) )
        self.dwgOutput_1_Display = Display( builder.get_object( "dwgOutput_1" ) )
        self.dwgOutput_2_Display = Display( builder.get_object( "dwgOutput_2" ) )

        self.sequenceControls = builder.get_object( "sequenceControls" )
        self.sequenceControls.setNumFrames( 1 )
        self.sequenceControls.setOnFrameIdxChangedCallback(
            self.onSequenceControlsFrameIdxChanged )

        # Hook the handlers named in the Glade file up to the methods of self
        builder.connect_signals( self )

        self.window.show()

        if bagFilename is not None:
            self.tryToLoadBagFile( bagFilename )
        
    #---------------------------------------------------------------------------
    def onWinMainDestroy( self, widget, data = None ):
        """Signal handler: leave the GTK main loop.

        Neither widget nor data is used.
        """
        gtk.main_quit()
        
    #---------------------------------------------------------------------------   
    def main( self ):
        """Enable GDK threading support, then hand control to gtk.main().

        gtk.main() waits for events (key presses, mouse events, ...) and
        dispatches them to the connected handlers.
        """
        gtk.gdk.threads_init()
        gtk.main()
        
    #---------------------------------------------------------------------------
    def isCurFrameReady( self ):
        return self.frameIdx < self.numFramesProcessed
        
    #---------------------------------------------------------------------------
    def setFrameIdx( self, frameIdx ):
        
        frameReady = frameIdx < self.numFramesProcessed
        
        if frameReady or frameIdx == 0:
            self.frameIdx = frameIdx
            self.updateDisplay()
        else:
            # Try to reset to current frame index
            if self.isCurFrameReady():
                self.sequenceControls.setFrameIdx( self.frameIdx )
            else:
                self.sequenceControls.setFrameIdx( 0 )
        
    #---------------------------------------------------------------------------
    def updateDisplay( self ):
        
        if self.isCurFrameReady():
            self.dwgInputDisplay.setImageFromOpenCVMatrix( self.inputImageList[ self.frameIdx ] )
            
            output_1_Mode = self.comboOutput_1_Mode.get_active_text()
            if output_1_Mode == OutputMode.OPTICAL_FLOW:
                self.dwgOutput_1_Display.setImageFromOpenCVMatrix( self.inputImageList[ self.frameIdx ] )
            elif output_1_Mode == OutputMode.DETECTED_MOTION:
                self.dwgOutput_1_Display.setImageFromNumpyArray( self.motionImageList[ self.frameIdx ] )
            elif output_1_Mode == OutputMode.SEGMENTATION:
                self.dwgOutput_1_Display.setImageFromNumpyArray( self.segmentationList[ self.frameIdx ] )
            elif output_1_Mode == OutputMode.SALIENCY:
                self.dwgOutput_1_Display.setImageFromNumpyArray( self.saliencyMapList[ self.frameIdx ] )
            elif output_1_Mode == OutputMode.SEGMENTATION_MASK:
                self.dwgOutput_1_Display.setImageFromNumpyArray( self.segmentationMaskList[ self.frameIdx ] )
                
            output_2_Mode = self.comboOutput_2_Mode.get_active_text()
            if output_2_Mode == OutputMode.OPTICAL_FLOW:
                self.dwgOutput_2_Display.setImageFromOpenCVMatrix( self.inputImageList[ self.frameIdx ] )
            elif output_2_Mode == OutputMode.DETECTED_MOTION:
                self.dwgOutput_2_Display.setImageFromNumpyArray( self.motionImageList[ self.frameIdx ] )
            elif output_2_Mode == OutputMode.SEGMENTATION:
                
                diffImage = np.array( self.motionImageList[ self.frameIdx ], dtype=np.int32 ) \
                     - np.array( self.imageFlowList[ self.frameIdx ][ 3 ], dtype=np.int32 )
                diffImage = np.array( np.maximum( diffImage, 0 ), dtype=np.uint8 )
                
                #self.dwgOutput_2_Display.setImageFromNumpyArray( diffImage )
                self.dwgOutput_2_Display.setImageFromNumpyArray( self.segmentationList[ self.frameIdx ] )
            elif output_2_Mode == OutputMode.SALIENCY:
                self.dwgOutput_2_Display.setImageFromNumpyArray( self.saliencyMapList[ self.frameIdx ] )
            elif output_2_Mode == OutputMode.SEGMENTATION_MASK:
                self.dwgOutput_2_Display.setImageFromNumpyArray( self.segmentationMaskList[ self.frameIdx ] )
        
    #---------------------------------------------------------------------------
    def chooseBagFile( self ):
        """Let the user pick an existing bag file.

        Returns the selected filename, or None if the dialog was cancelled
        or closed. (The GTK response code used to be returned in that case,
        which callers then tried to load as a file.)
        """
        # An OPEN dialog: the chosen file is going to be read, not written
        dialog = gtk.FileChooserDialog(
            title="Choose Bag File",
            action=gtk.FILE_CHOOSER_ACTION_OPEN,
            buttons=(gtk.STOCK_CANCEL, gtk.RESPONSE_REJECT,
                      gtk.STOCK_OK, gtk.RESPONSE_ACCEPT) )

        bagFilename = None

        try:
            # Joined rather than concatenated so that an empty scriptPath
            # gives a relative path instead of one below '/'
            dialog.set_current_folder( os.path.abspath( os.path.join(
                self.scriptPath, os.pardir, os.pardir, "test_data", "bags" ) ) )

            bagFilter = gtk.FileFilter()
            bagFilter.add_pattern( "*.bag" )
            bagFilter.set_name( "Bag Files" )
            dialog.add_filter( bagFilter )
            dialog.set_filter( bagFilter )

            if dialog.run() == gtk.RESPONSE_ACCEPT:
                bagFilename = dialog.get_filename()
        finally:
            # Always get rid of the dialog, even if something above raised
            dialog.destroy()

        return bagFilename
    
    #---------------------------------------------------------------------------
    def onMenuItemOpenBagActivate( self, widget ):
        """Menu handler: ask for a bag file and try to load the one chosen."""
        chosenBag = self.chooseBagFile()
        if chosenBag is None:
            return

        self.tryToLoadBagFile( chosenBag )
    
    #---------------------------------------------------------------------------
    def onMenuItemQuitActivate( self, widget ):
        """Menu handler: quit in the same way as destroying the main window."""
        self.onWinMainDestroy( widget )
       
    #---------------------------------------------------------------------------
    def onSequenceControlsFrameIdxChanged( self, widget ):
        self.setFrameIdx( widget.frameIdx )
    
    #---------------------------------------------------------------------------
    def onComboOutput_1_ModeChanged( self, widget ):
        self.updateDisplay()
        
    #---------------------------------------------------------------------------
    def onComboOutput_2_ModeChanged( self, widget ):
        self.updateDisplay()
       
    #---------------------------------------------------------------------------
    def onDwgInputExposeEvent( self, widget, data ):
        
        self.dwgInputDisplay.drawPixBufToDrawingArea( data.area )
        
    #---------------------------------------------------------------------------
    def onDwgOutput_1_ExposeEvent( self, widget, data = None ):
        
        imgRect = self.dwgOutput_1_Display.drawPixBufToDrawingArea( data.area ) 
        
        if imgRect != None:
            imgRect = imgRect.intersect( data.area )
            
            outputMode = self.comboOutput_1_Mode.get_active_text()
            self.drawOutputOverlay( widget, imgRect, outputMode )
        
    #---------------------------------------------------------------------------
    def onDwgOutput_2_ExposeEvent( self, widget, data = None ):
        
        imgRect = self.dwgOutput_2_Display.drawPixBufToDrawingArea( data.area )
        
        if imgRect != None:
            imgRect = imgRect.intersect( data.area )
            
            outputMode = self.comboOutput_2_Mode.get_active_text()
            self.drawOutputOverlay( widget, imgRect, outputMode )
        
    #---------------------------------------------------------------------------
    def drawOutputOverlay( self, widget, imgRect, outputMode ):
        """Draw the overlay belonging to outputMode onto an output widget.

        widget     -- widget whose window is drawn on
        imgRect    -- rectangle whose x and y give the origin of the overlay
        outputMode -- OutputMode.OPTICAL_FLOW draws one line per optical flow
                      block, OutputMode.SALIENCY draws one circle per
                      saliency cluster; other modes have no overlay
        """
        if outputMode == OutputMode.OPTICAL_FLOW:

            # Draw the optical flow if it's available.
            # Must be an identity test: comparing a numpy array to None with
            # '!=' is done element by element on current numpy versions and
            # the resulting array cannot be used as a condition
            opticalFlowX = self.opticalFlowListX[ self.frameIdx ]
            opticalFlowY = self.opticalFlowListY[ self.frameIdx ]
            if opticalFlowX is None or opticalFlowY is None:
                return

            graphicsContext = widget.window.new_gc()

            # Build the three colours once instead of once per line
            upColour = gtk.gdk.Color( 65535, 0, 0 )        # Up is red
            downColour = gtk.gdk.Color( 0, 0, 65535 )      # Down is blue
            staticColour = gtk.gdk.Color( 0, 65535, 0 )    # Static is green

            numRows = opticalFlowX.shape[ 0 ]
            numCols = opticalFlowX.shape[ 1 ]
            firstBlockCentreX = imgRect.x + self.OPTICAL_FLOW_BLOCK_WIDTH / 2

            blockCentreY = imgRect.y + self.OPTICAL_FLOW_BLOCK_HEIGHT / 2
            for y in range( numRows ):

                blockCentreX = firstBlockCentreX
                for x in range( numCols ):

                    endX = blockCentreX + opticalFlowX[ y, x ]
                    endY = blockCentreY + opticalFlowY[ y, x ]

                    if endY < blockCentreY:
                        lineColour = upColour
                    elif endY > blockCentreY:
                        lineColour = downColour
                    else:
                        lineColour = staticColour
                    graphicsContext.set_rgb_fg_color( lineColour )

                    widget.window.draw_line( graphicsContext,
                        int( blockCentreX ), int( blockCentreY ),
                        int( endX ), int( endY ) )

                    blockCentreX += self.OPTICAL_FLOW_BLOCK_WIDTH

                blockCentreY += self.OPTICAL_FLOW_BLOCK_HEIGHT

        elif outputMode == OutputMode.SALIENCY:

            graphicsContext = widget.window.new_gc()
            graphicsContext.set_rgb_fg_color( gtk.gdk.Color( 0, 65535, 0 ) )

            drawFilledArc = False
            FULL_CIRCLE = 360 * 64      # draw_arc angles are in 1/64ths of a degree

            for cluster in self.saliencyClusterList[ self.frameIdx ]:

                mean = cluster[ 0 ]
                stdDev = cluster[ 1 ]

                # Draw a circle with a radius of one standard deviation to
                # represent the cluster. draw_arc wants the top left corner
                # of the bounding box, so step back from the mean by the
                # radius to get the circle centred on the mean
                arcWidth = arcHeight = int( stdDev * 2 )
                arcX = int( imgRect.x + mean[ 0 ] - stdDev )
                arcY = int( imgRect.y + mean[ 1 ] - stdDev )

                widget.window.draw_arc( graphicsContext,
                    drawFilledArc, arcX, arcY, arcWidth, arcHeight, 0, FULL_CIRCLE )

    #---------------------------------------------------------------------------
    def tryToLoadBagFile( self, bagFilename ):
        """Open a bag file and start processing its images in the background.

        An error is printed and nothing else changes if the file cannot be
        opened or holds no sensor_msgs/Image messages. Otherwise any worker
        that is still running is cancelled and waited for, the per frame
        result lists are reset, a new daemon worker thread running
        processBag is started and the sequence controls are told how many
        frames to expect.
        """
        # Locate and open file
        try:
            bag = rosbag.Bag( bagFilename )
        except Exception:
            # Not a bare except, so Ctrl-C and SystemExit still get through
            print "Error: Unable to load", bagFilename
            return

        # Count the number of frames in the bag file
        numFrames = sum( 1 for topic, msg, t in bag.read_messages()
            if msg._type == "sensor_msgs/Image" )

        if numFrames == 0:
            print "Error: No frames in bag file"
            bag.close()     # Nothing will use the bag, so don't leave it open
            return

        # Only every PROCESSED_FRAME_DIFF'th frame gets processed
        numFrames = int( math.ceil( float( numFrames )/int( self.PROCESSED_FRAME_DIFF ) ) )

        # Throw away existing data and prepare to process the bag file
        if self.workerThread is not None and self.workerThread.is_alive():
            self.workCancelled = True
            self.workerThread.join()

        self.inputImageList = [ None ] * numFrames
        self.grayScaleImageList = [ None ] * numFrames
        self.motionImageList = [ None ] * numFrames
        self.opticalFlowListX = [ None ] * numFrames
        self.opticalFlowListY = [ None ] * numFrames
        self.segmentationList = [ None ] * numFrames
        self.segmentationMaskList = [ None ] * numFrames
        self.imageFlowList = [ ( 0, 0, 0, None ) ] * numFrames
        self.maxMotionCounts = [ 0 ] * numFrames
        self.leftMostMotionList = [ 0 ] * numFrames
        self.saliencyMapList = [ None ] * numFrames
        # Every frame needs a list of its own here, so no '*' for this one
        self.saliencyClusterList = [ [] for i in range( numFrames ) ]
        self.numFramesProcessed = 0

        # Kick off worker thread to process the bag file
        self.workCancelled = False
        self.workerThread = threading.Thread( target=self.processBag, args=( bag, ) )
        self.workerThread.daemon = True
        self.workerThread.start()

        self.sequenceControls.setNumFrames( numFrames )

    #---------------------------------------------------------------------------
    def produceSegmentation( self, startFrame, impactMotionImage, 
                                preMotionImages, postMotionImages ):
        
        ROI_X = 0
        ROI_Y = 76
        ROI_WIDTH = 230
        ROI_HEIGHT = 100
        
        blankFrame = np.zeros( ( startFrame.height, startFrame.width ), dtype=np.uint8 )
        
        imageFlowFilter = ImageFlowFilter()    
        
        # Create the accumulator image
        accumulatorArray = np.copy( impactMotionImage ).astype( np.int32 )
            
        # Take maximum values from motion images after the impact but
        # don't add them in to de-emphasise the manipulator
        imageNum = 1
        for postMotionImage in postMotionImages:
            
            print "Aligning post impact image {0}...".format( imageNum )
            imageNum += 1
            
            ( transX, transY, rotationAngle, alignedImage ) = \
                imageFlowFilter.calcImageFlow( impactMotionImage, postMotionImage )
            accumulatorArray = np.maximum( accumulatorArray, alignedImage )
                    
        # Dilate and subtract motion images from before the impact
        imageNum = 1
        for preMotionImage in preMotionImages:
            
            print "Aligning pre impact image {0}...".format( imageNum )
            imageNum += 1
            
            ( transX, transY, rotationAngle, alignedImage ) = \
                imageFlowFilter.calcImageFlow( impactMotionImage, preMotionImage )
                
            cv.Dilate( alignedImage, alignedImage )
            cv.Dilate( alignedImage, alignedImage )
            cv.Dilate( alignedImage, alignedImage )
            accumulatorArray = accumulatorArray - alignedImage
            
        accumulatorImage = np.clip( accumulatorArray, 0, 255 ).astype( np.uint8 )
        
        # Create the segmentation mask from the accumulator image
        startMask = np.copy( accumulatorImage )
        cv.Dilate( startMask, startMask )
        cv.Erode( startMask, startMask )
        cv.Dilate( startMask, startMask )
        cv.Erode( startMask, startMask )
        startMask = scipy.ndimage.filters.gaussian_filter( 
            startMask, 5.0, mode='constant' )
        
        startMask[ startMask > 0 ] = 255
            
        # Find the larget blob in the ROI
        # Label blobs
        startMask, numBlobs = PyBlobLib.labelBlobs( startMask )
        
        # Find blobs in the region of interest
        testMap = np.copy( startMask )
        testMap[ :ROI_Y, : ] = 0       # Mask out area above the ROI
        testMap[ :, :ROI_X ] = 0       # Mask out area to the left of the ROI
        testMap[ ROI_Y+ROI_HEIGHT: ] = 0   # Mask out area below the ROI
        testMap[ :, ROI_X+ROI_WIDTH: ] = 0   # Mask out area to the right of the ROI
    
        biggestBlobIdx = None
        biggestBlobSize = 0
    
        for blobIdx in range( 1, numBlobs + 1 ):
            if testMap[ testMap == blobIdx ].size > 0:
                blobSize = startMask[ startMask == blobIdx ].size
                if blobSize > biggestBlobSize:
                    biggestBlobSize = blobSize
                    biggestBlobIdx = blobIdx
    
        # Isolate the largest blob
        if biggestBlobIdx != None:
            biggestBlobPixels = (startMask == biggestBlobIdx)
            startMask[ biggestBlobPixels ] = 255
            startMask[ biggestBlobPixels == False ] = 0
        else:
            print "No central blob"
            return blankFrame
            
        # Now expand it to get exclusion mask
        exclusionMask = np.copy( startMask )
        for i in range( 10 ):
            cv.Dilate( exclusionMask, exclusionMask )
        cv.Erode( exclusionMask, exclusionMask )
        cv.Erode( exclusionMask, exclusionMask )
        
        #----------------------------------------------------
        
        maskArray = np.copy( startMask )
        possiblyForeground = ( maskArray > 0 ) & ( accumulatorImage > 0 )
        maskArray[ possiblyForeground ] = self.GC_PR_FGD
        maskArray[ possiblyForeground == False ] = self.GC_PR_BGD
        maskArray[ exclusionMask == 0 ] = self.GC_BGD
        
        definiteMask = np.copy( accumulatorImage )
        definiteMask[ possiblyForeground ] = 255
        definiteMask[ possiblyForeground == False ] = 0
        cv.Erode( definiteMask, definiteMask )
        cv.Erode( definiteMask, definiteMask )
        maskArray[ definiteMask == 255 ] = self.GC_FGD
        
        # Now create the working mask and segment the image
        
        workingMask = np.copy( maskArray )
            
        fgModel = cv.CreateMat( 1, 5*13, cv.CV_64FC1 )
        cv.Set( fgModel, 0 )
        bgModel = cv.CreateMat( 1, 5*13, cv.CV_64FC1 )
        cv.Set( bgModel, 0 )
        
        workingImage = np.copy( startFrame )
        cv.GrabCut( workingImage, workingMask, 
            (0,0,0,0), fgModel, bgModel, 6, self.GC_INIT_WITH_MASK )
            
        cv.Set( fgModel, 0 )
        cv.Set( bgModel, 0 )
        bgdPixels = (workingMask != self.GC_PR_FGD) & (workingMask != self.GC_FGD)
        workingMask[ bgdPixels ] = 0
        workingMask[ bgdPixels == False ] = 255
        cv.Erode( workingMask, workingMask )
        bgdPixels = workingMask == 0
        workingMask[ bgdPixels ] = self.GC_PR_BGD
        workingMask[ bgdPixels == False ] = self.GC_PR_FGD
        workingMask[ exclusionMask == 0 ] = self.GC_BGD
        
        cv.GrabCut( workingImage, workingMask, 
            (0,0,0,0), fgModel, bgModel, 6, self.GC_INIT_WITH_MASK )
        
        segmentation = np.copy( startFrame )
        segmentation[ (workingMask != self.GC_PR_FGD) & (workingMask != self.GC_FGD) ] = 0
        
        # Remove everything apart from the biggest blob in the ROI
        graySeg = np.zeros( ( startFrame.height, startFrame.width ), dtype=np.uint8 )
        cv.CvtColor( segmentation, graySeg, cv.CV_RGB2GRAY )
        startMask = np.copy( graySeg )
        startMask[ startMask > 0 ] = 255
            
        # Find the larget blob in the ROI
        
        # Label blobs
        startMask, numBlobs = PyBlobLib.labelBlobs( startMask )
        
        # Find blobs in the region of interest
        testMap = np.copy( startMask )
        testMap[ :ROI_Y, : ] = 0       # Mask out area above the ROI
        testMap[ :, :ROI_X ] = 0       # Mask out area to the left of the ROI
        testMap[ ROI_Y+ROI_HEIGHT: ] = 0   # Mask out area below the ROI
        testMap[ :, ROI_X+ROI_WIDTH: ] = 0   # Mask out area to the right of the ROI
    
        biggestBlobIdx = None
        biggestBlobSize = 0
    
        for blobIdx in range( 1, numBlobs + 1 ):
            if testMap[ testMap == blobIdx ].size > 0:
                blobSize = startMask[ startMask == blobIdx ].size
                if blobSize > biggestBlobSize:
                    biggestBlobSize = blobSize
                    biggestBlobIdx = blobIdx
    
        # Isolate the largest blob
        if biggestBlobIdx != None:
            biggestBlobPixels = (startMask == biggestBlobIdx)
            segmentation[ biggestBlobPixels == False, 0 ] = 255
            segmentation[ biggestBlobPixels == False, 1 ] = 0
            segmentation[ biggestBlobPixels == False, 2 ] = 255
        else:
            print "No central blob after main segmentation"
            return blankFrame
        
        return segmentation

    #---------------------------------------------------------------------------
    def processBag( self, bag ):
    
        FLIP_IMAGE = bool( self.options.frameFlip == "True" )
        USING_OPTICAL_FLOW_FOR_MOTION = False
        print "frameFlip = ", FLIP_IMAGE
    
        bagFrameIdx = 0
        frameIdx = 0
        impactFrameIdx = None
        
        # Setup filters
        opticalFlowFilter = OpticalFlowFilter(
            self.OPTICAL_FLOW_BLOCK_WIDTH, self.OPTICAL_FLOW_BLOCK_HEIGHT, 
            self.OPTICAL_FLOW_RANGE_WIDTH, self.OPTICAL_FLOW_RANGE_HEIGHT )
            
        motionDetectionFilter = MotionDetectionFilter()
        imageFlowFilter = ImageFlowFilter()
        residualSaliencyFilter = ResidualSaliencyFilter()
            
        # Process bag file
        for topic, msg, t in bag.read_messages():
            
            if self.workCancelled:
                # We've been given the signal to quit
                break
            
            if msg._type == "sensor_msgs/Image":
                
                bagFrameIdx += 1
                if (bagFrameIdx-1)%self.PROCESSED_FRAME_DIFF != 0:
                    continue
                
                print "Processing image", frameIdx
                
                # Get input image
                image = cv.CreateMatHeader( msg.height, msg.width, cv.CV_8UC3 )
                cv.SetData( image, msg.data, msg.step )
                
                if FLIP_IMAGE:
                    cv.Flip( image, None, 1 )
                
                # Convert to grayscale
                grayImage = cv.CreateMat( msg.height, msg.width, cv.CV_8UC1 )
                cv.CvtColor( image, grayImage, cv.CV_BGR2GRAY )
                grayImageNumpPy = np.array( grayImage )
                
                # Calculate optical flow
                opticalFlowArrayX, opticalFlowArrayY = \
                    opticalFlowFilter.calcOpticalFlow( grayImage )
                    
                # Detect motion
                if USING_OPTICAL_FLOW_FOR_MOTION:
                    if frameIdx == 0:
                        motionImage = PyVarFlowLib.createMotionMask( 
                            grayImageNumpPy, grayImageNumpPy )
                    else:
                        motionImage = PyVarFlowLib.createMotionMask( 
                            np.array( self.grayScaleImageList[ frameIdx - 1 ] ), 
                            grayImageNumpPy )
                else:
                    motionImage = motionDetectionFilter.calcMotion( grayImage )
                
                
                # Work out the left most point in the image where motion appears
                motionTest = np.copy( motionImage )
                
                cv.Erode( motionTest, motionTest )
                if frameIdx == 0:
                    leftMostMotion = motionImage.shape[ 1 ]
                else:
                    leftMostMotion = self.leftMostMotionList[ frameIdx - 1 ]
                
                leftMostMotionDiff = 0
                for i in range( leftMostMotion ):
                    if motionTest[ :, i ].max() > 0:
                        leftMostMotionDiff = abs( leftMostMotion - i )
                        leftMostMotion = i
                        break
                
                segmentationMask = np.zeros( ( msg.height, msg.width ), dtype=np.uint8 )
                
                FRAMES_BACK = 3
                
                if impactFrameIdx == None:        
                    if leftMostMotionDiff > 18 and leftMostMotion < 0.75*msg.width:
                        
                        # Found impact frame
                        impactFrameIdx = frameIdx
                    
                else:
                    PROCESS_IMPACT = False
                    if PROCESS_IMPACT and frameIdx - impactFrameIdx == FRAMES_BACK:
                        
                        # Should now have enough info to segment object
                        impactMotionImage = self.motionImageList[ impactFrameIdx ]
                        
                        print "Aligning"
                        postImpactRealFarFlow = imageFlowFilter.calcImageFlow( impactMotionImage, motionImage )
                        print "Aligning"
                        postImpactFarFlow = imageFlowFilter.calcImageFlow( impactMotionImage, self.motionImageList[ impactFrameIdx + 2 ] )
                        print "Aligning"
                        postImpactNearFlow = imageFlowFilter.calcImageFlow( impactMotionImage, self.motionImageList[ impactFrameIdx + 1 ] )
                        
                        segmentationMask = np.maximum( np.maximum( np.maximum( 
                            impactMotionImage, postImpactNearFlow[ 3 ] ), postImpactFarFlow[ 3 ] ), postImpactRealFarFlow[ 3 ] )
                        cv.Dilate( segmentationMask, segmentationMask )
                        
                        print "Aligning"
                        preImpactRealFarFlow = imageFlowFilter.calcImageFlow( impactMotionImage, self.motionImageList[ impactFrameIdx - 8 ] )
                        print "Aligning"
                        preImpactFarFlow = imageFlowFilter.calcImageFlow( impactMotionImage, self.motionImageList[ impactFrameIdx - 6 ] )
                        print "Aligning"
                        preImpactNearFlow = imageFlowFilter.calcImageFlow( impactMotionImage, self.motionImageList[ impactFrameIdx - 4 ] )
                        
                        subMask = np.maximum( np.maximum( 
                            preImpactRealFarFlow[ 3 ], preImpactFarFlow[ 3 ] ), preImpactNearFlow[ 3 ] )
                        cv.Erode( subMask, subMask )
                        cv.Dilate( subMask, subMask )
                        cv.Dilate( subMask, subMask )
                        cv.Dilate( subMask, subMask )
                        
                        subMask[ subMask > 0 ] = 255
                        diffImage = segmentationMask.astype( np.int32 ) - subMask.astype( np.int32 )
                        diffImage[ diffImage < 0 ] = 0
                        diffImage = diffImage.astype( np.uint8 )
                        cv.Erode( diffImage, diffImage )
                        #diffImage[ diffImage > 0 ] = 255

                        #segmentationMask = subMask
                        segmentationMask = diffImage
                        #segmentationMask = np.where( diffImage > 128, 255, 0 ).astype( np.uint8 )
                
                # Calculate image flow
                #imageFlow = imageFlowFilter.calcImageFlow( motionImage )
                
                ## Calculate saliency map
                #saliencyMap, largeSaliencyMap = residualSaliencyFilter.calcSaliencyMap( grayImageNumpPy )
                
                #blobMap = np.where( largeSaliencyMap > 128, 255, 0 ).astype( np.uint8 )
                
                #blobMap, numBlobs = PyBlobLib.labelBlobs( blobMap )
                #print "found", numBlobs, "blobs"
                
                #largeSaliencyMap = np.where( largeSaliencyMap > 128, 255, 0 ).astype( np.uint8 )
                
                
                
                
                
                
                # Threshold the saliency map
                #largeSaliencyMap = (largeSaliencyMap > 128).astype(np.uint8) * 255
                #cv.AdaptiveThreshold( largeSaliencyMap, largeSaliencyMap, 255 )
                
                # Detect clusters within the saliency map
                #NUM_CLUSTERS = 5
                
                #numSamples = np.sum( saliencyMap )
                #sampleList = np.ndarray( ( numSamples, 2 ), dtype=np.float32 )
                
                #sampleListIdx = 0
                #for y in range( saliencyMap.shape[ 0 ] ):
                    #for x in range( saliencyMap.shape[ 1 ] ):
                        
                        #numNewSamples = saliencyMap[ y, x ]
                        #if numNewSamples > 0:
                            #sampleList[ sampleListIdx:sampleListIdx+numNewSamples, 0 ] = x
                            #sampleList[ sampleListIdx:sampleListIdx+numNewSamples, 1 ] = y
                            #sampleListIdx += numNewSamples
                            
                #sampleList[ 0:numSamples/2 ] = ( 20, 20 )
                #sampleList[ numSamples/2: ] = ( 200, 200 )
                
                #labelList = np.ndarray( ( numSamples, 1 ), dtype=np.int32 )
                #cv.KMeans2( sampleList, NUM_CLUSTERS, labelList, 
                    #(cv.CV_TERMCRIT_ITER | cv.CV_TERMCRIT_EPS, 10, 0.01) )
                    
                #saliencyScaleX = float( largeSaliencyMap.shape[ 1 ] ) / saliencyMap.shape[ 1 ]
                #saliencyScaleY = float( largeSaliencyMap.shape[ 0 ] ) / saliencyMap.shape[ 0 ]
                clusterList = []
                #for clusterIdx in range( NUM_CLUSTERS ):
                    
                    #clusterSamples = sampleList[ 
                        #np.where( labelList == clusterIdx )[ 0 ], : ]

                    #if clusterSamples.size <= 0:
                        #mean = ( 0.0, 0.0 )
                        #stdDev = 0.0
                    #else:
                        #mean = clusterSamples.mean( axis=0 )
                        #mean = ( mean[ 0 ]*saliencyScaleX, mean[ 1 ]*saliencyScaleY )
                        #stdDev = clusterSamples.std()*saliencyScaleX
                    
                    #clusterList.append( ( mean, stdDev ) )
                
                
                
                
                # Work out the maximum amount of motion we've seen in a single frame so far
                #motionCount = motionImage[ motionImage > 0 ].size
                
                #if frameIdx == 0:
                    #lastMotionCount = 0
                #else:
                    #lastMotionCount = self.maxMotionCounts[ frameIdx - 1 ]
                    
                #if motionCount < lastMotionCount:
                    #motionCount = lastMotionCount
                
                ## Work out diffImage    
                #diffImage = np.array( motionImage, dtype=np.int32 ) \
                     #- np.array( imageFlow[ 3 ], dtype=np.int32 )
                #diffImage = np.array( np.maximum( diffImage, 0 ), dtype=np.uint8 )
                
                
                
                
                
                # Segment the image
                #workingMask = np.copy( motionImage )
                #workingMask = np.copy( diffImage )
                workingMask = np.copy( segmentationMask )
                kernel = cv.CreateStructuringElementEx( 
                    cols=3, rows=3, 
                    anchorX=1, anchorY=1, shape=cv.CV_SHAPE_CROSS )
                cv.Erode( workingMask, workingMask, kernel )
                cv.Dilate( workingMask, workingMask )
                
                extraExtraMask = np.copy( workingMask )
                cv.Dilate( extraExtraMask, extraExtraMask )
                cv.Dilate( extraExtraMask, extraExtraMask )
                cv.Dilate( extraExtraMask, extraExtraMask )
                cv.Dilate( extraExtraMask, extraExtraMask )
                cv.Dilate( extraExtraMask, extraExtraMask )
                cv.Dilate( extraExtraMask, extraExtraMask )
                
                allMask = np.copy( extraExtraMask )
                cv.Dilate( allMask, allMask )
                cv.Dilate( allMask, allMask )
                cv.Dilate( allMask, allMask )
                cv.Dilate( allMask, allMask )
                cv.Dilate( allMask, allMask )
                cv.Dilate( allMask, allMask )
                
                possibleForeground = workingMask > 0
            
                if workingMask[ possibleForeground ].size >= 100 \
                    and frameIdx >= 16:
                        
                    print "Msk size", workingMask[ possibleForeground ].size
                    print workingMask[ 0, 0:10 ]
                    
                    fgModel = cv.CreateMat( 1, 5*13, cv.CV_64FC1 )
                    bgModel = cv.CreateMat( 1, 5*13, cv.CV_64FC1 )
                    #workingMask[ possibleForeground ] = self.GC_FGD
                    #workingMask[ possibleForeground == False ] = self.GC_PR_BGD
                    
                    #workingMask[ : ] = self.GC_PR_BGD
                    #workingMask[ possibleForeground ] = self.GC_FGD
                    
                    workingMask[ : ] = self.GC_BGD
                    workingMask[ allMask > 0 ] = self.GC_PR_BGD
                    workingMask[ extraExtraMask > 0 ] = self.GC_PR_FGD
                    workingMask[ possibleForeground ] = self.GC_FGD
                    
                    
                    if frameIdx == 16:
                        # Save mask
                        maskCopy = np.copy( workingMask )
                        maskCopy[ maskCopy == self.GC_BGD ] = 0
                        maskCopy[ maskCopy == self.GC_PR_BGD ] = 64
                        maskCopy[ maskCopy == self.GC_PR_FGD ] = 128
                        maskCopy[ maskCopy == self.GC_FGD ] = 255
                        print "Unused pixels", \
                            maskCopy[ (maskCopy != 255) & (maskCopy != 0) ].size
                          
                        outputImage = cv.CreateMat( msg.height, msg.width, cv.CV_8UC3 )
                        cv.CvtColor( maskCopy, outputImage, cv.CV_GRAY2BGR )
                        
                        cv.SaveImage( "output.png", image );
                        cv.SaveImage( "outputMask.png", outputImage ); 
                        
                        print "Saved images"
                        #return 
                        
                    
                    #print "Set Msk size", workingMask[ workingMask == self.GC_PR_FGD ].size
                
                    imageToSegment = image #self.inputImageList[ frameIdx - FRAMES_BACK ]
                
                    imageCopy = np.copy( imageToSegment )
                    cv.CvtColor( imageCopy, imageCopy, cv.CV_BGR2RGB )
                
                    print "Start seg"
                    cv.GrabCut( imageCopy, workingMask, 
                        (0,0,0,0), fgModel, bgModel, 12, self.GC_INIT_WITH_MASK )
                    print "Finish seg"
                
                    segmentation = np.copy( imageToSegment )
                    segmentation[ (workingMask != self.GC_PR_FGD) & (workingMask != self.GC_FGD) ] = 0
                
                    
                    black = (workingMask != self.GC_PR_FGD) & (workingMask != self.GC_FGD)
                    #motionImage = np.where( black, 0, 255 ).astype( np.uint8 )
                    
                    # Refine the segmentation
                    REFINE_SEG = False
                    if REFINE_SEG:
                        motionImageCopy = np.copy( motionImage )
                        cv.Erode( motionImageCopy, motionImageCopy )
                        #cv.Erode( motionImageCopy, motionImageCopy )
                        #cv.Erode( motionImageCopy, motionImageCopy )
                        
                        workingMask[ motionImageCopy > 0 ] = self.GC_PR_FGD
                        workingMask[ motionImageCopy == 0 ] = self.GC_PR_BGD
                        
                        cv.Dilate( motionImageCopy, motionImageCopy )
                        cv.Dilate( motionImageCopy, motionImageCopy )
                        cv.Dilate( motionImageCopy, motionImageCopy )
                        cv.Dilate( motionImageCopy, motionImageCopy )
                        workingMask[ motionImageCopy == 0 ] = self.GC_BGD
                        
                        print "Other seg"
                        cv.GrabCut( imageCopy, workingMask, 
                            (0,0,0,0), fgModel, bgModel, 12, self.GC_INIT_WITH_MASK )
                        print "Other seg done"
                            
                        segmentation = np.copy( imageToSegment )
                        segmentation[ (workingMask != self.GC_PR_FGD) & (workingMask != self.GC_FGD) ] = 0
                    
                        
                        black = (workingMask != self.GC_PR_FGD) & (workingMask != self.GC_FGD)
                        motionImage = np.where( black, 0, 255 ).astype( np.uint8 )
                    
                
                else:
                    segmentation = np.zeros( ( image.height, image.width ), dtype=np.uint8 )
                
                
                # Save output data
                self.inputImageList[ frameIdx ] = image
                self.grayScaleImageList[ frameIdx ] = grayImage
                self.opticalFlowListX[ frameIdx ] = opticalFlowArrayX
                self.opticalFlowListY[ frameIdx ] = opticalFlowArrayY
                self.motionImageList[ frameIdx ] = motionImage
                self.segmentationList[ frameIdx ] = segmentation
                self.segmentationMaskList[ frameIdx ] = segmentationMask
                #self.maxMotionCounts[ frameIdx ] = motionCount
                #self.imageFlowList[ frameIdx ] = imageFlow
                #self.saliencyMapList[ frameIdx ] = largeSaliencyMap
                #self.saliencyClusterList[ frameIdx ] = clusterList
                self.leftMostMotionList[ frameIdx ] = leftMostMotion
                
                frameIdx += 1
                self.numFramesProcessed += 1
                
        if not self.workCancelled:
            
            
            SAVE_MOTION_IMAGES = True
            BASE_MOTION_IMAGE_NAME = self.scriptPath + "/../../test_data/motion_images/motion_{0:03}.png"
            
            if SAVE_MOTION_IMAGES and len( self.motionImageList ) > 0:
                
                width = self.motionImageList[ 0 ].shape[ 1 ]
                height = self.motionImageList[ 0 ].shape[ 0 ]
                colourImage = np.zeros( ( height, width, 3 ), dtype=np.uint8 )
                
                for frameIdx, motionImage in enumerate( self.motionImageList ):
                    
                    colourImage[ :, :, 0 ] = motionImage
                    colourImage[ :, :, 1 ] = motionImage
                    colourImage[ :, :, 2 ] = motionImage
                    
                    outputName = BASE_MOTION_IMAGE_NAME.format( frameIdx + 1 )
                    cv.SaveImage( outputName, colourImage )
            
            # Recalculate impactFrameIdx
            width = self.motionImageList[ 0 ].shape[ 1 ]
            
            totalMotionDiff = 0
            maxMotionDiff = 0
            impactFrameIdx = None
            for motionIdx in range( 1, len( self.leftMostMotionList ) ):
            
                motionDiff = abs( self.leftMostMotionList[ motionIdx ] \
                    - self.leftMostMotionList[ motionIdx - 1 ] )
                totalMotionDiff += motionDiff
                    
                if motionDiff > maxMotionDiff and totalMotionDiff > 0.5*width:
                    maxMotionDiff = motionDiff
                    impactFrameIdx = motionIdx
            
            if maxMotionDiff <= 18:
                impactFrameIdx = None
                    
            
            if impactFrameIdx != None:
                
                preMotionImages = []
                postMotionImages = []
                impactMotionImage = None
                
                NUM_FRAMES_BEFORE = 3
                
                prefix = self.options.outputPrefix
                if prefix != "":
                    prefix += "_"
                
                BASE_MOTION_IMAGE_NAME = self.scriptPath + "/../../test_data/impact_images/" + prefix + "motion_{0:03}.png"
                START_MOTION_IMAGE_NAME = self.scriptPath + "/../../test_data/impact_images/" + prefix + "start_motion.png"
                START_IMAGE_NAME = self.scriptPath + "/../../test_data/impact_images/" + prefix + "start.png"
                IMPACT_IMAGE_NAME = self.scriptPath + "/../../test_data/impact_images/" + prefix + "impact.png"
                SEGMENTATION_IMAGE_NAME = self.scriptPath + "/../../test_data/impact_images/" + prefix + "segmentation.png"
                NUM_FRAMES_AFTER = 3
                
                width = self.motionImageList[ 0 ].shape[ 1 ]
                height = self.motionImageList[ 0 ].shape[ 0 ]
                colourImage = np.zeros( ( height, width, 3 ), dtype=np.uint8 )
                
                for frameIdx in range( impactFrameIdx - NUM_FRAMES_BEFORE,
                    impactFrameIdx + NUM_FRAMES_AFTER + 1 ):
                    
                    motionImage = self.motionImageList[ frameIdx ]  
                    
                    if frameIdx < impactFrameIdx:
                        preMotionImages.append( motionImage )
                    elif frameIdx == impactFrameIdx:
                        impactMotionImage = motionImage
                    else: # frameIdx > impactFrameIdx
                        postMotionImages.append( motionImage )
                    
                    colourImage[ :, :, 0 ] = motionImage
                    colourImage[ :, :, 1 ] = motionImage
                    colourImage[ :, :, 2 ] = motionImage
                    
                    outputName = BASE_MOTION_IMAGE_NAME.format( frameIdx - impactFrameIdx )
                    cv.SaveImage( outputName, colourImage )
                
                motionDetectionFilter.calcMotion( self.grayScaleImageList[ 0 ] )
                startMotionImage = motionDetectionFilter.calcMotion( 
                    self.grayScaleImageList[ impactFrameIdx ] )
                colourImage[ :, :, 0 ] = startMotionImage
                colourImage[ :, :, 1 ] = startMotionImage
                colourImage[ :, :, 2 ] = startMotionImage  
                cv.SaveImage( START_MOTION_IMAGE_NAME, colourImage )
                
                cv.CvtColor( self.inputImageList[ 0 ], colourImage, cv.CV_RGB2BGR )    
                cv.SaveImage( START_IMAGE_NAME, colourImage )
                cv.CvtColor( self.inputImageList[ impactFrameIdx ], colourImage, cv.CV_RGB2BGR )    
                cv.SaveImage( IMPACT_IMAGE_NAME, colourImage )
                
                print "Segmenting..."
                segmentation = self.produceSegmentation( self.inputImageList[ 0 ], 
                    impactMotionImage, preMotionImages, postMotionImages )
                cv.CvtColor( segmentation, colourImage, cv.CV_RGB2BGR )    
                cv.SaveImage( SEGMENTATION_IMAGE_NAME, colourImage )
                    
            self.refreshGraphDisplay()
            
            
        print "Finished processing bag file"
        if bool( self.options.quitAfterFirstSegmentation == "True" ):
            print "Trying to quit"
            self.onWinMainDestroy( None )
        else:
            print "Not trying to quit so neeah"
        
    #---------------------------------------------------------------------------
    def refreshGraphDisplay( self ):
        """Rebuild the graph canvas and navigation toolbar in self.vboxGraphs.

        Any canvas/toolbar created by a previous call is removed from the box
        and destroyed. A new figure is then drawn that plots, against a
        1-based frame number, the change in self.leftMostMotionList from one
        frame to the next (the first frame is given a change of 0).
        """
        
        # Remove existing graph items
        if self.graphCanvas is not None:
            self.vboxGraphs.remove( self.graphCanvas )
            self.graphCanvas.destroy()
            self.graphCanvas = None
        if self.graphNavToolbar is not None:
            self.vboxGraphs.remove( self.graphNavToolbar )
            self.graphNavToolbar.destroy()
            self.graphNavToolbar = None
            
        # Work out the frame-to-frame change. The leading 0 keeps the data the
        # same length as the frame numbers; with no frames at all both must be
        # empty or plot() would be handed sequences of different lengths.
        positions = self.leftMostMotionList
        numFrames = len( positions )
        if numFrames > 0:
            diffs = [ 0 ] + [ positions[ i+1 ] - positions[ i ]
                for i in range( numFrames - 1 ) ]
        else:
            diffs = []
            
        # Draw the graphs
        self.graphFigure = Figure( figsize=(8,6), dpi=72 )
        self.graphAxis = self.graphFigure.add_subplot( 111 )
        self.graphAxis.plot( range( 1, numFrames + 1 ), diffs )
        
        # Build the new graph display
        self.graphCanvas = FigureCanvas( self.graphFigure ) # a gtk.DrawingArea
        self.graphCanvas.show()
        self.graphNavToolbar = NavigationToolbar( self.graphCanvas, self.window )
        self.graphNavToolbar.lastDir = '/var/tmp/'
        self.graphNavToolbar.show()
        
        # Show the graph (toolbar first, then the canvas taking the spare room)
        self.vboxGraphs.pack_start( self.graphNavToolbar, expand=False, fill=False )
        self.vboxGraphs.pack_start( self.graphCanvas, True, True )
        self.vboxGraphs.show()

    #---------------------------------------------------------------------------
    def update( self ):
        """Generator that yields True forever and does no other work.

        NOTE(review): presumably stepped from a GTK idle callback, where a
        true result keeps the callback scheduled -- confirm against the code
        that creates this generator.
        """

        while True:
            yield True
    def __init__(self):
        """Build and show the RTP motor test window.

        Creates the RTP.Motor_PD_Control_Test object, the action buttons, the
        vibration-suppression radio buttons, the text entries for the
        open-loop step and fixed-sine tests, and an embedded matplotlib canvas
        with its navigation toolbar. The buttons' "clicked" signals are
        connected to handler methods of this object, and the window is then
        centred and shown.
        """
        self.debug = 0
        
        # Test object created with kp=2 and kd=0.07.
        # NOTE(review): presumably the PD controller gains -- confirm in RTP.
        self.test = RTP.Motor_PD_Control_Test(kp=2, kd=0.07)

        # create a new window
        self.window = gtk.Window(gtk.WINDOW_TOPLEVEL)
    
        # Route the window's "delete_event" signal (given by the window
        # manager, usually by the "close" option or the titlebar button) to
        # self.delete_event.
        self.window.connect("delete_event", self.delete_event)
    
        # Route the "destroy" signal to self.destroy.
        # This event occurs when gtk_widget_destroy() is called on the window,
        # or if the "delete_event" callback returns FALSE.
        self.window.connect("destroy", self.destroy)
    
        # Sets the border width of the window.
        self.window.set_border_width(10)
    
        # Create one button per action offered by the GUI.
        self.swept_sine_button = gtk.Button("Swept Sine")
        self.step_button = gtk.Button("Step Response")
        self.reset_button = gtk.Button("Reset Theta")
        self.return_button = gtk.Button("Return to 0")
        self.sys_check_button = gtk.Button("System Check")
        self.ol_step_button = gtk.Button("OL Step")
        
        #self.vib_check = gtk.CheckButton(label="Use Vibration Suppression", \
        #                                 use_underline=False)
        # Vibration suppression on/off choice; the "Off" button is created
        # with the "On" button as its first argument and is the one selected
        # initially.
        self.vib_on_radio = gtk.RadioButton(None, "On")
        self.vib_off_radio = gtk.RadioButton(self.vib_on_radio, "Off")
        #button.connect("toggled", self.callback, "radio button 2")
        self.vib_off_radio.set_active(True)
        vib_label = gtk.Label("Vibration Suppression")

        # Open-loop step controls: amplitude and duration entries, each in its
        # own row with a label.
        ol_label = gtk.Label("OL Step Response")
        ol_hbox = gtk.HBox(homogeneous=False)
        amp_label = gtk.Label("amp:")
        dur_label = gtk.Label("duration:")
        self.ol_amp_entry = gtk.Entry()
        self.ol_amp_entry.set_max_length(7)
        self.ol_amp_entry.set_text("150")
        # Width request shared by all of the text entries below.
        Entry_width = 50
        self.ol_amp_entry.set_size_request(Entry_width,-1)
        self.dur_entry = gtk.Entry()
        self.dur_entry.set_max_length(7)
        self.dur_entry.set_text("300")
        self.dur_entry.set_size_request(Entry_width,-1)
        # Rows are filled with pack_end: the entry is packed first, then its
        # label.
        ol_hbox.pack_end(self.ol_amp_entry, False)
        ol_hbox.pack_end(amp_label, False)
        #ol_hbox.pack_start(amp_label, False)
        #ol_hbox.pack_start(self.ol_amp_entry, False)
        ol_dur_box = gtk.HBox(homogeneous=False)
        ol_dur_box.pack_end(self.dur_entry, False)
        ol_dur_box.pack_end(dur_label, False)
        
        #Fixed Sine Controls: amplitude, frequency and duration entries,
        #stacked in their own VBox together with the button that starts the
        #test.
        self.fs_amp_entry = gtk.Entry()
        self.fs_amp_entry.set_max_length(7)
        self.fs_amp_entry.set_size_request(Entry_width, -1)
        self.fs_amp_entry.set_text("5")
        self.fs_freq_entry = gtk.Entry()
        self.fs_freq_entry.set_max_length(7)
        self.fs_freq_entry.set_size_request(Entry_width, -1)
        self.fs_freq_entry.set_text("0.5")
        self.fs_dur_entry = gtk.Entry()
        self.fs_dur_entry.set_max_length(7)
        self.fs_dur_entry.set_size_request(Entry_width, -1)
        self.fs_dur_entry.set_text("300")
        fs_label = gtk.Label('Fixed Sine')
        fs_amp_label = gtk.Label("amp (counts):")
        fs_freq_label = gtk.Label("freq (Hz):")
        fs_dur_label = gtk.Label("duration:")
        fsvbox = gtk.VBox(homogeneous=False, spacing=5)
        fsvbox.show()
        fsvbox.pack_start(fs_label)
        fshbox1 = gtk.HBox(homogeneous=False)
        fshbox1.pack_end(self.fs_amp_entry, False)
        fshbox1.pack_end(fs_amp_label, False)
        fsvbox.pack_start(fshbox1, False)
        fshbox2 = gtk.HBox(homogeneous=False)
        fshbox2.pack_end(self.fs_freq_entry, False)
        fshbox2.pack_end(fs_freq_label, False)
        fsvbox.pack_start(fshbox2, False)
        fshbox3 = gtk.HBox(homogeneous=False)
        fshbox3.pack_end(self.fs_dur_entry, False)
        fshbox3.pack_end(fs_dur_label, False)
        fsvbox.pack_start(fshbox3, False)
        self.fixed_sine_button = gtk.Button("Fixed Sine")
        fsvbox.pack_start(self.fixed_sine_button, False)

        #ol_dur_box.pack_start(dur_label, False)
        #ol_dur_box.pack_start(self.dur_entry, False)
        # Separators placed between the groups of controls in the button
        # column.
        sep0 = gtk.HSeparator()
        sep1 = gtk.HSeparator()
        sep2 = gtk.HSeparator()
        sep3 = gtk.HSeparator()
        sep4 = gtk.HSeparator()
        
        
        #self.button.set_size_request(30, 40)
    
        # Connect each button's "clicked" signal to the method that runs the
        # corresponding action. None is passed as the extra user-data
        # argument of every handler.
        self.swept_sine_button.connect("clicked", self.swept_sine_test, None)
        self.step_button.connect("clicked", self.step_test, None)
        self.reset_button.connect("clicked", self.reset_theta, None)
        self.return_button.connect("clicked", self.return_to_zero, None)
        self.sys_check_button.connect("clicked", self.system_check, None)
        self.ol_step_button.connect("clicked", self.run_ol_step, None)
        self.fixed_sine_button.connect("clicked", self.fixed_sine_test, None)    
        # Disabled: destroying the window directly from a button click.
        #self.button.connect_object("clicked", gtk.Widget.destroy, self.window)

        # Layout: big_hbox holds the column of controls (button_vbox) and the
        # plot column (canvas_vbox, created further down).
        big_hbox = gtk.HBox()#homogeneous=False, spacing=5)
        button_vbox = gtk.VBox(homogeneous=False, spacing=5)
        #self.vbox1 = gtk.VBox(homogeneous=False, spacing=0)
        #self.window.add(self.button)

        # Fill the control column from top to bottom.
        button_vbox.pack_start(self.sys_check_button, False)
        button_vbox.pack_start(sep0, False)
        button_vbox.pack_start(self.swept_sine_button, False, False, 0)
        button_vbox.pack_start(sep1, False)
        button_vbox.pack_start(vib_label, False)
        #button_vbox.pack_start(self.vib_check, False)
        button_vbox.pack_start(self.vib_on_radio, False)
        button_vbox.pack_start(self.vib_off_radio, False)
        button_vbox.pack_start(sep2, False)
        button_vbox.pack_start(self.step_button, False, False, 0)
        button_vbox.pack_start(sep3, False)
        #Fixed Sine Stuff
        button_vbox.pack_start(fsvbox, False)
        button_vbox.pack_start(sep4, False)
        button_vbox.pack_start(ol_label, False)
        button_vbox.pack_start(ol_hbox, False)
        button_vbox.pack_start(ol_dur_box, False)
        button_vbox.pack_start(self.ol_step_button, False)
        button_vbox.pack_start(self.reset_button, False)
        button_vbox.pack_start(self.return_button, False)
        
        

        # Matplotlib figure, initially showing sin(2*pi*t) for t from 0.0 up
        # to (not including) 3.0 in steps of 0.01.
        self.f = Figure(figsize=(5,4), dpi=100)
        self.ax = self.f.add_subplot(111)
        t = arange(0.0,3.0,0.01)
        s = sin(2*pi*t)
        self.ax.plot(t,s)

        self.figcanvas = FigureCanvas(self.f)  # a gtk.DrawingArea
        self.figcanvas.show()
        canvas_vbox = gtk.VBox()
        toolbar = NavigationToolbar(self.figcanvas, self.window)
        #toolbar.set_size_request(-1,50)
        self.figcanvas.set_size_request(600,300)
        toolbar.set_size_request(600,50)
        toolbar.show()
        # Controls are packed first, then the canvas with its toolbar
        # underneath it.
        big_hbox.pack_start(button_vbox, False, False, 0)
        canvas_vbox.pack_start(self.figcanvas)#, expand=True, \
                               #fill=True, padding=5)
        canvas_vbox.pack_start(toolbar, False)#, False)#, padding=5)
        big_hbox.pack_start(canvas_vbox)#, expand=True, \
                            #fill=True, padding=5)
        

        self.window.add(big_hbox)

        #self.button.show()

        #self.window.set_size_request(1000,800)

        self.window.set_title('RTP Motor GUI v. 1.0')
        # Centre the window and show it together with everything packed
        # inside it.
        #self.window.show()
        self.window.set_position(gtk.WIN_POS_CENTER)
        self.window.show_all()
class MainWindow:
    """Main window of the optical flow explorer GUI.

    The class-level constants below configure the optical flow calculation,
    the cross-correlation step and the gripper wave that is searched for.
    """
    
    # Block size and search range handed to the optical flow calculation
    OPTICAL_FLOW_BLOCK_WIDTH = 16
    OPTICAL_FLOW_BLOCK_HEIGHT = 16
    OPTICAL_FLOW_RANGE_WIDTH = 16    # Range to look outside of a block for motion
    OPTICAL_FLOW_RANGE_HEIGHT = 16
    OPTICAL_FLOW_METHOD = "BlockMatching"
    #OPTICAL_FLOW_METHOD = "LucasKanade"
    #OPTICAL_FLOW_METHOD = "HornSchunck"
    COMBINATION_METHOD = "NoChange"
    
    # Ground truth marker file, relative to the directory of this script
    #GROUND_TRUTH_FILENAME = "/../../config/TopPosGripper.yaml"
    #GROUND_TRUTH_FILENAME = "/../../config/ExperimentPosGripper.yaml"
    #GROUND_TRUTH_FILENAME = "/../../config/OnTablePosGripper.yaml"
    GROUND_TRUTH_FILENAME = "/../../config/TightBasicWave_Gripper.yaml"
    
    CORRELATION_THRESHOLD = 0.52

    # Largest selectable test point (block index) in each axis. Floor division
    # is spelt out so that these stay integers even under true division.
    # NOTE(review): 640 and 480 are presumably the camera image width and
    # height in pixels -- confirm.
    MAX_TEST_POINT_X = (640 - OPTICAL_FLOW_BLOCK_WIDTH)//OPTICAL_FLOW_BLOCK_WIDTH - 1
    MAX_TEST_POINT_Y = (480 - OPTICAL_FLOW_BLOCK_HEIGHT)//OPTICAL_FLOW_BLOCK_HEIGHT - 1
    
    SAMPLES_PER_SECOND = 30.0
    MAX_CORRELATION_LAG = 2.0
    
    GRIPPER_WAVE_FREQUENCY = 1.0    # Waves per second
    GRIPPER_NUM_WAVES = 3.0
    GRIPPER_WAVE_AMPLITUDE = math.radians( 20.0 )
 
    #---------------------------------------------------------------------------
    def __init__( self, bagFilename ):
        """Process a recorded sequence and build the explorer window.

        bagFilename -- name of the bag file handed to InputSequence.

        The sequence is read and its optical flow calculated, resampled at
        SAMPLES_PER_SECOND and cross-correlated. Blocks whose correlation
        passes CORRELATION_THRESHOLD are used as a mask for building a colour
        histogram of the first camera image. The ground truth marker buffer
        is loaded to build a ROC curve, and finally the matplotlib figure and
        the Glade GUI are created and the window is shown maximised.

        Raises Exception if the ground truth marker buffer cannot be loaded.
        """
    
        self.scriptPath = os.path.dirname( __file__ )
        self.cameraImagePixBuf = None
        self.bagFilename = bagFilename
        self.lastImageGray = None
        
        # Read in sequence
        t1 = time.time()
        
        self.inputSequence = InputSequence( bagFilename )
        
        # Distractor objects are built but only used when the call below is
        # re-enabled
        distractors = [
            Distractor( radius=24, startPos=( 25, 35 ), endPos=( 100, 100 ), frequency=2.0 ),
            Distractor( radius=24, startPos=( 200, 200 ), endPos=( 150, 50 ), frequency=0.25 ),
            Distractor( radius=24, startPos=( 188, 130 ), endPos=( 168, 258 ), frequency=0.6 ),
            Distractor( radius=24, startPos=( 63, 94 ), endPos=( 170, 81 ), frequency=1.5 ),
            Distractor( radius=24, startPos=( 40, 287 ), endPos=( 50, 197 ), frequency=3.0 ) ]
        #self.inputSequence.addDistractorObjects( distractors )
        
        self.inputSequence.calculateOpticalFlow(
            self.OPTICAL_FLOW_BLOCK_WIDTH, self.OPTICAL_FLOW_BLOCK_HEIGHT,
            self.OPTICAL_FLOW_RANGE_WIDTH, self.OPTICAL_FLOW_RANGE_HEIGHT,
            self.OPTICAL_FLOW_METHOD )
            
        t2 = time.time()
        print( 'Processing sequence took %0.3f ms' % ((t2-t1)*1000.0) )
        
        # Resample sequence
        t1 = time.time()
        
        self.regularisedInputSequence = RegularisedInputSequence( 
            self.inputSequence, self.SAMPLES_PER_SECOND )
            
        t2 = time.time()
        print( 'Resampling took %0.3f ms' % ((t2-t1)*1000.0) )
        
        # Cross correlate the resampled sequence
        t1 = time.time()
        
        self.crossCorrelatedSequence = CrossCorrelatedSequence( 
            self.regularisedInputSequence, 
            self.MAX_CORRELATION_LAG, self.COMBINATION_METHOD )
        
        t2 = time.time()
        print( 'Correlation took %0.3f ms' % ((t2-t1)*1000.0) )
        
        # Detect the input signal based on the correlation in the x and y axis
        self.inputSignalDetectedArray = \
            self.crossCorrelatedSequence.detectInputSequence( self.CORRELATION_THRESHOLD )
          
        # Build a histogram for the gripper (bin counts must be integers,
        # hence the explicit floor division)
        self.gripperHistogram = cv.CreateHist( 
            [ 256//8, 256//8, 256//8 ], cv.CV_HIST_ARRAY, [ (0,255), (0,255), (0,255) ], 1 )
            
        # Wrap the first camera image in an OpenCV image header. The row step
        # of width*3 assumes 3 bytes per pixel with no row padding.
        firstImage = self.inputSequence.cameraImages[ 0 ]
        imageRGB = cv.CreateImageHeader( ( firstImage.shape[ 1 ], firstImage.shape[ 0 ] ), cv.IPL_DEPTH_8U, 3 )
        cv.SetData( imageRGB, firstImage.data, firstImage.shape[ 1 ]*3 )
            
        r_plane = cv.CreateMat( imageRGB.height, imageRGB.width, cv.CV_8UC1 )
        g_plane = cv.CreateMat( imageRGB.height, imageRGB.width, cv.CV_8UC1 )
        b_plane = cv.CreateMat( imageRGB.height, imageRGB.width, cv.CV_8UC1 )
        cv.Split( imageRGB, r_plane, g_plane, b_plane, None )
        planes = [ r_plane, g_plane, b_plane ]

        # Mask in the pixels of every optical flow block in which the input
        # signal was detected, so only they contribute to the histogram
        maskArray = np.zeros(shape=( imageRGB.height, imageRGB.width ), dtype=np.uint8 )
        for rowIdx in range( self.inputSignalDetectedArray.shape[ 0 ] ):
            for colIdx in range( self.inputSignalDetectedArray.shape[ 1 ] ):
                
                if self.inputSignalDetectedArray[ rowIdx, colIdx ]:
                    rowStartIdx = rowIdx*self.OPTICAL_FLOW_BLOCK_HEIGHT
                    rowEndIdx = rowStartIdx + self.OPTICAL_FLOW_BLOCK_HEIGHT
                    colStartIdx = colIdx*self.OPTICAL_FLOW_BLOCK_WIDTH
                    colEndIdx = colStartIdx + self.OPTICAL_FLOW_BLOCK_WIDTH
                    
                    maskArray[ rowStartIdx:rowEndIdx, colStartIdx:colEndIdx ] = 255

        cv.CalcHist( [ cv.GetImage( i ) for i in planes ], 
            self.gripperHistogram, 0, mask=cv.fromarray( maskArray ) )
        
        # Load the ground truth and use it to build the ROC curve
        markerBuffer = MarkerBuffer.loadMarkerBuffer( self.scriptPath + self.GROUND_TRUTH_FILENAME )
        if markerBuffer is None:
            raise Exception( "Unable to load marker buffer" )
        
        self.rocCurve = GripperDetectorROCCurve( self.crossCorrelatedSequence, markerBuffer )
        
        # Create the matplotlib graph
        self.figure = Figure( figsize=(8,6), dpi=72 )
        self.axisX = self.figure.add_subplot( 311 )
        self.axisY = self.figure.add_subplot( 312 )
        self.axisROC = self.figure.add_subplot( 313 )
        
        self.canvas = None  # Wait for GUI to be created before creating canvas
        self.navToolbar = None
 
        # Setup the GUI        
        builder = gtk.Builder()
        builder.add_from_file( self.scriptPath + "/GUI/OpticalFlowExplorer.glade" )
        
        self.dwgCameraImage = builder.get_object( "dwgCameraImage" )
        self.window = builder.get_object( "winMain" )
        self.vboxMain = builder.get_object( "vboxMain" )
        self.hboxWorkArea = builder.get_object( "hboxWorkArea" )
        self.adjTestPointX = builder.get_object( "adjTestPointX" )
        self.adjTestPointY = builder.get_object( "adjTestPointY" )
        self.adjTestPointX.set_upper( self.MAX_TEST_POINT_X )
        self.adjTestPointY.set_upper( self.MAX_TEST_POINT_Y )
        self.sequenceControls = builder.get_object( "sequenceControls" )
        self.sequenceControls.setNumFrames( len( self.inputSequence.cameraImages ) )
        self.sequenceControls.setOnFrameIdxChangedCallback( self.onSequenceControlsFrameIdxChanged )
        self.setFrameIdx( 0 )
        self.processOpticalFlowData()
        
        # Signals are connected only after the initial state has been set up
        builder.connect_signals( self )
               
        # Step the update generator whenever GTK is idle
        updateLoop = self.update()
        gobject.idle_add( updateLoop.next )
        
        self.window.show()
        self.window.maximize()
        
    #---------------------------------------------------------------------------
    def onWinMainDestroy( self, widget, data = None ):  
        """Quit the GTK main loop.

        widget and data are accepted but not used.
        NOTE(review): presumably hooked up to the main window's destroy
        signal by builder.connect_signals -- confirm in the Glade file.
        """
        gtk.main_quit()
        
    #---------------------------------------------------------------------------   
    def main( self ):
        """Hand control to the GTK main loop by calling gtk.main()."""
        # All PyGTK applications must have a gtk.main(). Control ends here
        # and waits for an event to occur (like a key press or mouse event).
        gtk.main()
        
    #---------------------------------------------------------------------------
    def processOpticalFlowData( self ):
        """Redraw the graphs for the selected test point and dump them to CSV.

        The test point is read from adjTestPointX / adjTestPointY. The
        normalised servo angle and optical flow at that point are plotted
        on the X and Y axes together with the first and second correlation
        channels (when present), and the ROC curve is plotted on the third
        axis. The X data is then written to
        <scriptPath>/../../test_results/CrossCorrelation.csv with the columns
        Time, Input, Output and CrossCorrelation; rows beyond the end of the
        correlation channel (or all rows if there is no correlation channel)
        get a CrossCorrelation of 0.0.
        """

        testX = int( self.adjTestPointX.get_value() )
        testY = int( self.adjTestPointY.get_value() )
        
        regSeq = self.regularisedInputSequence
        corSeq = self.crossCorrelatedSequence
        
        # Normalise the data ready for display
        normalisedServoAngleData = Utils.normaliseSequence( regSeq.regularServoAngleData )
        normalisedOpticalFlowDataX = Utils.normaliseSequence( regSeq.regularOpticalFlowArrayX[ testY ][ testX ] )
        normalisedOpticalFlowDataY = Utils.normaliseSequence( regSeq.regularOpticalFlowArrayY[ testY ][ testX ] )
        
        numCorrelationChannels = len( corSeq.correlationChannels )
        
        # The first channel is needed for both the X graph and the CSV file.
        # An empty channel stands in when there is none, so that the CSV file
        # can still be written.
        if numCorrelationChannels >= 1:
            correlationChannelX = corSeq.correlationChannels[ 0 ][ testY ][ testX ]
        else:
            correlationChannelX = []
        
        # Plot graphs
        self.axisX.clear()
        self.axisX.plot( regSeq.regularSampleTimes, normalisedServoAngleData )
        self.axisX.plot( regSeq.regularSampleTimes, normalisedOpticalFlowDataX )
        
        if numCorrelationChannels >= 1:
            self.axisX.plot( 
                regSeq.regularSampleTimes[:len(correlationChannelX)], correlationChannelX )
        
        self.axisY.clear()
        self.axisY.plot( regSeq.regularSampleTimes, normalisedServoAngleData )
        self.axisY.plot( regSeq.regularSampleTimes, normalisedOpticalFlowDataY )
        
        inpSeq = self.inputSequence
        self.axisY.plot( inpSeq.imageTimes, Utils.normaliseSequence( inpSeq.opticalFlowArraysY[ testY ][ testX ] ) )
        
        if numCorrelationChannels >= 2:
            correlationChannelY = corSeq.correlationChannels[ 1 ][ testY ][ testX ]
            self.axisY.plot( 
                regSeq.regularSampleTimes[:len(correlationChannelY)], correlationChannelY )
        
        self.axisROC.clear()
        self.axisROC.plot( self.rocCurve.falsePositiveRates, self.rocCurve.truePositiveRates )
        #self.axisROC.plot( self.rocCurve.thresholds, self.rocCurve.sensitivity )
        #self.axisROC.plot( self.rocCurve.thresholds, self.rocCurve.specificity )
        
        self.refreshGraphDisplay()
        
        # Write out the X data. The with block makes sure the file is closed
        # even if formatting or indexing one of the rows fails.
        numSamples = len( regSeq.regularSampleTimes )
        numCorrelationSamples = len( correlationChannelX )
        outputFilename = self.scriptPath + "/../../test_results/CrossCorrelation.csv"
        
        with open( outputFilename, "w" ) as outputFile:
            
            outputFile.write( "Time,Input,Output,CrossCorrelation\n" )
            for i in range( numSamples ):
                
                # Pad with zeros once the correlation data runs out
                if i < numCorrelationSamples:
                    correlationData = correlationChannelX[ i ]
                else:
                    correlationData = 0.0
                
                outputFile.write( "{0},{1},{2},{3}\n".format( 
                    regSeq.regularSampleTimes[ i ], 
                    normalisedServoAngleData[ i ],
                    normalisedOpticalFlowDataX[ i ],
                    correlationData ) )
                
    
    #---------------------------------------------------------------------------
    def refreshGraphDisplay( self ):
        """Recreate the matplotlib canvas and navigation toolbar for self.figure.

        An existing canvas is removed from hboxWorkArea and an existing toolbar
        from vboxMain, and both are destroyed. A new FigureCanvas for
        self.figure is then packed into hboxWorkArea and a new
        NavigationToolbar bound to that canvas is packed into vboxMain.
        """

        # Identity tests against None, so the check never depends on how the
        # widget classes implement comparison
        if self.canvas is not None:
            self.hboxWorkArea.remove( self.canvas )
            self.canvas.destroy()
            self.canvas = None
        if self.navToolbar is not None:
            self.vboxMain.remove( self.navToolbar )
            self.navToolbar.destroy()
            self.navToolbar = None

        self.canvas = FigureCanvas( self.figure ) # a gtk.DrawingArea
        self.canvas.show()
        self.hboxWorkArea.pack_start( self.canvas, True, True )
        self.hboxWorkArea.show()

        # Navigation toolbar
        self.navToolbar = NavigationToolbar( self.canvas, self.window )
        self.navToolbar.lastDir = '/var/tmp/'
        self.vboxMain.pack_start( self.navToolbar, expand=False, fill=False )
        self.navToolbar.show()
        self.vboxMain.show()
    
    #---------------------------------------------------------------------------
    def setFrameIdx( self, frameIdx ):
        """Make frameIdx the current frame and rebuild the camera image pixbuf.

        frameIdx indexes self.inputSequence.cameraImages. The drawing area is
        resized to the frame if needed and a redraw is queued.
        """

        self.frameIdx = frameIdx

        # Wrap the frame in a pixbuf for display
        frame = self.inputSequence.cameraImages[ frameIdx ]
        frameHeight, frameWidth = frame.shape[ 0 ], frame.shape[ 1 ]
        rowStride = frameWidth*3    # 3 bytes per pixel: 8 bit RGB, no alpha

        self.cameraImagePixBuf = gtk.gdk.pixbuf_new_from_data(
            frame.tostring(), gtk.gdk.COLORSPACE_RGB, False, 8,
            frameWidth, frameHeight, rowStride )

        # Track gripper: back project the gripper histogram onto a copy of the
        # frame, one single channel plane per colour
        rgbImage = cv.CreateImageHeader( ( frameWidth, frameHeight ), cv.IPL_DEPTH_8U, 3 )
        cv.SetData( rgbImage, frame.data, rowStride )
        rgbImage = cv.CloneImage( rgbImage )

        channelPlanes = [
            cv.CreateMat( rgbImage.height, rgbImage.width, cv.CV_8UC1 )
            for _ in range( 3 ) ]
        cv.Split( rgbImage, channelPlanes[ 0 ], channelPlanes[ 1 ], channelPlanes[ 2 ], None )

        backProjection = cv.CreateImage( cv.GetSize( rgbImage ), 8, 1 )
        cv.CalcArrBackProject( channelPlanes, backProjection, self.gripperHistogram )
        cv.CvtColor( backProjection, rgbImage, cv.CV_GRAY2RGB )
        # NOTE(review): nothing in this method reads rgbImage after this point,
        # so the back projection is computed but never shown - confirm whether
        # it is still needed

        # Resize the drawing area if necessary
        requiredSize = ( frameWidth, frameHeight )
        if self.dwgCameraImage.get_size_request() != requiredSize:
            self.dwgCameraImage.set_size_request( frameWidth, frameHeight )

        self.dwgCameraImage.queue_draw()
    
    #---------------------------------------------------------------------------
    def onTestPointAdjustmentValueChanged( self, widget ):
        """Replot the graphs and queue a camera image redraw for a new test point.

        widget is not used. Presumably this is connected to the value-changed
        signal of the test point adjustments - the wiring is outside this
        method.
        """
        self.processOpticalFlowData()
        self.dwgCameraImage.queue_draw()
    
    #---------------------------------------------------------------------------
    def onSequenceControlsFrameIdxChanged( self, widget ):
        """Display the frame whose index is held in widget.frameIdx."""
        self.setFrameIdx( widget.frameIdx )
    
    #---------------------------------------------------------------------------
    def onDwgCameraImageButtonPressEvent( self, widget, data ):
        """Move the test point to the optical flow block that was clicked.

        widget is the drawing area showing the camera image and data is the
        button press event. data.x and data.y are taken relative to the top
        left corner of the image rectangle and divided by the block size to
        get the block indices stored in adjTestPointX and adjTestPointY.
        Nothing happens while no camera image has been set.
        """

        if self.cameraImagePixBuf is None:
            return

        imgRect = self.getImageRectangleInWidget( widget,
            self.cameraImagePixBuf.get_width(), self.cameraImagePixBuf.get_height() )

        self.adjTestPointX.set_value( int( ( data.x - imgRect.x )/self.OPTICAL_FLOW_BLOCK_WIDTH ) )
        self.adjTestPointY.set_value( int( ( data.y - imgRect.y )/self.OPTICAL_FLOW_BLOCK_HEIGHT ) )
        
    #---------------------------------------------------------------------------
    def onDwgCameraImageExposeEvent( self, widget, data = None ):
        """Redraw the exposed part of the camera image and the detection overlay.

        widget is the drawing area being exposed and data is the expose event.
        Only the part of the image inside data.area is blitted; when data is
        None the whole image rectangle is redrawn. Nothing is drawn while no
        camera image has been set.
        """

        if self.cameraImagePixBuf is None:
            return

        imgRect = self.getImageRectangleInWidget( widget,
            self.cameraImagePixBuf.get_width(), self.cameraImagePixBuf.get_height() )

        # Remember where the image starts before imgRect is clipped
        imgOffsetX = imgRect.x
        imgOffsetY = imgRect.y

        # Get the total area that needs to be redrawn
        if data is not None:
            imgRect = imgRect.intersect( data.area )

        widget.window.draw_pixbuf( widget.get_style().fg_gc[ gtk.STATE_NORMAL ],
            self.cameraImagePixBuf, imgRect.x - imgOffsetX, imgRect.y - imgOffsetY,
            imgRect.x, imgRect.y, imgRect.width, imgRect.height )

        # Draw an overlay to show places where the input motion has been
        # detected. The test is by identity because comparing a NumPy array
        # with None is an element-wise operation
        if self.inputSignalDetectedArray is not None:
            self._drawDetectionOverlay( widget, imgOffsetX, imgOffsetY )

        # Drawing the optical flow vectors is currently disabled. To enable it
        # call self._drawOpticalFlowVectors( widget, imgOffsetX, imgOffsetY )

    def _drawDetectionOverlay( self, widget, imgOffsetX, imgOffsetY ):
        """Blend yellow over every block flagged in inputSignalDetectedArray.

        imgOffsetX and imgOffsetY give the position of the image's top left
        corner inside widget. Blocks are positioned from that corner, not from
        the clipped redraw rectangle, so they line up with the image on
        partial redraws as well.
        """

        pixBuf = self.cameraImagePixBuf
        imageData = np.frombuffer( pixBuf.get_pixels(), dtype=np.uint8 )
        imageData.shape = ( pixBuf.get_height(), pixBuf.get_width(), 3 )

        blockWidth = self.OPTICAL_FLOW_BLOCK_WIDTH
        blockHeight = self.OPTICAL_FLOW_BLOCK_HEIGHT
        styleGC = widget.get_style().fg_gc[ gtk.STATE_NORMAL ]

        # Half strength yellow. As a single RGB triple it broadcasts over a
        # block of any size, including blocks cut short by the image edge
        halfYellow = np.array( [ 255.0, 255.0, 0.0 ] )*0.5

        for y, x in np.argwhere( self.inputSignalDetectedArray ):

            # Get source block as NumPy array
            srcX = int( x )*blockWidth
            srcY = int( y )*blockHeight
            srcData = imageData[ srcY:srcY+blockHeight, srcX:srcX+blockWidth, : ]
            if srcData.size == 0:
                continue    # Block lies completely outside the image

            # Create a modified version of the block with a yellow layer
            # alpha blended over the top
            modifiedData = ( srcData.astype( np.float32 )*0.5 + halfYellow ).astype( np.uint8 )

            # Blit the modified version into the widget
            modifiedPixBuf = gtk.gdk.pixbuf_new_from_array(
                modifiedData, gtk.gdk.COLORSPACE_RGB, 8 )

            widget.window.draw_pixbuf( styleGC, modifiedPixBuf, 0, 0,
                imgOffsetX + srcX, imgOffsetY + srcY,
                modifiedData.shape[ 1 ], modifiedData.shape[ 0 ] )

    def _drawOpticalFlowVectors( self, widget, imgOffsetX, imgOffsetY ):
        """Draw the optical flow of the current frame as one line per block.

        Each line starts at the centre of its block. Flow pointing up is red,
        flow pointing down is blue and flow with no vertical part is green.
        The current test point is marked with a white circle. imgOffsetX and
        imgOffsetY give the position of the image's top left corner inside
        widget. Not called at present, see onDwgCameraImageExposeEvent.
        """

        opticalFlowX = self.inputSequence.opticalFlowArraysX[ :, :, self.frameIdx ]
        opticalFlowY = self.inputSequence.opticalFlowArraysY[ :, :, self.frameIdx ]

        testX = int( self.adjTestPointX.get_value() )
        testY = int( self.adjTestPointY.get_value() )

        blockWidth = self.OPTICAL_FLOW_BLOCK_WIDTH
        blockHeight = self.OPTICAL_FLOW_BLOCK_HEIGHT
        graphicsContext = widget.window.new_gc()

        for y in range( opticalFlowX.shape[ 0 ] ):

            blockCentreY = imgOffsetY + y*blockHeight + blockHeight // 2
            for x in range( opticalFlowX.shape[ 1 ] ):

                blockCentreX = imgOffsetX + x*blockWidth + blockWidth // 2

                if testX == x and testY == y:
                    # Highlight the current test point
                    radius = 2
                    drawFilledArc = False
                    graphicsContext.set_rgb_fg_color( gtk.gdk.Color( 65535, 65535, 65535 ) )
                    widget.window.draw_arc( graphicsContext, drawFilledArc,
                        int( blockCentreX - radius ), int( blockCentreY - radius ),
                        int( radius * 2 ), int( radius * 2 ), 0, 360 * 64 )

                endX = blockCentreX + opticalFlowX[ y, x ]
                endY = blockCentreY + opticalFlowY[ y, x ]

                if endY < blockCentreY:
                    # Up is red
                    graphicsContext.set_rgb_fg_color( gtk.gdk.Color( 65535, 0, 0 ) )
                elif endY > blockCentreY:
                    # Down is blue
                    graphicsContext.set_rgb_fg_color( gtk.gdk.Color( 0, 0, 65535 ) )
                else:
                    # Static is green
                    graphicsContext.set_rgb_fg_color( gtk.gdk.Color( 0, 65535, 0 ) )

                widget.window.draw_line( graphicsContext,
                    int( blockCentreX ), int( blockCentreY ),
                    int( endX ), int( endY ) )

    #---------------------------------------------------------------------------
    def getImageRectangleInWidget( self, widget, imageWidth, imageHeight ):
        """Return the rectangle an image occupies when centred inside widget.

        Along each axis where the widget's allocation is larger than the
        image, the rectangle is centred and takes the image's size. Otherwise
        it starts at 0 and spans the allocated size of the widget.
        """

        # Only the size of the allocation matters, not its position
        _, _, widgetWidth, widgetHeight = widget.get_allocation()

        imgRect = gtk.gdk.Rectangle( 0, 0, widgetWidth, widgetHeight )

        # Floor division keeps the offsets integral even under true division
        if widgetWidth > imageWidth:
            imgRect.x = ( widgetWidth - imageWidth ) // 2
            imgRect.width = imageWidth

        if widgetHeight > imageHeight:
            imgRect.y = ( widgetHeight - imageHeight ) // 2
            imgRect.height = imageHeight

        return imgRect

    #---------------------------------------------------------------------------
    def update( self ):
        """Generator for the periodic update loop; yields True on every step.

        Each step compares the current wall-clock time with the time of the
        last update and records the new time once 1/UPDATE_FREQUENCY seconds
        have passed. No other work is done per update at present. The
        generator never finishes.
        """

        UPDATE_FREQUENCY = 30.0    # Updates in Hz
        updatePeriod = 1.0 / UPDATE_FREQUENCY

        # time.time() instead of time.clock(): clock() was removed in
        # Python 3.8 and reports processor time rather than elapsed time on Unix
        lastTime = time.time()

        while True:

            curTime = time.time()

            if curTime - lastTime >= updatePeriod:

                # Save the time
                lastTime = curTime

            yield True
Beispiel #14
0
class MainWindow:
    """GTK window for exploring the optical flow of a recorded camera sequence.

    Constructing the window loads the sequence from a bag file, calculates its
    optical flow, resamples it, builds the cross-correlation data and the ROC
    curve, and shows the graphs next to the camera image.
    """

    OPTICAL_FLOW_BLOCK_WIDTH = 16
    OPTICAL_FLOW_BLOCK_HEIGHT = 16
    OPTICAL_FLOW_RANGE_WIDTH = 16  # Range to look outside of a block for motion
    OPTICAL_FLOW_RANGE_HEIGHT = 16
    # Other method names that have been used here: "LucasKanade", "HornSchunck"
    OPTICAL_FLOW_METHOD = "BlockMatching"
    COMBINATION_METHOD = "NoChange"

    # Ground truth file, appended to the script directory. Other files that
    # have been used here:
    #   "/../../config/TopPosGripper.yaml"
    #   "/../../config/ExperimentPosGripper.yaml"
    #   "/../../config/OnTablePosGripper.yaml"
    GROUND_TRUTH_FILENAME = "/../../config/TightBasicWave_Gripper.yaml"

    CORRELATION_THRESHOLD = 0.52

    # Upper limits for the test point adjustments. Floor division keeps them
    # integers under true division as well.
    # NOTE(review): 640 and 480 look like the expected image size - confirm
    MAX_TEST_POINT_X = (
        640 - OPTICAL_FLOW_BLOCK_WIDTH) // OPTICAL_FLOW_BLOCK_WIDTH - 1
    MAX_TEST_POINT_Y = (
        480 - OPTICAL_FLOW_BLOCK_HEIGHT) // OPTICAL_FLOW_BLOCK_HEIGHT - 1

    SAMPLES_PER_SECOND = 30.0
    MAX_CORRELATION_LAG = 2.0

    GRIPPER_WAVE_FREQUENCY = 1.0  # Waves per second
    GRIPPER_NUM_WAVES = 3.0
    GRIPPER_WAVE_AMPLITUDE = math.radians(20.0)

    #---------------------------------------------------------------------------
    def __init__(self, bagFilename):
        """Process the sequence stored in bagFilename and show the window.

        Raises:
            Exception: if the ground truth marker buffer cannot be loaded.
        """

        self.scriptPath = os.path.dirname(__file__)
        self.cameraImagePixBuf = None
        self.bagFilename = bagFilename
        self.lastImageGray = None

        # Read in sequence
        t1 = time.time()

        self.inputSequence = InputSequence(bagFilename)

        # Distractor objects are built but adding them to the sequence is
        # currently disabled
        distractors = [
            Distractor(radius=24,
                       startPos=startPos,
                       endPos=endPos,
                       frequency=frequency)
            for startPos, endPos, frequency in (
                ((25, 35), (100, 100), 2.0),
                ((200, 200), (150, 50), 0.25),
                ((188, 130), (168, 258), 0.6),
                ((63, 94), (170, 81), 1.5),
                ((40, 287), (50, 197), 3.0),
            )
        ]
        #self.inputSequence.addDistractorObjects( distractors )

        self.inputSequence.calculateOpticalFlow(self.OPTICAL_FLOW_BLOCK_WIDTH,
                                                self.OPTICAL_FLOW_BLOCK_HEIGHT,
                                                self.OPTICAL_FLOW_RANGE_WIDTH,
                                                self.OPTICAL_FLOW_RANGE_HEIGHT,
                                                self.OPTICAL_FLOW_METHOD)

        t2 = time.time()
        # print with one parenthesised argument behaves the same whether print
        # is a statement or a function
        print('Processing sequence took %0.3f ms' % ((t2 - t1) * 1000.0))

        # Resample sequence
        t1 = time.time()

        self.regularisedInputSequence = RegularisedInputSequence(
            self.inputSequence, self.SAMPLES_PER_SECOND)

        t2 = time.time()
        print('Resampling took %0.3f ms' % ((t2 - t1) * 1000.0))

        t1 = time.time()

        self.crossCorrelatedSequence = CrossCorrelatedSequence(
            self.regularisedInputSequence, self.MAX_CORRELATION_LAG,
            self.COMBINATION_METHOD)

        t2 = time.time()
        print('Correlation took %0.3f ms' % ((t2 - t1) * 1000.0))

        # Detect the input signal based on the correlation in the x and y axis
        self.inputSignalDetectedArray = \
            self.crossCorrelatedSequence.detectInputSequence( self.CORRELATION_THRESHOLD )

        # Build a histogram for the gripper
        self.gripperHistogram = self._buildGripperHistogram()

        markerBuffer = MarkerBuffer.loadMarkerBuffer(
            self.scriptPath + self.GROUND_TRUTH_FILENAME)
        if markerBuffer is None:
            raise Exception("Unable to load marker buffer")

        self.rocCurve = GripperDetectorROCCurve(self.crossCorrelatedSequence,
                                                markerBuffer)

        # Create the matplotlib graph
        self.figure = Figure(figsize=(8, 6), dpi=72)
        self.axisX = self.figure.add_subplot(311)
        self.axisY = self.figure.add_subplot(312)
        self.axisROC = self.figure.add_subplot(313)

        self.canvas = None  # Wait for GUI to be created before creating canvas
        self.navToolbar = None

        self._setupGui()

    #---------------------------------------------------------------------------
    def _buildGripperHistogram(self):
        """Return a colour histogram of the detected blocks of the first frame.

        The histogram has 32 bins per colour plane and only counts pixels that
        lie in blocks flagged in inputSignalDetectedArray.
        """

        numBins = 256 // 8
        histogram = cv.CreateHist([numBins, numBins, numBins],
                                  cv.CV_HIST_ARRAY, [(0, 255), (0, 255),
                                                     (0, 255)], 1)

        firstImage = self.inputSequence.cameraImages[0]
        imageRGB = cv.CreateImageHeader(
            (firstImage.shape[1], firstImage.shape[0]), cv.IPL_DEPTH_8U, 3)
        cv.SetData(imageRGB, firstImage.data, firstImage.shape[1] * 3)

        planes = [
            cv.CreateMat(imageRGB.height, imageRGB.width, cv.CV_8UC1)
            for _ in range(3)
        ]
        cv.Split(imageRGB, planes[0], planes[1], planes[2], None)

        # Mask that is 255 inside every detected block and 0 elsewhere
        blockWidth = self.OPTICAL_FLOW_BLOCK_WIDTH
        blockHeight = self.OPTICAL_FLOW_BLOCK_HEIGHT
        maskArray = np.zeros(shape=(imageRGB.height, imageRGB.width),
                             dtype=np.uint8)
        for rowIdx, colIdx in np.argwhere(self.inputSignalDetectedArray):
            rowStartIdx = rowIdx * blockHeight
            colStartIdx = colIdx * blockWidth
            maskArray[rowStartIdx:rowStartIdx + blockHeight,
                      colStartIdx:colStartIdx + blockWidth] = 255

        cv.CalcHist([cv.GetImage(i) for i in planes],
                    histogram,
                    0,
                    mask=cv.fromarray(maskArray))

        return histogram

    #---------------------------------------------------------------------------
    def _setupGui(self):
        """Load the Glade layout, wire up the widgets and show the window."""

        builder = gtk.Builder()
        builder.add_from_file(self.scriptPath +
                              "/GUI/OpticalFlowExplorer.glade")

        self.dwgCameraImage = builder.get_object("dwgCameraImage")
        self.window = builder.get_object("winMain")
        self.vboxMain = builder.get_object("vboxMain")
        self.hboxWorkArea = builder.get_object("hboxWorkArea")
        self.adjTestPointX = builder.get_object("adjTestPointX")
        self.adjTestPointY = builder.get_object("adjTestPointY")
        self.adjTestPointX.set_upper(self.MAX_TEST_POINT_X)
        self.adjTestPointY.set_upper(self.MAX_TEST_POINT_Y)
        self.sequenceControls = builder.get_object("sequenceControls")
        self.sequenceControls.setNumFrames(len(
            self.inputSequence.cameraImages))
        self.sequenceControls.setOnFrameIdxChangedCallback(
            self.onSequenceControlsFrameIdxChanged)
        self.setFrameIdx(0)
        self.processOpticalFlowData()

        builder.connect_signals(self)

        # Step the update generator from an idle handler. next() is used
        # rather than the generator's .next method, which only exists on
        # Python 2
        updateLoop = self.update()
        gobject.idle_add(lambda: next(updateLoop))

        self.window.show()
        self.window.maximize()

    #---------------------------------------------------------------------------
    def onWinMainDestroy(self, widget, data=None):
        """Leave the GTK main loop when the main window is destroyed."""
        gtk.main_quit()

    #---------------------------------------------------------------------------
    def main(self):
        """Run the GTK main loop until gtk.main_quit() is called."""
        # All PyGTK applications must have a gtk.main(). Control ends here
        # and waits for an event to occur (like a key press or mouse event).
        gtk.main()

    #---------------------------------------------------------------------------
    def processOpticalFlowData(self):
        """Replot the graphs for the current test point and export the X data.

        The test point is read from adjTestPointX and adjTestPointY and selects
        the block whose optical flow and correlation data are shown:

        * axisX: normalised servo angle, normalised X optical flow and, when
          present, correlation channel 0.
        * axisY: normalised servo angle, normalised Y optical flow (from the
          regularised sequence and from the input sequence) and, when present,
          correlation channel 1.
        * axisROC: true positive rate against false positive rate.

        The graph display is then rebuilt and the X axis data is written to
        test_results/CrossCorrelation.csv, two directories above scriptPath.
        """

        testX = int(self.adjTestPointX.get_value())
        testY = int(self.adjTestPointY.get_value())

        regSeq = self.regularisedInputSequence
        corSeq = self.crossCorrelatedSequence

        # Normalise the data ready for display
        normalisedServoAngleData = Utils.normaliseSequence(
            regSeq.regularServoAngleData)
        normalisedOpticalFlowDataX = Utils.normaliseSequence(
            regSeq.regularOpticalFlowArrayX[testY][testX])
        normalisedOpticalFlowDataY = Utils.normaliseSequence(
            regSeq.regularOpticalFlowArrayY[testY][testX])

        numCorrelationChannels = len(corSeq.correlationChannels)

        # Plot graphs
        self.axisX.clear()
        self.axisX.plot(regSeq.regularSampleTimes, normalisedServoAngleData)
        self.axisX.plot(regSeq.regularSampleTimes, normalisedOpticalFlowDataX)

        if numCorrelationChannels >= 1:
            correlationChannelX = corSeq.correlationChannels[0][testY][testX]
            self.axisX.plot(
                regSeq.regularSampleTimes[:len(correlationChannelX)],
                correlationChannelX)
        else:
            # No correlation data at all: the export below pads with zeros
            correlationChannelX = []

        self.axisY.clear()
        self.axisY.plot(regSeq.regularSampleTimes, normalisedServoAngleData)
        self.axisY.plot(regSeq.regularSampleTimes, normalisedOpticalFlowDataY)

        inpSeq = self.inputSequence
        self.axisY.plot(
            inpSeq.imageTimes,
            Utils.normaliseSequence(inpSeq.opticalFlowArraysY[testY][testX]))

        if numCorrelationChannels >= 2:
            correlationChannelY = corSeq.correlationChannels[1][testY][testX]
            self.axisY.plot(
                regSeq.regularSampleTimes[:len(correlationChannelY)],
                correlationChannelY)

        self.axisROC.clear()
        self.axisROC.plot(self.rocCurve.falsePositiveRates,
                          self.rocCurve.truePositiveRates)

        self.refreshGraphDisplay()

        # Export the X axis data. The correlation channel may hold fewer
        # samples than there are sample times; missing values are written as 0.0
        numSamples = len(regSeq.regularSampleTimes)
        numCorrelationSamples = len(correlationChannelX)

        outputFilename = (self.scriptPath +
                          "/../../test_results/CrossCorrelation.csv")
        with open(outputFilename, "w") as outputFile:

            outputFile.write("Time,Input,Output,CrossCorrelation\n")
            for i in range(numSamples):

                if i < numCorrelationSamples:
                    correlationData = correlationChannelX[i]
                else:
                    correlationData = 0.0

                outputFile.write("{0},{1},{2},{3}\n".format(
                    regSeq.regularSampleTimes[i], normalisedServoAngleData[i],
                    normalisedOpticalFlowDataX[i], correlationData))

    #---------------------------------------------------------------------------
    def refreshGraphDisplay(self):
        """Recreate the matplotlib canvas and navigation toolbar.

        An existing canvas is removed from hboxWorkArea and an existing toolbar
        from vboxMain, and both are destroyed. A new FigureCanvas for
        self.figure is then packed into hboxWorkArea and a new
        NavigationToolbar bound to that canvas is packed into vboxMain.
        """

        if self.canvas is not None:
            self.hboxWorkArea.remove(self.canvas)
            self.canvas.destroy()
            self.canvas = None
        if self.navToolbar is not None:
            self.vboxMain.remove(self.navToolbar)
            self.navToolbar.destroy()
            self.navToolbar = None

        self.canvas = FigureCanvas(self.figure)  # a gtk.DrawingArea
        self.canvas.show()
        self.hboxWorkArea.pack_start(self.canvas, True, True)
        self.hboxWorkArea.show()

        # Navigation toolbar
        self.navToolbar = NavigationToolbar(self.canvas, self.window)
        self.navToolbar.lastDir = '/var/tmp/'
        self.vboxMain.pack_start(self.navToolbar, expand=False, fill=False)
        self.navToolbar.show()
        self.vboxMain.show()

    #---------------------------------------------------------------------------
    def setFrameIdx(self, frameIdx):
        """Make frameIdx the current frame and rebuild the camera image pixbuf.

        frameIdx indexes self.inputSequence.cameraImages. The drawing area is
        resized to the frame if needed and a redraw is queued.
        """

        self.frameIdx = frameIdx

        # Display the frame
        image = self.inputSequence.cameraImages[frameIdx]
        imageWidth = image.shape[1]
        imageHeight = image.shape[0]
        imageStep = imageWidth * 3  # 3 bytes per pixel: 8 bit RGB, no alpha

        self.cameraImagePixBuf = gtk.gdk.pixbuf_new_from_data(
            image.tostring(), gtk.gdk.COLORSPACE_RGB, False, 8, imageWidth,
            imageHeight, imageStep)

        # Track gripper: back project the gripper histogram onto a copy of the
        # frame, one single channel plane per colour
        imageRGB = cv.CreateImageHeader((imageWidth, imageHeight),
                                        cv.IPL_DEPTH_8U, 3)
        cv.SetData(imageRGB, image.data, imageStep)
        imageRGB = cv.CloneImage(imageRGB)

        planes = [
            cv.CreateMat(imageRGB.height, imageRGB.width, cv.CV_8UC1)
            for _ in range(3)
        ]
        cv.Split(imageRGB, planes[0], planes[1], planes[2], None)

        backproject = cv.CreateImage(cv.GetSize(imageRGB), 8, 1)
        cv.CalcArrBackProject(planes, backproject, self.gripperHistogram)
        cv.CvtColor(backproject, imageRGB, cv.CV_GRAY2RGB)
        # NOTE(review): nothing in this method reads imageRGB after this point,
        # so the back projection is computed but never shown - confirm whether
        # it is still needed

        # Resize the drawing area if necessary
        if self.dwgCameraImage.get_size_request() != (imageWidth, imageHeight):
            self.dwgCameraImage.set_size_request(imageWidth, imageHeight)

        self.dwgCameraImage.queue_draw()

    #---------------------------------------------------------------------------
    def onTestPointAdjustmentValueChanged(self, widget):
        """Replot the graphs and queue a redraw after the test point changed.

        widget is not used.
        """
        self.processOpticalFlowData()
        self.dwgCameraImage.queue_draw()

    #---------------------------------------------------------------------------
    def onSequenceControlsFrameIdxChanged(self, widget):
        """Display the frame whose index is held in widget.frameIdx."""
        self.setFrameIdx(widget.frameIdx)

    #---------------------------------------------------------------------------
    def onDwgCameraImageButtonPressEvent(self, widget, data):
        """Move the test point to the optical flow block that was clicked.

        widget is the drawing area showing the camera image and data is the
        button press event. data.x and data.y are taken relative to the top
        left corner of the image rectangle and divided by the block size to
        get the block indices stored in adjTestPointX and adjTestPointY.
        Nothing happens while no camera image has been set.
        """

        if self.cameraImagePixBuf is None:
            return

        imgRect = self.getImageRectangleInWidget(
            widget, self.cameraImagePixBuf.get_width(),
            self.cameraImagePixBuf.get_height())

        self.adjTestPointX.set_value(
            int((data.x - imgRect.x) / self.OPTICAL_FLOW_BLOCK_WIDTH))
        self.adjTestPointY.set_value(
            int((data.y - imgRect.y) / self.OPTICAL_FLOW_BLOCK_HEIGHT))

    #---------------------------------------------------------------------------
    def onDwgCameraImageExposeEvent(self, widget, data=None):
        """Redraw the exposed part of the camera image and the overlay.

        widget is the drawing area being exposed and data is the expose event.
        Only the part of the image inside data.area is blitted; when data is
        None the whole image rectangle is redrawn. Nothing is drawn while no
        camera image has been set.
        """

        if self.cameraImagePixBuf is None:
            return

        imgRect = self.getImageRectangleInWidget(
            widget, self.cameraImagePixBuf.get_width(),
            self.cameraImagePixBuf.get_height())

        # Remember where the image starts before imgRect is clipped
        imgOffsetX = imgRect.x
        imgOffsetY = imgRect.y

        # Get the total area that needs to be redrawn
        if data is not None:
            imgRect = imgRect.intersect(data.area)

        widget.window.draw_pixbuf(
            widget.get_style().fg_gc[gtk.STATE_NORMAL],
            self.cameraImagePixBuf, imgRect.x - imgOffsetX,
            imgRect.y - imgOffsetY, imgRect.x, imgRect.y, imgRect.width,
            imgRect.height)

        # Draw an overlay to show places where the input motion has been
        # detected. The test is by identity because comparing a NumPy array
        # with None is an element-wise operation
        if self.inputSignalDetectedArray is not None:
            self._drawDetectionOverlay(widget, imgOffsetX, imgOffsetY)

        # Drawing the optical flow vectors is currently disabled. To enable it
        # call self._drawOpticalFlowVectors(widget, imgOffsetX, imgOffsetY)

    #---------------------------------------------------------------------------
    def _drawDetectionOverlay(self, widget, imgOffsetX, imgOffsetY):
        """Blend yellow over every block flagged in inputSignalDetectedArray.

        imgOffsetX and imgOffsetY give the position of the image's top left
        corner inside widget. Blocks are positioned from that corner, not from
        the clipped redraw rectangle, so they line up with the image on
        partial redraws as well.
        """

        pixBuf = self.cameraImagePixBuf
        imageData = np.frombuffer(pixBuf.get_pixels(), dtype=np.uint8)
        imageData.shape = (pixBuf.get_height(), pixBuf.get_width(), 3)

        blockWidth = self.OPTICAL_FLOW_BLOCK_WIDTH
        blockHeight = self.OPTICAL_FLOW_BLOCK_HEIGHT
        styleGC = widget.get_style().fg_gc[gtk.STATE_NORMAL]

        # Half strength yellow. As a single RGB triple it broadcasts over a
        # block of any size, including blocks cut short by the image edge
        halfYellow = np.array([255.0, 255.0, 0.0]) * 0.5

        for y, x in np.argwhere(self.inputSignalDetectedArray):

            # Get source block as NumPy array
            srcX = int(x) * blockWidth
            srcY = int(y) * blockHeight
            srcData = imageData[srcY:srcY + blockHeight,
                                srcX:srcX + blockWidth, :]
            if srcData.size == 0:
                continue  # Block lies completely outside the image

            # Create a modified version of the block with a yellow layer
            # alpha blended over the top
            modifiedData = (srcData.astype(np.float32) * 0.5 +
                            halfYellow).astype(np.uint8)

            # Blit the modified version into the widget
            modifiedPixBuf = gtk.gdk.pixbuf_new_from_array(
                modifiedData, gtk.gdk.COLORSPACE_RGB, 8)

            widget.window.draw_pixbuf(styleGC, modifiedPixBuf, 0, 0,
                                      imgOffsetX + srcX, imgOffsetY + srcY,
                                      modifiedData.shape[1],
                                      modifiedData.shape[0])

    #---------------------------------------------------------------------------
    def _drawOpticalFlowVectors(self, widget, imgOffsetX, imgOffsetY):
        """Draw the optical flow of the current frame as one line per block.

        Each line starts at the centre of its block. Flow pointing up is red,
        flow pointing down is blue and flow with no vertical part is green.
        The current test point is marked with a white circle. imgOffsetX and
        imgOffsetY give the position of the image's top left corner inside
        widget. Not called at present, see onDwgCameraImageExposeEvent.
        """

        opticalFlowX = self.inputSequence.opticalFlowArraysX[:, :,
                                                             self.frameIdx]
        opticalFlowY = self.inputSequence.opticalFlowArraysY[:, :,
                                                             self.frameIdx]

        testX = int(self.adjTestPointX.get_value())
        testY = int(self.adjTestPointY.get_value())

        blockWidth = self.OPTICAL_FLOW_BLOCK_WIDTH
        blockHeight = self.OPTICAL_FLOW_BLOCK_HEIGHT
        graphicsContext = widget.window.new_gc()

        for y in range(opticalFlowX.shape[0]):

            blockCentreY = imgOffsetY + y * blockHeight + blockHeight // 2
            for x in range(opticalFlowX.shape[1]):

                blockCentreX = imgOffsetX + x * blockWidth + blockWidth // 2

                if testX == x and testY == y:
                    # Highlight the current test point
                    radius = 2
                    drawFilledArc = False
                    graphicsContext.set_rgb_fg_color(
                        gtk.gdk.Color(65535, 65535, 65535))
                    widget.window.draw_arc(graphicsContext, drawFilledArc,
                                           int(blockCentreX - radius),
                                           int(blockCentreY - radius),
                                           int(radius * 2), int(radius * 2),
                                           0, 360 * 64)

                endX = blockCentreX + opticalFlowX[y, x]
                endY = blockCentreY + opticalFlowY[y, x]

                if endY < blockCentreY:
                    # Up is red
                    graphicsContext.set_rgb_fg_color(
                        gtk.gdk.Color(65535, 0, 0))
                elif endY > blockCentreY:
                    # Down is blue
                    graphicsContext.set_rgb_fg_color(
                        gtk.gdk.Color(0, 0, 65535))
                else:
                    # Static is green
                    graphicsContext.set_rgb_fg_color(
                        gtk.gdk.Color(0, 65535, 0))

                widget.window.draw_line(graphicsContext, int(blockCentreX),
                                        int(blockCentreY), int(endX),
                                        int(endY))

    #---------------------------------------------------------------------------
    def getImageRectangleInWidget(self, widget, imageWidth, imageHeight):
        """Return the rectangle an image occupies when centred inside widget.

        Along each axis where the widget's allocation is larger than the
        image, the rectangle is centred and takes the image's size. Otherwise
        it starts at 0 and spans the allocated size of the widget.
        """

        # Only the size of the allocation matters, not its position
        _, _, widgetWidth, widgetHeight = widget.get_allocation()

        imgRect = gtk.gdk.Rectangle(0, 0, widgetWidth, widgetHeight)

        # Floor division keeps the offsets integral even under true division
        if widgetWidth > imageWidth:
            imgRect.x = (widgetWidth - imageWidth) // 2
            imgRect.width = imageWidth

        if widgetHeight > imageHeight:
            imgRect.y = (widgetHeight - imageHeight) // 2
            imgRect.height = imageHeight

        return imgRect

    #---------------------------------------------------------------------------
    def update(self):
        """Generator for the periodic update loop; yields True on every step.

        _setupGui steps this generator from an idle handler. Each step
        compares the current wall-clock time with the time of the last update
        and records the new time once 1/UPDATE_FREQUENCY seconds have passed.
        No other work is done per update at present. The generator never
        finishes.
        """

        UPDATE_FREQUENCY = 30.0  # Updates in Hz
        updatePeriod = 1.0 / UPDATE_FREQUENCY

        # time.time() instead of time.clock(): clock() was removed in
        # Python 3.8 and reports processor time rather than elapsed time on Unix
        lastTime = time.time()

        while True:

            curTime = time.time()

            if curTime - lastTime >= updatePeriod:

                # Save the time
                lastTime = curTime

            yield True
Beispiel #15
0
class HurricaneUI:
    """GTK window for preparing data, training an SVM and showing detections.

    The widget tree is loaded from ``HurricaneUI.glade`` and this object is
    handed to ``builder.connect_signals``, so the ``on_*`` methods,
    ``gridsize_changed_cb`` and ``gtk_main_quit`` act as signal handlers.
    A matplotlib figure with a globe map is embedded in the "map" container.

    State kept on the instance:
      clssfr     -- the trained/loaded classifier, or None
      stormlocs  -- storm locations read from a key file, or None
      detected   -- grid keys predicted as hurricane, or None
    """

    # Set to True to dump each detection result to 'demo.detected', which
    # __init__ reads back on the next start (replaces a hard-coded `if False`).
    SAVE_DEMO_SNAPSHOT = False

    def __init__(self):
        """Build the window, read the initial option state and draw the globe."""
        gladefile = "HurricaneUI.glade"
        builder = gtk.Builder()
        builder.add_from_file(gladefile)
        self.window = builder.get_object("mainWindow")
        builder.connect_signals(self)

        self.figure = Figure(figsize=(10,10), dpi=75)
        self.axis = self.figure.add_subplot(111)

        self.lat = 50
        self.lon = -100
        self.globe= globDisp.GlobeMap(self.axis, self.lat, self.lon)

        self.canvas = FigureCanvasGTK(self.figure)
        self.canvas.show()
        self.canvas.set_size_request(500,500)

        self.globeview = builder.get_object("map")
        self.globeview.pack_start(self.canvas, True, True)

        self.navToolbar = NavigationToolbar(self.canvas, self.globeview)
        self.navToolbar.lastDir = '/var/tmp'
        self.globeview.pack_start(self.navToolbar)
        self.navToolbar.show()

        self.gridcombo = builder.get_object("gridsize")
        cell=gtk.CellRendererText()
        self.gridcombo.pack_start(cell,True)
        self.gridcombo.add_attribute(cell, 'text', 0)

        # read menu configuration
        self.gridopt = builder.get_object("gridopt").get_active()
        self.chkDetected = builder.get_object("detectedopt")
        self.detectedopt = self.chkDetected.get_active()
        self.chkHurricane = builder.get_object("hurricaneopt")
        self.hurricaneopt = self.chkHurricane.get_active()
        model = builder.get_object("liststore1")
        index = self.gridcombo.get_active()
        self.gridsize = model[index][0]
        # the label of whichever radio button of the group is active
        radio = [ r for r in builder.get_object("classifieropt1").get_group() if r.get_active() ][0]
        self.sClassifier = radio.get_label()
        self.start = builder.get_object("startdate")
        self.end = builder.get_object("enddate")

        self.chkUndersample = builder.get_object("undersample")
        self.chkGenKey = builder.get_object("genKey")

        # disable unimplemented classifier selection
        for name in ("classifieropt2", "classifieropt3", "classifieropt4"):
            builder.get_object(name).set_sensitive(False)

        self.btnStore =  builder.get_object("store")
        self.datapath = 'GFSdat'
        self.trackpath = 'tracks'
        builder.get_object("btnDatapath").set_current_folder(self.datapath)
        builder.get_object("btnTrackpath").set_current_folder(self.trackpath)
        self.btnDetect =  builder.get_object("detect")

        # current operation status
        self.stormlocs = None
        self.detected = None
        self.clssfr = None

        # for test drawing functions
        if os.path.exists('demo.detected'):
            # NOTE: unpickling can run arbitrary code; only keep trusted
            # files under this name.  Pickle streams are read in binary mode.
            with open('demo.detected','rb') as f:
                self.detected = pickle.load(f)
                self.stormlocs = pickle.load(f)
            self.chkHurricane.set_label(str(self.stormlocs.shape[0])+" Hurricanes")
            self.chkDetected.set_label(str(self.detected.shape[0])+" Detected")

        self.setDisabledBtns()

        # draw Globe
        self.drawGlobe()


    def setDisabledBtns(self):
        """Enable each control only when the data it depends on exists.

        Identity tests are used on purpose: ``array != None`` on a numpy
        array is evaluated element-wise and does not give a plain bool.
        """
        self.chkDetected.set_sensitive(self.detected is not None)
        self.chkHurricane.set_sensitive(self.stormlocs is not None)
        hasClassifier = self.clssfr is not None
        self.btnStore.set_sensitive(hasClassifier)
        self.btnDetect.set_sensitive(hasClassifier)

    def drawGlobe(self):
        """Draw the globe and, if switched on, the storm and detection layers."""
        self.globe.drawGlobe(self.gridsize, self.gridopt)
        if self.hurricaneopt:
            self.globe.drawHurricanes(self.stormlocs)
        if self.detectedopt:
            self.globe.fillGrids(self.detected)

    def main(self):
        """Show the window and enter the GTK main loop."""
        self.window.show_all()
        gtk.main()

    def redraw(self):
        """Clear the axes, draw the globe again and request a canvas repaint."""
        self.axis.cla()
        self.drawGlobe()
        self.canvas.draw_idle()

    def gtk_main_quit(self,widget):
        """Signal handler: leave the GTK main loop."""
        gtk.main_quit()


    ###############################################################################
    #
    #  utility functions (dialogs)
    #
    def getFilenameToRead(self, stitle, save=False, filter='all'):
        """Run a file chooser and return the chosen path, or None if cancelled.

        stitle -- dialog title
        save   -- False for an "open" dialog, True for a "save" dialog
        filter -- which file filters to offer before "All files":
                  'mat' (*.mat), 'svm' (*.svm), 'dat' (*.dat),
                  'mdat' (both *.mat and *.dat); any other value offers
                  "All files" only
        """
        if save:
            action = gtk.FILE_CHOOSER_ACTION_SAVE
            acceptStock = gtk.STOCK_SAVE
        else:
            action = gtk.FILE_CHOOSER_ACTION_OPEN
            acceptStock = gtk.STOCK_OPEN
        chooser = gtk.FileChooserDialog(title=stitle, parent=self.window,
                                        action=action,
                                        buttons=(gtk.STOCK_CANCEL,gtk.RESPONSE_CANCEL,acceptStock,gtk.RESPONSE_OK))
        chooser.set_default_response(gtk.RESPONSE_OK)

        # Collect (name, pattern) pairs first.  The `filter` argument is never
        # rebound, so every comparison below sees the caller's string; the
        # original overwrote it with a FileFilter and lost the 'mdat' "Data"
        # entry.
        wanted = []
        if filter in ('mat', 'mdat'):
            wanted.append(("Matrix files", "*.mat"))
        if filter == 'svm':
            wanted.append(("SVM", "*.svm"))
        if filter in ('dat', 'mdat'):
            wanted.append(("Data", "*.dat"))
        wanted.append(("All files", "*"))

        for name, pattern in wanted:
            fileFilter = gtk.FileFilter()
            fileFilter.set_name(name)
            fileFilter.add_pattern(pattern)
            chooser.add_filter(fileFilter)

        chooser.set_current_folder(os.getcwd())

        response = chooser.run()
        filen = chooser.get_filename() if response == gtk.RESPONSE_OK else None
        chooser.destroy()
        return filen


    def showMessage(self,msg):
        """Show `msg` in a modal info dialog with a Close button."""
        md = gtk.MessageDialog(self.window, gtk.DIALOG_DESTROY_WITH_PARENT, gtk.MESSAGE_INFO,
                                gtk.BUTTONS_CLOSE, msg)
        md.run()
        md.destroy()


    ###############################################################################
    #
    #  read data files and convert into a training matrix input
    #
    def on_btnTrackpath_current_folder_changed(self,widget):
        """Signal handler: remember the folder chosen for track files."""
        self.trackpath = widget.get_current_folder()

    def on_btnDatapath_current_folder_changed(self,widget):
        """Signal handler: remember the folder chosen for data files."""
        self.datapath = widget.get_current_folder()

    def on_createMat_clicked(self,widget):
        """Signal handler: convert the data in the date range to a matrix file."""
        filen = self.getFilenameToRead("Save converted matrix for training", True, filter='mat')
        if filen is None:
            return

        start = self.start.get_text()
        end = self.end.get_text()
        bundersmpl = self.chkUndersample.get_active()
        bgenkey = self.chkGenKey.get_active()

        ### FIX ME: currently gridsize for classification is fixed 1 (no clustering of grids - 2x2)
        # createMat appends to an existing file, so delete it if it exists
        if os.path.exists(filen):
            os.unlink(filen)
        gdtool.createMat(self.datapath, self.trackpath, start, end, store=filen,
                         undersample=bundersmpl, genkeyf=bgenkey)
        self.showMessage("Matrix has been stored to "+filen)


    ###############################################################################
    #
    #   train the selected classifier
    #
    def on_train_clicked(self, widget):
        """Signal handler: train the selected classifier on a chosen matrix file."""
        # FOR NOW, only SVM is supported
        if self.sClassifier != "SVM":
            self.showMessage("The classifier is not supported yet!")
            return

        filen = self.getFilenameToRead("Open training data",filter='mat')
        if filen is None:
            return

        data = ml.VectorDataSet(filen,labelsColumn=0)
        self.clssfr = ml.SVM()
        self.clssfr.train(data)
        # train finished. need to update button status
        self.setDisabledBtns()
        self.showMessage("Training SVM is done.")

    def on_classifieropt_toggled(self,widget, data=None):
        """Signal handler: remember the label of the toggled classifier option."""
        self.sClassifier = widget.get_label()


    ###############################################################################
    #
    #   Classify on test data
    #
    def on_detect_clicked(self, widget):
        """Signal handler: classify a chosen data set and show the detections.

        A ``.dat`` file is first converted to a temporary matrix (and key
        file) which is removed afterwards; any other file is used directly,
        with its key file expected at ``<file>.keys``.
        """
        if self.clssfr is None:
            self.showMessage("There is no trained classifier!")
            return

        filen = self.getFilenameToRead("Open hurricane data", filter='mdat')
        if filen is None:
            return

        key, ext = os.path.splitext(os.path.basename(filen))
        bneedDel = (ext == '.dat')
        if bneedDel:
            tmpfn = 'f__tmpDetected__'
        else:
            # Use the full chosen path; the bare file name only resolved
            # when the file happened to be in the working directory.
            tmpfn = filen
        gdkeyfilen = ''.join([tmpfn,'.keys'])

        try:
            if bneedDel:
                key = key[1:] # take 'g' out
                if os.path.exists(tmpfn):
                    os.unlink(tmpfn)
                # for DEMO, undersampled the normal data -- without undersampling there are too many candidates
                gdtool.createMat(self.datapath, self.trackpath, key, key, store=tmpfn, undersample=True, genkeyf=True)

            result = self.clssfr.test(ml.VectorDataSet(tmpfn,labelsColumn=0))

            # NOTE: unpickling can run arbitrary code; key files must come
            # from a trusted source.  Pickle streams are read in binary mode.
            with open(gdkeyfilen, 'rb') as f:
                gridkeys = pickle.load(f)
                self.stormlocs = pickle.load(f)

            # A list is built explicitly: on Python 3 `map` is lazy and
            # np.array(map(...)) would wrap the iterator instead of its values.
            predicted = np.array([float(p) for p in result.getPredictedLabels()])
            self.detected = np.array(gridkeys)[predicted==1]
        finally:
            # remove the temporary files even when something above failed
            if bneedDel:
                for path in (tmpfn, gdkeyfilen):
                    if os.path.exists(path):
                        os.unlink(path)

        snstroms = str(self.stormlocs.shape[0])
        sndetected = str(self.detected.shape[0])
        self.chkHurricane.set_label(snstroms+" Hurricanes")
        self.chkDetected.set_label(sndetected+" Detected")

        self.showMessage(''.join([sndetected,"/",snstroms," grids are predicted to have hurricane."]))
        if self.SAVE_DEMO_SNAPSHOT:
            with open('demo.detected','wb') as f:
                pickle.dump(self.detected,f)
                pickle.dump(self.stormlocs,f)

        # test data tested. update buttons
        self.setDisabledBtns()
        self.redraw()


    ###############################################################################
    #
    #   load and store trained classifier
    #
    def on_load_clicked(self, widget):
        """Signal handler: load a stored classifier.

        Two files are asked for: the stored classifier and the training
        data set that loadSVM is given alongside it.  Cancelling either
        dialog leaves the current classifier untouched and shows no message.
        """
        filen = self.getFilenameToRead("Load Classifier",filter='svm')
        if filen is None:
            return

        datfn = self.getFilenameToRead("Open Training Data",filter='mat')
        if datfn is None:
            # Nothing was loaded, so do not claim success.
            return

        data = ml.VectorDataSet(datfn,labelsColumn=0)
        # original author's open question: why does loadSVM need the data?
        self.clssfr = loadSVM(filen,data)

        # classifier has been loaded. need to update button status
        self.setDisabledBtns()
        self.showMessage("The classifier has been loaded!")

    def on_store_clicked(self, widget):
        """Signal handler: save the current classifier to a chosen file."""
        if self.clssfr is None:
            self.showMessage("There is no trained classifier!")
            return

        filen = self.getFilenameToRead("Store Classifier", True, filter='svm')
        if filen is not None:
            self.clssfr.save(filen)
            self.showMessage("The classifier has been saved!")


    ###############################################################################
    #
    #   Display Globe
    #
    def on_right_clicked(self,widget):
        """Signal handler: move the stored longitude 10 degrees east.

        TODO: only ``self.lon`` changes; rotating the drawn globe is not
        implemented here.
        """
        self.lon += 10

    def on_left_clicked(self,widget):
        """Signal handler: move the stored longitude 10 degrees west.

        TODO: only ``self.lon`` changes; rotating the drawn globe is not
        implemented here.
        """
        self.lon -= 10

    def gridsize_changed_cb(self, widget):
        """Signal handler: take the grid size from the combo box and redraw."""
        model = widget.get_model()
        index = widget.get_active()
        # an index of -1 or below means there is no selection to read
        if index > -1:
            self.gridsize = model[index][0]
        self.redraw()

    def on_gridopt_toggled(self, widget):
        """Signal handler: flip the grid option and redraw."""
        self.gridopt = not self.gridopt
        self.redraw()

    def on_Hurricane_toggled(self, widget):
        """Signal handler: flip the hurricane layer option and redraw."""
        self.hurricaneopt = not self.hurricaneopt
        self.redraw()

    def on_detected_toggled(self, widget):
        """Signal handler: flip the detected layer option and redraw."""
        self.detectedopt = not self.detectedopt
        self.redraw()
Beispiel #16
0
class PlotViewer(gtk.VBox):
    """Vertical box with a matplotlib canvas, its toolbar and a check-box row.

    One subplot and one PlottedData entry are created per plotter.  Plotters
    report back through two callbacks: an "info" callback whose strings are
    joined and forwarded to handlers registered with onUpdateInfo(), and a
    "position" callback whose argument is passed on to every other plotter.
    """

    def __init__(self, plotters, fields):
        """Create the canvas, toolbar, check boxes and one plot per plotter.

        plotters -- sequence of (plotterClass, default) pairs; each class is
                    called as plotterClass(axes, fields, infoCb, posCb) and
                    `default` is handed to PlottedData
        fields   -- passed unchanged to every plotter
        """
        gtk.VBox.__init__(self)

        self.figure = mpl.figure.Figure()
        self.canvas = FigureCanvas(self.figure)
        self.canvas.unset_flags(gtk.CAN_FOCUS)
        self.canvas.set_size_request(600, 400)
        self.pack_start(self.canvas, True, True)
        self.canvas.show()

        self.navToolbar = NavigationToolbar(self.canvas, self.window)
        self.pack_start(self.navToolbar, False, False)
        self.navToolbar.show()

        # NOTE(review): the first positional parameter of gtk.HBox is
        # `homogeneous`, so the plotter count is passed as that flag here --
        # confirm whether a spacing value was intended.
        self.checkboxes = gtk.HBox(len(plotters))
        self.pack_start(self.checkboxes, False, False)
        self.checkboxes.show()

        self.pol = (1 + 0j, 0j)
        self.pol2 = None

        self.handlers = []

        # Callback state is set up before any plotter exists, so a plotter
        # that reports info/position from its constructor finds it in place.
        self.__infos = [None] * len(plotters)
        self.__posi = None

        # Factories bind the plot index per plotter (avoids late binding).
        def makeUpdateInfo(index):
            return lambda s: self.__updateInfo(index, s)

        def makeUpdatePos(index):
            return lambda s: self.__updatePos(index, s)

        self.plots = []
        for i, (plotterClass, default) in enumerate(plotters):
            # Subplot indices are 1-based; index 0 is not a valid position.
            axes = self.figure.add_subplot(len(plotters), 1, i + 1)

            plotter = plotterClass(axes, fields, makeUpdateInfo(i),
                                   makeUpdatePos(i))
            d = PlottedData(axes, plotter, default, self.__updateChildren)
            self.checkboxes.pack_start(d.checkBox, False, False)
            self.plots.append(d)

        self.legendBox = gtk.CheckButton("Show legend")
        self.legendBox.set_active(True)
        self.legendBox.connect("toggled", self.__on_legend_toggled)
        self.checkboxes.pack_start(self.legendBox, False, False)
        self.legendBox.show()

        self.__updateChildren()

    def __on_legend_toggled(self, button):
        """Pass the legend check box state to every plotter."""
        showLegend = button.get_active()
        for pd in self.plots:
            pd.plotter.setLegend(showLegend)

    def __updateChildren(self):
        """Stack the visible axes in one column, then redraw and replot."""
        allAxes = self.figure.get_axes()

        # number of rows in the layout; at least one even if nothing is visible
        count = sum(1 for axes in allAxes if axes.get_visible()) or 1

        nr = 1
        for axes in allAxes:
            axes.change_geometry(count, 1, nr)
            if axes.get_visible():
                if nr < count:
                    nr += 1
            else:
                # Hack to prevent the invisible axes from getting mouse events
                axes.set_position((0, 0, 1e-10, 1e-10))
        self.figure.canvas.draw()
        self.__updateGraph()

    def __updateGraph(self):
        """Replot every plotter whose axes are visible."""
        for pd in self.plots:
            if pd.axes.get_visible():
                pd.plotter.plot(self.pol, self.pol2)

    def __updateInfo(self, i, arg):
        """Store the info string of plot `i` and notify all handlers.

        Handlers receive the non-None info strings joined with single spaces.
        """
        self.__infos[i] = arg
        s = ''
        for info in self.__infos:
            if info is None:
                continue
            if s != '':
                s += ' '
            s += info
        for handler in self.handlers:
            handler(s)

    def __updatePos(self, i, arg):
        """Forward a position update from plot `i` to all other plotters.

        A None update is ignored unless it comes from the plot that sent the
        most recently accepted update.
        """
        if arg is None and self.__posi != i:
            return
        self.__posi = i
        for j, pd in enumerate(self.plots):
            if i != j:
                pd.plotter.updateCPos(arg)

    def onUpdateInfo(self, handler):
        """Register `handler`, called with the combined info string."""
        self.handlers.append(handler)

    def setPol(self, value):
        """Set `pol` and replot if the value changed."""
        oldValue = self.pol
        self.pol = value
        if value != oldValue:
            self.__updateGraph()

    def setPol2(self, value):
        """Set `pol2` and replot if the value changed."""
        oldValue = self.pol2
        self.pol2 = value
        if value != oldValue:
            self.__updateGraph()