예제 #1
0
    def _functionObservableChanged(self, selectedFunctionObservable):
        """Switch the viewer to the observable with the given title.

        Removes the current callback (if an observable is connected) from the
        previously observed observable, looks up the observable named by
        ``selectedFunctionObservable`` and registers a new callback on it.
        The title is first looked up as a FunctionOverStateSpaceObservable;
        if that lookup returns None, it is looked up again as a
        StateActionValuesObservable.

        :param selectedFunctionObservable: Title of the newly selected
            observable; converted with ``str()`` before it is used.
        """
        # Use the lock as a context manager so that it is released even if
        # one of the observable calls below raises. A bare acquire()/release()
        # pair would leave the lock held and block every later caller.
        with self.lock:
            if self.functionObservable is not None:
                # Disconnect from old function observable
                self.functionObservable.removeObserver(
                    self.functionObservableCallback)

            # Determine new observed function observable
            self.selectedFunctionObservable = str(selectedFunctionObservable)

            # Connect to new function observable
            self.functionObservable = OBSERVABLES.getObservable(
                self.selectedFunctionObservable,
                FunctionOverStateSpaceObservable)
            if self.functionObservable is None:
                # Observing a StateActionValuesObservable
                self.functionObservable = OBSERVABLES.getObservable(
                    self.selectedFunctionObservable,
                    StateActionValuesObservable)
                # Forget the old action set so that the first notification
                # repopulates the suboption combobox.
                self.actions = None

                def functionObservableCallback(evalFunction, actions):
                    # If we get new options to select from
                    if actions != self.actions:
                        # Update suboptionComboBox
                        self.actions = actions
                        self.suboptionComboBox.clear()
                        self.suboptionComboBox.addItems(
                            [str(action) for action in actions])
                    self._updateFunction(evalFunction)

                self.functionObservableCallback = functionObservableCallback
            else:
                # Observing a FunctionOverStateSpaceObservable
                self.functionObservableCallback = \
                    lambda evalFunction: self._updateFunction(evalFunction)

            # The callback is stored on self so that the very same object can
            # be passed to removeObserver() on the next switch.
            self.functionObservable.addObserver(
                self.functionObservableCallback)
예제 #2
0
 def _policyObservableChanged(self, selectedPolicyObservable):
     """Switch the viewer to the policy observable with the given title.

     Removes the current callback from the previously observed policy
     observable (if one is connected), looks up the
     FunctionOverStateSpaceObservable named by ``selectedPolicyObservable``,
     registers a new callback that forwards to ``updatePolicy`` and resets
     ``self.actions`` to an empty list.

     :param selectedPolicyObservable: Title of the newly selected
         observable; converted with ``str()`` before it is used.
     """
     # Use the lock as a context manager so that it is released even if one
     # of the observable calls below raises. A bare acquire()/release() pair
     # would leave the lock held and block every later caller.
     with self.lock:
         if self.policyObservable:
             # Disconnect from old policy observable
             self.policyObservable.removeObserver(
                 self.policyObservableCallback)

         # Determine new observed policy observable
         self.selectedPolicyObservable = str(selectedPolicyObservable)

         # Connect to new policy observable. The callback is stored on self
         # so that the same object can be removed again on the next switch.
         self.policyObservable = OBSERVABLES.getObservable(
             self.selectedPolicyObservable,
             FunctionOverStateSpaceObservable)
         self.policyObservableCallback = \
             lambda policyEvalFunction: self.updatePolicy(policyEvalFunction)
         self.policyObservable.addObserver(self.policyObservableCallback)

         self.actions = []
예제 #3
0
    def _functionObservableChanged(self, selectedFunctionObservable):
        """Switch the viewer to the function observable with the given title.

        Removes the current callback from the previously observed observable
        (if one is connected), looks up the FunctionOverStateSpaceObservable
        named by ``selectedFunctionObservable``, registers a new callback
        that forwards to ``_updateFunction`` and resets ``self.actions`` to
        an empty list.

        :param selectedFunctionObservable: Title of the newly selected
            observable; converted with ``str()`` before it is used.
        """
        # Use the lock as a context manager so that it is released even if
        # one of the observable calls below raises. A bare acquire()/release()
        # pair would leave the lock held and block every later caller.
        with self.lock:
            if self.functionObservable is not None:
                # Disconnect from old function observable
                self.functionObservable.removeObserver(
                    self.functionObservableCallback)

            # Determine new observed function observable
            self.selectedFunctionObservable = str(selectedFunctionObservable)

            # Connect to new function observable. The callback is stored on
            # self so that the same object can be removed on the next switch.
            self.functionObservable = OBSERVABLES.getObservable(
                self.selectedFunctionObservable,
                FunctionOverStateSpaceObservable)
            self.functionObservableCallback = \
                lambda evalFunction: self._updateFunction(evalFunction)
            self.functionObservable.addObserver(
                self.functionObservableCallback)

            self.actions = []
