Example #1
0
    def __init__(self):
        """Build the cart-pole drawing canvas and subscribe to trajectories.

        The first registered ``TrajectoryObservable`` is the data source; each
        transition it publishes is forwarded to ``self._updateSamples``.
        """
        super(SPBTrajectoryViewer, self).__init__()

        # Drawn length of the pole (plot units)
        self.lenpole = 1.0

        # Indexing [0] assumes at least one trajectory observable exists
        trajectoryObservables = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)
        self.trajectoryObservable = trajectoryObservables[0]

        # Plain QWidget that hosts the matplotlib canvas
        canvasHost = QtGui.QWidget(self)
        canvasHost.setMinimumSize(600, 500)
        canvasHost.setWindowTitle("SPB Cart Viewer")

        figureSize = (6.0, 5.0)
        xLimits = (-3.125, 3.125)
        yLimits = (-0.5, 5.5)
        self.fig = Figure(figureSize, dpi=100)
        self.axis = self.fig.gca()
        self.axis.set_xlim(xLimits)
        self.axis.set_ylim(yLimits)

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(canvasHost)
        self.canvas.draw()

        # A single horizontal row holding the plot
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(canvasHost)
        self.setLayout(self.hlayout)

        # Subscribe as the final step, once every attribute the callback
        # relies on has been created.
        self.trajectoryObservableCallback = \
             lambda *transition: self._updateSamples(*transition)
        self.trajectoryObservable.addObserver(
            self.trajectoryObservableCallback)
    def __init__(self, stateSpace):
        """Set up the value-function figure and the observable selector.

        :param stateSpace: state-space mapping; its
            ``["count"]["dimensionValues"]`` entry is kept as ``self.states``.
        """
        super(SeventeenAndFourValuefunctionViewer, self).__init__()

        self.stateSpace = stateSpace
        self.states = stateSpace["count"]["dimensionValues"]

        def titlesOf(observables):
            # Display strings for the selector, one per observable
            return map(lambda obs: "%s" % obs.title, observables)

        # Selector listing every StateActionValuesObservable by title
        self.comboBox = QtGui.QComboBox(self)
        self.stateActionValuesObservables = \
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
        self.comboBox.addItems(titlesOf(self.stateActionValuesObservables))
        self.connect(self.comboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._observableChanged)

        # Keep the selector in sync with observables registered at runtime
        def updateComboBox(observable, action):
            self.comboBox.clear()
            self.stateActionValuesObservables = \
                    OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
            self.comboBox.addItems(titlesOf(self.stateActionValuesObservables))

        OBSERVABLES.addObserver(updateComboBox)

        # Matplotlib figure hosted in a plain QWidget
        valueFunctionHost = QtGui.QWidget(self)
        valueFunctionHost.setMinimumSize(800, 500)

        self.figValueFunction = Figure((8.0, 5.0), dpi=100)
        self.axisValueFunction = self.figValueFunction.gca()
        self.canvasValueFunction = FigureCanvas(self.figValueFunction)
        self.canvasValueFunction.setParent(valueFunctionHost)

        # Plot on the left, selector on the right
        self.hlayout = QtGui.QHBoxLayout()
        for widget in (valueFunctionHost, self.comboBox):
            self.hlayout.addWidget(widget)
        self.setLayout(self.hlayout)

        # Subscribe as the final step, once everything the callback uses
        # has been created.
        self.stateActionValuesObservableCallback = \
             lambda valueAccessFunction, actions: \
                 self.updateValues(valueAccessFunction, actions)
        if not self.stateActionValuesObservables:
            self.stateActionValuesObservable = None
            return

        # Observe the first available observable by default
        chosen = self.stateActionValuesObservables[0]
        self.stateActionValuesObservable = chosen
        valueFunctionHost.setWindowTitle(chosen.title)
        chosen.addObserver(self.stateActionValuesObservableCallback)
Example #3
0
    def __init__(self, stateSpace):
        """Assemble the policy plot, its control widgets and observer wiring.

        :param stateSpace: state space description, stored as
            ``self.stateSpace`` for later plotting.
        """
        super(MountainCarPolicyViewer, self).__init__()
        self.stateSpace = stateSpace
        self.actions = []
        self.colors = ['r', 'g', 'b', 'c', 'y']

        # Guards plotting against a concurrent change of the policy observable
        self.lock = threading.Lock()

        # --- selection of the policy observable ---------------------------
        self.policyObservableLabel = QtGui.QLabel("Policy Observable")
        self.policyObservableComboBox = QtGui.QComboBox(self)
        available = \
            OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
        self.policyObservableComboBox.addItems([obs.title for obs in available])
        self.selectedPolicyObservable = \
            available[0].title if len(available) > 0 else None

        self.connect(self.policyObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._policyObservableChanged)

        def updatePolicyObservableBox(viewer, action):
            # Repopulate the selector when observables appear at runtime and
            # preselect the first one (or nothing when none exist).
            combo = self.policyObservableComboBox
            combo.clear()
            current = OBSERVABLES.getAllObservablesOfType(
                FunctionOverStateSpaceObservable)
            combo.addItems([obs.title for obs in current])
            self.selectedPolicyObservable = \
                current[0].title if len(current) > 0 else None

        OBSERVABLES.addObserver(updatePolicyObservableBox)

        # The trajectory observable tells us when an episode has ended
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]
        self.episodeTerminated = False

        # --- grid resolution slider ----------------------------------------
        self.gridNodesPerDim = 25
        self.gridNodesSlider = slider = \
            QtGui.QSlider(QtCore.Qt.Horizontal, self)
        slider.setValue(self.gridNodesPerDim)
        slider.setMinimum(0)
        slider.setMaximum(100)
        slider.setTickInterval(10)
        slider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(slider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s"
                                           % self.gridNodesPerDim)

        # --- matplotlib canvas ----------------------------------------------
        plotHost = QtGui.QWidget(self)
        plotHost.setMinimumSize(600, 500)
        plotHost.setWindowTitle("Policy")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotHost)

        # Small text in plot legend (global matplotlib setting)
        matplotlib.rcParams.update({'legend.fontsize': 6})

        # Plot on the left; a column of controls on the right
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotHost)
        self.vlayout = QtGui.QVBoxLayout()
        for control in (self.policyObservableLabel,
                        self.policyObservableComboBox,
                        self.gridNodesLabel,
                        self.gridNodesSlider):
            self.vlayout.addWidget(control)
        self.hlayout.addLayout(self.vlayout)

        self.setLayout(self.hlayout)

        # Observer wiring comes last, once all attributes exist
        self.trajectoryObservableCallback = \
             lambda *transition: self.updateSamples(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)

        self.policyObservable = None
        if not self.selectedPolicyObservable:
            return
        self.policyObservable = OBSERVABLES.getObservable(
            self.selectedPolicyObservable, FunctionOverStateSpaceObservable)
        self.policyObservableCallback = \
             lambda policyEvalFunction: self.updatePolicy(policyEvalFunction)
        self.policyObservable.addObserver(self.policyObservableCallback)
Example #4
0
class MountainCarPolicyViewer(Viewer):
    """Plot a discrete policy over the (position, velocity) plane.

    The policy comes from a ``FunctionOverStateSpaceObservable`` chosen in a
    combo box. The plot is refreshed only when a policy update arrives after
    the trajectory observable has reported the end of an episode.
    """

    def __init__(self, stateSpace):
        """Assemble the policy plot, its control widgets and observer wiring.

        :param stateSpace: state space passed to ``generate2dStateSlice`` when
            the policy is plotted.
        """
        super(MountainCarPolicyViewer, self).__init__()
        self.stateSpace = stateSpace
        self.actions = []
        self.colors = ['r', 'g', 'b', 'c', 'y']

        # Serialises plotting (updatePolicy) against switching the observed
        # policy (_policyObservableChanged)
        self.lock = threading.Lock()

        # Add a combobox for selecting the policy observable
        self.policyObservableLabel = QtGui.QLabel("Policy Observable")
        self.policyObservableComboBox = QtGui.QComboBox(self)
        policyObservables = \
            OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
        self.policyObservableComboBox.addItems([policyObservable.title
                                                for policyObservable in policyObservables])
        self.selectedPolicyObservable = None
        if len(policyObservables) > 0:
            self.selectedPolicyObservable = policyObservables[0].title

        self.connect(self.policyObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._policyObservableChanged)

        # Automatically update policy observable combobox when new observables
        # are created during runtime
        def updatePolicyObservableBox(viewer, action):
            self.policyObservableComboBox.clear()
            policyObservables = OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            self.policyObservableComboBox.addItems([policyObservable.title
                                                    for policyObservable in policyObservables])
            if len(policyObservables) > 0:
                self.selectedPolicyObservable = policyObservables[0].title
            else:
                self.selectedPolicyObservable = None

        OBSERVABLES.addObserver(updatePolicyObservableBox)

        # Get trajectory observable which is required for informing about end of episode
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]
        self.episodeTerminated = False

        # Slider that controls the granularity of the plot-grid
        self.gridNodesPerDim = 25
        self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.gridNodesSlider.setValue(self.gridNodesPerDim)
        self.gridNodesSlider.setMinimum(0)
        self.gridNodesSlider.setMaximum(100)
        self.gridNodesSlider.setTickInterval(10)
        self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.gridNodesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s"
                                           % self.gridNodesPerDim)

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Policy")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)

        # Small text in plot legend (global matplotlib setting)
        matplotlib.rcParams.update({'legend.fontsize': 6})

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.policyObservableLabel)
        self.vlayout.addWidget(self.policyObservableComboBox)
        self.vlayout.addWidget(self.gridNodesLabel)
        self.vlayout.addWidget(self.gridNodesSlider)
        self.hlayout.addLayout(self.vlayout)

        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.trajectoryObservableCallback = \
             lambda *transition: self.updateSamples(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)

        self.policyObservable = None
        if self.selectedPolicyObservable:
            self.policyObservable = OBSERVABLES.getObservable(self.selectedPolicyObservable,
                                                              FunctionOverStateSpaceObservable)
            self.policyObservableCallback = \
                 lambda policyEvalFunction: self.updatePolicy(policyEvalFunction)
            self.policyObservable.addObserver(self.policyObservableCallback)

    def close(self):
        """Detach from the trajectory and policy observables, then close."""
        self.trajectoryObservable.removeObserver(self.trajectoryObservableCallback)
        if self.policyObservable:
            self.policyObservable.removeObserver(self.policyObservableCallback)

        super(MountainCarPolicyViewer, self).close()

    def updateSamples(self, state, action, reward, succState, episodeTerminated):
        """Remember that an episode ended so the next policy update is drawn.

        Only ``episodeTerminated`` is inspected; the other transition fields
        are ignored.
        """
        if episodeTerminated:
            self.episodeTerminated = True

    def updatePolicy(self, policyEvalFunction):
        """Redraw the policy plot, at most once per finished episode.

        :param policyEvalFunction: callable handed to ``generate2dPlotArray``
            to evaluate the policy on a (position, velocity) grid slice.
        """
        # We only update the plot at the end of an episode
        if not self.episodeTerminated:
            return
        self.episodeTerminated = False

        # ``with`` releases the lock even if evaluating or plotting raises;
        # a leaked lock would block every later update and the combobox
        # handler forever.
        with self.lock:
            # Generate 2d state slice
            defaultDimValues = {"position": 0, "velocity": 0}
            stateSlice = generate2dStateSlice(["position", "velocity"],
                                              self.stateSpace, defaultDimValues,
                                              gridNodesPerDim=self.gridNodesPerDim)

            # Compute values that should be plotted
            values, colorMapping = \
                        generate2dPlotArray(policyEvalFunction, stateSlice,
                                            continuousFunction=False,
                                            shape=(self.gridNodesPerDim,
                                                   self.gridNodesPerDim))

            # If all values are masked (None), there is nothing useful to draw;
            # keep showing the previous plot.
            if values.mask.all():
                return

            # Drop the previous plot; otherwise every episode stacks another
            # PolyCollection onto the axis. Limits, labels and legend are all
            # re-applied below.
            self.axis.clear()

            # Plot data
            polyCollection = self.axis.pcolor(numpy.linspace(0.0, 1.0, self.gridNodesPerDim),
                                              numpy.linspace(0.0, 1.0, self.gridNodesPerDim),
                                              values.T)

            # Dummy patches that are never added to the axis; they only serve
            # as legend handles mapping each policy value to its colour.
            from matplotlib.patches import Rectangle
            linearSegmentedColorbar = polyCollection.get_cmap()
            patches = []
            functionValues = []
            for functionValue, colorValue in colorMapping.items():
                if isinstance(functionValue, tuple):
                    functionValue = functionValue[0]  # deal with '(action,)'
                normValue = polyCollection.norm(colorValue)
                if isinstance(normValue, numpy.ndarray):
                    normValue = normValue[0]  # happens when function is constant
                rgbaColor = linearSegmentedColorbar(normValue)

                patches.append(Rectangle((0, 0), 1, 1, fc=rgbaColor))
                functionValues.append(functionValue)
            # This is the only legend call: an argument-less legend() would
            # replace it with an (empty) auto-generated legend.
            self.axis.legend(patches, functionValues)

            # Labeling etc.
            self.axis.set_xlim(0, 1)
            self.axis.set_ylim(0, 1)
            self.axis.set_xlabel("position")
            self.axis.set_ylabel("velocity")

            self.canvas.draw()

    def _policyObservableChanged(self, selectedPolicyObservable):
        """Switch observation to the policy observable picked in the combobox.

        :param selectedPolicyObservable: title of the chosen observable, as
            delivered by the ``activated (const QString&)`` signal.
        """
        with self.lock:
            if self.policyObservable:
                # Disconnect from old policy observable
                self.policyObservable.removeObserver(self.policyObservableCallback)

            # Determine new observed policy observable
            self.selectedPolicyObservable = str(selectedPolicyObservable)

            # Connect to new policy observable
            self.policyObservable = OBSERVABLES.getObservable(self.selectedPolicyObservable,
                                                              FunctionOverStateSpaceObservable)
            self.policyObservableCallback = \
                 lambda policyEvalFunction: self.updatePolicy(policyEvalFunction)
            self.policyObservable.addObserver(self.policyObservableCallback)

            self.actions = []

    def _changeGridNodes(self):
        """Adopt the slider's grid resolution and refresh its label."""
        self.gridNodesPerDim = self.gridNodesSlider.value()
        self.gridNodesLabel.setText("Grid Nodes Per Dimension: %s"
                                    % self.gridNodesPerDim)
