def _determineColor(self, value):
    """Return the colour used to draw a trajectory segment tagged with *value*.

    For the "Action" and "Reward" criteria the values form a finite set:
    each distinct value is given the next colour from ``self.colorsCycle``
    the first time it is seen. That colour is cached in
    ``self.valueToColorMapping`` and added as an entry to
    ``self.legendWidget``.

    For any other criterion *value* is treated as real-valued. It is placed
    on a logarithmic scale between ``self.minValue`` and ``self.maxValue``,
    and the result is the RGB tuple ``(alpha, 0, 1 - alpha)``: blue at the
    minimum, red at the maximum. When the range is empty, ``alpha`` is 0.5.

    :param value: the action, reward or Q-value the segment is tagged with
    :returns: a colour whose first three components are used as RGB in
        [0, 1] (for the discrete case this assumes ``self.colorsCycle``
        yields such tuples -- TODO confirm)
    """
    if self.colorCriterion in ["Action", "Reward"]:
        # Finite number of values: one fixed colour per distinct value
        if value not in self.valueToColorMapping:
            # Builtin next() instead of .next() so this runs on Python 2 and 3
            color = next(self.colorsCycle)
            self.valueToColorMapping[value] = color
            # Add to legend
            item = QtGui.QListWidgetItem(str(value), self.legendWidget)
            qColor = QtGui.QColor(int(color[0] * 255), int(color[1] * 255),
                                  int(color[2] * 255))
            item.setTextColor(qColor)
            self.legendWidget.addItem(item)
        return self.valueToColorMapping[value]

    if self.maxValue == self.minValue:
        alpha = 0.5
    else:
        # Logarithmic position of value within [minValue, maxValue]
        alpha = numpy.log10(value - self.minValue + 1) \
            / numpy.log10(self.maxValue - self.minValue + 1)
        # Keep the result a valid RGB colour for values slightly outside the
        # tracked range (e.g. floating-point drift); matplotlib rejects colour
        # components outside [0, 1].
        alpha = min(max(alpha, 0.0), 1.0)
    return (alpha, 0, 1 - alpha)
def _plotFunction(self):
    """Redraw the observed function as coloured squares over the maze grid.

    Every maze cell ``(column, row)`` is wrapped in a ``State`` and passed to
    ``self.evalFunction``. For an observable that is not a
    ``FunctionOverStateSpaceObservable``, the option whose ``str()`` matches
    the text of ``self.suboptionComboBox`` is passed as the second argument.

    * Discrete-valued functions: each distinct value gets the next colour of
      a fixed palette (recorded in ``self.colorMapping``).
    * Continuous functions: values are scaled linearly between their minimum
      and maximum and mapped through the "jet" colormap.

    Cells where the function returns None, NaN or +/-inf are left blank.
    The legend widget is rebuilt to describe the colours used. Nothing
    happens if ``self.evalFunction`` is None.

    ``self.lock`` is held while plotting and is released even when
    evaluation or drawing raises, so one failure cannot block later redraws.
    """
    if self.evalFunction is None:
        return

    def isUndefined(functionValue):
        # The function may be defined over only part of the state space and
        # reports the rest as None / NaN / +-inf. NaN must be detected with
        # isfinite: ``x in [numpy.nan]`` only matches that one float object,
        # not a NaN produced by arithmetic.
        if functionValue is None:
            return True
        if isinstance(functionValue, (float, numpy.floating)):
            return not numpy.isfinite(functionValue)
        return False

    def addLegendEntry(label, rgbaColor):
        # Add a legend text entry drawn in the given matplotlib colour
        # (components in [0, 1]).
        item = QtGui.QListWidgetItem(label, self.legendWidget)
        item.setTextColor(QtGui.QColor(int(rgbaColor[0] * 255),
                                       int(rgbaColor[1] * 255),
                                       int(rgbaColor[2] * 255)))
        self.legendWidget.addItem(item)

    self.lock.acquire()
    try:
        # Clean up old plot
        for patch in self.plottedPatches:
            patch.remove()
        self.plottedPatches = []
        self.colorMapping = dict()
        self.colors = cycle(["b", "g", "r", "c", "m", "y"])
        cmap = pylab.get_cmap("jet")

        columns = self.maze.getColumns()
        rows = self.maze.getRows()

        # Only state-space functions can be discrete valued; everything else
        # (state-action values) is plotted on a continuous scale.
        isStateFunction = isinstance(self.functionObservable,
                                     FunctionOverStateSpaceObservable)
        discreteFunction = isStateFunction \
            and self.functionObservable.discreteValues

        # For state-action values, resolve the option chosen in the combo
        # box once instead of once per cell.
        selectedOption = None
        if not isStateFunction:  # StateActionValuesObservable
            selectedOptionName = str(self.suboptionComboBox.currentText())
            for option in self.actions:
                if str(option) == selectedOptionName:
                    selectedOption = option
                    break
            assert selectedOption is not None

        if not discreteFunction:
            # Continuous values are collected first because min and max must
            # be known before they can be coloured; masked cells stay blank.
            values = numpy.ma.array(
                numpy.zeros((columns, rows)),
                mask=numpy.zeros((columns, rows), dtype=bool))

        # Iterate over all states and compute the value of the observed function
        dimensions = [self.stateSpace[dimName] for dimName in ["column", "row"]]
        for column in range(columns):
            for row in range(rows):
                state = State((column, row), dimensions)
                if isStateFunction:
                    functionValue = self.evalFunction(state)
                else:
                    functionValue = self.evalFunction(state, selectedOption)

                if discreteFunction:
                    if isUndefined(functionValue):
                        continue
                    if functionValue not in self.colorMapping:
                        # First occurrence of this value: next palette colour
                        self.colorMapping[functionValue] = next(self.colors)
                    patch = self.maze.plotSquare(
                        self.axis, (column, row),
                        self.colorMapping[functionValue])
                    self.plottedPatches.append(patch[0])
                elif isUndefined(functionValue):
                    values.mask[column, row] = True
                else:
                    values[column, row] = functionValue

        # Do the actual plotting for functions with continuous values
        hasContinuousValues = not discreteFunction and not values.mask.all()
        if hasContinuousValues:
            minValue = values.min()
            maxValue = values.max()
            valueRange = maxValue - minValue

            def normalize(value):
                # Scale onto [0, 1] for the colormap. A constant function
                # would divide by zero, so it is drawn mid-scale instead.
                if valueRange == 0:
                    return 0.5
                return (value - minValue) / valueRange

            for column in range(columns):
                for row in range(rows):
                    if values.mask[column, row]:
                        continue
                    patch = self.maze.plotSquare(
                        self.axis, (column, row),
                        cmap(normalize(values[column, row])), zorder=0)
                    self.plottedPatches.append(patch[0])

        # Set limits
        self.axis.set_xlim(0, len(self.maze.structure[0]) - 1)
        self.axis.set_ylim(0, len(self.maze.structure) - 1)

        # Update legend
        self.legendWidget.clear()
        if discreteFunction:
            converter = matplotlib.colors.ColorConverter()
            for functionValue, colorValue in self.colorMapping.items():
                if isinstance(functionValue, tuple):
                    functionValue = functionValue[0]  # deal with '(action,)'
                addLegendEntry(str(functionValue), converter.to_rgba(colorValue))
        elif hasContinuousValues:
            for value in numpy.linspace(minValue, maxValue, 10):
                addLegendEntry(str(value), cmap(normalize(value)))

        self.canvas.draw()
    finally:
        self.lock.release()