예제 #4
0
    def __init__(self, stateSpace):
        """Build the policy viewer widget and hook it up to the observables.

        Creates a combobox for choosing a FunctionOverStateSpaceObservable,
        a slider for the number of grid nodes per dimension and a matplotlib
        canvas, arranges them in a layout, and finally registers callbacks on
        the trajectory observable and (if one exists) on the initially
        selected policy observable.

        :param stateSpace: Stored unchanged as ``self.stateSpace``; it is not
            used otherwise inside this constructor.
        """
        super(MountainCarPolicyViewer, self).__init__()
        self.stateSpace = stateSpace
        self.actions = []
        self.colors = ['r','g','b', 'c', 'y']

        # Lock created here; it is not acquired anywhere in this constructor.
        self.lock = threading.Lock()

        # Add a combobox for selecting the policy observable
        self.policyObservableLabel = QtGui.QLabel("Policy Observable")
        self.policyObservableComboBox = QtGui.QComboBox(self)
        policyObservables = \
            OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
        self.policyObservableComboBox.addItems([policyObservable.title 
                                                 for policyObservable in policyObservables])
        # Preselect the first available observable (None if there is none)
        self.selectedPolicyObservable = None
        if len(policyObservables) > 0:
            self.selectedPolicyObservable = policyObservables[0].title

        # Old-style signal connection: picking a combobox entry calls
        # _policyObservableChanged with the entry's text
        self.connect(self.policyObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'), 
                     self._policyObservableChanged) 

        # Automatically update policy observable combobox when new observables 
        # are created during runtime
        def updatePolicyObservableBox(viewer, action):
            # Both arguments are unused; the box is simply rebuilt from the
            # current set of observables.
            self.policyObservableComboBox.clear()
            policyObservables = OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            self.policyObservableComboBox.addItems([policyObservable.title 
                                                for policyObservable in policyObservables])
            # NOTE: only the selected title is reset here; the callback
            # registered on self.policyObservable is not touched.
            if len(policyObservables) > 0:
                self.selectedPolicyObservable = policyObservables[0].title
            else: 
                self.selectedPolicyObservable = None

        OBSERVABLES.addObserver(updatePolicyObservableBox)

        # Get trajectory observable which is required for informing about end of episode
        # NOTE(review): the [0] presumably raises IndexError when no
        # TrajectoryObservable has been registered yet -- confirm against
        # getAllObservablesOfType.
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]
        self.episodeTerminated = False

        # Slider that controls the granularity of the plot-grid
        self.gridNodesPerDim = 25        
        self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.gridNodesSlider.setValue(self.gridNodesPerDim)
        self.gridNodesSlider.setMinimum(0)
        self.gridNodesSlider.setMaximum(100)
        self.gridNodesSlider.setTickInterval(10)
        self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.gridNodesSlider, QtCore.SIGNAL('sliderReleased()'), 
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s" 
                                           % self.gridNodesPerDim )

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Policy")
 
        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)    

        # Small text in plot legend
        # (rcParams is updated on the matplotlib module itself, not on this
        # figure only)
        matplotlib.rcParams.update({'legend.fontsize': 6})

        # Layout: plot on the left, column of controls on the right
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.policyObservableLabel)
        self.vlayout.addWidget(self.policyObservableComboBox)
        self.vlayout.addWidget(self.gridNodesLabel)
        self.vlayout.addWidget(self.gridNodesSlider)
        self.hlayout.addLayout(self.vlayout)

        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        # The callbacks are kept as attributes so that the same objects are
        # available for a later removeObserver() call.
        self.trajectoryObservableCallback = \
             lambda *transition: self.updateSamples(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)

        # Only connect to a policy observable if one was available above;
        # otherwise policyObservable stays None and policyObservableCallback
        # is left undefined.
        self.policyObservable = None
        if self.selectedPolicyObservable:
            self.policyObservable = OBSERVABLES.getObservable(self.selectedPolicyObservable,
                                                              FunctionOverStateSpaceObservable)
            self.policyObservableCallback = \
                 lambda policyEvalFunction: self.updatePolicy(policyEvalFunction)
            self.policyObservable.addObserver(self.policyObservableCallback)
