def __init__(self, pinballMazeEnv, stateSpace):
    """Build the trajectory viewer widget for a pinball maze.

    Creates the matplotlib canvas showing the maze structure, the
    controls for enabling drawing and for choosing the draw style and
    the coloring criterion, a legend list, and finally registers a
    callback at the trajectory observable.

    Parameters:
        pinballMazeEnv: environment object; its
            ``plotStateSpaceStructure`` method is called with the
            matplotlib axis of this viewer.
        stateSpace: mapping from dimension name to dimension; the values
            are stored in ``self.dimensions`` ordered by sorted name.
    """
    super(PinballMazeTrajectoryViewer, self).__init__()

    self.pinballMazeEnv = pinballMazeEnv
    self.dimensions = [stateSpace[dimName]
                       for dimName in sorted(stateSpace.keys())]

    # The segments that are obtained while drawing is disabled. These
    # segments are drawn once drawing is reenabled
    self.rememberedSegments = []

    # The eval function that can be used for coloring the trajectory
    self.evalFunction = None

    self.colorsCycle = cycle([(1.0, 0.0, 0.0), (0.0, 1.0, 0.0),
                              (0.0, 0.0, 1.0), (1.0, 1.0, 0.0),
                              (0.0, 1.0, 1.0), (1.0, 0.0, 1.0),
                              (0.5, 0.0, 0.0), (0.0, 0.5, 0.0),
                              (0.0, 0.0, 0.5)])
    # Every key that is looked up for the first time gets the next color
    # of the cycle. The builtin next() is used instead of the iterator's
    # .next() method since the latter only exists on Python 2.
    self.colors = defaultdict(lambda: next(self.colorsCycle))
    self.valueToColorMapping = dict()

    # Get required observables. The first registered TrajectoryObservable
    # is used; the [0] lookup fails if none has been registered.
    self.trajectoryObservable = \
        OBSERVABLES.getAllObservablesOfType(TrajectoryObservable)[0]

    # Create matplotlib widgets
    plotWidget = QtGui.QWidget(self)
    plotWidget.setMinimumSize(600, 500)
    plotWidget.setWindowTitle("Pinball Maze")

    self.fig = Figure((6.0, 5.0), dpi=100)
    self.axis = self.fig.gca()
    self.pinballMazeEnv.plotStateSpaceStructure(self.axis)

    self.canvas = FigureCanvas(self.fig)
    self.canvas.setParent(plotWidget)
    self.canvas.draw()

    self.ballPatch = None
    self.linePatches = []

    # Add other elements to GUI
    self.drawingEnabledCheckbox = \
        QtGui.QCheckBox("Drawing enabled", self)
    self.drawingEnabledCheckbox.setChecked(True)

    self.drawStyle = "Current Position"
    self.drawStyleLabel = QtGui.QLabel("Draw style")
    self.drawStyleComboBox = QtGui.QComboBox(self)
    self.drawStyleComboBox.addItems(["Current Position", "Last Episode",
                                     "Online (All)"])
    self.connect(self.drawStyleComboBox,
                 QtCore.SIGNAL('activated (const QString&)'),
                 self._drawStyleChanged)

    self.colorCriterion = "Action"
    self.colorCriterionLabel = QtGui.QLabel("Coloring of trajectory")
    self.colorCriterionComboBox = QtGui.QComboBox(self)
    self.colorCriterionComboBox.addItems(["Action", "Reward", "Q-Value"])
    self.connect(self.colorCriterionComboBox,
                 QtCore.SIGNAL('activated (const QString&)'),
                 self._colorCriterionChanged)

    # Legend of plot
    self.legendLabel = QtGui.QLabel("Legend:")
    self.legendWidget = QtGui.QListWidget(self)

    # Create layout: plot on the left, controls stacked on the right
    self.hlayout = QtGui.QHBoxLayout()
    self.hlayout.addWidget(plotWidget)

    self.vlayout = QtGui.QVBoxLayout()
    self.vlayout.addWidget(self.drawingEnabledCheckbox)

    self.drawStyleLayout = QtGui.QHBoxLayout()
    self.drawStyleLayout.addWidget(self.drawStyleLabel)
    self.drawStyleLayout.addWidget(self.drawStyleComboBox)
    self.vlayout.addLayout(self.drawStyleLayout)

    self.coloringLayout = QtGui.QHBoxLayout()
    self.coloringLayout.addWidget(self.colorCriterionLabel)
    self.coloringLayout.addWidget(self.colorCriterionComboBox)
    self.vlayout.addLayout(self.coloringLayout)

    self.vlayout.addWidget(self.legendLabel)
    self.vlayout.addWidget(self.legendWidget)
    self.hlayout.addLayout(self.vlayout)
    self.setLayout(self.hlayout)

    # Connect to observer (has to be the last thing!!)
    # The callback is kept as an attribute so that the very same object
    # that was registered stays referenced by the viewer.
    self.trajectoryObservableCallback = \
        lambda *transition: self._updateSamples(*transition)
    self.trajectoryObservable.addObserver(self.trajectoryObservableCallback)