def _updateSamples(self, state, action, reward, succState, episodeTerminated):
    """Record one transition ``state -> succState`` and update the drawing.

    The transition is tagged with a value chosen by ``self.colorCriterion``:
    the action ("Action"), the reward ("Reward") or ``self.evalFunction``
    evaluated at ``succState`` ("Q-Value"). A "Q-Value" also widens
    ``self.minValue`` / ``self.maxValue``. The tag later picks the segment
    colour through ``self._determineColor``.

    When drawing is enabled, ``self.drawStyle`` decides what happens:
    "Current Position" shows only a ball at ``state``; "Online (All)" draws
    the segment immediately; any other style collects segments and draws the
    whole episode once ``episodeTerminated`` is true, rebuilding the legend
    for "Q-Value". When drawing is disabled, the segment is appended to
    ``self.rememberedSegments``, except in "Current Position" style.

    Returns early, without drawing, if the criterion is "Q-Value" but
    ``self.evalFunction`` is None.
    """
    criterion = self.colorCriterion
    if criterion == "Action":
        value = action
    elif criterion == "Reward":
        value = reward
    elif criterion == "Q-Value":
        if self.evalFunction is None:
            return
        coordinates = (succState['x'], succState['xdot'],
                       succState['y'], succState['ydot'])
        value = self.evalFunction(State(coordinates, self.dimensions))
        # Track the observed range; _determineColor normalises against it
        self.minValue = min(value, self.minValue)
        self.maxValue = max(value, self.maxValue)

    if not self.drawingEnabledCheckbox.checkState():
        # Drawing is paused: keep the segment so it can be drawn once drawing
        # is re-enabled ("Current Position" keeps no history).
        if self.drawStyle != "Current Position":
            self.rememberedSegments.append(
                (state["x"], succState["x"], state["y"], succState["y"], value))
        return

    # Drawing is live: the ball from the previous step is always removed
    if self.ballPatch is not None:
        self.ballPatch.remove()
        self.ballPatch = None

    style = self.drawStyle
    if style == "Current Position":
        # Only the current position is shown: drop the trajectory, draw the ball
        self._removeTrajectory()
        self.rememberedSegments = []
        self.ballPatch = Circle([state["x"], state["y"]],
                                self.pinballMazeEnv.maze.ballRadius,
                                facecolor='k')
        self.axis.add_patch(self.ballPatch)
        self.canvas.draw()
    elif style == "Online (All)":
        # Flush anything collected while drawing was paused, then this step
        self._drawRememberedSegments()
        xs = [state["x"], succState["x"]]
        ys = [state["y"], succState["y"]]
        self.linePatches.extend(
            self.axis.plot(xs, ys, '-', color=self._determineColor(value)))
        self.canvas.draw()
    else:  # "Last Episode"
        self.rememberedSegments.append(
            (state["x"], succState["x"], state["y"], succState["y"], value))
        if not episodeTerminated:
            return
        # Episode over: replace the previous trajectory with this one
        self._removeTrajectory()
        self._drawRememberedSegments()
        self.canvas.draw()
        if self.colorCriterion == "Q-Value":
            # Real-valued colours: rebuild the legend from 10 log-spaced
            # samples spanning [minValue, maxValue]
            self.legendWidget.clear()
            upper = numpy.log10(self.maxValue - self.minValue + 1)
            for sample in numpy.logspace(0, upper, 10):
                legendValue = sample - 1 + self.minValue
                rgb = self._determineColor(legendValue)
                entry = QtGui.QListWidgetItem(str(legendValue), self.legendWidget)
                entry.setTextColor(QtGui.QColor(int(rgb[0] * 255),
                                                int(rgb[1] * 255),
                                                int(rgb[2] * 255)))
                self.legendWidget.addItem(entry)