예제 #5
0
    def __init__(self, pinballMazeEnv, stateSpace):
        """Build the function viewer widget and hook it up to the observables.

        Creates the matplotlib canvas (with the maze structure drawn by
        ``pinballMazeEnv.plotStateSpaceStructure``), a combobox for choosing
        a FunctionOverStateSpaceObservable, sliders for grid resolution,
        update frequency and the x/y velocity values, an "Update Plot" button
        and a legend list, arranges them in a layout and finally registers a
        callback on the initially selected function observable (if any).

        :param pinballMazeEnv: Environment object; must provide
            ``plotStateSpaceStructure(axis)``, which is called once here.
        :param stateSpace: Stored unchanged as ``self.stateSpace``; it is not
            used otherwise inside this constructor.
        """
        super(PinballMazeFunctionViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv
        self.stateSpace = stateSpace
        self.actions = []

        self.updateCounter = 0
        self.updatePlotNow = False
        self.evalFunction = None

        # Lock created here; it is not acquired anywhere in this constructor.
        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.plottedPatches = []

        # Add a combobox for selecting the function over state space that is
        # observed
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        # Preselect the first available observable (None if there is none)
        if len(functionObservables) > 0:
            self.selectedFunctionObservable = functionObservables[0].title

        # Old-style signal connection: picking a combobox entry calls
        # _functionObservableChanged with the entry's text
        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update function observable combobox when new
        # observables are created during runtime
        def updateFunctionObservableBox(viewer, action):
            # Both arguments are unused; the box is simply rebuilt from the
            # current set of observables.
            self.functionObservableComboBox.clear()
            functionObservables = \
                OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            if len(functionObservables) > 0:
                self.selectedFunctionObservable = functionObservables[0].title

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # Slider that controls the granularity of the plot-grid
        self.gridNodesPerDim = 50
        self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.gridNodesSlider.setValue(self.gridNodesPerDim)
        self.gridNodesSlider.setMinimum(0)
        self.gridNodesSlider.setMaximum(100)
        self.gridNodesSlider.setTickInterval(10)
        self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.gridNodesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s" %
                                           self.gridNodesPerDim)

        # Slider that controls the frequency of updating the plot. The slider
        # works on integers, so the frequency 0.0..1.0 is scaled to 0..100.
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        # The tick interval is an integer in slider units: a step of 0.1 in
        # frequency corresponds to 10 on the 0..100 slider scale. (The float
        # 0.1 is not a valid argument for this integer property.)
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Chosen xvel and yvel values; the sliders hold the value scaled by
        # 10 (0..10) since they only work on integers.
        self.xVel = 0.5
        self.xVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.xVelSlider.setValue(int(self.xVel * 10))
        self.xVelSlider.setMinimum(0)
        self.xVelSlider.setMaximum(10)
        self.xVelSlider.setTickInterval(1)
        self.xVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.xVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeXVel)
        self.xVelLabel = QtGui.QLabel("xvel value: %s" % self.xVel)

        self.yVel = 0.5
        self.yVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.yVelSlider.setValue(int(self.yVel * 10))
        self.yVelSlider.setMinimum(0)
        self.yVelSlider.setMaximum(10)
        self.yVelSlider.setTickInterval(1)
        self.yVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.yVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeYVel)
        # The y label must show the y velocity (it used to format self.xVel).
        self.yVelLabel = QtGui.QLabel("yvel value: %s" % self.yVel)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        # Layout: plot on the left, column of controls on the right; each
        # label/control pair shares one horizontal row.
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.gridNodesLayout = QtGui.QHBoxLayout()
        self.gridNodesLayout.addWidget(self.gridNodesLabel)
        self.gridNodesLayout.addWidget(self.gridNodesSlider)
        self.vlayout.addLayout(self.gridNodesLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.xVelLayout = QtGui.QHBoxLayout()
        self.xVelLayout.addWidget(self.xVelLabel)
        self.xVelLayout.addWidget(self.xVelSlider)
        self.vlayout.addLayout(self.xVelLayout)
        self.yVelLayout = QtGui.QHBoxLayout()
        self.yVelLayout.addWidget(self.yVelLabel)
        self.yVelLayout.addWidget(self.yVelSlider)
        self.vlayout.addLayout(self.yVelLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        # Only connect if an observable was available above; otherwise
        # functionObservable stays None and functionObservableCallback is
        # left undefined.
        self.functionObservable = None
        if self.selectedFunctionObservable:
            self.functionObservable = \
                    OBSERVABLES.getObservable(self.selectedFunctionObservable,
                                              FunctionOverStateSpaceObservable)
            # Kept as an attribute so the same object is available for a
            # later removeObserver() call.
            self.functionObservableCallback = \
                 lambda evalFunction: self._updateFunction(evalFunction)
            self.functionObservable.addObserver(
                self.functionObservableCallback)