def __init__(self, pinballMazeEnv, stateSpace):
        """Build the trajectory viewer widget for a pinball maze.

        Creates the matplotlib canvas (with the maze structure drawn into
        it), the controls for enabling drawing and for choosing the draw
        style and the coloring criterion, the legend list, and finally
        registers a callback on the trajectory observable.

        pinballMazeEnv -- environment object; its
                          ``plotStateSpaceStructure(axis)`` method is called
                          to draw the maze into the plot axis.
        stateSpace     -- mapping from dimension name to dimension object;
                          the values are stored in ``self.dimensions``
                          ordered by sorted dimension name.
        """
        super(PinballMazeTrajectoryViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv

        # Dimensions in a deterministic order (sorted by dimension name)
        self.dimensions = [stateSpace[dimName]
                           for dimName in sorted(stateSpace.keys())]

        # The segments that are obtained while drawing is disabled. These
        # segments are drawn once drawing is reenabled
        self.rememberedSegments = []

        # The eval function that can be used for coloring the trajectory
        self.evalFunction = None

        self.colorsCycle = cycle([(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0),
                                  (1.0, 1.0, 0.0), (0.0, 1.0, 1.0), (1.0, 0.0, 1.0),
                                  (0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 0.5)])
        # Every key that is looked up for the first time gets the next color
        # of the cycle. The builtin next() is used instead of the iterator's
        # .next() method since the latter only exists in Python 2.
        self.colors = defaultdict(lambda: next(self.colorsCycle))
        self.valueToColorMapping = dict()

        # Get required observables
        # NOTE(review): takes the first TrajectoryObservable returned by the
        # registry; presumably at least one is registered at this point.
        self.trajectoryObservable = \
                OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.ballPatch = None
        self.linePatches = []

        # Add other elements to GUI
        self.drawingEnabledCheckbox = \
                QtGui.QCheckBox("Drawing enabled", self)
        self.drawingEnabledCheckbox.setChecked(True)

        # Combobox for the draw style; the default matches the first item
        self.drawStyle = "Current Position"
        self.drawStyleLabel = QtGui.QLabel("Draw style")
        self.drawStyleComboBox = QtGui.QComboBox(self)
        self.drawStyleComboBox.addItems(["Current Position", "Last Episode",
                                         "Online (All)"])
        self.connect(self.drawStyleComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._drawStyleChanged)

        # Combobox for the coloring criterion; the default matches the
        # first item
        self.colorCriterion = "Action"
        self.colorCriterionLabel = QtGui.QLabel("Coloring of trajectory")
        self.colorCriterionComboBox = QtGui.QComboBox(self)
        self.colorCriterionComboBox.addItems(["Action", "Reward", "Q-Value"])
        self.connect(self.colorCriterionComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._colorCriterionChanged)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        # Create layout: plot on the left, column of controls on the right
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.vlayout.addWidget(self.drawingEnabledCheckbox)
        self.drawStyleLayout = QtGui.QHBoxLayout()
        self.drawStyleLayout.addWidget(self.drawStyleLabel)
        self.drawStyleLayout.addWidget(self.drawStyleComboBox)
        self.vlayout.addLayout(self.drawStyleLayout)
        self.coloringLayout = QtGui.QHBoxLayout()
        self.coloringLayout.addWidget(self.colorCriterionLabel)
        self.coloringLayout.addWidget(self.colorCriterionComboBox)
        self.vlayout.addLayout(self.coloringLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        # The callback is kept as an attribute so that the very same object
        # that was registered stays referenced by the viewer.
        self.trajectoryObservableCallback = \
             lambda *transition: self._updateSamples(*transition)
        self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)
Example #2
0
    def __init__(self, maze, stateSpace):
        """Build the viewer widget for functions over a 2D maze state space.

        Creates the matplotlib canvas (with the maze drawn into it), the
        combobox for choosing the observed function, the suboption combobox,
        the update-frequency slider, the "Update Plot" button and the legend
        list. If an observable is available, it is selected at the end.

        maze       -- maze object; its ``drawIntoAxis(axis)`` method is
                      called to draw the maze into the plot axis.
        stateSpace -- state space object, stored as ``self.stateSpace``.
        """
        super(Maze2DFunctionViewer, self).__init__()

        self.maze = maze
        self.stateSpace = stateSpace

        self.updateCounter = 0
        self.updatePlotNow = False
        self.evalFunction = None

        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Maze2D")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.axis.clear()
        self.maze.drawIntoAxis(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.plottedPatches = []

        # Add a combobox for selecting the function over state space that is
        # observed. Both FunctionOverStateSpaceObservables and
        # StateActionValuesObservables are offered.
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        stateActionValueObservables = \
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
        functionObservables.extend(stateActionValueObservables)
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        if len(functionObservables) > 0:
            self.selectedFunctionObservable = functionObservables[0].title

        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update function observable combobox when new
        # observables are created during runtime
        def updateFunctionObservableBox(viewer, action):
            self.functionObservableComboBox.clear()
            functionObservables = \
                OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            stateActionValueObservables = \
                OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
            functionObservables.extend(stateActionValueObservables)
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            if self.selectedFunctionObservable is None:
                # Nothing selected so far: default to the first observable
                # (if any). Without a selection there is nothing to look up
                # in the combobox, so findText is not called with None.
                if len(functionObservables) > 0:
                    self.selectedFunctionObservable = \
                        functionObservables[0].title
            else:
                # Let combobox still show the selected observable
                index = self.functionObservableComboBox.findText(
                    self.selectedFunctionObservable)
                if index != -1:
                    self.functionObservableComboBox.setCurrentIndex(index)

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # Add a combobox for selecting the suboption that is used when
        # a StateActionValuesObservable is observed
        self.selectedSuboption = None
        self.suboptionLabel = QtGui.QLabel("Suboption")
        self.suboptionComboBox = QtGui.QComboBox(self)

        # Slider that controls the frequency of updating the plot. The
        # slider works on integers 0..100, i.e. updateFrequency * 100.
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        # Set the range before the value so that the initial value cannot be
        # clamped to the slider's default range
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        # The tick interval is an integer in slider units: 10 corresponds to
        # steps of 0.1 in update frequency
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        # Create layout: plot on the left, column of controls on the right
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.suboptionLayout = QtGui.QHBoxLayout()
        self.suboptionLayout.addWidget(self.suboptionLabel)
        self.suboptionLayout.addWidget(self.suboptionComboBox)
        self.vlayout.addLayout(self.suboptionLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.functionObservable = None
        if self.selectedFunctionObservable is not None:
            self._functionObservableChanged(self.selectedFunctionObservable)
    def __init__(self, pinballMazeEnv, stateSpace):
        """Build the viewer widget for functions over the pinball maze.

        Creates the matplotlib canvas (with the maze structure drawn into
        it), the combobox for choosing the observed function, sliders for
        grid granularity, update frequency and the x/y velocity values, the
        "Update Plot" button and the legend list. If an observable is
        available, a callback is registered on it at the end.

        pinballMazeEnv -- environment object; its
                          ``plotStateSpaceStructure(axis)`` method is called
                          to draw the maze into the plot axis.
        stateSpace     -- state space object, stored as ``self.stateSpace``.
        """
        super(PinballMazeFunctionViewer, self).__init__()

        self.pinballMazeEnv = pinballMazeEnv
        self.stateSpace = stateSpace
        self.actions = []

        self.updateCounter = 0
        self.updatePlotNow = False
        self.evalFunction = None

        self.lock = threading.Lock()

        # Create matplotlib widgets
        plotWidget = QtGui.QWidget(self)
        plotWidget.setMinimumSize(600, 500)
        plotWidget.setWindowTitle("Pinball Maze")

        self.fig = Figure((6.0, 5.0), dpi=100)
        self.axis = self.fig.gca()
        self.pinballMazeEnv.plotStateSpaceStructure(self.axis)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(plotWidget)
        self.canvas.draw()

        self.plottedPatches = []

        # Add a combobox for selecting the function over state space that is
        # observed
        self.selectedFunctionObservable = None
        self.functionObservableLabel = QtGui.QLabel(
            "Function over State Space")
        self.functionObservableComboBox = QtGui.QComboBox(self)
        functionObservables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        self.functionObservableComboBox.addItems([
            functionObservable.title
            for functionObservable in functionObservables
        ])
        if len(functionObservables) > 0:
            self.selectedFunctionObservable = functionObservables[0].title

        self.connect(self.functionObservableComboBox,
                     QtCore.SIGNAL('activated (const QString&)'),
                     self._functionObservableChanged)

        # Automatically update function observable combobox when new
        # observables are created during runtime
        def updateFunctionObservableBox(viewer, action):
            self.functionObservableComboBox.clear()
            functionObservables = \
                OBSERVABLES.getAllObservablesOfType(FunctionOverStateSpaceObservable)
            self.functionObservableComboBox.addItems([
                functionObservable.title
                for functionObservable in functionObservables
            ])
            # NOTE(review): this resets the selection to the first observable
            # on every registry change without switching the observed
            # observable -- confirm this is intended.
            if len(functionObservables) > 0:
                self.selectedFunctionObservable = functionObservables[0].title

        OBSERVABLES.addObserver(updateFunctionObservableBox)

        # For all sliders below the range is set before the value so that
        # the initial value cannot be clamped to the slider's default range.

        # Slider that controls the granularity of the plot-grid
        self.gridNodesPerDim = 50
        self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.gridNodesSlider.setMinimum(0)
        self.gridNodesSlider.setMaximum(100)
        self.gridNodesSlider.setValue(self.gridNodesPerDim)
        self.gridNodesSlider.setTickInterval(10)
        self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.gridNodesSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeGridNodes)
        self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s" %
                                           self.gridNodesPerDim)

        # Slider that controls the frequency of updating the plot. The
        # slider works on integers 0..100, i.e. updateFrequency * 100.
        self.updateFrequency = 0.0
        self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.updateFrequencySlider.setMinimum(0)
        self.updateFrequencySlider.setMaximum(100)
        self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
        # The tick interval is an integer in slider units: 10 corresponds to
        # steps of 0.1 in update frequency
        self.updateFrequencySlider.setTickInterval(10)
        self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.updateFrequencySlider,
                     QtCore.SIGNAL('sliderReleased()'),
                     self._changeUpdateFrequency)
        self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s" %
                                                 self.updateFrequency)

        # Button to enforce update of plot
        self.updatePlotButton = QtGui.QPushButton("Update Plot")
        self.connect(self.updatePlotButton, QtCore.SIGNAL('clicked()'),
                     self._updatePlot)

        # Chosen xvel and yvel values. Both sliders work on integers 0..10,
        # i.e. the velocity value * 10.
        self.xVel = 0.5
        self.xVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.xVelSlider.setMinimum(0)
        self.xVelSlider.setMaximum(10)
        self.xVelSlider.setValue(int(self.xVel * 10))
        self.xVelSlider.setTickInterval(1)
        self.xVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.xVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeXVel)
        self.xVelLabel = QtGui.QLabel("xvel value: %s" % self.xVel)

        self.yVel = 0.5
        self.yVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
        self.yVelSlider.setMinimum(0)
        self.yVelSlider.setMaximum(10)
        self.yVelSlider.setValue(int(self.yVel * 10))
        self.yVelSlider.setTickInterval(1)
        self.yVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
        self.connect(self.yVelSlider, QtCore.SIGNAL('sliderReleased()'),
                     self._changeYVel)
        # The y label shows the y velocity (not the x velocity)
        self.yVelLabel = QtGui.QLabel("yvel value: %s" % self.yVel)

        # Legend of plot
        self.legendLabel = QtGui.QLabel("Legend:")
        self.legendWidget = QtGui.QListWidget(self)

        # Create layout: plot on the left, column of controls on the right
        self.hlayout = QtGui.QHBoxLayout()
        self.hlayout.addWidget(plotWidget)
        self.vlayout = QtGui.QVBoxLayout()
        self.functionObservableLayout = QtGui.QHBoxLayout()
        self.functionObservableLayout.addWidget(self.functionObservableLabel)
        self.functionObservableLayout.addWidget(
            self.functionObservableComboBox)
        self.vlayout.addLayout(self.functionObservableLayout)
        self.gridNodesLayout = QtGui.QHBoxLayout()
        self.gridNodesLayout.addWidget(self.gridNodesLabel)
        self.gridNodesLayout.addWidget(self.gridNodesSlider)
        self.vlayout.addLayout(self.gridNodesLayout)
        self.updateFrequencyLayout = QtGui.QHBoxLayout()
        self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
        self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
        self.vlayout.addLayout(self.updateFrequencyLayout)
        self.vlayout.addWidget(self.updatePlotButton)
        self.xVelLayout = QtGui.QHBoxLayout()
        self.xVelLayout.addWidget(self.xVelLabel)
        self.xVelLayout.addWidget(self.xVelSlider)
        self.vlayout.addLayout(self.xVelLayout)
        self.yVelLayout = QtGui.QHBoxLayout()
        self.yVelLayout.addWidget(self.yVelLabel)
        self.yVelLayout.addWidget(self.yVelSlider)
        self.vlayout.addLayout(self.yVelLayout)
        self.vlayout.addWidget(self.legendLabel)
        self.vlayout.addWidget(self.legendWidget)
        self.hlayout.addLayout(self.vlayout)
        self.setLayout(self.hlayout)

        # Connect to observer (has to be the last thing!!)
        self.functionObservable = None
        if self.selectedFunctionObservable:
            self.functionObservable = \
                    OBSERVABLES.getObservable(self.selectedFunctionObservable,
                                              FunctionOverStateSpaceObservable)
            # The callback is kept as an attribute so that the very same
            # object that was registered stays referenced by the viewer.
            self.functionObservableCallback = \
                 lambda evalFunction: self._updateFunction(evalFunction)
            self.functionObservable.addObserver(
                self.functionObservableCallback)