def _plotFunction(self):
    """Redraw the observed function over the (x, y) plane of the state space.

    The velocity dimensions are fixed to ``xdot = self.xVel`` and
    ``ydot = self.yVel`` (x and y default to 0 in the slice definition).
    The slice is sampled on a ``self.gridNodesPerDim`` x
    ``self.gridNodesPerDim`` grid via ``generate2dStateSlice`` /
    ``generate2dPlotArray`` and drawn with ``pcolor`` over the unit square.

    The legend widget is rebuilt with one entry per discrete value, or with
    ten evenly spaced samples of the colour scale for continuous functions.
    Nothing is drawn if no function is set or if every grid value is masked.

    ``self.lock`` is held while plotting and is released even when
    evaluation or drawing raises, so one failure cannot block later redraws.
    """
    if self.evalFunction is None:
        return

    self.lock.acquire()
    try:
        # Clean up old plot
        for patch in self.plottedPatches:
            patch.remove()
        self.plottedPatches = []

        # Check if the observed function returns discrete or continuous value
        discreteFunction = self.functionObservable.discreteValues

        # Generate 2d state slice
        defaultDimValues = {"x": 0, "xdot": self.xVel,
                            "y": 0, "ydot": self.yVel}
        stateSlice = generate2dStateSlice(["x", "y"], self.stateSpace,
                                          defaultDimValues,
                                          gridNodesPerDim=self.gridNodesPerDim)

        # Compute values that should be plotted
        values, colorMapping = generate2dPlotArray(
            self.evalFunction, stateSlice, not discreteFunction,
            shape=(self.gridNodesPerDim, self.gridNodesPerDim))

        # If all values are masked, we cannot draw anything useful
        # (the finally clause still releases the lock)
        if values.mask.all():
            return

        # NOTE(review): X, Y and C all have gridNodesPerDim entries per axis.
        # Newer matplotlib releases reject equal sizes for the default 'flat'
        # shading -- confirm against the matplotlib version in use.
        gridCoordinates = numpy.linspace(0.0, 1.0, self.gridNodesPerDim)
        polyCollection = self.axis.pcolor(gridCoordinates, gridCoordinates,
                                          values.T)
        self.plottedPatches.append(polyCollection)

        # Set axis limits
        self.axis.set_xlim(0, 1)
        self.axis.set_ylim(0, 1)

        # Update legend, taking colours from the plot's own norm and colormap
        self.legendWidget.clear()
        colormap = polyCollection.get_cmap()
        norm = polyCollection.norm

        def addLegendEntry(label, colorValue):
            # Add a legend text entry coloured like colorValue in the plot
            normValue = norm(colorValue)
            if isinstance(normValue, numpy.ndarray):
                normValue = normValue[0]  # happens when function is constant
            rgbaColor = colormap(normValue)
            item = QtGui.QListWidgetItem(label, self.legendWidget)
            item.setTextColor(QtGui.QColor(int(rgbaColor[0] * 255),
                                           int(rgbaColor[1] * 255),
                                           int(rgbaColor[2] * 255)))
            self.legendWidget.addItem(item)

        if discreteFunction:
            for functionValue, colorValue in colorMapping.items():
                if isinstance(functionValue, tuple):
                    functionValue = functionValue[0]  # deal with '(action,)'
                addLegendEntry(str(functionValue), colorValue)
        else:
            for value in numpy.linspace(norm.vmin, norm.vmax, 10):
                addLegendEntry(str(value), value)

        self.canvas.draw()
    finally:
        self.lock.release()