def __init__(self, maze, stateSpace):
    """Build the function viewer widget for a 2D maze.

    Creates the matplotlib canvas with the maze drawn into it, a combobox
    for choosing the observed function, a combobox for the suboption, a
    slider for the update frequency, a button that triggers
    ``_updatePlot`` and a legend list. If an observable is available,
    ``_functionObservableChanged`` is called with its title at the end.

    Parameters:
        maze: maze object; its ``drawIntoAxis`` method is called with the
            matplotlib axis of this viewer.
        stateSpace: stored unchanged in ``self.stateSpace``.
    """
    super(Maze2DFunctionViewer, self).__init__()

    self.maze = maze
    self.stateSpace = stateSpace

    self.updateCounter = 0
    self.updatePlotNow = False
    self.evalFunction = None

    self.lock = threading.Lock()

    # Create matplotlib widgets
    plotWidget = QtGui.QWidget(self)
    plotWidget.setMinimumSize(600, 500)
    plotWidget.setWindowTitle("Maze2D")

    self.fig = Figure((6.0, 5.0), dpi=100)
    self.axis = self.fig.gca()
    self.axis.clear()
    self.maze.drawIntoAxis(self.axis)

    self.canvas = FigureCanvas(self.fig)
    self.canvas.setParent(plotWidget)
    self.canvas.draw()

    self.plottedPatches = []

    def collectFunctionObservables():
        # All observables that can be shown by this viewer: functions over
        # state space followed by state-action value observables.
        observables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        stateActionValueObservables = \
            OBSERVABLES.getAllObservablesOfType(StateActionValuesObservable)
        observables.extend(stateActionValueObservables)
        return observables

    # Add a combobox for selecting the function over state space that is
    # observed
    self.selectedFunctionObservable = None

    self.functionObservableLabel = QtGui.QLabel(
        "Function over State Space")
    self.functionObservableComboBox = QtGui.QComboBox(self)
    functionObservables = collectFunctionObservables()
    self.functionObservableComboBox.addItems(
        [functionObservable.title
         for functionObservable in functionObservables])
    if len(functionObservables) > 0:
        self.selectedFunctionObservable = functionObservables[0].title

    self.connect(self.functionObservableComboBox,
                 QtCore.SIGNAL('activated (const QString&)'),
                 self._functionObservableChanged)

    # Automatically update function observable combobox when new
    # observables are created during runtime
    def updateFunctionObservableBox(viewer, action):
        # Both arguments are supplied by OBSERVABLES and are not needed.
        self.functionObservableComboBox.clear()
        functionObservables = collectFunctionObservables()
        self.functionObservableComboBox.addItems(
            [functionObservable.title
             for functionObservable in functionObservables])

        if self.selectedFunctionObservable is None:
            # Nothing selected so far: fall back to the first observable,
            # if there is one. With no observables there is nothing to
            # select, and findText must not be called with None.
            if len(functionObservables) > 0:
                self.selectedFunctionObservable = \
                    functionObservables[0].title
        else:
            # Let combobox still show the selected observable
            index = self.functionObservableComboBox.findText(
                self.selectedFunctionObservable)
            if index != -1:
                self.functionObservableComboBox.setCurrentIndex(index)

    OBSERVABLES.addObserver(updateFunctionObservableBox)

    # Add a combobox for selecting the suboption that is used when
    # a StateActionValuesObservable is observed
    self.selectedSuboption = None
    self.suboptionLabel = QtGui.QLabel("Suboption")
    self.suboptionComboBox = QtGui.QComboBox(self)

    # Slider that controls the frequency of updating the plot. The slider
    # works on the integers 0..100 while updateFrequency is a float that
    # is scaled by 100 for the slider.
    self.updateFrequency = 0.0
    self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
    self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
    self.updateFrequencySlider.setMinimum(0)
    self.updateFrequencySlider.setMaximum(100)
    # The tick interval is an int in slider units; 10 corresponds to 0.1
    # on the frequency scale (the float 0.1 is not a valid int argument).
    self.updateFrequencySlider.setTickInterval(10)
    self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
    self.connect(self.updateFrequencySlider,
                 QtCore.SIGNAL('sliderReleased()'),
                 self._changeUpdateFrequency)
    self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s"
                                             % self.updateFrequency)

    # Button to enforce update of plot
    self.updatePlotButton = QtGui.QPushButton("Update Plot")
    self.connect(self.updatePlotButton,
                 QtCore.SIGNAL('clicked()'), self._updatePlot)

    # Legend of plot
    self.legendLabel = QtGui.QLabel("Legend:")
    self.legendWidget = QtGui.QListWidget(self)

    # Create layout: plot on the left, controls stacked on the right
    self.hlayout = QtGui.QHBoxLayout()
    self.hlayout.addWidget(plotWidget)

    self.vlayout = QtGui.QVBoxLayout()

    self.functionObservableLayout = QtGui.QHBoxLayout()
    self.functionObservableLayout.addWidget(self.functionObservableLabel)
    self.functionObservableLayout.addWidget(
        self.functionObservableComboBox)
    self.vlayout.addLayout(self.functionObservableLayout)

    self.suboptionLayout = QtGui.QHBoxLayout()
    self.suboptionLayout.addWidget(self.suboptionLabel)
    self.suboptionLayout.addWidget(self.suboptionComboBox)
    self.vlayout.addLayout(self.suboptionLayout)

    self.updateFrequencyLayout = QtGui.QHBoxLayout()
    self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
    self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
    self.vlayout.addLayout(self.updateFrequencyLayout)

    self.vlayout.addWidget(self.updatePlotButton)
    self.vlayout.addWidget(self.legendLabel)
    self.vlayout.addWidget(self.legendWidget)
    self.hlayout.addLayout(self.vlayout)
    self.setLayout(self.hlayout)

    # Connect to observer (has to be the last thing!!)
    self.functionObservable = None
    if self.selectedFunctionObservable is not None:
        self._functionObservableChanged(self.selectedFunctionObservable)