Example #5
0
class SPBTrajectoryViewer(Viewer):
    """Draw the current cart/pole configuration of the single pole balancing task.

    The viewer observes the first ``TrajectoryObservable`` and, for every
    transition it receives, redraws the cart (red box with two black wheels)
    and the pole (green bar) for the transition's successor state.
    """

    # Fixed view window. ``Axes.clear()`` resets axis limits to matplotlib's
    # defaults, so these are re-applied after every clear.
    _XLIM = (-3.125, 3.125)
    _YLIM = (-0.5, 5.5)

    def __init__(self):
        super(SPBTrajectoryViewer, self).__init__()

        # Drawn length of the pole (plot units)
        self.lenpole = 1.0

        # Get required observables (indexing [0] assumes one exists)
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("SPB Cart Viewer")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self._setViewLimits()

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        # Create layout
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.trajectoryObservableCallback = \
             lambda *transition: self._updateSamples(*transition)
        self.trajectoryObservable.addObserver(
            self.trajectoryObservableCallback)

    def close(self):
        """Unsubscribe from the trajectory observable, then close the widget."""
        self.trajectoryObservable.removeObserver(
            self.trajectoryObservableCallback)

        super(SPBTrajectoryViewer, self).close()

    def _setViewLimits(self):
        """Apply the fixed x/y view window to ``self.axis``."""
        self.axis.set_xlim(self._XLIM)
        self.axis.set_ylim(self._YLIM)

    def _updateSamples(self, state, action, reward, succState,
                       episodeTerminated):
        """Redraw the cart and pole for the successor state of a transition.

        Only ``succState`` is used. It must provide ``"cartPosition"`` and
        ``"poleAngularPosition"``. The angle goes through ``math.sin`` and
        ``math.cos``, so it is in radians. The pole tip is offset by
        ``lenpole * (sin, cos)``, i.e. the angle is measured from the vertical.
        """
        cartPosition = succState["cartPosition"]
        poleAngularPosition = succState["poleAngularPosition"]

        self.axis.clear()
        # clear() also reset the limits; restore the fixed frame so the
        # picture does not jump or shrink to the default (0, 1) window.
        self._setViewLimits()

        # Cart
        cartPath = Path([(cartPosition - 1.0, 0), (cartPosition + 1.0, 0),
                         (cartPosition + 1.0, 0.1), (cartPosition - 1.0, 0.1),
                         (cartPosition - 1.0, 0)])
        cartPatch = PathPatch(cartPath,
                              facecolor=(1.0, 0.0, 0.0),
                              edgecolor=(0.0, 0.0, 0.0))
        self.axis.add_patch(cartPatch)
        # Wheels
        wheelPatch1 = Circle([cartPosition - 0.8, -0.2], 0.2, facecolor='k')
        wheelPatch2 = Circle([cartPosition + 0.8, -0.2], 0.2, facecolor='k')
        self.axis.add_patch(wheelPatch1)
        self.axis.add_patch(wheelPatch2)
        # Pole
        sintheta = math.sin(poleAngularPosition)
        costheta = math.cos(poleAngularPosition)
        polePath = Path([(cartPosition - 0.1, 0.1),
                         (cartPosition - 0.1 + self.lenpole * sintheta,
                          0.1 + self.lenpole * costheta),
                         (cartPosition + 0.1 + self.lenpole * sintheta,
                          0.1 + self.lenpole * costheta),
                         (cartPosition + 0.1, 0.1), (cartPosition - 0.1, 0.1)])
        polePatch = PathPatch(polePath,
                              facecolor=(0.0, 1.0, 0.0),
                              edgecolor=(0.0, 0.0, 0.0))
        self.axis.add_patch(polePatch)

        # Redraw
        self.canvas.draw()
Example #6
0
    def __init__(self):
        """Create the float-stream plot, observable selector and sliders.

        The first ``FloatStreamObservable`` (if any) is observed initially.
        Each published ``(time, value, ...)`` is forwarded to
        ``self.update(time, value)``.
        """
        super(FloatStreamViewer, self).__init__()

        # Create matplotlib widget
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(800, 500)

        fig = Figure((8.0, 5.0), dpi=100)
        self.canvas = FigureCanvas(fig)
        self.canvas.setParent(plotWidget)
        self.axis = fig.gca()

        # Local container for displayed values
        self.values = deque()
        self.times = deque()

        # Combo Box for selecting the observable. A real list (not a lazy
        # ``map``) is passed so addItems receives a sequence on Python 3 too.
        self.comboBox = QtGui.QComboBox(self)
        self.floatStreamObservables = \
                OBSERVABLES.getAllObservablesOfType(FloatStreamObservable)
        self.comboBox.addItems(
            ["%s" % x.title for x in self.floatStreamObservables])
        self.connect(self.comboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._observableChanged)

        # Automatically update combobox when new float stream observables
        # are created during runtime
        def updateComboBox(observable, action):
            self.comboBox.clear()
            self.floatStreamObservables = \
                    OBSERVABLES.getAllObservablesOfType(FloatStreamObservable)
            self.comboBox.addItems(
                ["%s" % x.title for x in self.floatStreamObservables])

        OBSERVABLES.addObserver(updateComboBox)

        # The number of values from the observable that are remembered
        self.windowSize = 64

        # Slider for controlling the window size. Its position is in log2
        # units (initially log2(64) == 6). QSlider.setValue takes an int, so
        # the float returned by numpy.log2 is converted explicitly. The range
        # is set before the value so the value is never clamped by a
        # not-yet-configured range.
        self.windowSizeSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.windowSizeSlider.setMinimum(0)
        self.windowSizeSlider.setMaximum(15)
        self.windowSizeSlider.setValue(int(round(numpy.log2(self.windowSize))))
        self.windowSizeSlider.setTickInterval(1)
        self.windowSizeSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.windowSizeSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeWindowSize)

        self.windowSizeLabel = QtGui.QLabel("WindowSize: %s" % self.windowSize)

        # The length of the moving window average
        self.mwaSize = 10

        # Slider for controlling the moving average window
        self.mwaSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.mwaSlider.setMinimum(1)
        self.mwaSlider.setMaximum(50)
        self.mwaSlider.setValue(self.mwaSize)
        self.mwaSlider.setTickInterval(10)
        self.mwaSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.mwaSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeMWA)

        self.mwaLabel = QtGui.QLabel("Moving Window Average : %s" %
                                     self.mwaSize)

        # Create layout: plot on the left, control column on the right
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.comboBox)
        self.vlayout.addWidget(self.windowSizeSlider)
        self.vlayout.addWidget(self.windowSizeLabel)
        self.vlayout.addWidget(self.mwaSlider)
        self.vlayout.addWidget(self.mwaLabel)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.hlayout.addLayout(self.vlayout)

        self.setLayout(self.hlayout)

        # Handling connecting to observable; extra positional args are ignored
        self.observableCallback = lambda time, value, *args: self.update(
            time, value)
        if len(self.floatStreamObservables) > 0:
            # Show per default the first observable
            self.observable = self.floatStreamObservables[0]
            # Connect to observer (has to be the last thing!!)
            self.observable.addObserver(self.observableCallback)
        else:
            self.observable = None
Example #7
0
class Maze2DDetailedViewer(Viewer):
    """Detailed view of a 2D maze: greedy policy plus per-action values.

    One MDI sub-window shows the maze with arrows for the greedy action(s) in
    each cell. One further sub-window per action shows, in each cell, that
    action's value and how often the (state, action) pair has been observed.
    """

    # Arrow offsets (dx, dy) drawn by ``_plotArrow`` for each direction
    _ARROW_OFFSETS = {'up': (0.0, 0.6),
                      'down': (0.0, -0.6),
                      'right': (0.6, 0.0),
                      'left': (-0.6, 0.0)}

    # Directions whose values are evaluated and displayed by ``redraw``
    _DIRECTIONS = ("up", "down", "left", "right")

    def __init__(self, maze, stateSpace, actions):
        """Build the policy/value sub-windows and subscribe to observables.

        :param maze: object providing ``drawIntoAxis(axis)``, ``getColumns()``
            and ``getRows()``.
        :param stateSpace: provides ``parseStateDict`` and the ``"column"`` /
            ``"row"`` dimensions.
        :param actions: actions for which a value sub-window is created; they
            are used as dictionary keys.
        """
        super(Maze2DDetailedViewer, self).__init__()

        self.maze = maze
        self.stateSpace = stateSpace
        self.actions = actions

        # Visit counts per (state, action), shown below each value
        self.samples = defaultdict(lambda: 0)
        self.valueAccessFunction = None

        # Set at episode end; honoured by the next value update
        self.redrawRequested = False

        # Get required observables
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]
        self.stateActionValuesObservables = \
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)

        # Combo Box for selecting the observable
        self.observableLabel = QtGui.QLabel("Observable")
        self.comboBox = QtGui.QComboBox(self)
        self.comboBox.addItems(
            ["%s" % x.title for x in self.stateActionValuesObservables])
        self.connect(self.comboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._observableChanged)

        # Automatically update combobox when new state-action value
        # observables are created during runtime
        def updateComboBox(observable, action):
            self.comboBox.clear()
            self.stateActionValuesObservables = \
                    OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
            self.comboBox.addItems(
                ["%s" % x.title for x in self.stateActionValuesObservables])

        OBSERVABLES.addObserver(updateComboBox)

        # Combo Box for selecting the updateFrequency
        self.updateFreqLabel = QtGui.QLabel("Update")
        self.updateComboBox = QtGui.QComboBox(self)
        self.updateComboBox.addItems(["Every Episode", "Every Step"])

        # Create matplotlib widgets
        plotWidgetPolicy = QtGui.QWidget(self)
        plotWidgetPolicy.setMinimumSize(300, 400)
        plotWidgetPolicy.setWindowTitle("Policy")

        self.figPolicy = Figure((3.0, 4.0), dpi=100)
        self.figPolicy.subplots_adjust(left=0.01,
                                       bottom=0.01,
                                       right=0.99,
                                       top=0.99,
                                       wspace=0.05,
                                       hspace=0.11)

        self.canvasPolicy = FigureCanvas(self.figPolicy)
        self.canvasPolicy.setParent(plotWidgetPolicy)

        ax = self.figPolicy.gca()
        ax.clear()
        self.maze.drawIntoAxis(ax)

        # One figure/canvas/widget per action for its value display
        self.plotWidgetValueFunction = dict()
        self.figValueFunction = dict()
        self.canvasValueFunction = dict()
        for action in self.actions:
            self.plotWidgetValueFunction[action] = QtGui.QWidget(self)
            self.plotWidgetValueFunction[action].setMinimumSize(300, 400)
            self.plotWidgetValueFunction[action].setWindowTitle(str(action))

            self.figValueFunction[action] = Figure((3.0, 4.0), dpi=100)
            self.figValueFunction[action].subplots_adjust(left=0.01,
                                                          bottom=0.01,
                                                          right=0.99,
                                                          top=0.99,
                                                          wspace=0.05,
                                                          hspace=0.11)

            self.canvasValueFunction[action] = FigureCanvas(
                self.figValueFunction[action])
            self.canvasValueFunction[action].setParent(
                self.plotWidgetValueFunction[action])

            ax = self.figValueFunction[action].gca()
            ax.clear()
            self.maze.drawIntoAxis(ax)

        # Matplotlib text artists keyed by (state, action), reused on redraw
        self.textInstances = dict()
        # Arrow artists currently drawn on the policy axis
        self.arrowInstances = []

        self.canvasPolicy.draw()
        for action in self.actions:
            self.canvasValueFunction[action].draw()

        self.mdiArea = QtGui.QMdiArea(self)
        self.mdiArea.addSubWindow(plotWidgetPolicy)
        for action in self.actions:
            self.mdiArea.addSubWindow(self.plotWidgetValueFunction[action])
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.mdiArea)
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(self.observableLabel)
        self.hlayout.addWidget(self.comboBox)
        self.hlayout.addWidget(self.updateFreqLabel)
        self.hlayout.addWidget(self.updateComboBox)
        self.vlayout.addLayout(self.hlayout)
        self.setLayout(self.vlayout)

        # Connect to observer (has to be the last thing!!)
        self.trajectoryObservableCallback = \
             lambda *transition: self.updateSamples(*transition)
        self.trajectoryObservable.addObserver(
            self.trajectoryObservableCallback)

        self.stateActionValuesObservableCallback = \
             lambda valueAccessFunction, actions: self.updateValues(valueAccessFunction, actions)
        if len(self.stateActionValuesObservables) > 0:
            # Show per default the first observable
            self.stateActionValuesObservable = \
                self.stateActionValuesObservables[0]

            self.stateActionValuesObservable.addObserver(
                self.stateActionValuesObservableCallback)
        else:
            self.stateActionValuesObservable = None

    def close(self):
        """Detach from all observables this viewer subscribed to, then close."""
        # Without this, transitions keep arriving at a closed widget
        # (updateSamples reads self.updateComboBox).
        self.trajectoryObservable.removeObserver(
            self.trajectoryObservableCallback)
        if self.stateActionValuesObservable is not None:
            # Remove old observable
            self.stateActionValuesObservable.removeObserver(
                self.stateActionValuesObservableCallback)

        super(Maze2DDetailedViewer, self).close()

    def updateValues(self, valueAccessFunction, actions):
        """Store the latest value function and redraw if required.

        A redraw happens on every call in "Every Step" mode, or once after an
        episode end was flagged by ``updateSamples``. ``actions`` is unused.
        """
        self.valueAccessFunction = valueAccessFunction
        # Check if we have to redraw
        if self.redrawRequested or \
                str(self.updateComboBox.currentText()) == "Every Step":
            self.redraw()
            self.redrawRequested = False

    def updateSamples(self, state, action, reward, succState,
                      episodeTerminated):
        """Count the visited (state, action) pair and flag episode ends."""
        state = self.stateSpace.parseStateDict(state)
        self.samples[(state, action)] = self.samples[(state, action)] + 1
        # Check if we have to redraw
        if str(self.updateComboBox.currentText()) == "Every Episode" \
                 and episodeTerminated:
            self.redrawRequested = True  # Request redrawing once the next observable update happens

    def _actionValue(self, state, action):
        """Return the value of ``action`` in ``state`` (0.0 before any update)."""
        if self.valueAccessFunction is None:
            return 0.0
        return self.valueAccessFunction(state, (action, ))

    def redraw(self):
        """Redraw greedy-policy arrows and per-action value/visit annotations."""
        # Update policy visualization
        for arrow in self.arrowInstances:
            arrow.remove()
        self.arrowInstances = []

        dimensions = [
            self.stateSpace[dimName] for dimName in ["column", "row"]
        ]
        states = [
            State((column, row), dimensions)
            for column in range(self.maze.getColumns())
            for row in range(self.maze.getRows())
        ]

        # Evaluate every (state, direction) once; the values are reused for
        # both the policy arrows and the numeric annotations.
        valuesPerState = [
            (state, dict((action, self._actionValue(state, action))
                         for action in self._DIRECTIONS))
            for state in states
        ]

        policyAxis = self.figPolicy.gca()
        for state, actionValues in valuesPerState:
            maxValue = max(actionValues.values())
            # Every action reaching the maximum gets an arrow (ties included)
            for action in self._DIRECTIONS:
                if actionValues[action] == maxValue:
                    self._plotArrow(policyAxis, (state[0], state[1]), action)

        # Update Q-function visualization
        for state, actionValues in valuesPerState:
            for action in self._DIRECTIONS:
                value = actionValues[action]
                visits = self.samples[(state, action)]
                # float(...).is_integer() is False for NaN/inf instead of
                # raising like int(value) would
                if float(value).is_integer():
                    valueString = "%s\n%s" % (int(value), visits)
                else:
                    valueString = "%.1f\n%s" % (value, visits)
                if (state, action) not in self.textInstances:
                    # For TD-Agents that use crossproduct of action space
                    figKey = action[0] if isinstance(action, tuple) else action
                    axis = self.figValueFunction[figKey].gca()
                    textInstance = \
                        axis.text(state[0] - 0.3, state[1], valueString, fontsize=8)
                    self.textInstances[(state, action)] = textInstance
                else:
                    self.textInstances[(state, action)].set_text(valueString)

        self.canvasPolicy.draw()
        for action in self.actions:
            self.canvasValueFunction[action].draw()

    def _plotArrow(self, axis, center, direction):
        """Draw an arrow for ``direction`` centred on ``center`` into ``axis``.

        :raises ValueError: if ``direction`` is not up/down/left/right.
        """
        if isinstance(direction, tuple):  # For TD agent with action crossproduct
            direction = direction[0]
        try:
            dx, dy = self._ARROW_OFFSETS[direction]
        except KeyError:
            raise ValueError("Unknown arrow direction: %r" % (direction, ))

        arr = axis.arrow(center[0] - dx / 2,
                         center[1] - dy / 2,
                         dx,
                         dy,
                         width=0.05,
                         fc='k')
        self.arrowInstances.append(arr)

    def _observableChanged(self, comboBoxIndex):
        """Observe the StateActionValuesObservable selected in the combobox.

        QComboBox reports index -1 when it becomes empty, e.g. during
        ``clear()`` in the runtime refresh. In that case the viewer detaches
        and observes nothing until a valid index arrives.
        """
        if self.stateActionValuesObservable is not None:
            # Remove old observable
            self.stateActionValuesObservable.removeObserver(
                self.stateActionValuesObservableCallback)
        if not 0 <= comboBoxIndex < len(self.stateActionValuesObservables):
            self.stateActionValuesObservable = None
            return
        # Get new observable and add as listener
        self.stateActionValuesObservable = \
            self.stateActionValuesObservables[comboBoxIndex]
        self.stateActionValuesObservable.addObserver(
            self.stateActionValuesObservableCallback)
Example #8
0
class TrajectoryViewer(Viewer):
    """Plot recent trajectories projected onto two state-space dimensions.

    Two combo boxes select the dimensions shown on the x and y axis, a slider
    bounds how many trajectories stay on screen and a checkbox toggles
    plotting. Transitions arrive from the first registered
    ``TrajectoryObservable``; a trajectory is drawn once its episode ends.
    """

    def __init__(self, stateSpace):
        super(TrajectoryViewer, self).__init__()

        self.stateSpace = stateSpace

        # Define colors for plotting (one per finished trajectory, cycled)
        self.colors = itertools.cycle(['g', 'b', 'c', 'm', 'k', 'y'])

        # Line artists of the plotted trajectories, oldest first
        self.plotLines = deque()

        # Combo Boxes for selecting displayed state space dimensions; both
        # list the dimension names in sorted order
        self.comboBox1 = QtGui.QComboBox(self)
        self.comboBox1.addItems(sorted(stateSpace.keys()))
        self.comboBox2 = QtGui.QComboBox(self)
        self.comboBox2.addItems(sorted(stateSpace.keys()))
        self.comboBox2.setCurrentIndex(1)
        self.dimension1 = sorted(self.stateSpace.keys())[0]
        self.dimension2 = sorted(self.stateSpace.keys())[1]
        self.connect(self.comboBox1, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._dimension1Changed)
        self.connect(self.comboBox2, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._dimension2Changed)

        # Slider for controlling the number of Trajectories
        self.maxLines = 5
        self.numberTrajectoriesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.numberTrajectoriesSlider.setValue(self.maxLines)
        self.numberTrajectoriesSlider.setMinimum(1)
        self.numberTrajectoriesSlider.setMaximum(20)
        self.numberTrajectoriesSlider.setTickInterval(5)
        self.numberTrajectoriesSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.numberTrajectoriesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeNumTrajectories)

        # Checkbox for dis-/enabling trajectory plotting
        self.plottingEnabled = True
        self.enabledCheckBox = QtGui.QCheckBox("Plotting Enabled")
        self.enabledCheckBox.setChecked(self.plottingEnabled)
        self.connect(self.enabledCheckBox, QtCore.SIGNAL('stateChanged(int)'),
                     self._enablingPlotting)

        # Some labels
        self.dimension1Label = QtGui.QLabel("Dimension X Axis")
        self.dimension2Label = QtGui.QLabel("Dimension Y Axis")
        self.numTrajectoriesLabel = QtGui.QLabel("Trajectories shown")

        # Create matplotlib widgets
        plotWidgetTrajectory = QtGui.QWidget(self)
        plotWidgetTrajectory.setMinimumSize(800, 500)

        self.figTrajectory = Figure((8.0, 5.0), dpi=100)
        self.axisTrajectory = self.figTrajectory.gca()
        self.canvasTrajectory = FigureCanvas(self.figTrajectory)
        self.canvasTrajectory.setParent(plotWidgetTrajectory)

        # Initialize plotting
        self._reinitializePlot()

        # Create layout
        self.vlayout = QtGui.QVBoxLayout()
        self.hlayout1 = QtGui.QHBoxLayout()
        self.hlayout1.addWidget(self.dimension1Label)
        self.hlayout1.addWidget(self.comboBox1)
        self.hlayout2 = QtGui.QHBoxLayout()
        self.hlayout2.addWidget(self.dimension2Label)
        self.hlayout2.addWidget(self.comboBox2)
        self.hlayout3 = QtGui.QHBoxLayout()
        self.hlayout3.addWidget(self.numTrajectoriesLabel)
        self.hlayout3.addWidget(self.numberTrajectoriesSlider)

        self.vlayout.addLayout(self.hlayout1)
        self.vlayout.addLayout(self.hlayout2)
        self.vlayout.addLayout(self.hlayout3)
        self.vlayout.addWidget(self.enabledCheckBox)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidgetTrajectory)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to trajectory observable
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]
        self.trajectoryObservableCallback = \
             lambda *transition: self.addTransition(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)

    def close(self):
        """Detach from the trajectory observable, then close the widget."""
        self.trajectoryObservable.removeObserver(self.trajectoryObservableCallback)

        super(TrajectoryViewer, self).close()

    def addTransition(self, state, action, reward, succState, episodeTerminated):
        """Record one transition and draw the trajectory when the episode ends.

        ``state`` and ``succState`` are indexed by dimension name; ``action``
        and ``reward`` are accepted to match the observable's callback
        signature but are not used. Does nothing while plotting is disabled.
        At most ``self.maxLines`` trajectories are kept; older ones are removed.
        """
        if not self.plottingEnabled:
            return
        if not self.dim1Values:  # for the start state
            self.dim1Values.append(state[self.dimension1])
            self.dim2Values.append(state[self.dimension2])
        self.dim1Values.append(succState[self.dimension1])
        self.dim2Values.append(succState[self.dimension2])
        if episodeTerminated:
            # next() works on Python 2.6+ and 3; iterator.next() is Python 2 only
            plotLine = self.axisTrajectory.plot(self.dim1Values, self.dim2Values,
                                                next(self.colors))[0]
            self.plotLines.append(plotLine)
            while len(self.plotLines) > self.maxLines:
                # Remove the oldest line
                oldestLine = self.plotLines.popleft()
                oldestLine.remove()

            # Plotting may have autoscaled the axes; restore the state ranges
            self._updateAxisLimits()
            self.canvasTrajectory.draw()
            self.dim1Values = []
            self.dim2Values = []

    def _updateAxisLimits(self):
        """Set the x/y limits to the value ranges of the selected dimensions.

        For continuous dimensions the first entry of ``dimensionValues`` is
        passed to ``set_xlim``/``set_ylim`` as-is (presumably a (min, max)
        pair -- TODO confirm); for discrete dimensions the min and max of all
        listed values are used.
        """
        for dimension, setLimits in ((self.dimension1, self.axisTrajectory.set_xlim),
                                     (self.dimension2, self.axisTrajectory.set_ylim)):
            dimensionSpace = self.stateSpace[dimension]
            dimensionValues = dimensionSpace["dimensionValues"]
            if dimensionSpace.isContinuous():
                setLimits(dimensionValues[0])
            else:
                setLimits(min(dimensionValues), max(dimensionValues))

    def _reinitializePlot(self):
        """Drop all plotted trajectories and partial data, reset limits, redraw."""
        self.dim1Values = []
        self.dim2Values = []
        for line in self.plotLines:
            line.remove()
        self.plotLines = deque()
        self._updateAxisLimits()
        self.canvasTrajectory.draw()

    def _dimension1Changed(self, comboBoxIndex):
        """Use the dimension at ``comboBoxIndex`` (sorted order) for the x axis."""
        self.dimension1 = sorted(self.stateSpace.keys())[comboBoxIndex]
        self._reinitializePlot()

    def _dimension2Changed(self, comboBoxIndex):
        """Use the dimension at ``comboBoxIndex`` (sorted order) for the y axis."""
        self.dimension2 = sorted(self.stateSpace.keys())[comboBoxIndex]
        self._reinitializePlot()

    def _changeNumTrajectories(self):
        """Take the maximal number of shown trajectories from the slider."""
        self.maxLines = self.numberTrajectoriesSlider.value()

    def _enablingPlotting(self):
        """Follow the checkbox; re-enabling starts from an empty plot."""
        if self.enabledCheckBox.isChecked():
            self.plottingEnabled = True
            self._reinitializePlot()
        else:
            self.plottingEnabled = False
Example #9
0
class PlotExperimentWindow(QtGui.QMainWindow):
    """Window plotting the selected metric of the current experiment's runs.

    Data comes from ``tableModel.getRunDataForSelectedMetric()``: a mapping
    from ``(worldName, runNumber)`` to a sequence of pairs whose first entry
    is plotted on the x axis ("Episode") and whose second entry is the metric
    value. Either every run is drawn as its own line or, per world, the mean
    over runs. Values are smoothed with a centred moving-window average whose
    size (a power of two) is chosen with a slider.
    """

    def __init__(self, tableModel, parent):
        super(PlotExperimentWindow, self).__init__(parent)

        self.tableModel = tableModel
        # Colors used for different configurations in plots; a world gets the
        # next color of the cycle the first time it is looked up.
        # next() works on Python 2.6+ and 3; iterator.next() is Python 2 only.
        self.colors = cycle(["b", "g", "r", "c", "m", "y", "k"])
        self.colorMapping = defaultdict(lambda: next(self.colors))
        # The length of the moving window average (always a power of two)
        self.mwaSize = 2**0
        # Whether we plot each run separately or only their mean
        self.linePlotTypes = ["Each Run", "Average"]
        self.linePlot = self.linePlotTypes[0]

        # The central widget
        self.centralWidget = QtGui.QWidget(self)

        # Create matplotlib widget
        self.plotWidget = QtGui.QWidget(self.centralWidget)
        self.plotWidget.setMinimumSize(800, 500)

        self.fig = Figure((8.0, 5.0), dpi=100)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.plotWidget)
        self.axis = self.fig.gca()

        self.mplToolbar = NavigationToolbar(self.canvas, self.centralWidget)

        self.mwaLabel = QtGui.QLabel("Moving Window Average: %s" % self.mwaSize)
        # Slider for controlling the moving average window (size = 2**value)
        self.mwaSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.mwaSlider.setValue(0)
        self.mwaSlider.setMinimum(0)
        self.mwaSlider.setMaximum(10)
        self.mwaSlider.setTickInterval(10)
        self.mwaSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.mwaSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeMWA)

        self.lineLabel = QtGui.QLabel("Plot of agent: ")
        # Combo Box for choosing per-run or averaged lines; its entries must
        # match self.linePlotTypes since _linePlotChanged indexes into it
        self.lineComboBox = QtGui.QComboBox(self)
        self.lineComboBox.addItems(self.linePlotTypes)
        self.connect(self.lineComboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._linePlotChanged)

        # Add a button for replotting
        self.replotButton = QtGui.QPushButton("Update")
        self.connect(self.replotButton, QtCore.SIGNAL('clicked()'),
                     self._plot)

        # Add a button for saving a plot
        self.saveButton = QtGui.QPushButton("Save")
        self.connect(self.saveButton, QtCore.SIGNAL('clicked()'),
                     self._save)

        # Set layout
        self.topLinelayout = QtGui.QHBoxLayout()
        self.topLinelayout.addWidget(self.mwaLabel)
        self.topLinelayout.addWidget(self.mwaSlider)
        self.topLinelayout.addWidget(self.lineLabel)
        self.topLinelayout.addWidget(self.lineComboBox)
        self.topLinelayout.addWidget(self.replotButton)
        self.topLinelayout.addWidget(self.saveButton)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addLayout(self.topLinelayout)
        self.vlayout.addWidget(self.plotWidget)
        self.vlayout.addWidget(self.mplToolbar)
        self.centralWidget.setLayout(self.vlayout)

        self.setCentralWidget(self.centralWidget)
        self.setWindowTitle("Current experiment's results")

        # Plot the results once upon creation
        self._plot()

    def resizeEvent(self, qResizeEvent):
        """Resize canvas and figure to follow the window, then replot."""
        width = qResizeEvent.size().width()
        height = qResizeEvent.size().height()
        self.canvas.setMinimumSize(width - 50, height - 100)
        # Float division: under Python 2 integer division made the figure
        # size (in inches at dpi=100) jump in whole-inch steps
        self.fig.set_size_inches(width / 100.0 - 0.5, height / 100.0 - 1)

        self._plot()
        self.plotWidget.repaint()
        self.canvas.draw()

    def _movingWindowAverage(self, inData):
        """Return ``inData`` smoothed by a centred moving-window average.

        Each output value is the mean of the input values at most
        ``mwaSize // 2`` positions before and after it; the window is
        truncated at both ends of the sequence.
        """
        halfWindow = self.mwaSize // 2  # // keeps slice bounds integral on py3
        lastIndex = len(inData) - 1
        outData = []
        for i in range(len(inData)):
            start = max(0, i - halfWindow)
            # Clamp to the last valid index so that the divisor below equals
            # the number of sliced values (clamping to len() over-counted by
            # one near the end and biased the tail low).
            end = min(lastIndex, i + halfWindow)
            outData.append(float(sum(inData[start:end + 1])) / (end + 1 - start))
        return outData

    def _plotEachRun(self, data):
        """Draw every run as its own line; one legend entry per world."""
        plottedWorlds = set()
        for (worldName, runNumber), runData in data.items():
            episodes = [entry[0] for entry in runData]
            averageValues = self._movingWindowAverage(
                [entry[1] for entry in runData])
            if worldName not in plottedWorlds:
                label = str(worldName)
                plottedWorlds.add(worldName)
            else:
                label = "_nolegend_"  # matplotlib skips this label in legends
            self.axis.plot(episodes,
                           averageValues,
                           color=self.colorMapping[worldName],
                           label=label)

    def _plotAverage(self, data):
        """Draw, per world, the per-index mean over all of its runs.

        Runs may differ in length; index ``i`` is averaged over the runs that
        reach it. The x values are taken from the longest run of each world.
        """
        agentAvgData = defaultdict(list)  # per world: summed values per index
        agentCounter = defaultdict(list)  # per world: contributing runs per index
        worldsLongestRun = dict()
        for (worldName, runNumber), runData in data.items():
            for i, entry in enumerate(runData):
                if i >= len(agentAvgData[worldName]):
                    agentAvgData[worldName].append(entry[1])
                    agentCounter[worldName].append(1.0)
                    # The last run to extend the sums is the longest so far
                    worldsLongestRun[worldName] = runNumber
                else:
                    agentAvgData[worldName][i] += entry[1]
                    agentCounter[worldName][i] += 1.0
        for worldName in agentCounter.keys():
            plotData = [total / count
                        for total, count in zip(agentAvgData[worldName],
                                                agentCounter[worldName])]
            averagePlotData = self._movingWindowAverage(plotData)
            longestRun = data[(worldName, worldsLongestRun[worldName])]
            self.axis.plot([entry[0] for entry in longestRun],
                           averagePlotData, color=self.colorMapping[worldName],
                           label=str(worldName))

    def _plot(self):
        """Redraw the plot for the currently selected metric and line mode."""
        self.axis.clear()
        # Update internal data
        data = self.tableModel.getRunDataForSelectedMetric()

        if self.linePlot == "Each Run":
            self._plotEachRun(data)
        elif self.linePlot == "Average":
            self._plotAverage(data)

        self.axis.legend(loc = 'best')
        self.axis.set_xlabel("Episode")
        self.axis.set_ylabel(self.tableModel.selectedMetric)
        #Redraw
        self.canvas.draw()

    def _changeMWA(self):
        """Take the moving-window size (2**slider value) and replot."""
        self.mwaSize = 2**self.mwaSlider.value()
        self.mwaLabel.setText("Moving Window Average: %s" % self.mwaSize)
        # Replot
        self._plot()

    def _linePlotChanged(self, linePlot):
        """Switch between per-run and averaged lines and replot."""
        self.linePlot = self.linePlotTypes[linePlot]
        # Replot
        self._plot()

    def _save(self):
        """Ask for a file name and save the figure there at 400 dpi."""
        rootDirectory = \
            self.tableModel.rootDirectory if hasattr(self.tableModel,
                                                     "rootDirectory") \
                else mmlf.getRWPath()
        graphicFileName = \
            str(QtGui.QFileDialog.getSaveFileName(self,
                                                  "Select a file for the stored graphic",
                                                  rootDirectory,
                                                  "Plots (*.pdf)"))
        if not graphicFileName:
            # The dialog was cancelled; savefig("") would raise
            return
        self.fig.savefig(graphicFileName, dpi=400)