def __init__(self, pinballMazeEnv, stateSpace):
    """Build the function viewer widget for a pinball maze.

    Creates the matplotlib canvas showing the maze structure, a combobox
    for choosing the observed function, sliders for the plot grid
    resolution, the update frequency and the xvel/yvel values, a button
    that triggers ``_updatePlot`` and a legend list. If an observable is
    available, ``_updateFunction`` is registered as its observer at the
    end.

    Parameters:
        pinballMazeEnv: environment object; its
            ``plotStateSpaceStructure`` method is called with the
            matplotlib axis of this viewer.
        stateSpace: stored unchanged in ``self.stateSpace``.
    """
    super(PinballMazeFunctionViewer, self).__init__()

    self.pinballMazeEnv = pinballMazeEnv
    self.stateSpace = stateSpace
    self.actions = []

    self.updateCounter = 0
    self.updatePlotNow = False
    self.evalFunction = None

    self.lock = threading.Lock()

    # Create matplotlib widgets
    plotWidget = QtGui.QWidget(self)
    plotWidget.setMinimumSize(600, 500)
    plotWidget.setWindowTitle("Pinball Maze")

    self.fig = Figure((6.0, 5.0), dpi=100)
    self.axis = self.fig.gca()
    self.pinballMazeEnv.plotStateSpaceStructure(self.axis)

    self.canvas = FigureCanvas(self.fig)
    self.canvas.setParent(plotWidget)
    self.canvas.draw()

    self.plottedPatches = []

    # Add a combobox for selecting the function over state space that is
    # observed
    self.selectedFunctionObservable = None

    self.functionObservableLabel = QtGui.QLabel(
        "Function over State Space")
    self.functionObservableComboBox = QtGui.QComboBox(self)
    functionObservables = OBSERVABLES.getAllObservablesOfType(
        FunctionOverStateSpaceObservable)
    self.functionObservableComboBox.addItems(
        [functionObservable.title
         for functionObservable in functionObservables])
    if len(functionObservables) > 0:
        self.selectedFunctionObservable = functionObservables[0].title

    self.connect(self.functionObservableComboBox,
                 QtCore.SIGNAL('activated (const QString&)'),
                 self._functionObservableChanged)

    # Automatically update function observable combobox when new
    # observables are created during runtime
    def updateFunctionObservableBox(viewer, action):
        # Both arguments are supplied by OBSERVABLES and are not needed.
        self.functionObservableComboBox.clear()
        functionObservables = OBSERVABLES.getAllObservablesOfType(
            FunctionOverStateSpaceObservable)
        self.functionObservableComboBox.addItems(
            [functionObservable.title
             for functionObservable in functionObservables])
        # NOTE(review): this resets the selection to the first observable
        # without reconnecting self.functionObservable -- confirm that
        # this is intended.
        if len(functionObservables) > 0:
            self.selectedFunctionObservable = functionObservables[0].title

    OBSERVABLES.addObserver(updateFunctionObservableBox)

    # Slider that controls the granularity of the plot-grid
    self.gridNodesPerDim = 50
    self.gridNodesSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
    self.gridNodesSlider.setValue(self.gridNodesPerDim)
    self.gridNodesSlider.setMinimum(0)
    self.gridNodesSlider.setMaximum(100)
    self.gridNodesSlider.setTickInterval(10)
    self.gridNodesSlider.setTickPosition(QtGui.QSlider.TicksBelow)
    self.connect(self.gridNodesSlider,
                 QtCore.SIGNAL('sliderReleased()'),
                 self._changeGridNodes)
    self.gridNodesLabel = QtGui.QLabel("Grid Nodes Per Dimension: %s"
                                       % self.gridNodesPerDim)

    # Slider that controls the frequency of updating the plot. The slider
    # works on the integers 0..100 while updateFrequency is a float that
    # is scaled by 100 for the slider.
    self.updateFrequency = 0.0
    self.updateFrequencySlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
    self.updateFrequencySlider.setValue(int(self.updateFrequency * 100))
    self.updateFrequencySlider.setMinimum(0)
    self.updateFrequencySlider.setMaximum(100)
    # The tick interval is an int in slider units; 10 corresponds to 0.1
    # on the frequency scale (the float 0.1 is not a valid int argument).
    self.updateFrequencySlider.setTickInterval(10)
    self.updateFrequencySlider.setTickPosition(QtGui.QSlider.TicksBelow)
    self.connect(self.updateFrequencySlider,
                 QtCore.SIGNAL('sliderReleased()'),
                 self._changeUpdateFrequency)
    self.updateFrequencyLabel = QtGui.QLabel("UpdateFrequency: %s"
                                             % self.updateFrequency)

    # Button to enforce update of plot
    self.updatePlotButton = QtGui.QPushButton("Update Plot")
    self.connect(self.updatePlotButton,
                 QtCore.SIGNAL('clicked()'), self._updatePlot)

    # Chosen xvel and yvel values. Both sliders work on the integers
    # 0..10 while the values are floats that are scaled by 10.
    self.xVel = 0.5
    self.xVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
    self.xVelSlider.setValue(int(self.xVel * 10))
    self.xVelSlider.setMinimum(0)
    self.xVelSlider.setMaximum(10)
    self.xVelSlider.setTickInterval(1)
    self.xVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
    self.connect(self.xVelSlider,
                 QtCore.SIGNAL('sliderReleased()'),
                 self._changeXVel)
    self.xVelLabel = QtGui.QLabel("xvel value: %s" % self.xVel)

    self.yVel = 0.5
    self.yVelSlider = QtGui.QSlider(QtCore.Qt.Horizontal, self)
    self.yVelSlider.setValue(int(self.yVel * 10))
    self.yVelSlider.setMinimum(0)
    self.yVelSlider.setMaximum(10)
    self.yVelSlider.setTickInterval(1)
    self.yVelSlider.setTickPosition(QtGui.QSlider.TicksBelow)
    self.connect(self.yVelSlider,
                 QtCore.SIGNAL('sliderReleased()'),
                 self._changeYVel)
    # The label shows the yvel value (the xvel value was used here before)
    self.yVelLabel = QtGui.QLabel("yvel value: %s" % self.yVel)

    # Legend of plot
    self.legendLabel = QtGui.QLabel("Legend:")
    self.legendWidget = QtGui.QListWidget(self)

    # Create layout: plot on the left, controls stacked on the right
    self.hlayout = QtGui.QHBoxLayout()
    self.hlayout.addWidget(plotWidget)

    self.vlayout = QtGui.QVBoxLayout()

    self.functionObservableLayout = QtGui.QHBoxLayout()
    self.functionObservableLayout.addWidget(self.functionObservableLabel)
    self.functionObservableLayout.addWidget(
        self.functionObservableComboBox)
    self.vlayout.addLayout(self.functionObservableLayout)

    self.gridNodesLayout = QtGui.QHBoxLayout()
    self.gridNodesLayout.addWidget(self.gridNodesLabel)
    self.gridNodesLayout.addWidget(self.gridNodesSlider)
    self.vlayout.addLayout(self.gridNodesLayout)

    self.updateFrequencyLayout = QtGui.QHBoxLayout()
    self.updateFrequencyLayout.addWidget(self.updateFrequencyLabel)
    self.updateFrequencyLayout.addWidget(self.updateFrequencySlider)
    self.vlayout.addLayout(self.updateFrequencyLayout)

    self.vlayout.addWidget(self.updatePlotButton)

    self.xVelLayout = QtGui.QHBoxLayout()
    self.xVelLayout.addWidget(self.xVelLabel)
    self.xVelLayout.addWidget(self.xVelSlider)
    self.vlayout.addLayout(self.xVelLayout)

    self.yVelLayout = QtGui.QHBoxLayout()
    self.yVelLayout.addWidget(self.yVelLabel)
    self.yVelLayout.addWidget(self.yVelSlider)
    self.vlayout.addLayout(self.yVelLayout)

    self.vlayout.addWidget(self.legendLabel)
    self.vlayout.addWidget(self.legendWidget)
    self.hlayout.addLayout(self.vlayout)
    self.setLayout(self.hlayout)

    # Connect to observer (has to be the last thing!!)
    self.functionObservable = None
    if self.selectedFunctionObservable:
        self.functionObservable = \
            OBSERVABLES.getObservable(self.selectedFunctionObservable,
                                      FunctionOverStateSpaceObservable)
        # The callback is kept as an attribute so that the very same
        # object that was registered stays referenced by the viewer.
        self.functionObservableCallback = \
            lambda evalFunction: self._updateFunction(evalFunction)
        self.functionObservable.addObserver(
            self.functionObservableCallback)