Example #10
0
    def __init__(self, maze, stateSpace):
        """Build the maze plot and the control widgets, then start observing.

        :param maze: maze object; its ``drawIntoAxis`` draws the layout.
        :param stateSpace: state space of the maze, stored as
            ``self.stateSpace``.

        If any function/state-action-value observable exists, the first one is
        selected and connected at the end of construction.
        """
        super(Maze2DFunctionViewer, self).__init__()

        self.maze = maze
        self.stateSpace = stateSpace

        # Plot throttling state (controlled by the update frequency slider)
        self.updateCounter = 0
        self.updatePlotNow = False
        self.evalFunction = None

        # NOTE(review): presumably guards plotting against observer callbacks
        # arriving from a non-GUI thread -- confirm
        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Maze2D")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.axis.clear()
        self.maze.drawIntoAxis(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.plottedPatches = []

        def getFunctionObservables():
            """Return all observables whose values can be shown over the maze."""
            functionObservables = OBSERVABLES.getAllObservablesOfType(
                FunctionOverStateSpaceObservable)
            functionObservables.extend(
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable))
            return functionObservables

        # Add a combobox for selecting the function over state space that is observed
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = getFunctionObservables()
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        if functionObservables:
            self.selectedFunctionObservable = functionObservables[0].title

        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update function observable combobox when new observables
        # are created during runtime
        def updateFunctionObservableBox(viewer, action):
            self.functionObservableComboBox.clear()
            functionObservables = getFunctionObservables()
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            if self.selectedFunctionObservable is None:
                if functionObservables:
                    self.selectedFunctionObservable = functionObservables[0].title
            else:
                # Let combobox still show the selected observable. (Guarded so
                # that findText is never called with None.)
                index = self.functionObservableComboBox.findText(
                    self.selectedFunctionObservable)
                if index != -1:
                    self.functionObservableComboBox.setCurrentIndex(index)

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # Add a combobox for for selecting the suboption that is used when
        # a StateActionValuesObservable is observed
        self.selectedSuboption = None
        self.suboptionLabel = QtGui.QLabel("Suboption")
        self.suboptionComboBox = QtGui.QComboBox(self)

        # Slider that controls the frequency of update the plot. Slider units
        # are hundredths of updateFrequency (range 0..100).
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        # setTickInterval takes an int in slider units: 10 == a frequency of
        # 0.1 (the previous float 0.1 is not a valid tick interval)
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.suboptionLayout = QtGui.QHBoxLayout()
        self.suboptionLayout.addWidget(self.suboptionLabel)
        self.suboptionLayout.addWidget(self.suboptionComboBox)
        self.vlayout.addLayout(self.suboptionLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.functionObservable = None
        if self.selectedFunctionObservable is not None:
            self._functionObservableChanged(self.selectedFunctionObservable)
Example #11
0
class Maze2DFunctionViewer(Viewer):
    def __init__(self, maze, stateSpace):
        """Build the maze plot and the control widgets, then start observing.

        :param maze: maze object; its ``drawIntoAxis`` draws the layout.
        :param stateSpace: state space of the maze, stored as
            ``self.stateSpace``.

        If any function/state-action-value observable exists, the first one is
        selected and connected at the end of construction.
        """
        super(Maze2DFunctionViewer, self).__init__()

        self.maze = maze
        self.stateSpace = stateSpace

        # Plot throttling state (controlled by the update frequency slider)
        self.updateCounter = 0
        self.updatePlotNow = False
        self.evalFunction = None

        # NOTE(review): presumably guards plotting against observer callbacks
        # arriving from a non-GUI thread -- confirm
        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Maze2D")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.axis.clear()
        self.maze.drawIntoAxis(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.plottedPatches = []

        def getFunctionObservables():
            """Return all observables whose values can be shown over the maze."""
            functionObservables = OBSERVABLES.getAllObservablesOfType(
                FunctionOverStateSpaceObservable)
            functionObservables.extend(
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable))
            return functionObservables

        # Add a combobox for selecting the function over state space that is observed
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = getFunctionObservables()
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        if functionObservables:
            self.selectedFunctionObservable = functionObservables[0].title

        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update function observable combobox when new observables
        # are created during runtime
        def updateFunctionObservableBox(viewer, action):
            self.functionObservableComboBox.clear()
            functionObservables = getFunctionObservables()
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            if self.selectedFunctionObservable is None:
                if functionObservables:
                    self.selectedFunctionObservable = functionObservables[0].title
            else:
                # Let combobox still show the selected observable. (Guarded so
                # that findText is never called with None.)
                index = self.functionObservableComboBox.findText(
                    self.selectedFunctionObservable)
                if index != -1:
                    self.functionObservableComboBox.setCurrentIndex(index)

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # Add a combobox for for selecting the suboption that is used when
        # a StateActionValuesObservable is observed
        self.selectedSuboption = None
        self.suboptionLabel = QtGui.QLabel("Suboption")
        self.suboptionComboBox = QtGui.QComboBox(self)

        # Slider that controls the frequency of update the plot. Slider units
        # are hundredths of updateFrequency (range 0..100).
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        # setTickInterval takes an int in slider units: 10 == a frequency of
        # 0.1 (the previous float 0.1 is not a valid tick interval)
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.suboptionLayout = QtGui.QHBoxLayout()
        self.suboptionLayout.addWidget(self.suboptionLabel)
        self.suboptionLayout.addWidget(self.suboptionComboBox)
        self.vlayout.addLayout(self.suboptionLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.functionObservable = None
        if self.selectedFunctionObservable is not None:
            self._functionObservableChanged(self.selectedFunctionObservable)

    def _updateFunction(self, evalFunction):
        """Observer callback: store ``evalFunction`` and replot when due.

        The plot is redrawn when an update was explicitly requested
        (``self.updatePlotNow``) or, if ``self.updateFrequency`` is positive,
        once the counter of skipped calls exceeds ``1 / updateFrequency``.
        Otherwise the call is only counted and ``evalFunction`` is not stored.
        """
        # numpy.float was merely an alias of the builtin float and was removed
        # in NumPy 1.24; a plain float literal gives the same result.
        periodicUpdateDue = (self.updateFrequency > 0.0
                             and self.updateCounter > 1.0 / self.updateFrequency)
        if periodicUpdateDue or self.updatePlotNow:
            # Reset counter and update plot
            self.updateCounter = 0
            self.updatePlotNow = False
            self.evalFunction = evalFunction
            self._plotFunction()
        else:
            # Do not update the plot
            self.updateCounter += 1

    def _plotFunction(self):
        """Draw the stored function (``self.evalFunction``) over the maze.

        Every maze cell (column, row) is evaluated. If the observed observable
        is a FunctionOverStateSpaceObservable with ``discreteValues`` set, each
        distinct value gets the next color of a fixed color cycle; otherwise
        values are min/max-normalised and colored with the "jet" colormap. For
        a StateActionValuesObservable the option selected in the suboption
        combo box is passed to the function. Cells whose value is None, NaN or
        +/-inf are not drawn. The legend list widget is rebuilt to match.
        Does nothing if no function has been received yet.
        """
        if self.evalFunction is None:
            return

        # ``with`` releases the lock even if evaluation or drawing raises; the
        # former acquire()/release() pair left it locked forever on error.
        with self.lock:
            # Clean up old plot
            for patch in self.plottedPatches:
                patch.remove()
            self.plottedPatches = []

            self.colorMapping = dict()
            self.colors = cycle(["b", "g", "r", "c", "m", "y"])

            cmap = pylab.get_cmap("jet")

            numColumns = self.maze.getColumns()
            numRows = self.maze.getRows()

            observesStateFunction = isinstance(self.functionObservable,
                                               FunctionOverStateSpaceObservable)
            # Check if the observed function returns discrete or continuous value
            discreteFunction = observesStateFunction \
                                    and self.functionObservable.discreteValues
            if not discreteFunction:
                # The values of the observed function over the 2d state space;
                # masked entries mark cells where the function is undefined
                values = numpy.ma.array(numpy.zeros((numColumns, numRows)),
                                        mask=numpy.zeros((numColumns, numRows)))

            # The chosen option is the same for every state, so look it up once
            selectedOption = None
            if not observesStateFunction:
                selectedOption = self._selectedOption()

            # Iterate over all states and compute the value of the observed function
            dimensions = [
                self.stateSpace[dimName] for dimName in ["column", "row"]
            ]
            for column in range(numColumns):
                for row in range(numRows):
                    # Create state object
                    state = State((column, row), dimensions)
                    # Evaluate function for this state
                    if observesStateFunction:
                        functionValue = self.evalFunction(state)
                    else:  # StateActionValuesObservable
                        assert selectedOption is not None
                        functionValue = self.evalFunction(state, selectedOption)

                    # The function may only be defined over part of the state space
                    undefined = self._isUndefinedValue(functionValue)
                    if discreteFunction:
                        if undefined:
                            continue
                        if functionValue not in self.colorMapping:
                            # Choose color for a value occurring for the first time
                            self.colorMapping[functionValue] = next(self.colors)
                        patch = self.maze.plotSquare(
                            self.axis, (column, row),
                            self.colorMapping[functionValue])
                        self.plottedPatches.append(patch[0])
                    elif undefined:
                        values.mask[column, row] = True
                    else:
                        # Remember values since we have to know the min and max
                        # value before we can plot
                        values[column, row] = functionValue

            # Do the actual plotting for functions with continuous values
            # (skipped entirely if every cell is undefined)
            hasContinuousValues = not discreteFunction and values.count() > 0
            if hasContinuousValues:
                minValue = values.min()
                maxValue = values.max()
                valueRange = maxValue - minValue
                if valueRange == 0:
                    # Constant function: avoid 0/0 (NaN colors)
                    valueRange = 1.0
                for column in range(numColumns):
                    for row in range(numRows):
                        if values.mask[column, row]:
                            continue
                        value = (values[column, row] - minValue) / valueRange
                        patch = self.maze.plotSquare(self.axis, (column, row),
                                                     cmap(value),
                                                     zorder=0)
                        self.plottedPatches.append(patch[0])

            # Set limits
            self.axis.set_xlim(0, len(self.maze.structure[0]) - 1)
            self.axis.set_ylim(0, len(self.maze.structure) - 1)

            # Update legend
            self.legendWidget.clear()
            if discreteFunction:
                for functionValue, colorValue in self.colorMapping.items():
                    if isinstance(functionValue, tuple):
                        functionValue = functionValue[0]  # deal with '(action,)'
                    rgbaColor = matplotlib.colors.ColorConverter().to_rgba(
                        colorValue)
                    self._addLegendItem(str(functionValue), rgbaColor)
            elif hasContinuousValues:
                for value in numpy.linspace(minValue, maxValue, 10):
                    self._addLegendItem(str(value),
                                        cmap((value - minValue) / valueRange))

            self.canvas.draw()

    def _selectedOption(self):
        """Return the entry of ``self.actions`` whose str() is the selected suboption.

        Returns None if no entry matches.
        """
        selectedOptionName = str(self.suboptionComboBox.currentText())
        for option in self.actions:
            if str(option) == selectedOptionName:
                return option
        return None

    @staticmethod
    def _isUndefinedValue(value):
        """Return True if ``value`` is None or a scalar NaN/+inf/-inf.

        A membership test like ``value in [numpy.nan, ...]`` misses computed
        NaNs (NaN != NaN), hence the explicit isfinite check. Non-numeric
        values (e.g. action names) and sequences such as ``(action,)`` tuples
        count as defined.
        """
        if value is None:
            return True
        try:
            if numpy.ndim(value) != 0:
                return False
            return bool(not numpy.isfinite(value))
        except (TypeError, ValueError):
            return False

    def _addLegendItem(self, text, rgbaColor):
        """Append a legend entry showing ``text`` in the given RGBA color (0..1)."""
        item = QtGui.QListWidgetItem(text, self.legendWidget)
        color = QtGui.QColor(int(rgbaColor[0] * 255),
                             int(rgbaColor[1] * 255),
                             int(rgbaColor[2] * 255))
        item.setTextColor(color)
        self.legendWidget.addItem(item)

    def _functionObservableChanged(self, selectedFunctionObservable):
        """Switch the viewer to the observable titled *selectedFunctionObservable*.

        Disconnects the callback from the currently observed observable (if
        any). The title is then looked up first as a
        FunctionOverStateSpaceObservable and, failing that, as a
        StateActionValuesObservable, and a callback forwarding to
        ``self._updateFunction`` is registered on the result.

        The whole switch runs while holding ``self.lock``. The lock is
        released even if a lookup or registration raises. If no observable
        with the given title exists, the viewer stays disconnected
        (``self.functionObservable`` is None) instead of crashing.

        :param selectedFunctionObservable: title of the observable to show;
            may be a QString when delivered by the combo box signal.
        """
        with self.lock:
            if self.functionObservable is not None:
                # Disconnect from old function observable
                self.functionObservable.removeObserver(
                    self.functionObservableCallback)

            # Determine new observed function observable
            self.selectedFunctionObservable = str(selectedFunctionObservable)

            # Connect to new function observable
            self.functionObservable = OBSERVABLES.getObservable(
                self.selectedFunctionObservable,
                FunctionOverStateSpaceObservable)
            if self.functionObservable is not None:
                # Observing a FunctionOverStateSpaceObservable
                self.functionObservableCallback = \
                     lambda evalFunction: self._updateFunction(evalFunction)
                self.functionObservable.addObserver(
                    self.functionObservableCallback)
                return

            # Observing a StateActionValuesObservable
            self.functionObservable = OBSERVABLES.getObservable(
                self.selectedFunctionObservable, StateActionValuesObservable)
            self.actions = None
            if self.functionObservable is None:
                # No observable with this title (any more): stay disconnected
                # rather than failing on addObserver with the lock held.
                return

            def functionObservableCallback(evalFunction, actions):
                # If we get new options to select from
                if actions != self.actions:
                    # Update suboptionComboBox
                    self.actions = actions
                    self.suboptionComboBox.clear()
                    self.suboptionComboBox.addItems(
                        [str(action) for action in actions])
                self._updateFunction(evalFunction)

            self.functionObservableCallback = functionObservableCallback
            self.functionObservable.addObserver(
                self.functionObservableCallback)

    def _changeUpdateFrequency(self):
        self.updateFrequency = self.updateFrequencySlider.value() / 100.0
        self.updateFrequencyLabel.setText("UpdateFrequency: %s" %
                                          self.updateFrequency)

    def _updatePlot(self):
        self.updatePlotNow = True
    def __init__(self, pinballMazeEnv, stateSpace):
        """Create the trajectory viewer for the pinball maze environment.

        :param pinballMazeEnv: environment whose ``plotStateSpaceStructure``
            draws the maze structure into the plot axis.
        :param stateSpace: mapping from dimension name to dimension; the
            dimensions are stored ordered by name in ``self.dimensions``.

        Builds the matplotlib canvas, the drawing/coloring controls and the
        legend. It then registers ``self._updateSamples`` (via a lambda) on
        the first TrajectoryObservable. That registration is deliberately the
        last step, presumably so no transition arrives before the widgets
        exist.
        """
        super(PinballMazeTrajectoryViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv

        # State dimensions, ordered by dimension name
        self.dimensions = [stateSpace[dimName]
                           for dimName in sorted(stateSpace.keys())]

        # The segments that are obtained while drawing is disabled. These
        # segments are drawn once drawing is re-enabled
        self.rememberedSegments = []

        # The eval function that can be used for coloring the trajectory
        self.evalFunction = None
        # Range of observed values for the real-valued "Q-Value" coloring;
        # starts empty (inf/-inf) so the first observed value sets both ends
        self.minValue = numpy.inf
        self.maxValue = -numpy.inf

        self.colorsCycle = cycle([(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0),
                                  (1.0, 1.0, 0.0), (0.0, 1.0, 1.0), (1.0, 0.0, 1.0),
                                  (0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 0.5)])
        # The next() builtin works on Python 2 and 3 iterators alike, unlike
        # the Python-2-only iterator.next() method
        self.colors = defaultdict(lambda: next(self.colorsCycle))
        self.valueToColorMapping = dict()

        # Get required observables
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.ballPatch = None
        self.linePatches = []

        # Add other elements to GUI
        self.drawingEnabledCheckbox = \
                QtGui.QCheckBox("Drawing enabled", self)
        self.drawingEnabledCheckbox.setChecked(True)

        self.drawStyle = "Current Position"
        self.drawStyleLabel = QtGui.QLabel("Draw style")
        self.drawStyleComboBox = QtGui.QComboBox(self)
        self.drawStyleComboBox.addItems(["Current Position", "Last Episode",
                                         "Online (All)"])
        self.connect(self.drawStyleComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._drawStyleChanged)

        self.colorCriterion = "Action"
        self.colorCriterionLabel = QtGui.QLabel("Coloring of trajectory")
        self.colorCriterionComboBox = QtGui.QComboBox(self)
        self.colorCriterionComboBox.addItems(["Action", "Reward", "Q-Value"])
        self.connect(self.colorCriterionComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._colorCriterionChanged)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        # Create layout
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.drawingEnabledCheckbox)
        self.drawStyleLayout = QtGui.QHBoxLayout()
        self.drawStyleLayout.addWidget(self.drawStyleLabel)
        self.drawStyleLayout.addWidget(self.drawStyleComboBox)
        self.vlayout.addLayout(self.drawStyleLayout)
        self.coloringLayout = QtGui.QHBoxLayout()
        self.coloringLayout.addWidget(self.colorCriterionLabel)
        self.coloringLayout.addWidget(self.colorCriterionComboBox)
        self.vlayout.addLayout(self.coloringLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.trajectoryObservableCallback = \
             lambda *transition: self._updateSamples(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)
class PinballMazeTrajectoryViewer(Viewer):
    
    def __init__(self, pinballMazeEnv, stateSpace):
        """Create the trajectory viewer for the pinball maze environment.

        :param pinballMazeEnv: environment whose ``plotStateSpaceStructure``
            draws the maze structure into the plot axis.
        :param stateSpace: mapping from dimension name to dimension; the
            dimensions are stored ordered by name in ``self.dimensions``.

        Builds the matplotlib canvas, the drawing/coloring controls and the
        legend. It then registers ``self._updateSamples`` (via a lambda) on
        the first TrajectoryObservable. That registration is deliberately the
        last step, presumably so no transition arrives before the widgets
        exist.
        """
        super(PinballMazeTrajectoryViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv

        # State dimensions, ordered by dimension name
        self.dimensions = [stateSpace[dimName]
                           for dimName in sorted(stateSpace.keys())]

        # The segments that are obtained while drawing is disabled. These
        # segments are drawn once drawing is re-enabled
        self.rememberedSegments = []

        # The eval function that can be used for coloring the trajectory
        self.evalFunction = None
        # Range of observed values for the real-valued "Q-Value" coloring;
        # starts empty (inf/-inf) so the first observed value sets both ends
        self.minValue = numpy.inf
        self.maxValue = -numpy.inf

        self.colorsCycle = cycle([(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0),
                                  (1.0, 1.0, 0.0), (0.0, 1.0, 1.0), (1.0, 0.0, 1.0),
                                  (0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 0.5)])
        # The next() builtin works on Python 2 and 3 iterators alike, unlike
        # the Python-2-only iterator.next() method
        self.colors = defaultdict(lambda: next(self.colorsCycle))
        self.valueToColorMapping = dict()

        # Get required observables
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.ballPatch = None
        self.linePatches = []

        # Add other elements to GUI
        self.drawingEnabledCheckbox = \
                QtGui.QCheckBox("Drawing enabled", self)
        self.drawingEnabledCheckbox.setChecked(True)

        self.drawStyle = "Current Position"
        self.drawStyleLabel = QtGui.QLabel("Draw style")
        self.drawStyleComboBox = QtGui.QComboBox(self)
        self.drawStyleComboBox.addItems(["Current Position", "Last Episode",
                                         "Online (All)"])
        self.connect(self.drawStyleComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._drawStyleChanged)

        self.colorCriterion = "Action"
        self.colorCriterionLabel = QtGui.QLabel("Coloring of trajectory")
        self.colorCriterionComboBox = QtGui.QComboBox(self)
        self.colorCriterionComboBox.addItems(["Action", "Reward", "Q-Value"])
        self.connect(self.colorCriterionComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._colorCriterionChanged)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        # Create layout
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.drawingEnabledCheckbox)
        self.drawStyleLayout = QtGui.QHBoxLayout()
        self.drawStyleLayout.addWidget(self.drawStyleLabel)
        self.drawStyleLayout.addWidget(self.drawStyleComboBox)
        self.vlayout.addLayout(self.drawStyleLayout)
        self.coloringLayout = QtGui.QHBoxLayout()
        self.coloringLayout.addWidget(self.colorCriterionLabel)
        self.coloringLayout.addWidget(self.colorCriterionComboBox)
        self.vlayout.addLayout(self.coloringLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.trajectoryObservableCallback = \
             lambda *transition: self._updateSamples(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)
    
    def close(self):
        self.trajectoryObservable.removeObserver(self.trajectoryObservableCallback)
        
        super(PinballMazeTrajectoryViewer, self).close()
                
    def _updateSamples(self, state, action, reward, succState, episodeTerminated):
        # Determine color
        if self.colorCriterion == "Action":
            value = action 
        elif self.colorCriterion == "Reward": 
            value = reward
        elif self.colorCriterion == "Q-Value":
            if self.evalFunction is None: return
            queryState = State((succState['x'], succState['xdot'], 
                                succState['y'], succState['ydot']), 
                               self.dimensions)
            value = self.evalFunction(queryState)
            
            self.minValue = min(value, self.minValue)
            self.maxValue = max(value, self.maxValue)

        if self.drawingEnabledCheckbox.checkState(): # Immediate drawing           
            # Remove ball patch if it is drawn currently
            if self.ballPatch != None:
                self.ballPatch.remove()
                self.ballPatch = None
                
            if self.drawStyle == "Current Position":
                # Remove old trajectory
                self._removeTrajectory()
                self.rememberedSegments = []
                # Plot ball     
                self.ballPatch = Circle([state["x"], state["y"]], 
                                        self.pinballMazeEnv.maze.ballRadius, facecolor='k') 
                self.axis.add_patch(self.ballPatch)
                self.canvas.draw()
            elif self.drawStyle == "Online (All)":   
                # If drawing was just reactivated
                self._drawRememberedSegments()
                # Draw current transition             
                lines = self.axis.plot([state["x"], succState["x"]], 
                                       [state["y"], succState["y"]], '-',
                                       color=self._determineColor(value))
                self.linePatches.extend(lines)
                self.canvas.draw()
            else: # "Last Episode"
                # Remember state trajectory, it will be drawn at the end 
                # of the episode
                self.rememberedSegments.append((state["x"], succState["x"],
                                                state["y"], succState["y"], 
                                                value))
                if episodeTerminated:
                    # Remove last trajectory, draw this episode's trajectory
                    self._removeTrajectory()
                    self._drawRememberedSegments()
                    self.canvas.draw()
                    # When coloring trajectory based on real valued criteria,
                    # we have to update the legend now 
                    if self.colorCriterion == "Q-Value":
                        self.legendWidget.clear()
                        for value in numpy.logspace(0, numpy.log10(self.maxValue - self.minValue + 1), 10):
                            value = value - 1 + self.minValue
                            
                            color = self._determineColor(value)
                            item = QtGui.QListWidgetItem(str(value), self.legendWidget)
                            qColor = QtGui.QColor(int(color[0]*255),
                                                  int(color[1]*255), 
                                                  int(color[2]*255))
                            item.setTextColor(qColor)
                            self.legendWidget.addItem(item) 
        else:
            if self.drawStyle != "Current Position":
                # Remember state trajectory, it will be drawn once drawing is
                # reenabled
                self.rememberedSegments.append((state["x"], succState["x"],
                                                state["y"], succState["y"], 
                                                value))
                
    def _determineColor(self, value):
        # Choose the color for the value
        if self.colorCriterion in ["Action", "Reward"]:
            # Finite number of values 
            if value not in self.valueToColorMapping:
                color = self.colorsCycle.next()
                self.valueToColorMapping[value] = color
                
                # Add to legend
                item = QtGui.QListWidgetItem(str(value), self.legendWidget)
                qColor = QtGui.QColor(int(color[0]*255), int(color[1]*255), 
                                     int(color[2]*255))
                item.setTextColor(qColor)
                self.legendWidget.addItem(item)
                   
            return self.valueToColorMapping[value]
        else:
            if self.maxValue != self.minValue:
                alpha = numpy.log10(value - self.minValue + 1) \
                                / numpy.log10(self.maxValue - self.minValue + 1) 
            else:
                alpha = 0.5
            return (alpha, 0, 1-alpha)
            
    
    def _removeTrajectory(self):
        if len(self.linePatches) > 0:
            for line in self.linePatches:
                line.remove()
            self.linePatches = []
            
    def _drawRememberedSegments(self):
        if len(self.rememberedSegments) > 0:
            for x1, x2, y1, y2, value  in self.rememberedSegments:
                lines = self.axis.plot([x1, x2], [y1, y2], '-', 
                                       color=self._determineColor(value))
                self.linePatches.extend(lines)
            self.rememberedSegments = []
        
    def _drawStyleChanged(self, drawStyle):
        self.drawStyle = drawStyle
        
        if self.drawStyle != "Last Episode" and self.colorCriterion == "Q-Value":
            # This combination is not possible, change coloring criterion
            self._colorCriterionChanged("Reward")
            
            
    def _colorCriterionChanged(self, colorCriterion):
        # If we changed color criterion 
        if colorCriterion != self.colorCriterion:
            # Remove old trajectory
            self._removeTrajectory()
            self.rememberedSegments = []
            self.legendWidget.clear()
            
        self.colorCriterion = colorCriterion
        self.valueToColorMapping = {}

        if self.colorCriterion == "Q-Value":
            # Register to FunctionOverStateSpaceObservable for global Q-Function
            from mmlf.framework.observables import OBSERVABLES, \
                                             FunctionOverStateSpaceObservable
            for functionObservable in OBSERVABLES.getAllObservablesOfType(
                                            FunctionOverStateSpaceObservable):
                # TODO: Name of function is hard-coded
                if functionObservable.title == "Option TopLevel (optimal value function)":
                    def updateEvalFunction(evalFunction):
                        self.evalFunction = evalFunction
                    functionObservable.addObserver(updateEvalFunction)
                    break
                
            # Displaying Q-Value makes only sense when plotting at the end of 
            # an episode
            self._drawStyleChanged("Last Episode")
            
            # We have to remember minimal and maximal value
            self.minValue = numpy.inf
            self.maxValue = -numpy.inf
            
            
                    
    def __init__(self, pinballMazeEnv, stateSpace):
        """Create the function viewer for the pinball maze environment.

        The widget contains:

        * a matplotlib canvas showing the maze structure;
        * a combo box selecting the observed FunctionOverStateSpaceObservable;
        * sliders for the plot-grid resolution, the update frequency and the
          x/y velocity values;
        * an "Update Plot" button and a legend list.

        :param pinballMazeEnv: environment whose ``plotStateSpaceStructure``
            draws the maze structure into the plot axis.
        :param stateSpace: the state space; stored as ``self.stateSpace``.

        If at least one FunctionOverStateSpaceObservable exists, the viewer
        connects to the first one as the very last construction step.
        """
        super(PinballMazeFunctionViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv
        self.stateSpace = stateSpace
        self.actions = []

        self.updateCounter = 0
        self.updatePlotNow = False
        self.evalFunction = None

        # NOTE(review): presumably serializes observable callbacks with
        # GUI-triggered changes; the acquiring methods are not shown here.
        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.plottedPatches = []

        # Add a combobox for selecting the function over state space that is observed
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        if len(functionObservables) > 0:
            self.selectedFunctionObservable = functionObservables[0].title

        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update function observable combobox when new
        # observables are created during runtime
        def updateFunctionObservableBox(viewer, action):
            self.functionObservableComboBox.clear()
            functionObservables = \
                OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            if len(functionObservables) > 0:
                self.selectedFunctionObservable = functionObservables[0].title

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # Slider that controls the granularity of the plot-grid.
        # The range is set before the value so the value is never clamped
        # against Qt's default slider range.
        self.gridNodesPerDim = 50
        self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.gridNodesSlider.setMinimum(0)
        self.gridNodesSlider.setMaximum(100)
        self.gridNodesSlider.setValue(self.gridNodesPerDim)
        self.gridNodesSlider.setTickInterval(10)
        self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.gridNodesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s" %
                                           self.gridNodesPerDim)

        # Slider that controls the frequency of updating the plot. The slider
        # shows the frequency scaled by 100 (positions 0..100).
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        # The tick interval is an int in slider units: 10 units == 0.1
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Chosen xvel and yvel values; the sliders show them scaled by 10
        # (positions 0..10)
        self.xVel = 0.5
        self.xVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.xVelSlider.setMinimum(0)
        self.xVelSlider.setMaximum(10)
        self.xVelSlider.setValue(int(self.xVel * 10))
        self.xVelSlider.setTickInterval(1)
        self.xVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.xVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeXVel)
        self.xVelLabel = QtGui.QLabel("xvel value: %s" % self.xVel)

        self.yVel = 0.5
        self.yVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.yVelSlider.setMinimum(0)
        self.yVelSlider.setMaximum(10)
        self.yVelSlider.setValue(int(self.yVel * 10))
        self.yVelSlider.setTickInterval(1)
        self.yVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.yVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeYVel)
        # Label shows the y velocity (was mistakenly formatted with xVel)
        self.yVelLabel = QtGui.QLabel("yvel value: %s" % self.yVel)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.gridNodesLayout = QtGui.QHBoxLayout()
        self.gridNodesLayout.addWidget(self.gridNodesLabel)
        self.gridNodesLayout.addWidget(self.gridNodesSlider)
        self.vlayout.addLayout(self.gridNodesLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.xVelLayout = QtGui.QHBoxLayout()
        self.xVelLayout.addWidget(self.xVelLabel)
        self.xVelLayout.addWidget(self.xVelSlider)
        self.vlayout.addLayout(self.xVelLayout)
        self.yVelLayout = QtGui.QHBoxLayout()
        self.yVelLayout.addWidget(self.yVelLabel)
        self.yVelLayout.addWidget(self.yVelSlider)
        self.vlayout.addLayout(self.yVelLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.functionObservable = None
        if self.selectedFunctionObservable:
            self.functionObservable = \
                    OBSERVABLES.getObservable(self.selectedFunctionObservable,
                                              FunctionOverStateSpaceObservable)
            self.functionObservableCallback = \
                 lambda evalFunction: self._updateFunction(evalFunction)
            self.functionObservable.addObserver(
                self.functionObservableCallback)
class PinballMazeFunctionViewer(Viewer):
    """Viewer that plots a function over the pinball maze state space.

    The observed function (one of the ``FunctionOverStateSpaceObservable``
    instances registered in ``OBSERVABLES``) is evaluated on a 2d slice of
    the state space spanned by ``x`` and ``y``; the velocity dimensions
    ``xdot`` and ``ydot`` are fixed to the values chosen with the sliders.
    The result is drawn with ``pcolor`` on top of the maze structure, and a
    color legend is shown in a list widget next to the plot.
    """

    def __init__(self, pinballMazeEnv, stateSpace):
        super(PinballMazeFunctionViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv
        self.stateSpace = stateSpace
        self.actions = []

        # Number of function updates skipped since the last replot
        self.updateCounter = 0
        # Set by the "Update Plot" button; forces a replot on the next update
        self.updatePlotNow = False
        # The most recently plotted function (None until the first replot)
        self.evalFunction = None

        # Guards the plot state, which is modified both from the observer
        # callback (_updateFunction) and from GUI handlers
        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        # Artists of the current function plot (removed before each replot)
        self.plottedPatches = []

        # Add a combobox for selecting the function over state space that is observed
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        if len(functionObservables) > 0:
            self.selectedFunctionObservable = functionObservables[0].title

        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update funtion observable combobox when new observables
        # are created during runtime
        def updateFunctionObservableBox(viewer, action):
            # NOTE(review): this refreshes the list and selectedFunctionObservable
            # but does not reconnect self.functionObservable.
            self.functionObservableComboBox.clear()
            functionObservables = \
                OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            if len(functionObservables) > 0:
                self.selectedFunctionObservable = functionObservables[0].title

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # Slider that controls the granularity of the plot-grid.
        # The range is set before the value so the value is never clamped.
        self.gridNodesPerDim = 50
        self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.gridNodesSlider.setMinimum(0)
        self.gridNodesSlider.setMaximum(100)
        self.gridNodesSlider.setValue(self.gridNodesPerDim)
        self.gridNodesSlider.setTickInterval(10)
        self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.gridNodesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s" %
                                           self.gridNodesPerDim)

        # Slider that controls the frequency of update the plot.
        # The slider value is updateFrequency * 100 (see _changeUpdateFrequency).
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        # setTickInterval takes an int; one tick per 0.1 of update frequency
        # (the previous float 0.1 was truncated to 0)
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Chosen xvel and yvel values (slider value is velocity * 10)
        self.xVel = 0.5
        self.xVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.xVelSlider.setMinimum(0)
        self.xVelSlider.setMaximum(10)
        self.xVelSlider.setValue(int(self.xVel * 10))
        self.xVelSlider.setTickInterval(1)
        self.xVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.xVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeXVel)
        self.xVelLabel = QtGui.QLabel("xvel value: %s" % self.xVel)

        self.yVel = 0.5
        self.yVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.yVelSlider.setMinimum(0)
        self.yVelSlider.setMaximum(10)
        self.yVelSlider.setValue(int(self.yVel * 10))
        self.yVelSlider.setTickInterval(1)
        self.yVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.yVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeYVel)
        self.yVelLabel = QtGui.QLabel("yvel value: %s" % self.yVel)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.gridNodesLayout = QtGui.QHBoxLayout()
        self.gridNodesLayout.addWidget(self.gridNodesLabel)
        self.gridNodesLayout.addWidget(self.gridNodesSlider)
        self.vlayout.addLayout(self.gridNodesLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.xVelLayout = QtGui.QHBoxLayout()
        self.xVelLayout.addWidget(self.xVelLabel)
        self.xVelLayout.addWidget(self.xVelSlider)
        self.vlayout.addLayout(self.xVelLayout)
        self.yVelLayout = QtGui.QHBoxLayout()
        self.yVelLayout.addWidget(self.yVelLabel)
        self.yVelLayout.addWidget(self.yVelSlider)
        self.vlayout.addLayout(self.yVelLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.functionObservable = None
        if self.selectedFunctionObservable:
            self.functionObservable = \
                    OBSERVABLES.getObservable(self.selectedFunctionObservable,
                                              FunctionOverStateSpaceObservable)
            self.functionObservableCallback = \
                 lambda evalFunction: self._updateFunction(evalFunction)
            self.functionObservable.addObserver(
                self.functionObservableCallback)

    def close(self):
        """Detach from the observed function observable, then close."""
        if self.functionObservable is not None:
            # Disconnect from old function observable
            self.functionObservable.removeObserver(
                self.functionObservableCallback)

        super(PinballMazeFunctionViewer, self).close()

    def _updateFunction(self, evalFunction):
        """Observer callback: replot *evalFunction* if an update is due.

        A replot happens when the user requested one ("Update Plot" button),
        or when the update frequency is positive and more than
        ``1 / updateFrequency`` updates have been skipped since the last
        replot. Otherwise the update is only counted.
        """
        # Plain float division; numpy.float was removed in NumPy 1.24
        periodicUpdateDue = (self.updateFrequency > 0.0 and
                             self.updateCounter > 1.0 / self.updateFrequency)
        if periodicUpdateDue or self.updatePlotNow:
            # Reset counter and update plot
            self.updateCounter = 0
            self.updatePlotNow = False
            self.evalFunction = evalFunction
            self._plotFunction()
        else:
            # Do not update the plot
            self.updateCounter += 1

    def _plotFunction(self):
        """Plot the last received function on the current 2d state slice.

        Does nothing until a function has been received. The lock is held for
        the whole replot and is released even if plotting raises.
        """
        if self.evalFunction is None:
            return

        with self.lock:
            # Clean up old plot
            for patch in self.plottedPatches:
                patch.remove()
            self.plottedPatches = []

            # Check if the observed function returns discrete or continuous value
            discreteFunction = self.functionObservable.discreteValues

            # Generate 2d state slice over x and y; the velocities are fixed
            # to the slider values (the x/y defaults are presumably replaced
            # by the slice grid -- TODO confirm in generate2dStateSlice)
            defaultDimValues = {
                "x": 0,
                "xdot": self.xVel,
                "y": 0,
                "ydot": self.yVel
            }
            stateSlice = generate2dStateSlice(["x", "y"],
                                              self.stateSpace,
                                              defaultDimValues,
                                              gridNodesPerDim=self.gridNodesPerDim)

            # Compute values that should be plotted
            values, colorMapping = \
                    generate2dPlotArray(self.evalFunction, stateSlice,
                                        not discreteFunction,
                                        shape=(self.gridNodesPerDim,
                                               self.gridNodesPerDim))

            # If all value are None, we cannot draw anything useful. Still
            # clear the legend and redraw, otherwise the previous plot (whose
            # patches were removed above) would remain visible.
            if values.mask.all():
                self.legendWidget.clear()
                self.canvas.draw()
                return

            # The grid is drawn on the unit square (assumes x and y are
            # normalized to [0, 1] -- TODO confirm)
            polyCollection = self.axis.pcolor(
                numpy.linspace(0.0, 1.0, self.gridNodesPerDim),
                numpy.linspace(0.0, 1.0, self.gridNodesPerDim), values.T)
            self.plottedPatches.append(polyCollection)

            # Set axis limits
            self.axis.set_xlim(0, 1)
            self.axis.set_ylim(0, 1)

            # Update legend
            self.legendWidget.clear()
            norm = polyCollection.norm
            colorMap = polyCollection.get_cmap()
            if discreteFunction:
                # One entry per distinct function value
                for functionValue, colorValue in colorMapping.items():
                    if isinstance(functionValue, tuple):
                        functionValue = functionValue[0]  # deal with '(action,)'
                    self._addLegendItem(str(functionValue), colorValue,
                                        norm, colorMap)
            else:
                # Ten evenly spaced sample values across the color range
                for value in numpy.linspace(norm.vmin, norm.vmax, 10):
                    self._addLegendItem(str(value), value, norm, colorMap)

            self.canvas.draw()

    def _addLegendItem(self, label, colorValue, norm, colorMap):
        """Append *label* to the legend, colored like *colorValue* in the plot."""
        normValue = norm(colorValue)
        if isinstance(normValue, numpy.ndarray):
            normValue = normValue[0]  # happens when function is constant
        rgbaColor = colorMap(normValue)
        item = QtGui.QListWidgetItem(label, self.legendWidget)
        color = QtGui.QColor(int(rgbaColor[0] * 255),
                             int(rgbaColor[1] * 255),
                             int(rgbaColor[2] * 255))
        item.setTextColor(color)
        self.legendWidget.addItem(item)

    def _functionObservableChanged(self, selectedFunctionObservable):
        """Re-connect the viewer to the observable chosen in the combo box."""
        with self.lock:
            if self.functionObservable is not None:
                # Disconnect from old function observable
                self.functionObservable.removeObserver(
                    self.functionObservableCallback)

            # Determine new observed function observable
            self.selectedFunctionObservable = str(selectedFunctionObservable)

            # Connect to new function observable
            self.functionObservable = OBSERVABLES.getObservable(
                self.selectedFunctionObservable, FunctionOverStateSpaceObservable)
            self.functionObservableCallback = \
                 lambda evalFunction: self._updateFunction(evalFunction)
            self.functionObservable.addObserver(self.functionObservableCallback)

            self.actions = []

    def _changeGridNodes(self):
        """Apply the grid-node slider value and replot immediately."""
        self.gridNodesPerDim = self.gridNodesSlider.value()
        self.gridNodesLabel.setText("Grid Nodes Per Dimension: %s" %
                                    self.gridNodesPerDim)
        # update plot
        self._plotFunction()

    def _changeUpdateFrequency(self):
        """Apply the update-frequency slider (slider value / 100)."""
        self.updateFrequency = self.updateFrequencySlider.value() / 100.0
        self.updateFrequencyLabel.setText("UpdateFrequency: %s" %
                                          self.updateFrequency)

    def _updatePlot(self):
        """Request a replot on the next function update."""
        self.updatePlotNow = True

    def _changeXVel(self):
        """Apply the xvel slider (slider value / 10); used on the next replot."""
        self.xVel = self.xVelSlider.value() / 10.0
        self.xVelLabel.setText("xvel value: %s" % self.xVel)

    def _changeYVel(self):
        """Apply the yvel slider (slider value / 10); used on the next replot."""
        self.yVel = self.yVelSlider.value() / 10.0
        self.yVelLabel.setText("yvel value: %s" % self.yVel)
class SeventeenAndFourValuefunctionViewer(Viewer):
    """Viewer plotting state-action values for the Seventeen and Four domain.

    For every action, the value of each state of the ``count`` dimension
    (labelled "Sum of cards") is drawn as one line. A combo box selects which
    ``StateActionValuesObservable`` is shown.
    """

    def __init__(self, stateSpace):
        super(SeventeenAndFourValuefunctionViewer, self).__init__()

        self.stateSpace = stateSpace
        # The values of the "count" dimension; these are the plotted states
        self.states = stateSpace["count"]["dimensionValues"]

        # Combo Box for selecting the observable
        self.comboBox = QtGui.QComboBox(self)
        self.stateActionValuesObservables = \
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
        self.comboBox.addItems(
            ["%s" % observable.title
             for observable in self.stateActionValuesObservables])
        self.connect(self.comboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._observableChanged)

        # Automatically update combobox when new float stream observables
        #  are created during runtime
        def updateComboBox(observable, action):
            # clear() makes Qt emit currentIndexChanged(-1), which
            # _observableChanged treats as "nothing selected"
            self.comboBox.clear()
            self.stateActionValuesObservables = \
                    OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
            self.comboBox.addItems(
                ["%s" % observable.title
                 for observable in self.stateActionValuesObservables])

        OBSERVABLES.addObserver(updateComboBox)

        # Create matplotlib widgets
        plotWidgetValueFunction = QtGui.QWidget(self)
        plotWidgetValueFunction.setMinimumSize(800, 500)

        self.figValueFunction = Figure((8.0, 5.0), dpi=100)
        self.axisValueFunction = self.figValueFunction.gca()
        self.canvasValueFunction = FigureCanvas(self.figValueFunction)
        self.canvasValueFunction.setParent(plotWidgetValueFunction)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidgetValueFunction)
        self.hlayout.addWidget(self.comboBox)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.stateActionValuesObservableCallback = \
             lambda valueAccessFunction, actions: self.updateValues(valueAccessFunction, actions)
        if len(self.stateActionValuesObservables) > 0:
            # Show per default the first observable
            self.stateActionValuesObservable = \
                    self.stateActionValuesObservables[0]
            plotWidgetValueFunction.setWindowTitle(
                self.stateActionValuesObservable.title)

            self.stateActionValuesObservable.addObserver(
                self.stateActionValuesObservableCallback)
        else:
            self.stateActionValuesObservable = None

    def close(self):
        """Detach from the observed observable (if any), then close."""
        if self.stateActionValuesObservable is not None:
            # Remove old observable
            self.stateActionValuesObservable.removeObserver(
                self.stateActionValuesObservableCallback)

        super(SeventeenAndFourValuefunctionViewer, self).close()

    def updateValues(self, valueAccessFunction, actions):
        """Redraw one value curve per action over the sorted states.

        :param valueAccessFunction: callable ``(State, action) -> value``.
        :param actions: the actions for which a curve is drawn.
        """
        self.axisValueFunction.clear()
        sortedStates = sorted(self.states)
        for action in actions:
            actionValues = [
                valueAccessFunction(State([state], self.stateSpace.values()),
                                    action)
                for state in sortedStates
            ]
            self.axisValueFunction.plot(sortedStates,
                                        actionValues,
                                        label=str(action))

        self.axisValueFunction.set_xlabel('Sum of cards')
        self.axisValueFunction.set_ylabel('Value')
        self.axisValueFunction.legend()
        self.canvasValueFunction.draw()

    def _observableChanged(self, comboBoxIndex):
        """Switch the plotted observable to the one at *comboBoxIndex*.

        Qt emits ``currentIndexChanged`` with -1 when the combo box is
        cleared (as done when the list of observables is refreshed). That
        index must not be used for lookup. Python would silently pick the
        last observable, or raise IndexError on an empty list. In that case
        the viewer only detaches from the old observable.
        """
        if self.stateActionValuesObservable is not None:
            # Remove old observable
            self.stateActionValuesObservable.removeObserver(
                self.stateActionValuesObservableCallback)
            self.stateActionValuesObservable = None

        if not 0 <= comboBoxIndex < len(self.stateActionValuesObservables):
            return

        # Get new observable and add as listener
        self.stateActionValuesObservable = \
                self.stateActionValuesObservables[comboBoxIndex]
        self.stateActionValuesObservable.addObserver(
            self.stateActionValuesObservableCallback)
Example #17
0
    def __init__(self, tableModel, parent):
        """Create the window that plots the current experiment's results.

        :param tableModel: model providing the results to plot (presumably the
            experiment results table -- TODO confirm against callers).
        :param parent: parent widget of this window.
        """
        super(PlotExperimentWindow, self).__init__(parent)

        self.tableModel = tableModel
        # Colors used for different configurations in plots: each new key of
        # colorMapping is assigned the next color of the cycle
        self.colors = cycle(["b", "g", "r", "c", "m", "y", "k"])
        self.colorMapping = defaultdict(lambda: next(self.colors))
        # The length of the moving window average (the slider starts at 0 and
        # the size at 2**0; presumably size = 2**slider -- see _changeMWA)
        self.mwaSize = 2**0
        # Whether we plot each run separately or only their mean
        self.linePlotTypes = ["Each Run", "Average"]
        self.linePlot = self.linePlotTypes[0]

        # The central widget
        self.centralWidget = QtGui.QWidget(self)

        # Create matplotlib widget
        self.plotWidget = QtGui.QWidget(self.centralWidget)
        self.plotWidget.setMinimumSize(800, 500)

        self.fig = Figure((8.0, 5.0), dpi=100)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.plotWidget)
        self.axis = self.fig.gca()

        self.mplToolbar = NavigationToolbar(self.canvas, self.centralWidget)

        self.mwaLabel = QtGui.QLabel("Moving Window Average: %s" % self.mwaSize)
        # Slider for controlling the moving average window
        # (range set before value so the value is never clamped)
        self.mwaSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.mwaSlider.setMinimum(0)
        self.mwaSlider.setMaximum(10)
        self.mwaSlider.setValue(0)
        self.mwaSlider.setTickInterval(10)
        self.mwaSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.mwaSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeMWA)

        self.lineLabel = QtGui.QLabel("Plot of agent: ")
        # Combo Box for selecting how runs are plotted; its entries are the
        # line plot types so both stay in sync
        self.lineComboBox = QtGui.QComboBox(self)
        self.lineComboBox.addItems(self.linePlotTypes)
        self.connect(self.lineComboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._linePlotChanged)

        # Add a button for replotting
        self.replotButton = QtGui.QPushButton("Update")
        self.connect(self.replotButton, QtCore.SIGNAL('clicked()'),
                     self._plot)

        # Add a button for saving a plot
        self.saveButton = QtGui.QPushButton("Save")
        self.connect(self.saveButton, QtCore.SIGNAL('clicked()'),
                     self._save)

        # Set layout: controls in a top row, then the plot and its toolbar
        self.topLinelayout = QtGui.QHBoxLayout()
        self.topLinelayout.addWidget(self.mwaLabel)
        self.topLinelayout.addWidget(self.mwaSlider)
        self.topLinelayout.addWidget(self.lineLabel)
        self.topLinelayout.addWidget(self.lineComboBox)
        self.topLinelayout.addWidget(self.replotButton)
        self.topLinelayout.addWidget(self.saveButton)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addLayout(self.topLinelayout)
        self.vlayout.addWidget(self.plotWidget)
        self.vlayout.addWidget(self.mplToolbar)
        self.centralWidget.setLayout(self.vlayout)

        self.setCentralWidget(self.centralWidget)
        self.setWindowTitle("Current experiment's results")

        # Plot the results once upon creation
        self._plot()
Example #18
0
class StatisticalAnalysisWidget(QtGui.QWidget):
    """Widget for statistically comparing the configurations of an experiment.

    For the selected metric, each run's time series is reduced to one scalar
    by a user-entered aggregation function. The per-configuration samples are
    shown as box plots. A table shows, for every pair of configurations, the
    p-value of the selected hypothesis test (intended as one-sided, x > y).
    Entries below the entered significance level are shown in bold.
    """

    def __init__(self, experimentResults, parent=None):
        super(StatisticalAnalysisWidget, self).__init__(parent)

        self.experimentResults = experimentResults

        # Statistical test: maps a name to a function (x, y) -> p-value.
        # The t-test's two-sided p-value is halved to obtain a one-sided one.
        self.TESTS = {'MannWhitney U-Test': lambda x, y: scipy.stats.mannwhitneyu(x,y)[1],
                      'Student t-test': lambda x, y: scipy.stats.ttest_ind(x,y)[1]/2}

        # Create combobox for selecting the metric
        metricsLabel = QtGui.QLabel("Metric")
        self.metricsComboBox = QtGui.QComboBox(self)
        self.metricsComboBox.addItems(self.experimentResults.metrics)

        # Text field for the aggregation function
        aggregationLabel = QtGui.QLabel("Aggregation")
        self.aggregationFctEdit = QtGui.QLineEdit("lambda x: mean(x[:])")
        self.aggregationFctEdit.minimumSizeHint = lambda : QtCore.QSize(100,30)
        # (a space was missing between "testing." and "The functions")
        self.aggregationFctEdit.setToolTip("Function which maps a time series "
                                           "onto a single scalar value, which "
                                           "is then used as a sample in "
                                           "the statistical hypothesis testing. "
                                           "The functions min, max, mean, and "
                                           "median may be used.")

        # Create combobox for selecting the test
        testLabel = QtGui.QLabel("Hypothesis test")
        self.testComboBox = QtGui.QComboBox(self)
        self.testComboBox.addItems(self.TESTS.keys())

        # Text field for the p-Value
        pValueLabel = QtGui.QLabel("p <")
        self.pValueEdit = QtGui.QLineEdit("0.05")
        self.pValueEdit.minimumSizeHint = lambda : QtCore.QSize(100,30)
        self.pValueEdit.setToolTip("Significance level: The minimal p-Value "
                                   "which is required for something to be "
                                   "considered as significant.")

        # button for redoing the statistics for the current setting
        self.updateButton = QtGui.QPushButton("Update")
        self.connect(self.updateButton, QtCore.SIGNAL('clicked()'),
                     self._analyze)

        # Create matplotlib widget
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(500, 500)

        fig = Figure((5.0, 5.0), dpi=100)
        fig.subplots_adjust(0.2)
        self.canvas = FigureCanvas(fig)
        self.canvas.setParent(plotWidget)
        self.axis = fig.gca()

        # The table for statistics results
        self.significanceTable = QtGui.QTableWidget(self)

        # Do the analyzing once for the default values
        self._analyze()

        # Create layout
        layout = QtGui.QVBoxLayout()
        hlayout1 = QtGui.QHBoxLayout()
        hlayout1.addWidget(metricsLabel)
        hlayout1.addWidget(self.metricsComboBox)
        hlayout1.addWidget(aggregationLabel)
        hlayout1.addWidget(self.aggregationFctEdit)
        hlayout1.addWidget(testLabel)
        hlayout1.addWidget(self.testComboBox)
        hlayout1.addWidget(pValueLabel)
        hlayout1.addWidget(self.pValueEdit)
        hlayout1.addWidget(self.updateButton)
        hlayout2 = QtGui.QHBoxLayout()
        hlayout2.addWidget(plotWidget)
        hlayout2.addWidget(self.significanceTable)
        layout.addLayout(hlayout1)
        layout.addLayout(hlayout2)
        self.setLayout(layout)

    def _analyze(self):
        """Recompute the box plots and the significance table.

        The metric, aggregation function, hypothesis test and significance
        level are read from the input widgets once. Invalid input therefore
        fails before the currently displayed results are cleared.
        """
        metric = str(self.metricsComboBox.currentText())
        # get the raw metric's values, keyed by (configuration, run)
        data = self.experimentResults.runData[metric]
        # NOTE(review): the aggregation function is user-entered text run
        # through eval(); only trusted expressions must be entered here.
        aggregationFct = eval(str(self.aggregationFctEdit.text()))
        testFct = self.TESTS[str(self.testComboBox.currentText())]
        significanceLevel = float(str(self.pValueEdit.text()))

        # Aggregate each run's values (second element of each entry) and
        # group the results by configuration
        performances = defaultdict(list)
        for (config, run), values in data.iteritems():
            performances[config].append(aggregationFct(map(itemgetter(1), values)))
        configs = sorted(performances.keys())
        means = dict((config, mean(performances[config])) for config in configs)

        # Do the plotting
        self.axis.clear()
        self.axis.boxplot([performances[config] for config in configs])
        self.axis.set_xticklabels(configs)
        self.axis.set_ylabel(metric)

        # Prepare significanceTable
        self.significanceTable.clear()
        self.significanceTable.setRowCount(len(configs))
        self.significanceTable.setColumnCount(len(configs))

        # Setting tables headers
        self.significanceTable.setHorizontalHeaderLabels(
                                    ["y=%s" % key for key in configs])
        self.significanceTable.setVerticalHeaderLabels(
                                    ["x=%s" % key for key in configs])

        boldFont = QtGui.QFont("Times", 10, QtGui.QFont.Bold)
        normalFont = QtGui.QFont("Times", 10)

        # Add actual p-Value into table
        for index1, config1 in enumerate(configs):
            for index2, config2 in enumerate(configs):
                # Compute p-Value with selected test; tests may reject
                # degenerate samples (e.g. identical values) with ValueError
                try:
                    pValue = testFct(performances[config1], performances[config2])
                except ValueError:
                    pValue = nan

                # Distinguish a>b and a<b
                if means[config1] < means[config2]:
                    pValue = 1 - pValue

                tableWidgetItem = QtGui.QTableWidgetItem("x>y: p = %.4f" % pValue)
                if pValue < significanceLevel:
                    tableWidgetItem.setFont(boldFont)
                else:
                    tableWidgetItem.setFont(normalFont)

                self.significanceTable.setItem(index1, index2, tableWidgetItem)

        self.significanceTable.update()
        self.canvas.draw()
                
        
        
Example #19
0
    def __init__(self, stateSpace):
        """Create the trajectory viewer for *stateSpace*.

        Two combo boxes choose the state space dimensions shown on the x and
        y axis, a slider the number of trajectories shown, and a check box
        enables or disables plotting. The viewer then connects to the first
        registered ``TrajectoryObservable``.
        """
        super(TrajectoryViewer, self).__init__()

        self.stateSpace = stateSpace
        dimensions = sorted(stateSpace.keys())

        # Define colors for plotting
        self.colors = itertools.cycle(['g', 'b', 'c', 'm', 'k', 'y'])

        self.plotLines = deque()

        # Combo Boxes for selecting displaced state space dimensions.
        # By default the first two dimensions are shown; a one-dimensional
        # state space is plotted against itself instead of raising IndexError.
        secondIndex = 1 if len(dimensions) > 1 else 0
        self.comboBox1 = QtGui.QComboBox(self)
        self.comboBox1.addItems(dimensions)
        self.comboBox2 = QtGui.QComboBox(self)
        self.comboBox2.addItems(dimensions)
        self.comboBox2.setCurrentIndex(secondIndex)
        self.dimension1 = dimensions[0]
        self.dimension2 = dimensions[secondIndex]
        self.connect(self.comboBox1, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._dimension1Changed)
        self.connect(self.comboBox2, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._dimension2Changed)

        # Slider for controlling the number of Trajectories
        # (range set before value so the value is never clamped)
        self.maxLines = 5
        self.numberTrajectoriesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.numberTrajectoriesSlider.setMinimum(1)
        self.numberTrajectoriesSlider.setMaximum(20)
        self.numberTrajectoriesSlider.setValue(self.maxLines)
        self.numberTrajectoriesSlider.setTickInterval(5)
        self.numberTrajectoriesSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.numberTrajectoriesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeNumTrajectories)

        # Checkbox for dis-/enabling trajectory plotting
        self.plottingEnabled = True
        self.enabledCheckBox = QtGui.QCheckBox("Plotting Enabled")
        self.enabledCheckBox.setChecked(self.plottingEnabled)
        self.connect(self.enabledCheckBox, QtCore.SIGNAL('stateChanged(int)'),
                     self._enablingPlotting)

        # Some labels
        self.dimension1Label = QtGui.QLabel("Dimension X Axis")
        self.dimension2Label = QtGui.QLabel("Dimension Y Axis")
        self.numTrajectoriesLabel = QtGui.QLabel("Trajectories shown")

        # Create matplotlib widgets
        plotWidgetTrajectory = QtGui.QWidget(self)
        plotWidgetTrajectory.setMinimumSize(800, 500)

        self.figTrajectory = Figure((8.0, 5.0), dpi=100)
        self.axisTrajectory = self.figTrajectory.gca()
        self.canvasTrajectory = FigureCanvas(self.figTrajectory)
        self.canvasTrajectory.setParent(plotWidgetTrajectory)

        # Initialize plotting
        self._reinitializePlot()

        # Create layout: plot on the left, controls stacked on the right
        self.vlayout = QtGui.QVBoxLayout()
        self.hlayout1 = QtGui.QHBoxLayout()
        self.hlayout1.addWidget(self.dimension1Label)
        self.hlayout1.addWidget(self.comboBox1)
        self.hlayout2 = QtGui.QHBoxLayout()
        self.hlayout2.addWidget(self.dimension2Label)
        self.hlayout2.addWidget(self.comboBox2)
        self.hlayout3 = QtGui.QHBoxLayout()
        self.hlayout3.addWidget(self.numTrajectoriesLabel)
        self.hlayout3.addWidget(self.numberTrajectoriesSlider)

        self.vlayout.addLayout(self.hlayout1)
        self.vlayout.addLayout(self.hlayout2)
        self.vlayout.addLayout(self.hlayout3)
        self.vlayout.addWidget(self.enabledCheckBox)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidgetTrajectory)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to trajectory observable
        # NOTE(review): raises IndexError if no TrajectoryObservable is registered
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]
        self.trajectoryObservableCallback = \
             lambda *transition: self.addTransition(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)
Example #20
0
    def __init__(self, experimentResults, parent=None):
        """Build the widget for statistically comparing experiment results.

        The top row lets the user choose:

        * a metric;
        * an aggregation function, i.e. a Python expression such as
          ``lambda x: mean(x[:])`` that maps a time series onto one scalar;
        * a hypothesis test;
        * a significance level.

        The "Update" button reruns ``_analyze``. Below the controls are a
        matplotlib plot and a table for the results, which ``_analyze``
        presumably fills (its body is not visible here).

        :param experimentResults: results container; its ``metrics``
            attribute provides the metric names listed in the combo box.
        :param parent: optional parent Qt widget.
        """
        super(StatisticalAnalysisWidget, self).__init__(parent)

        self.experimentResults = experimentResults

        # Available hypothesis tests: display name -> f(x, y) -> p-value.
        # NOTE(review): the t-test p-value is halved, presumably to turn
        # SciPy's two-sided p-value into a one-sided one. Whether
        # mannwhitneyu's p-value is one- or two-sided depends on the SciPy
        # version, so confirm the intended sidedness.
        self.TESTS = {
            'MannWhitney U-Test':
                lambda x, y: scipy.stats.mannwhitneyu(x, y)[1],
            'Student t-test':
                lambda x, y: scipy.stats.ttest_ind(x, y)[1] / 2,
        }

        # Create combobox for selecting the metric
        metricsLabel = QtGui.QLabel("Metric")
        self.metricsComboBox = QtGui.QComboBox(self)
        self.metricsComboBox.addItems(self.experimentResults.metrics)

        # Text field for the aggregation function
        aggregationLabel = QtGui.QLabel("Aggregation")
        self.aggregationFctEdit = QtGui.QLineEdit("lambda x: mean(x[:])")
        # Give the field a minimum size hint of 100x30 pixels
        self.aggregationFctEdit.minimumSizeHint = lambda: QtCore.QSize(100, 30)
        # Adjacent string literals are joined without a separator, so the
        # sentence break after "testing." needs an explicit trailing space.
        self.aggregationFctEdit.setToolTip("Function which maps a time series "
                                           "onto a single scalar value, which "
                                           "is then used as a sample in "
                                           "the statistical hypothesis testing. "
                                           "The functions min, max, mean, and "
                                           "median may be used.")

        # Create combobox for selecting the test. list() because Qt expects a
        # real sequence of strings, and dict views are not lists on Python 3.
        testLabel = QtGui.QLabel("Hypothesis test")
        self.testComboBox = QtGui.QComboBox(self)
        self.testComboBox.addItems(list(self.TESTS.keys()))

        # Text field for the p-Value
        pValueLabel = QtGui.QLabel("p <")
        self.pValueEdit = QtGui.QLineEdit("0.05")
        self.pValueEdit.minimumSizeHint = lambda: QtCore.QSize(100, 30)
        self.pValueEdit.setToolTip("Significance level: The minimal p-Value "
                                   "which is required for something to be "
                                   "considered as significant.")

        # Button for redoing the statistics for the current setting
        self.updateButton = QtGui.QPushButton("Update")
        self.connect(self.updateButton, QtCore.SIGNAL('clicked()'),
                     self._analyze)

        # Create matplotlib widget
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(500, 500)

        fig = Figure((5.0, 5.0), dpi=100)
        fig.subplots_adjust(0.2)  # left margin: 20% of the figure width
        self.canvas = FigureCanvas(fig)
        self.canvas.setParent(plotWidget)
        self.axis = fig.gca()

        # The table for statistics results
        self.significanceTable = QtGui.QTableWidget(self)

        # Do the analyzing once for the default values. This stays after all
        # widgets above are created, since _analyze presumably reads the input
        # widgets and writes to the axis/table.
        self._analyze()

        # Create layout: one row of controls above the plot and the table
        hlayout1 = QtGui.QHBoxLayout()
        for widget in (metricsLabel, self.metricsComboBox,
                       aggregationLabel, self.aggregationFctEdit,
                       testLabel, self.testComboBox,
                       pValueLabel, self.pValueEdit,
                       self.updateButton):
            hlayout1.addWidget(widget)
        hlayout2 = QtGui.QHBoxLayout()
        hlayout2.addWidget(plotWidget)
        hlayout2.addWidget(self.significanceTable)
        layout = QtGui.QVBoxLayout()
        layout.addLayout(hlayout1)
        layout.addLayout(hlayout2)
        self.setLayout(layout)
Example #21
0
    def __init__(self, maze, stateSpace, actions):
        """Build the detailed 2D maze viewer.

        An MDI area holds one sub-window titled "Policy" and one sub-window
        per action (titled ``str(action)``). Each sub-window contains a
        matplotlib figure into which the maze is drawn. A combo box selects
        the StateActionValuesObservable to display. A second combo box offers
        "Every Episode" / "Every Step" as update frequency, presumably read
        by the update methods (not visible here).

        :param maze: maze object; must provide ``drawIntoAxis(axis)``.
        :param stateSpace: state space description, stored as-is.
        :param actions: iterable of actions; one plot is created per action.
        """
        super(Maze2DDetailedViewer, self).__init__()

        self.maze, self.stateSpace, self.actions = maze, stateSpace, actions

        # Sample counters; unseen keys default to 0
        self.samples = defaultdict(int)
        self.valueAccessFunction = None
        self.redrawRequested = False

        # Observables this viewer reads from
        self.trajectoryObservable = OBSERVABLES.getAllObservablesOfType(
            TrajectoryObservable)[0]
        self.stateActionValuesObservables = \
            OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)

        def observableTitles():
            # Titles of the currently known state-action-values observables
            return ["%s" % obs.title
                    for obs in self.stateActionValuesObservables]

        # Combo box for choosing which observable is shown
        self.observableLabel = QtGui.QLabel("Observable")
        self.comboBox = QtGui.QComboBox(self)
        self.comboBox.addItems(observableTitles())
        self.connect(self.comboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._observableChanged)

        # Keep the combo box in sync with observables created at runtime
        def updateComboBox(observable, action):
            self.comboBox.clear()
            self.stateActionValuesObservables = \
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
            self.comboBox.addItems(observableTitles())

        OBSERVABLES.addObserver(updateComboBox)

        # Combo box for choosing the update frequency
        self.updateFreqLabel = QtGui.QLabel("Update")
        self.updateComboBox = QtGui.QComboBox(self)
        self.updateComboBox.addItems(["Every Episode", "Every Step"])

        # Margins shared by every maze figure
        margins = dict(left=0.01, bottom=0.01, right=0.99, top=0.99,
                       wspace=0.05, hspace=0.11)

        def createMazePlot(title):
            # Build a 300x400 widget hosting a 3x4 inch figure with the maze
            # drawn into it; returns (widget, figure, canvas).
            hostWidget = QtGui.QWidget(self)
            hostWidget.setMinimumSize(300, 400)
            hostWidget.setWindowTitle(title)
            figure = Figure((3.0, 4.0), dpi=100)
            figure.subplots_adjust(**margins)
            canvas = FigureCanvas(figure)
            canvas.setParent(hostWidget)
            axis = figure.gca()
            axis.clear()
            self.maze.drawIntoAxis(axis)
            return hostWidget, figure, canvas

        plotWidgetPolicy, self.figPolicy, self.canvasPolicy = \
            createMazePlot("Policy")

        self.plotWidgetValueFunction = dict()
        self.figValueFunction = dict()
        self.canvasValueFunction = dict()
        for action in self.actions:
            (self.plotWidgetValueFunction[action],
             self.figValueFunction[action],
             self.canvasValueFunction[action]) = createMazePlot(str(action))

        self.textInstances = dict()
        self.arrowInstances = []

        # Initial rendering of every canvas
        self.canvasPolicy.draw()
        for action in self.actions:
            self.canvasValueFunction[action].draw()

        # Plots as MDI sub-windows, with the combo boxes in a row beneath them
        self.mdiArea = QtGui.QMdiArea(self)
        self.mdiArea.addSubWindow(plotWidgetPolicy)
        for action in self.actions:
            self.mdiArea.addSubWindow(self.plotWidgetValueFunction[action])
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.mdiArea)
        self.hlayout = QtGui.QHBoxLayout()
        for control in (self.observableLabel, self.comboBox,
                        self.updateFreqLabel, self.updateComboBox):
            self.hlayout.addWidget(control)
        self.vlayout.addLayout(self.hlayout)
        self.setLayout(self.vlayout)

        # Connect to observer (has to be the last thing!!)
        self.trajectoryObservableCallback = (
            lambda *transition: self.updateSamples(*transition))
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)

        self.stateActionValuesObservableCallback = (
            lambda valueAccessFunction, actions:
                self.updateValues(valueAccessFunction, actions))
        if self.stateActionValuesObservables:
            # Show the first observable by default
            self.stateActionValuesObservable = \
                self.stateActionValuesObservables[0]
            self.stateActionValuesObservable.addObserver(
                self.stateActionValuesObservableCallback)
        else:
            self.stateActionValuesObservable = None
Example #22
0
class FloatStreamViewer(Viewer):
    def __init__(self):
        """Create a viewer that plots a stream of float values.

        The values of the selected FloatStreamObservable are drawn as a black
        line together with their centered moving average (red line). A combo
        box selects the observable. Two sliders control how many of the most
        recent values are kept (``2**position`` values) and the width of the
        moving average.
        """
        super(FloatStreamViewer, self).__init__()

        # Create matplotlib widget
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(800, 500)

        fig = Figure((8.0, 5.0), dpi=100)
        self.canvas = FigureCanvas(fig)
        self.canvas.setParent(plotWidget)
        self.axis = fig.gca()

        # Local container for displayed values
        self.values = deque()
        self.times = deque()

        # Callback feeding new (time, value) samples into update(); extra
        # arguments from the observable are ignored. The callback and the
        # "nothing selected" state are defined before the combo box gets
        # connected. _observableChanged reads both attributes, so it then
        # cannot fail with AttributeError if the signal fires during
        # construction.
        self.observable = None
        self.observableCallback = lambda time, value, *args: self.update(
            time, value)

        # Combo Box for selecting the observable
        self.comboBox = QtGui.QComboBox(self)
        self.floatStreamObservables = \
                OBSERVABLES.getAllObservablesOfType(FloatStreamObservable)
        self.comboBox.addItems(
            ["%s" % x.title for x in self.floatStreamObservables])
        self.connect(self.comboBox, QtCore.SIGNAL('currentIndexChanged (int)'),
                     self._observableChanged)

        # Automatically update combobox when new float stream observables
        #  are created during runtime.
        # NOTE(review): this callback is never removed from OBSERVABLES, so it
        # outlives the viewer; consider unregistering it in close().
        def updateComboBox(observable, action):
            self.comboBox.clear()
            self.floatStreamObservables = \
                    OBSERVABLES.getAllObservablesOfType(FloatStreamObservable)
            self.comboBox.addItems(
                ["%s" % x.title for x in self.floatStreamObservables])

        OBSERVABLES.addObserver(updateComboBox)

        # The number of values from the observable that are remembered
        self.windowSize = 64

        # Slider for controlling the window size: position p selects a window
        # of 2**p values. The range is set before the value so the initial
        # value is never clamped. The value is converted to int because
        # QSlider.setValue() takes an int, not a numpy.float64.
        self.windowSizeSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.windowSizeSlider.setMinimum(0)
        self.windowSizeSlider.setMaximum(15)
        self.windowSizeSlider.setValue(int(round(numpy.log2(self.windowSize))))
        self.windowSizeSlider.setTickInterval(1)
        self.windowSizeSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        # Apply the new size only once the slider is released
        self.connect(self.windowSizeSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeWindowSize)

        self.windowSizeLabel = QtGui.QLabel("WindowSize: %s" % self.windowSize)

        # The length of the moving window average
        self.mwaSize = 10

        # Slider for controlling the moving average window (range set first,
        # as above)
        self.mwaSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.mwaSlider.setMinimum(1)
        self.mwaSlider.setMaximum(50)
        self.mwaSlider.setValue(self.mwaSize)
        self.mwaSlider.setTickInterval(10)
        self.mwaSlider.setTickPosition(QtGui.QSlider.TicksBelow)

        self.connect(self.mwaSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeMWA)

        self.mwaLabel = QtGui.QLabel("Moving Window Average : %s" %
                                     self.mwaSize)

        # Create layout: plot on the left, controls stacked on the right
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.comboBox)
        self.vlayout.addWidget(self.windowSizeSlider)
        self.vlayout.addWidget(self.windowSizeLabel)
        self.vlayout.addWidget(self.mwaSlider)
        self.vlayout.addWidget(self.mwaLabel)

        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.hlayout.addLayout(self.vlayout)

        self.setLayout(self.hlayout)

        # Show per default the first observable.
        # Connect to observer (has to be the last thing!!)
        if len(self.floatStreamObservables) > 0:
            self.observable = self.floatStreamObservables[0]
            self.observable.addObserver(self.observableCallback)
        else:
            self.observable = None

    def close(self):
        """Detach from the currently shown observable, then close the viewer.

        NOTE(review): the combo-box refresh callback that __init__ registers
        with OBSERVABLES is not removed here.
        """
        current = self.observable
        if current is not None:
            current.removeObserver(self.observableCallback)

        super(FloatStreamViewer, self).close()

    def update(self, time, value):
        """Buffer one (time, value) sample and redraw the plot.

        Invoked through ``self.observableCallback`` for each new sample. Once
        the buffer holds more than ``windowSize`` samples, the oldest one is
        discarded.

        NOTE(review): if Viewer is a QWidget subclass (as the layout calls in
        __init__ suggest), this method shadows QWidget.update().
        """
        self.times.append(time)
        self.values.append(value)

        if self.windowSize < len(self.values):
            self.times.popleft()
            self.values.popleft()

        self._redraw()

    def _redraw(self):
        """Redraw the buffered values and their moving average.

        The raw values are plotted in black against their times; the red
        line is a centered moving average.

        * Near both ends of the buffer the averaging window shrinks so that
          it stays centered on each point. The first and last values are
          therefore plotted unsmoothed.
        * In the interior the window spans ``2 * (mwaSize // 2) + 1``
          values, i.e. ``mwaSize`` rounded up to the next odd number.

        Axis labels are taken from the current observable. If no data is
        buffered, the axis is only cleared and the canvas is not redrawn.
        """
        self.axis.clear()
        if len(self.times) == 0:  # No data available
            return
        self.axis.plot(self.times, self.values, 'k')

        # Copy once: deques cannot be sliced, and copying inside the loop
        # would make this quadratic in the window size.
        values = list(self.values)
        numValues = len(values)
        averageValues = []
        for i in range(numValues):
            # Shrink the window near the ends so it stays centered on i
            effectiveMWASize = min(2 * i, 2 * (numValues - 1 - i),
                                   self.mwaSize)
            # Floor division keeps the slice bounds integral. Plain "/"
            # yields floats on Python 3, and the slice would raise TypeError.
            halfWidth = effectiveMWASize // 2
            start = i - halfWidth
            end = i + halfWidth + 1
            averageValues.append(float(sum(values[start:end])) / (end - start))

        self.axis.plot(self.times, averageValues, 'r')

        self.axis.set_xlim((min(self.times), max(self.times)))

        self.axis.set_xlabel(self.observable.time_dimension_name)
        self.axis.set_ylabel(self.observable.value_name)

        self.canvas.draw()

    def _observableChanged(self, comboBoxIndex):
        """Switch the plotted stream to the observable selected in the combo box.

        Connected to the combo box's ``currentIndexChanged(int)`` signal. It
        detaches the observer callback from the previous observable and
        attaches it to the newly selected one. The buffered values of the
        previous stream are discarded.

        :param comboBoxIndex: index into ``self.floatStreamObservables``. Qt
            passes -1 when the combo box becomes empty, e.g. when
            ``updateComboBox`` clears it before refilling.
        """
        if self.observable is not None:
            # Remove old observable
            self.observable.removeObserver(self.observableCallback)
        if comboBoxIndex < 0:
            # Nothing is selected. Without this guard, -1 would silently pick
            # the *last* observable of the stale list via negative indexing.
            self.observable = None
        else:
            # Get new observable and add as listener
            self.observable = self.floatStreamObservables[comboBoxIndex]
            self.observable.addObserver(self.observableCallback)
        # Remove old values
        self.values = deque()
        self.times = deque()

    def _changeWindowSize(self):
        """Apply the window size chosen with the slider.

        Slider position p selects a window of ``2**p`` values. The oldest
        buffered samples are dropped until the buffer fits; then the plot and
        the label are refreshed.
        """
        self.windowSize = 2 ** self.windowSizeSlider.value()

        # How many of the oldest samples no longer fit (<= 0 means none)
        surplus = len(self.values) - self.windowSize
        for _ in range(surplus):
            self.values.popleft()
            self.times.popleft()

        self._redraw()

        self.windowSizeLabel.setText("WindowSize: %s" % self.windowSize)

    def _changeMWA(self):
        """Apply the moving-average width chosen with the slider.

        Stores the new width, redraws the plot and updates the label.
        """
        newSize = self.mwaSlider.value()
        self.mwaSize = newSize

        self._redraw()

        self.mwaLabel.setText("Moving Window Average : %s" % newSize)