示例#1
0
    def _plot_valfun(self, VMat, xlim=None, ylim=None):
        """
        Render ``VMat`` as an image in the "Value Function" figure.

        The image handle is stored in ``self.valueFunction_fig``; this method
        itself returns nothing.

        :param VMat: array of values to display; its ``min()`` and ``max()``
            define the color range.
        :param xlim: optional ``(low, high)`` pair for the horizontal extent.
        :param ylim: optional ``(low, high)`` pair for the vertical extent.
            The unit square is used unless both limits are supplied.
        """
        plt.figure("Value Function")
        plt.title('Value Function')

        # Fall back to the unit square when either limit is missing.
        if xlim is None or ylim is None:
            bounds = [0, 1, 0, 1]
        else:
            bounds = [xlim[0], xlim[1], ylim[0], ylim[1]]

        image = plt.imshow(VMat,
                           cmap='ValueFunction',
                           interpolation='nearest',
                           origin='lower',
                           extent=bounds)
        self.valueFunction_fig = image

        # Stretch the color mapping over the range of the data being shown.
        value_range = colors.Normalize(vmin=VMat.min(), vmax=VMat.max())
        image.set_data(VMat)
        image.set_norm(value_range)
        plt.draw()
示例#2
0
    def showDomain(self, a=0, s=None):
        """
        Draw the grid map and mark the agent's position on it.

        :param a: not used by this method.
        :param s: state to display; ``s[1]`` is plotted on the horizontal
            axis and ``s[0]`` on the vertical axis. Defaults to ``self.state``.
        """
        state = self.state if s is None else s
        marker_size = 20.0 - self.COLS

        if self.domain_fig is None:
            # First call: build the figure and the static map image.
            plt.figure("Domain")
            self.domain_fig = plt.imshow(self.map,
                                         cmap='GridWorld',
                                         interpolation='nearest',
                                         vmin=0,
                                         vmax=5)
            plt.xticks(np.arange(self.COLS), fontsize=FONTSIZE)
            plt.yticks(np.arange(self.ROWS), fontsize=FONTSIZE)
            self.agent_fig = plt.gca().plot(
                state[1], state[0], 'kd', markersize=marker_size)
            plt.show()

        # Drop the previous agent marker before drawing the new one.
        self.agent_fig.pop(0).remove()
        plt.figure("Domain")
        # Instead of '>' you can use 'D', 'o'
        self.agent_fig = plt.gca().plot(
            state[1], state[0], 'k>', markersize=marker_size)
        plt.draw()
示例#3
0
 def showDomain(self, a=0):
     """
     Draw the chain as a row of circles and highlight the current state.

     On the first call (``self.circles`` is ``None``) the axes and one circle
     patch per chain state are created. On every call all circles are reset
     to white, goal states are painted green and the current state black.

     :param a: not used by this method.
     """
     s = self.state
     if self.circles is None:
         # First call: set up the axes and the circle patches.
         self.domain_fig = plt.subplot(3, 1, 1)
         plt.figure(1, (self.chainSize * 2 / 10.0, 2))
         self.domain_fig.set_xlim(0, self.chainSize * 2 / 10.0)
         self.domain_fig.set_ylim(0, 2)
         # A slightly larger circle behind the last state makes it look
         # like a double circle.
         self.domain_fig.add_patch(
             mpatches.Circle(
                 (1 / 5.0 + 2 / 10.0 * (self.chainSize - 1), self.Y),
                 self.RADIUS * 1.1,
                 fc="w"))
         self.domain_fig.xaxis.set_visible(False)
         self.domain_fig.yaxis.set_visible(False)
         self.circles = [
             mpatches.Circle((1 / 5.0 + 2 / 10.0 * i, self.Y),
                             self.RADIUS,
                             fc="w") for i in range(self.chainSize)
         ]
         for circle in self.circles:
             self.domain_fig.add_patch(circle)
         # Show the window once, after every patch has been added, rather
         # than once per circle.
         plt.show()
     for p in self.circles:
         p.set_facecolor('w')
     for p in self.GOAL_STATES:
         self.circles[p].set_facecolor('g')
     self.circles[s].set_facecolor('k')
     plt.draw()
示例#4
0
    def showDomain(self, a=0, s=None):
        """
        Show the grid map with a marker at the agent's position.

        :param a: not used by this method.
        :param s: state to display (``self.state`` when omitted); ``s[0]``
            is used as the vertical and ``s[1]`` as the horizontal coordinate.
        """
        if s is None:
            s = self.state

        if self.domain_fig is None:
            # Nothing drawn yet: create the figure, the map image and an
            # initial marker.
            plt.figure("Domain")
            self.domain_fig = plt.imshow(self.map,
                                         cmap='GridWorld',
                                         interpolation='nearest',
                                         vmin=0,
                                         vmax=5)
            plt.xticks(np.arange(self.COLS), fontsize=FONTSIZE)
            plt.yticks(np.arange(self.ROWS), fontsize=FONTSIZE)
            current_axes = plt.gca()
            self.agent_fig = current_axes.plot(
                s[1], s[0], 'kd', markersize=20.0 - self.COLS)
            plt.show()

        # Replace the old marker with one at the requested state.
        old_marker = self.agent_fig.pop(0)
        old_marker.remove()
        plt.figure("Domain")
        current_axes = plt.gca()
        # Instead of '>' you can use 'D', 'o'
        self.agent_fig = current_axes.plot(
            s[1], s[0], 'k>', markersize=20.0 - self.COLS)
        plt.draw()
示例#5
0
 def showDomain(self, a=None):
     """
     Draw the swimmer in the "Swimmer Domain" figure.

     Segment end points are computed from ``self.theta``, ``self.P``,
     ``self.lengths`` and ``self.pos_cm``. The line and the action text are
     created on the first call and updated in place afterwards.

     :param a: index into ``self.actions`` of the action to display, or
         ``None`` to display ``"None"``.
     """
     if a is not None:
         a = self.actions[a]
     # Unit direction of each segment.
     T = np.empty((self.d, 2))
     T[:, 0] = np.cos(self.theta)
     T[:, 1] = np.sin(self.theta)
     R = np.dot(self.P, T)
     # Both ends of each segment: half a length either side of R.
     R1 = R - .5 * self.lengths[:, None] * T
     R2 = R + .5 * self.lengths[:, None] * T
     Rx = np.hstack([R1[:, 0], R2[:, 0]]) + self.pos_cm[0]
     Ry = np.hstack([R1[:, 1], R2[:, 1]]) + self.pos_cm[1]
     fig = plt.figure("Swimmer Domain")
     if not hasattr(self, "swimmer_lines"):
         # First call: draw the origin marker and create the artists.
         plt.plot(0., 0., "ro")
         self.swimmer_lines = plt.plot(Rx, Ry)[0]
         self.action_text = plt.text(-2, -8, str(a))
         plt.xlim(-5, 15)
         plt.ylim(-10, 10)
     else:
         self.swimmer_lines.set_data(Rx, Ry)
         self.action_text.set_text(str(a))
     # Redraw through the canvas of the figure selected above.
     fig.canvas.draw()
     fig.canvas.flush_events()
示例#6
0
文件: FiftyChain.py 项目: MLDL/rlpy
 def showDomain(self, a=0):
     """
     Draw the chain as a row of circles and highlight the current state.

     The axes and circle patches are built on the first call, when
     ``self.circles`` is ``None``. Each call then repaints every circle
     white, the goal states green and the current state black.

     :param a: not used by this method.
     """
     s = self.state
     if self.circles is None:
         # First call: set up the axes and the circle patches.
         self.domain_fig = plt.subplot(3, 1, 1)
         plt.figure(1, (self.chainSize * 2 / 10.0, 2))
         self.domain_fig.set_xlim(0, self.chainSize * 2 / 10.0)
         self.domain_fig.set_ylim(0, 2)
         # Make the last one a double circle by adding a larger one first.
         self.domain_fig.add_patch(
             mpatches.Circle((1 / 5.0 + 2 / 10.0 * (self.chainSize - 1),
                              self.Y),
                             self.RADIUS * 1.1,
                             fc="w"))
         self.domain_fig.xaxis.set_visible(False)
         self.domain_fig.yaxis.set_visible(False)
         self.circles = [mpatches.Circle((1 / 5.0 + 2 / 10.0 * i, self.Y), self.RADIUS, fc="w")
                         for i in range(self.chainSize)]
         for circle in self.circles:
             self.domain_fig.add_patch(circle)
         # Show the window a single time once all patches are in place;
         # doing so inside the loop repeated the call for every circle.
         plt.show()
     for p in self.circles:
         p.set_facecolor('w')
     for p in self.GOAL_STATES:
         self.circles[p].set_facecolor('g')
     self.circles[s].set_facecolor('k')
     plt.draw()
示例#7
0
 def showExploration(self):
     """
     Scatter-plot the visited states together with the source and target.

     Visited states are drawn in black, ``self.src`` in red and
     ``self.target`` in blue, in a new figure that is then shown.
     """
     plt.figure()
     xs = [state[0] for state in self.visited_states]
     ys = [state[1] for state in self.visited_states]
     plt.scatter(xs, ys, color='k')
     source, goal = self.src, self.target
     plt.scatter([source['x']], [source['y']], color='r')
     plt.scatter([goal['x']], [goal['y']], color='b')
     plt.show()
示例#8
0
    def showLearning(self, representation):
        """
        Display the value function and greedy policy over a state grid.

        The state space is sampled on a grid of ``X_discretization`` positions
        by ``XDot_discretization`` velocities. For every grid point the best
        action and the maximum Q-value are queried from ``representation`` and
        shown in the "Policy" and "Value Function" figures, which are created
        on the first call.

        :param representation: object providing ``Qs(s, terminal)`` and
            ``bestAction(s, terminal, actions)``.
        """
        # Rows index velocity and columns index position, matching the
        # ``[row, col]`` writes below and the axis labels. The shape must
        # therefore be (velocity, position); the reverse only works when
        # both discretizations are equal.
        grid_shape = (self.XDot_discretization, self.X_discretization)
        pi = np.zeros(grid_shape, 'uint8')
        V = np.zeros(grid_shape)

        if self.valueFunction_fig is None:
            self.valueFunction_fig = plt.figure("Value Function")
            self.valueFunction_im = plt.imshow(
                V,
                cmap='ValueFunction',
                interpolation='nearest',
                origin='lower',
                vmin=self.MIN_RETURN,
                vmax=self.MAX_RETURN)

            plt.xticks(self.xTicks, self.xTicksLabels, fontsize=12)
            plt.yticks(self.yTicks, self.yTicksLabels, fontsize=12)
            plt.xlabel(r"$x$")
            plt.ylabel(r"$\dot x$")

            self.policy_fig = plt.figure("Policy")
            self.policy_im = plt.imshow(
                pi,
                cmap='MountainCarActions',
                interpolation='nearest',
                origin='lower',
                vmin=0,
                vmax=self.actions_num)

            plt.xticks(self.xTicks, self.xTicksLabels, fontsize=12)
            plt.yticks(self.yTicks, self.yTicksLabels, fontsize=12)
            plt.xlabel(r"$x$")
            plt.ylabel(r"$\dot x$")
            plt.show()

        velocities = np.linspace(
            self.XDOTMIN, self.XDOTMAX, self.XDot_discretization)
        positions = np.linspace(self.XMIN, self.XMAX, self.X_discretization)
        for row, xDot in enumerate(velocities):
            for col, x in enumerate(positions):
                s = [x, xDot]
                Qs = representation.Qs(s, False)
                As = self.possibleActions()
                pi[row, col] = representation.bestAction(s, False, As)
                V[row, col] = max(Qs)
        self.valueFunction_im.set_data(V)
        self.policy_im.set_data(pi)

        # Make each figure current in turn so that both get redrawn.
        self.valueFunction_fig = plt.figure("Value Function")
        plt.draw()
        self.policy_fig = plt.figure("Policy")
        plt.draw()
示例#9
0
文件: RCCar.py 项目: zhexiaozhe/rlpy
    def showDomain(self, a):
        """
        Draw the goal region and the car at its current pose.

        The figure and the goal circle are created on the first call. On every
        call the previous car rectangle is removed and a new one is added,
        rotated about ``(x, y)`` by the heading.

        :param a: not used by this method.
        """
        s = self.state
        # Plot the car
        x, y, speed, heading = s
        car_xmin = x - self.REAR_WHEEL_RELATIVE_LOC
        car_ymin = y - self.CAR_WIDTH / 2.
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.figure()
            # Goal
            plt.gca().add_patch(
                plt.Circle(self.GOAL,
                           radius=self.GOAL_RADIUS,
                           color='g',
                           alpha=.4))
            plt.xlim([self.XMIN, self.XMAX])
            plt.ylim([self.YMIN, self.YMAX])
            plt.gca().set_aspect('1')
        # Car
        if self.car_fig is not None:
            # Remove through the artist itself: ``Axes.patches`` cannot be
            # modified directly on current matplotlib versions.
            self.car_fig.remove()

        self.car_fig = mpatches.Rectangle([car_xmin, car_ymin],
                                          self.CAR_LENGTH,
                                          self.CAR_WIDTH,
                                          alpha=.4)
        # The heading is converted from radians to degrees for the rotation.
        rotation = mpl.transforms.Affine2D().rotate_deg_around(
            x, y, heading * 180 / np.pi) + plt.gca().transData
        self.car_fig.set_transform(rotation)
        plt.gca().add_patch(self.car_fig)

        plt.draw()
示例#10
0
文件: Experiment.py 项目: MLDL/rlpy
 def plot(self, y="return", x="learning_steps", save=False):
     """Plot one recorded quantity of the experiment against another.

     This function has only limited capabilities.
     For more advanced plotting of results consider
     :py:class:`Tools.Merger.Merger`.

     :param y: key into ``self.result`` used for the vertical axis.
     :param x: key into ``self.result`` used for the horizontal axis.
     :param save: when true, the figure is also written as a PDF file
         inside ``self.full_path``.
     """
     label_map = rlpy.Tools.results.default_labels
     fig = plt.figure("Performance")
     data = self.result
     plt.plot(data[x], data[y], '-bo', lw=3, markersize=10)
     plt.xlim(0, data[x][-1] * 1.01)

     values = np.array(data[y])
     lowest = values.min()
     highest = values.max()
     spread = highest - lowest
     # Only widen the y-range when the data is not flat.
     if spread > 0:
         plt.ylim(lowest - .1 * spread - .1, highest + .1 * spread + .1)

     # Use the descriptive label when one is registered for the key.
     x_text = label_map[x] if x in label_map else x
     y_text = label_map[y] if y in label_map else y
     plt.xlabel(x_text, fontsize=16)
     plt.ylabel(y_text, fontsize=16)

     if save:
         file_name = "{:3}-performance.pdf".format(self.exp_id)
         target = os.path.join(self.full_path, file_name)
         fig.savefig(target, transparent=True, pad_inches=.1)
     plt.ioff()
     plt.show()
 def plot(self, y="return", x="learning_steps", save=False):
     """Plot the performance of the experiment.

     This function has only limited capabilities.
     For more advanced plotting of results consider
     :py:class:`Tools.Merger.Merger`.

     :param y: name of the entry in ``self.result`` drawn on the y-axis.
     :param x: name of the entry in ``self.result`` drawn on the x-axis.
     :param save: if true, also save the figure as a PDF file under
         ``self.full_path``.
     """
     known_labels = rlpy.Tools.results.default_labels
     figure = plt.figure("Performance")
     results = self.result
     plt.plot(results[x], results[y], '-bo', lw=3, markersize=10)
     plt.xlim(0, results[x][-1] * 1.01)

     series = np.array(results[y])
     y_min, y_max = series.min(), series.max()
     y_range = y_max - y_min
     if y_range > 0:
         # Leave a margin above and below the plotted values.
         plt.ylim(y_min - .1 * y_range - .1, y_max + .1 * y_range + .1)

     if x in known_labels:
         plt.xlabel(known_labels[x], fontsize=16)
     else:
         plt.xlabel(x, fontsize=16)
     if y in known_labels:
         plt.ylabel(known_labels[y], fontsize=16)
     else:
         plt.ylabel(y, fontsize=16)

     if save:
         pdf_path = os.path.join(
             self.full_path, "{:3}-performance.pdf".format(self.exp_id))
         figure.savefig(pdf_path, transparent=True, pad_inches=.1)
     plt.ioff()
     plt.show()
示例#12
0
文件: RCCar.py 项目: okkhoy/rlpy
    def showDomain(self, a):
        """
        Draw the goal region and the car at its current pose.

        The figure and the goal circle are created on the first call. Each
        call replaces the car rectangle with a new one rotated about
        ``(x, y)`` by the heading.

        :param a: not used by this method.
        """
        s = self.state
        # Plot the car
        x, y, speed, heading = s
        car_xmin = x - self.REAR_WHEEL_RELATIVE_LOC
        car_ymin = y - old_div(self.CAR_WIDTH, 2.)
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.figure()
            # Goal
            plt.gca().add_patch(
                plt.Circle(
                    self.GOAL,
                    radius=self.GOAL_RADIUS,
                    color='g',
                    alpha=.4))
            plt.xlim([self.XMIN, self.XMAX])
            plt.ylim([self.YMIN, self.YMAX])
            plt.gca().set_aspect('1')
        # Car
        if self.car_fig is not None:
            # Remove through the artist itself: ``Axes.patches`` cannot be
            # modified directly on current matplotlib versions.
            self.car_fig.remove()

        self.car_fig = mpatches.Rectangle(
            [car_xmin,
             car_ymin],
            self.CAR_LENGTH,
            self.CAR_WIDTH,
            alpha=.4)
        # The heading is converted from radians to degrees for the rotation.
        rotation = mpl.transforms.Affine2D().rotate_deg_around(
            x, y, heading * 180 / np.pi) + plt.gca().transData
        self.car_fig.set_transform(rotation)
        plt.gca().add_patch(self.car_fig)

        plt.draw()
示例#13
0
    def _plot_policy(self,
                     piMat,
                     title="Policy",
                     var="policy_fig",
                     xlim=None,
                     ylim=None):
        """
        Show ``piMat`` as a color-coded image in the figure named ``title``.

        The image handle is kept in the instance attribute named by ``var``.
        It is created on the first call and only has its data replaced on
        later calls. Nothing is returned.

        :param piMat: array of action indices to display.
        :param title: name and title of the figure.
        :param var: name of the instance attribute holding the image handle.
        :param xlim: optional ``(low, high)`` pair for the horizontal extent.
        :param ylim: optional ``(low, high)`` pair for the vertical extent.
            Both limits are only used when the image is first created; the
            unit square is used unless both are given.
        """
        if getattr(self, var, None) is None:
            plt.figure(title)

            # Start from the jet colormap but force its first entry to grey.
            base_map = plt.cm.jet
            palette = [base_map(k) for k in range(base_map.N)]
            palette[0] = (.5, .5, .5, 1.0)
            custom_map = base_map.from_list('Custom cmap', palette, base_map.N)

            # One color bin per action.
            bin_edges = np.linspace(0, self.actions_num, self.actions_num + 1)
            binning = mpl.colors.BoundaryNorm(bin_edges, custom_map.N)

            if xlim is None or ylim is None:
                bounds = [0, 1, 0, 1]
            else:
                bounds = [xlim[0], xlim[1], ylim[0], ylim[1]]

            self.__dict__[var] = plt.imshow(
                piMat,
                interpolation='nearest',
                origin='lower',
                cmap=custom_map,
                norm=binning,
                extent=bounds)
            plt.title(title)
            plt.colorbar()

        # Make the figure current and refresh the displayed data.
        plt.figure(title)
        self.__dict__[var].set_data(piMat)
        plt.draw()
示例#14
0
    def gridworld_showlearning(self, representation):
        """
        Display the value function of ``self.actual_domain`` as an image.

        The "Value Function" figure is created on the first call; afterwards
        only the image data is replaced with the result of
        ``self.get_value_function(representation)``.

        :param representation: passed through to ``get_value_function``.
        """
        domain = self.actual_domain
        if self.valueFunction_fig is None:
            # First call: build the figure using the domain map as a
            # placeholder image.
            plt.figure("Value Function")
            self.valueFunction_fig = plt.imshow(
                domain.map,
                cmap='ValueFunction',
                interpolation='nearest',
                vmin=domain.MIN_RETURN,
                vmax=domain.MAX_RETURN)
            col_ticks = np.arange(domain.COLS)
            row_ticks = np.arange(domain.ROWS)
            plt.xticks(col_ticks, fontsize=12)
            plt.yticks(row_ticks, fontsize=12)
            plt.show()

        plt.figure("Value Function")
        values = self.get_value_function(representation)
        self.valueFunction_fig.set_data(values)
        plt.draw()
示例#15
0
文件: Swimmer.py 项目: okkhoy/rlpy
    def _plot_policy(self, piMat, title="Policy",
                     var="policy_fig", xlim=None, ylim=None):
        """
        Draw the policy matrix ``piMat`` in the figure named ``title``.

        On the first call the image is created and stored in the instance
        attribute named by ``var``; later calls just update its data. This
        method returns nothing.

        :param piMat: array of action indices to display.
        :param title: name and title of the figure.
        :param var: name of the instance attribute that stores the image.
        :param xlim: optional ``(low, high)`` horizontal extent.
        :param ylim: optional ``(low, high)`` vertical extent. The extent is
            the unit square unless both limits are given, and is applied only
            when the image is first created.
        """
        needs_setup = getattr(self, var, None) is None
        if needs_setup:
            plt.figure(title)

            # Copy every color of the jet map, replacing the first by grey.
            jet = plt.cm.jet
            color_list = [jet(index) for index in range(jet.N)]
            color_list[0] = (.5, .5, .5, 1.0)
            action_map = jet.from_list('Custom cmap', color_list, jet.N)

            # Bin boundaries: one bin for each action.
            boundaries = np.linspace(0, self.actions_num, self.actions_num + 1)
            action_norm = mpl.colors.BoundaryNorm(boundaries, action_map.N)

            have_limits = xlim is not None and ylim is not None
            if have_limits:
                image_extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
            else:
                image_extent = [0, 1, 0, 1]

            self.__dict__[var] = plt.imshow(piMat,
                                            interpolation='nearest',
                                            origin='lower',
                                            cmap=action_map,
                                            norm=action_norm,
                                            extent=image_extent)
            plt.title(title)
            plt.colorbar()

        plt.figure(title)
        self.__dict__[var].set_data(piMat)
        plt.draw()
示例#16
0
 def showDomain(self, a=0):
     """
     Draw the current block configuration as an image.

     A ``blocks`` x ``blocks`` grid is filled so that each block appears
     directly above the block it rests on, with row 0 standing for the table.
     The image is created on the first call and updated in place afterwards.

     :param a: not used by this method.
     """
     state = self.state
     grid = np.zeros((self.blocks, self.blocks), 'uint8')
     pending = np.arange(self.blocks)
     while pending.size:
         block, pending = pending[0], pending[1:]
         support = state[block]
         if support == block:
             # The block is on the table. Cell values are shifted by one
             # because 0 is rendered white.
             grid[0, block] = block + 1
             continue
         # Look up where the supporting block has been drawn, if at all.
         rows, cols = findElemArray2D(support + 1, grid)
         if len(rows):
             grid[rows + 1, cols] = block + 1
         else:
             # Support not drawn yet: retry this block later.
             pending = np.hstack((pending, [block]))

     if self.domain_fig is not None:
         self.domain_fig.set_data(grid)
         canvas = plt.figure("Domain").canvas
         canvas.draw()
         canvas.flush_events()
         return

     plt.figure("Domain")
     self.domain_fig = plt.imshow(grid,
                                  cmap='BlocksWorld',
                                  origin='lower',
                                  interpolation='nearest')
     tick_positions = np.arange(self.blocks)
     plt.xticks(tick_positions, fontsize=FONTSIZE)
     plt.yticks(np.arange(self.blocks), fontsize=FONTSIZE)
     plt.axis('off')
     plt.show()
示例#17
0
 def showDomain(self, a=0):
     """
     Render the block towers described by ``self.state``.

     Each block is written into a ``blocks`` x ``blocks`` grid one row above
     the block it rests on; blocks resting on themselves are placed in row 0
     (the table). The first call creates the image, later calls update it.

     :param a: not used by this method.
     """
     s = self.state
     world = np.zeros((self.blocks, self.blocks), 'uint8')
     queue = np.arange(self.blocks)
     while queue.size > 0:
         top = queue[0]
         queue = queue[1:]
         bottom = s[top]
         on_table = bottom == top
         if on_table:
             world[0, top] = top + 1  # 0 is white, hence the offset
         else:
             found_rows, found_cols = findElemArray2D(bottom + 1, world)
             if len(found_rows):
                 # The supporting block is drawn: stack this one above it.
                 world[found_rows + 1, found_cols] = top + 1
             else:
                 # Put it at the back of the queue and try again later.
                 queue = np.hstack((queue, [top]))

     if self.domain_fig is None:
         plt.figure("Domain")
         self.domain_fig = plt.imshow(world,
                                      cmap='BlocksWorld',
                                      origin='lower',
                                      interpolation='nearest')
         plt.xticks(np.arange(self.blocks), fontsize=FONTSIZE)
         plt.yticks(np.arange(self.blocks), fontsize=FONTSIZE)
         plt.axis('off')
         plt.show()
     else:
         self.domain_fig.set_data(world)
         domain_canvas = plt.figure("Domain").canvas
         domain_canvas.draw()
         domain_canvas.flush_events()
示例#18
0
    def showDomain(self, a):
        """
        Draw the goal region, recorded slip positions and the car.

        The figure and goal circle are created on the first call. On every
        call the slip markers are created or refreshed, the previous car
        rectangle is replaced by one at the current pose, and the figure is
        redrawn.

        :param a: not used by this method.
        """
        if self.gcf is None:
            self.gcf = plt.gcf()

        s = self.state
        # Plot the car
        x, y, speed, heading = s
        car_xmin = x - self.REAR_WHEEL_RELATIVE_LOC
        car_ymin = y - self.CAR_WIDTH / 2.
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.figure()
            # Goal
            plt.gca().add_patch(
                plt.Circle(
                    self.GOAL,
                    radius=self.GOAL_RADIUS,
                    color='g',
                    alpha=.4))
            plt.xlim([self.XMIN, self.XMAX])
            plt.ylim([self.YMIN, self.YMAX])
            plt.gca().set_aspect('1')
        # Car
        if self.car_fig is not None:
            # Remove through the artist itself: ``Axes.patches`` cannot be
            # modified directly on current matplotlib versions.
            self.car_fig.remove()

        if self.slips:
            slip_x, slip_y = zip(*self.slips)
            # Use the current axes: ``plt.axes()`` would add a new Axes on
            # current matplotlib versions instead of returning this one.
            try:
                line = plt.gca().lines[0]
            except IndexError:
                # No slip markers drawn yet.
                plt.plot(slip_x, slip_y, 'x', color='b')
            else:
                # Refresh only if the plot is out of sync with the data.
                if len(line.get_xdata()) != len(slip_x):
                    line.set_xdata(slip_x)
                    line.set_ydata(slip_y)

        self.car_fig = mpatches.Rectangle(
            [car_xmin,
             car_ymin],
            self.CAR_LENGTH,
            self.CAR_WIDTH,
            alpha=.4)
        # The heading is converted from radians to degrees for the rotation.
        rotation = mpl.transforms.Affine2D().rotate_deg_around(
            x, y, heading * 180 / np.pi) + plt.gca().transData
        self.car_fig.set_transform(rotation)
        plt.gca().add_patch(self.car_fig)

        plt.draw()
        # Give the GUI event loop a moment to process the redraw.
        plt.pause(0.001)
示例#19
0
文件: HIVTreatment.py 项目: MLDL/rlpy
    def showDomain(self, a=0, s=None):
        """
        Show a live graph of each concentration.

        One subplot per row of ``self.episode_data`` is created on the first
        call; later calls only replace the y-data and rescale the axes.

        :param a: not used by this method.
        :param s: not used by this method.
        """
        # only update the graph every couple of steps, otherwise it is
        # extremely slow
        if self.t % self.show_domain_every != 0 and not self.t >= self.episodeCap:
            return

        n = self.state_space_dims + 1
        names = list(self.state_names) + ["Action"]
        colors = ["b", "b", "b", "b", "r", "g", "k"]
        handles = getattr(self, "_state_graph_handles", None)
        plt.figure("Domain", figsize=(12, 10))
        if handles is None:
            # First call: create one stacked subplot per plotted quantity.
            handles = []
            f, axes = plt.subplots(
                n, sharex=True, num="Domain", figsize=(12, 10))
            f.subplots_adjust(hspace=0.1)
            for i in range(n):
                ax = axes[i]
                d = np.arange(self.episodeCap + 1) * 5
                ax.set_ylabel(names[i])
                ax.locator_params(tight=True, nbins=4)
                handles.append(
                    ax.plot(d,
                            self.episode_data[i],
                            color=colors[i])[0])
            self._state_graph_handles = handles
            # Only the bottom subplot carries the shared x-axis label.
            ax.set_xlabel("Days")
        for i in range(n):
            handles[i].set_ydata(self.episode_data[i])
            # ``Artist.get_axes()`` no longer exists in matplotlib; the
            # ``axes`` attribute is the supported way to reach the Axes.
            ax = handles[i].axes
            ax.relim()
            ax.autoscale_view()
        plt.draw()
示例#20
0
    def showDomain(self, a=0, s=None):
        """
        Show a live graph of each concentration.

        The subplots, one per row of ``self.episode_data``, are built on the
        first call. Subsequent calls update the plotted y-data and rescale
        each axes.

        :param a: not used by this method.
        :param s: not used by this method.
        """
        # only update the graph every couple of steps, otherwise it is
        # extremely slow
        if self.t % self.show_domain_every != 0 and not self.t >= self.episodeCap:
            return

        n = self.state_space_dims + 1
        names = list(self.state_names) + ["Action"]
        colors = ["b", "b", "b", "b", "r", "g", "k"]
        handles = getattr(self, "_state_graph_handles", None)
        plt.figure("Domain", figsize=(12, 10))
        if handles is None:
            # First call: create one stacked subplot per plotted quantity.
            handles = []
            f, axes = plt.subplots(n,
                                   sharex=True,
                                   num="Domain",
                                   figsize=(12, 10))
            f.subplots_adjust(hspace=0.1)
            for i in range(n):
                ax = axes[i]
                d = np.arange(self.episodeCap + 1) * 5
                ax.set_ylabel(names[i])
                ax.locator_params(tight=True, nbins=4)
                handles.append(
                    ax.plot(d, self.episode_data[i], color=colors[i])[0])
            self._state_graph_handles = handles
            # Only the bottom subplot carries the shared x-axis label.
            ax.set_xlabel("Days")
        for i in range(n):
            handles[i].set_ydata(self.episode_data[i])
            # ``Artist.get_axes()`` no longer exists in matplotlib; the
            # ``axes`` attribute is the supported way to reach the Axes.
            ax = handles[i].axes
            ax.relim()
            ax.autoscale_view()
        plt.draw()
示例#21
0
    def showDomain(self, a=0):
        """
        Draw the chain as circles joined by arrows and mark the current state.

        The figure, circles and arrows are created on the first call, when
        ``self.circles`` is ``None``. On every call all circles are reset to
        white and the circle of the current state is painted black.

        :param a: not used by this method.
        """
        # Draw the environment
        s = self.state
        s = s[0]
        if self.circles is None:
            fig = plt.figure(1, (self.chainSize * 2, 2))
            ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=1.)
            ax.set_xlim(0, self.chainSize * 2)
            ax.set_ylim(0, 2)
            # Make the last one double circle
            ax.add_patch(
                mpatches.Circle((1 + 2 * (self.chainSize - 1), self.Y), self.RADIUS * 1.1, fc="w"))
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            self.circles = [mpatches.Circle((1 + 2 * i, self.Y), self.RADIUS, fc="w")
                            for i in range(self.chainSize)]
            for i in range(self.chainSize):
                ax.add_patch(self.circles[i])
                if i != self.chainSize - 1:
                    # Arrow from state i to state i + 1.
                    fromAtoB(
                        1 + 2 * i + self.SHIFT,
                        self.Y + self.SHIFT,
                        1 + 2 * (i + 1) - self.SHIFT,
                        self.Y + self.SHIFT)
                    if i != self.chainSize - 2:
                        # Red arrow from state i + 1 back to state i.
                        fromAtoB(
                            1 + 2 * (i + 1) - self.SHIFT,
                            self.Y - self.SHIFT,
                            1 + 2 * i + self.SHIFT,
                            self.Y - self.SHIFT,
                            'r')
                # NOTE(review): this curved arrow has fixed coordinates but is
                # drawn on every iteration; it probably belongs outside the
                # loop. Confirm before moving it.
                fromAtoB(
                    .75,
                    self.Y - 1.5 * self.SHIFT,
                    .75,
                    self.Y + 1.5 * self.SHIFT,
                    'r',
                    connectionstyle='arc3,rad=-1.2')
            # Show the window once after everything is drawn, rather than
            # once per state inside the loop.
            plt.show()

        for p in self.circles:
            p.set_facecolor('w')
        self.circles[s].set_facecolor('k')
        plt.draw()
示例#22
0
    def plot_trials(self,
                    y="eps_return",
                    x="learning_steps",
                    average=10,
                    save=False):
        """Plot per-trial results, optionally smoothed by a moving average.

        This function has only limited capabilities.
        For more advanced plotting of results consider
        :py:class:`Tools.Merger.Merger`.

        :param y: key into ``self.trials`` used for the vertical axis.
        :param x: key into ``self.trials`` used for the horizontal axis.
        :param average: window length of the moving average applied to the
            y-values; a falsy value disables smoothing.
        :param save: when true, the figure is also written as a PDF file
            inside ``self.full_path``.
        """
        label_map = rlpy.Tools.results.default_labels
        fig = plt.figure("Performance")
        records = self.trials
        values = np.array(records[y])
        if average:
            assert type(average) is int, "Filter length is not an integer!"
            # Box filter: convolve with a uniform window of equal weights.
            kernel = np.ones(int(average)) / float(average)
            values = np.convolve(values, kernel, 'same')

        plt.plot(records[x], values, '-bo', lw=3, markersize=10)
        plt.xlim(0, records[x][-1] * 1.01)

        lowest = values.min()
        highest = values.max()
        spread = highest - lowest
        # Only widen the y-range when the data is not flat.
        if spread > 0:
            plt.ylim(lowest - .1 * spread - .1, highest + .1 * spread + .1)

        x_text = label_map[x] if x in label_map else x
        y_text = label_map[y] if y in label_map else y
        plt.xlabel(x_text, fontsize=16)
        plt.ylabel(y_text, fontsize=16)

        if save:
            file_name = "{:3}-trials.pdf".format(self.exp_id)
            fig.savefig(os.path.join(self.full_path, file_name),
                        transparent=True,
                        pad_inches=.1)
        plt.ioff()
        plt.show()
示例#23
0
    def showDomain(self, a):
        """
        Draw the map with markers for the allied agents and the intruders.

        The map image is created on the first call. On every call the
        previous markers are removed and new ones drawn from the current
        state, whose first ``NUMBER_OF_AGENTS * 2`` entries are read as
        agent coordinate pairs and whose remaining entries are read as
        intruder coordinate pairs.

        :param a: not used by this method.
        """
        s = self.state
        # Draw the environment
        if self.domain_fig is None:
            plt.figure("Domain")
            self.domain_fig = plt.imshow(
                self.map,
                cmap='IntruderMonitoring',
                interpolation='nearest',
                vmin=0,
                vmax=3)
            plt.xticks(np.arange(self.COLS), fontsize=FONTSIZE)
            plt.yticks(np.arange(self.ROWS), fontsize=FONTSIZE)
            plt.show()
        if self.ally_fig is not None:
            # Remove the markers left over from the previous call.
            self.ally_fig.pop(0).remove()
            self.intruder_fig.pop(0).remove()

        # Each row is one coordinate pair; column 1 goes on the horizontal
        # axis and column 0 on the vertical axis.
        s_ally = s[0:self.NUMBER_OF_AGENTS * 2].reshape((-1, 2))
        s_intruder = s[self.NUMBER_OF_AGENTS * 2:].reshape((-1, 2))
        self.ally_fig = plt.plot(
            s_ally[:, 1],
            s_ally[:, 0],
            'bo',
            markersize=30.0,
            alpha=.7,
            markeredgecolor='k',
            markeredgewidth=2)
        # Marker only in the format string: the color comes from the keyword.
        # Giving a color in both places makes matplotlib warn about a
        # redundant definition, and the keyword wins anyway.
        self.intruder_fig = plt.plot(
            s_intruder[:, 1],
            s_intruder[:, 0],
            '>',
            color='gray',
            markersize=30.0,
            alpha=.7,
            markeredgecolor='k',
            markeredgewidth=2)
        domain_canvas = plt.figure("Domain").canvas
        domain_canvas.draw()
        domain_canvas.flush_events()
示例#24
0
    def showDomain(self, a):
        """
        Draw the map with markers for the allied agents and the intruders.

        The map image is built on the first call. Each call removes the
        markers drawn previously and plots new ones from the current state:
        the first ``NUMBER_OF_AGENTS * 2`` entries are read as agent
        coordinate pairs, the rest as intruder coordinate pairs.

        :param a: not used by this method.
        """
        s = self.state
        # Draw the environment
        if self.domain_fig is None:
            plt.figure("Domain")
            self.domain_fig = plt.imshow(self.map,
                                         cmap='IntruderMonitoring',
                                         interpolation='nearest',
                                         vmin=0,
                                         vmax=3)
            plt.xticks(np.arange(self.COLS), fontsize=FONTSIZE)
            plt.yticks(np.arange(self.ROWS), fontsize=FONTSIZE)
            plt.show()
        if self.ally_fig is not None:
            # Remove the markers left over from the previous call.
            self.ally_fig.pop(0).remove()
            self.intruder_fig.pop(0).remove()

        # Each row is one coordinate pair; column 1 goes on the horizontal
        # axis and column 0 on the vertical axis.
        s_ally = s[0:self.NUMBER_OF_AGENTS * 2].reshape((-1, 2))
        s_intruder = s[self.NUMBER_OF_AGENTS * 2:].reshape((-1, 2))
        self.ally_fig = plt.plot(s_ally[:, 1],
                                 s_ally[:, 0],
                                 'bo',
                                 markersize=30.0,
                                 alpha=.7,
                                 markeredgecolor='k',
                                 markeredgewidth=2)
        # Marker only in the format string: the color comes from the keyword.
        # Giving a color in both places makes matplotlib warn about a
        # redundant definition, and the keyword wins anyway.
        self.intruder_fig = plt.plot(s_intruder[:, 1],
                                     s_intruder[:, 0],
                                     '>',
                                     color='gray',
                                     markersize=30.0,
                                     alpha=.7,
                                     markeredgecolor='k',
                                     markeredgewidth=2)
        domain_canvas = plt.figure("Domain").canvas
        domain_canvas.draw()
        domain_canvas.flush_events()
示例#25
0
    def showLearning(self, representation):
        """Plot the value function and the per-cell action arrows.

        On the first call the "Value Function" figure, its image and four
        quiver sets (one per action) are created. Every call then recomputes
        the value of each EMPTY/START cell from ``representation`` and updates
        the image and the arrows.

        :param representation: Object providing ``Qs(s, terminal)`` and
            ``bestActions(s, terminal, As)``; ``Qs`` is indexed by action id.
        """
        if self.valueFunction_fig is None:
            plt.figure("Value Function")
            self.valueFunction_fig = plt.imshow(self.map,
                                                cmap='ValueFunction',
                                                interpolation='nearest',
                                                vmin=self.MIN_RETURN,
                                                vmax=self.MAX_RETURN)
            plt.xticks(np.arange(self.COLS), fontsize=12)
            plt.yticks(np.arange(self.ROWS), fontsize=12)
            # length of arrow/width of box. Less than 0.5 because each arrow
            # is offset, 0.4 looks nice but could be better/auto generated
            arrow_ratio = 0.4
            Max_Ratio_ArrowHead_to_ArrowLength = 0.25
            ARROW_WIDTH = 0.5 * Max_Ratio_ArrowHead_to_ArrowLength / 5.0
            # Placeholder vectors/colours; real values are set via set_UVC.
            ones = np.ones((self.COLS, self.ROWS))
            C = np.zeros((self.COLS, self.ROWS))
            C[0, 0] = 1  # Making sure C has both 0 and 1

            def make_arrows(row_shift, col_shift, **quiver_kwargs):
                # One quiver set with an arrow per cell, offset from the cell
                # centre by the given shift.
                X, Y = np.meshgrid(np.arange(self.ROWS) + row_shift,
                                   np.arange(self.COLS) + col_shift)
                arrows = plt.quiver(Y, X, ones, ones, C,
                                    cmap='Actions',
                                    **quiver_kwargs)
                arrows.set_clim(vmin=0, vmax=1)
                return arrows

            vertical = dict(units='y',
                            scale_units="height",
                            scale=self.ROWS / arrow_ratio,
                            width=-1 * ARROW_WIDTH)
            horizontal = dict(units='x',
                              scale_units="width",
                              scale=self.COLS / arrow_ratio,
                              width=ARROW_WIDTH)
            # Create quivers for each action. 4 in total
            self.upArrows_fig = make_arrows(-self.SHIFT, 0, **vertical)
            self.downArrows_fig = make_arrows(self.SHIFT, 0, **vertical)
            self.leftArrows_fig = make_arrows(0, -self.SHIFT, **horizontal)
            self.rightArrows_fig = make_arrows(0, self.SHIFT, **horizontal)
            plt.show()
        plt.figure("Value Function")
        V = np.zeros((self.ROWS, self.COLS))
        per_action_shape = (self.COLS, self.ROWS, self.actions_num)
        # Boolean 3 dimensional array. The third axis selects the action.
        # The mask is True wherever an arrow must NOT be drawn.
        Mask = np.ones(per_action_shape, dtype='bool')
        arrowSize = np.zeros(per_action_shape, dtype='float')
        # 0 = suboptimal action, 1 = optimal action
        arrowColors = np.zeros(per_action_shape, dtype='uint8')
        for r in range(self.ROWS):
            for c in range(self.COLS):
                cell = self.map[r, c]
                if cell == self.BLOCKED:
                    V[r, c] = 0
                if cell == self.GOAL:
                    V[r, c] = self.MAX_RETURN
                if cell == self.PIT:
                    V[r, c] = self.MIN_RETURN
                if cell == self.EMPTY or cell == self.START:
                    s = np.array([r, c])
                    As = self.possibleActions(s)
                    terminal = self.isTerminal(s)
                    Qs = representation.Qs(s, terminal)
                    bestA = representation.bestActions(s, terminal, As)
                    V[r, c] = max(Qs[As])
                    Mask[c, r, As] = False
                    arrowColors[c, r, bestA] = 1

                    for a in As:
                        # Qs is indexed by action id (see max(Qs[As]) above),
                        # so look the value up by ``a``, not by the position
                        # of ``a`` inside As.
                        arrowSize[c, r, a] = linearMap(Qs[a],
                                                       self.MIN_RETURN,
                                                       self.MAX_RETURN, 0, 1)
        # Show Value Function
        self.valueFunction_fig.set_data(V)

        def masked(values, action):
            # Hide the entries of cells where ``action`` is not available.
            return np.ma.masked_array(values, mask=Mask[:, :, action])

        # (quiver, action index, direction sign, arrow is vertical)
        arrow_specs = (
            (self.upArrows_fig, 0, 1, True),
            (self.downArrows_fig, 1, -1, True),
            (self.leftArrows_fig, 2, -1, False),
            (self.rightArrows_fig, 3, 1, False),
        )
        for arrows, action, sign, is_vertical in arrow_specs:
            length = masked(sign * arrowSize[:, :, action], action)
            zero = masked(np.zeros((self.COLS, self.ROWS)), action)
            color = masked(arrowColors[:, :, action], action)
            if is_vertical:
                arrows.set_UVC(zero, length, color)
            else:
                arrows.set_UVC(length, zero, color)
        plt.draw()
示例#26
0
    def showDomain(self, a=0):
        """Redraw the UAV diagram for the current state.

        The figure is cleared and rebuilt on every call: one rectangle per
        location with a label underneath, and for each UAV (one horizontal
        lane per UAV) a circle with its fuel value as text, a wedge in front
        of it for the sensor and a wedge on top of it for the actuator. The
        wedges are black when the component is RUNNING and red otherwise.

        :param a: Not used by the rendering code.
        """
        s = self.state
        if self.domain_fig is None:
            # Figure with x width corresponding to the number of location
            # states and one row (lane) in y per UAV, plus a buffer of 1.
            self.domain_fig = plt.figure(
                1, (UAVLocation.SIZE * self.dist_between_locations + 1,
                    self.NUM_UAV + 1))
            plt.show()
        # Wedge shapes cannot be removed from the matplotlib environment, nor
        # can their properties be changed, without clearing the figure. Thus
        # the whole figure is redrawn on each timestep.
        plt.clf()
        self.subplot_axes = self.domain_fig.add_axes([0, 0, 1, 1],
                                                     frameon=False,
                                                     aspect=1.)
        crashLocationX = 2 * \
            (self.dist_between_locations) * (UAVLocation.SIZE - 1)
        self.subplot_axes.set_xlim(0, 1 + crashLocationX + self.RECT_GAP)
        self.subplot_axes.set_ylim(0, 1 + self.NUM_UAV)
        self.subplot_axes.xaxis.set_visible(False)
        self.subplot_axes.yaxis.set_visible(False)

        # Assign coordinates of each possible uav location on figure
        self.location_coord = [
            0.5 + (self.LOCATION_WIDTH / 2) + (self.dist_between_locations) * i
            for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_coord.append(crashLocationX + self.LOCATION_WIDTH / 2)

        # Create rectangular patches at each of those locations
        self.location_rect_vis = [
            mpatches.Rectangle([0.5 + (self.dist_between_locations) * i, 0],
                               self.LOCATION_WIDTH,
                               self.NUM_UAV * 2,
                               fc='w') for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_rect_vis.append(
            mpatches.Rectangle([crashLocationX, 0],
                               self.LOCATION_WIDTH,
                               self.NUM_UAV * 2,
                               fc='w'))
        # Add every rectangle that was created (not a hard-coded count)
        for rect in self.location_rect_vis:
            self.subplot_axes.add_patch(rect)

        # Horizontal connectors between neighbouring rectangles; created
        # invisible and only shown further down.
        comms_y = [self.NUM_UAV * 0.5 + 0.5, self.NUM_UAV * 0.5 + 0.5]
        self.comms_line = [
            lines.Line2D([
                0.5 + self.LOCATION_WIDTH +
                (self.dist_between_locations) * i, 0.5 + self.LOCATION_WIDTH +
                (self.dist_between_locations) * i + self.RECT_GAP
            ], comms_y,
                         linewidth=3,
                         color='black',
                         visible=False) for i in range(UAVLocation.SIZE - 2)
        ]
        # Last connector: from the rectangle at index SIZE - 2 to the final
        # rectangle at crashLocationX.
        self.comms_line.append(
            lines.Line2D([
                0.5 + self.LOCATION_WIDTH +
                (self.dist_between_locations) * (UAVLocation.SIZE - 2),
                crashLocationX
            ], comms_y,
                         linewidth=3,
                         color='black',
                         visible=False))

        # Create location text below rectangles
        locText = ["Base", "Refuel", "Communication", "Surveillance"]
        self.location_rect_txt = [
            plt.text(0.5 + self.dist_between_locations * i +
                     0.5 * self.LOCATION_WIDTH,
                     -0.3,
                     locText[i],
                     ha='center') for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_rect_txt.append(
            plt.text(crashLocationX + 0.5 * self.LOCATION_WIDTH,
                     -0.3,
                     locText[UAVLocation.SIZE - 1],
                     ha='center'))

        # Member variables storing the per-UAV figure objects; filled below.
        self.uav_circ_vis = []
        self.uav_text_vis = []
        self.uav_sensor_vis = []
        self.uav_actuator_vis = []

        sStruct = self.state2Struct(s)

        for uav_id in range(self.NUM_UAV):
            # Values of this UAV taken from the structured state
            uav_location = sStruct.locations[uav_id]
            uav_fuel = sStruct.fuel[uav_id]
            uav_sensor = sStruct.sensor[uav_id]
            uav_actuator = sStruct.actuator[uav_id]

            # Assign coordinates on figure where UAV should be drawn
            uav_x = self.location_coord[uav_location]
            uav_y = 1 + uav_id

            # Circle with the fuel value written inside
            circle = mpatches.Circle((uav_x, uav_y), self.UAV_RADIUS, fc="w")
            fuel_text = plt.text(uav_x - 0.05, uav_y - 0.05, uav_fuel)

            # Wedge in front of the UAV: sensor state
            if uav_sensor == SensorState.RUNNING:
                objColor = 'black'
            else:
                objColor = 'red'
            sensor_wedge = mpatches.Wedge(
                (uav_x + self.SENSOR_REL_X, uav_y),
                self.SENSOR_LENGTH,
                -30,
                30,
                color=objColor)

            # Wedge on top of the UAV: actuator state
            if uav_actuator == ActuatorState.RUNNING:
                objColor = 'black'
            else:
                objColor = 'red'
            actuator_wedge = mpatches.Wedge(
                (uav_x, uav_y + self.ACTUATOR_REL_Y),
                self.ACTUATOR_HEIGHT,
                60,
                120,
                color=objColor)

            self.uav_circ_vis.append(circle)
            self.uav_text_vis.append(fuel_text)
            self.uav_sensor_vis.append(sensor_wedge)
            self.uav_actuator_vis.append(actuator_wedge)

            self.subplot_axes.add_patch(circle)
            self.subplot_axes.add_patch(sensor_wedge)
            self.subplot_axes.add_patch(actuator_wedge)

        numHealthySurveil = np.sum(
            np.logical_and(sStruct.locations == UAVLocation.SURVEIL,
                           sStruct.sensor))
        # At least one UAV is at the COMMS location: show the connectors
        if (any(sStruct.locations == UAVLocation.COMMS)):
            for comms_line in self.comms_line:
                comms_line.set_visible(True)
                comms_line.set_color('black')
                self.subplot_axes.add_line(comms_line)
            # A UAV with a truthy sensor value is also at SURVEIL: colour the
            # last location rectangle green
            if numHealthySurveil > 0:
                self.location_rect_vis[-1].set_color('green')
        plt.draw()
        sleep(0.5)
示例#27
0
    def showDomain(self, a=0):
        """Draw the computer network for the current state.

        The first call builds the graph (one node per computer, one edge per
        entry of ``self.UNIQUE_EDGES``), computes a circular layout and draws
        everything in the default colours. Later calls clear the figure and
        redraw nodes and edges according to ``self.state``:

        * nodes equal to ``self.RUNNING`` are white, all others red;
        * edges whose two endpoints are both RUNNING are solid, all other
          edges are dotted (both in black).

        :param a: Not used by the rendering code.
        """
        s = self.state
        plt.figure("Domain")

        if self.networkGraph is None:  # or self.networkPos is None:
            self.networkGraph = nx.Graph()
            # Add a node to the network for each computer
            for computer_id, _ in enumerate(zip(self.NEIGHBORS, s)):
                self.networkGraph.add_node(computer_id, node_color="w")
            # Add an edge between each pair of neighbors
            for uniqueEdge in self.UNIQUE_EDGES:
                self.networkGraph.add_edge(uniqueEdge[0],
                                           uniqueEdge[1],
                                           edge_color="k")
            self.networkPos = nx.circular_layout(self.networkGraph)
            nx.draw_networkx_nodes(self.networkGraph,
                                   self.networkPos,
                                   node_color="w")
            # The keyword is ``edge_color``; the misspelt ``edges_color`` is
            # not a parameter of draw_networkx_edges.
            nx.draw_networkx_edges(self.networkGraph,
                                   self.networkPos,
                                   edge_color="k")
            nx.draw_networkx_labels(self.networkGraph, self.networkPos)
            plt.show()
        else:
            plt.clf()
            blackEdges = []
            redEdges = []
            greenNodes = []
            redNodes = []
            for computer_id, (_, compstatus) in enumerate(
                    zip(self.NEIGHBORS, s)):
                if (compstatus == self.RUNNING):
                    greenNodes.append(computer_id)
                else:
                    redNodes.append(computer_id)
            # Iterate through all unique edges
            for uniqueEdge in self.UNIQUE_EDGES:
                if (s[uniqueEdge[0]] == self.RUNNING
                        and s[uniqueEdge[1]] == self.RUNNING):
                    # Then both computers are working
                    blackEdges.append(uniqueEdge)
                else:
                    # At least one endpoint is not RUNNING; these edges are
                    # drawn dotted below
                    redEdges.append(uniqueEdge)
            # Only draw things in the network if these lists aren't empty
            if redNodes:
                nx.draw_networkx_nodes(self.networkGraph,
                                       self.networkPos,
                                       nodelist=redNodes,
                                       node_color="r",
                                       linewidths=2)
            if greenNodes:
                nx.draw_networkx_nodes(self.networkGraph,
                                       self.networkPos,
                                       nodelist=greenNodes,
                                       node_color="w",
                                       linewidths=2)
            if blackEdges:
                nx.draw_networkx_edges(self.networkGraph,
                                       self.networkPos,
                                       edgelist=blackEdges,
                                       edge_color="k",
                                       width=2,
                                       style='solid')
            if redEdges:
                nx.draw_networkx_edges(self.networkGraph,
                                       self.networkPos,
                                       edgelist=redEdges,
                                       edge_color="k",
                                       width=2,
                                       style='dotted')
            # Labels are drawn once per call (the first-call branch above
            # already drew them)
            nx.draw_networkx_labels(self.networkGraph, self.networkPos)
        fig = plt.figure("Domain")
        fig.canvas.draw()
        fig.canvas.flush_events()
示例#28
0
                               kernel_args=[kernel_width],
                               active_threshold=active_threshold,
                               discover_threshold=discover_threshold,
                               normalization=False,
                               max_active_base_feat=100,
                               max_base_feat_sim=max_base_feat_sim)
    policy = SwimmerPolicy(representation)
    #policy = eGreedy(representation, epsilon=0.1)
    stat_bins_per_state_dim = 20
    # agent           = SARSA(representation,policy,domain,initial_learn_rate=initial_learn_rate,
    # lambda_=.0, learn_rate_decay_mode="boyan", boyan_N0=boyan_N0)
    opt["agent"] = SARSA(
        policy, representation, discount_factor=domain.discount_factor,
        lambda_=lambda_, initial_learn_rate=initial_learn_rate,
        learn_rate_decay_mode="boyan", boyan_N0=boyan_N0)
    experiment = Experiment(**opt)
    return experiment

if __name__ == '__main__':
    # Script entry point: build experiment 1, run it with visualisation
    # enabled, then plot the learning state counts of the first nine
    # dimensions in a fresh figure.
    from rlpy.Tools.run import run_profiled
    from rlpy.Tools import plt

    # run_profiled(make_experiment)
    experiment = make_experiment(1)
    experiment.run(visualize_performance=1, visualize_learning=True)
    # experiment.plot()
    # experiment.save()

    plt.figure()
    for dim in range(9):
        counts = experiment.state_counts_learn[dim]
        plt.plot(counts, label="Dim " + str(dim))
    plt.legend()
 def showExploration(self):
     """Scatter-plot the visited states together with source and target.

     Visited states are drawn in black using their first two components,
     ``self.src`` in red and ``self.target`` in blue (both read through
     their 'x' and 'y' keys). A new figure is opened and shown.
     """
     xs = []
     ys = []
     for state in self.visited_states:
         xs.append(state[0])
         ys.append(state[1])

     plt.figure()
     plt.scatter(xs, ys, color='k')
     # Source and target are single points, passed as one-element lists
     src_x, src_y = self.src['x'], self.src['y']
     plt.scatter([src_x], [src_y], color='r')
     target_x, target_y = self.target['x'], self.target['y']
     plt.scatter([target_x], [target_y], color='b')
     plt.show()
示例#30
0
    def showLearning(self, representation):
        """Update the value-function image and the action arrows.

        The first call creates the "Value Function" figure with the image and
        one quiver set per action (up, down, left, right). Each call fills a
        value matrix from ``representation`` for EMPTY/START cells (fixed
        values for BLOCKED, GOAL and PIT cells) and refreshes image and
        arrows. Arrow length is the Q-value mapped linearly from
        [MIN_RETURN, MAX_RETURN] to [0, 1]; arrow colour marks best actions.

        :param representation: Object providing ``Qs(s, terminal)`` and
            ``bestActions(s, terminal, As)``; ``Qs`` is indexed by action id.
        """
        if self.valueFunction_fig is None:
            plt.figure("Value Function")
            self.valueFunction_fig = plt.imshow(
                self.map,
                cmap='ValueFunction',
                interpolation='nearest',
                vmin=self.MIN_RETURN,
                vmax=self.MAX_RETURN)
            plt.xticks(np.arange(self.COLS), fontsize=12)
            plt.yticks(np.arange(self.ROWS), fontsize=12)
            # length of arrow/width of box. Less than 0.5 because each arrow
            # is offset, 0.4 looks nice but could be better/auto generated
            arrow_ratio = 0.4
            Max_Ratio_ArrowHead_to_ArrowLength = 0.25
            ARROW_WIDTH = 0.5 * Max_Ratio_ArrowHead_to_ArrowLength / 5.0
            # Initial vectors/colours are placeholders, replaced by set_UVC.
            unit = np.ones((self.COLS, self.ROWS))
            C = np.zeros((self.COLS, self.ROWS))
            C[0, 0] = 1  # Making sure C has both 0 and 1

            def new_quiver(row_offset, col_offset, units, scale_units, scale,
                           width):
                # Build one quiver set whose arrows start at the cell centres
                # moved by the given offsets.
                X, Y = np.meshgrid(np.arange(self.ROWS) + row_offset,
                                   np.arange(self.COLS) + col_offset)
                quiver = plt.quiver(
                    Y, X, unit, unit, C,
                    units=units,
                    cmap='Actions',
                    scale_units=scale_units,
                    scale=scale,
                    width=width)
                quiver.set_clim(vmin=0, vmax=1)
                return quiver

            # Create quivers for each action. 4 in total
            row_scale = self.ROWS / arrow_ratio
            col_scale = self.COLS / arrow_ratio
            self.upArrows_fig = new_quiver(
                -self.SHIFT, 0, 'y', "height", row_scale, -1 * ARROW_WIDTH)
            self.downArrows_fig = new_quiver(
                self.SHIFT, 0, 'y', "height", row_scale, -1 * ARROW_WIDTH)
            self.leftArrows_fig = new_quiver(
                0, -self.SHIFT, 'x', "width", col_scale, ARROW_WIDTH)
            self.rightArrows_fig = new_quiver(
                0, self.SHIFT, 'x', "width", col_scale, ARROW_WIDTH)
            plt.show()
        plt.figure("Value Function")
        V = np.zeros((self.ROWS, self.COLS))
        shape = (self.COLS, self.ROWS, self.actions_num)
        # Boolean 3 dimensional array. The third axis selects the action.
        # True means "no arrow for this action in this cell".
        Mask = np.ones(shape, dtype='bool')
        arrowSize = np.zeros(shape, dtype='float')
        # 0 = suboptimal action, 1 = optimal action
        arrowColors = np.zeros(shape, dtype='uint8')
        for r in range(self.ROWS):
            for c in range(self.COLS):
                if self.map[r, c] == self.BLOCKED:
                    V[r, c] = 0
                if self.map[r, c] == self.GOAL:
                    V[r, c] = self.MAX_RETURN
                if self.map[r, c] == self.PIT:
                    V[r, c] = self.MIN_RETURN
                if self.map[r, c] == self.EMPTY or self.map[r, c] == self.START:
                    s = np.array([r, c])
                    As = self.possibleActions(s)
                    terminal = self.isTerminal(s)
                    Qs = representation.Qs(s, terminal)
                    bestA = representation.bestActions(s, terminal, As)
                    V[r, c] = max(Qs[As])
                    Mask[c, r, As] = False
                    arrowColors[c, r, bestA] = 1

                    for a in As:
                        # Index Qs by the action id itself: Qs[As] above
                        # shows it is laid out per action, so using the
                        # position within As would pick the wrong value
                        # whenever some action is unavailable.
                        arrowSize[c, r, a] = linearMap(
                            Qs[a],
                            self.MIN_RETURN,
                            self.MAX_RETURN,
                            0,
                            1)
        # Show Value Function
        self.valueFunction_fig.set_data(V)

        # Show Policy arrows: (quiver, action index, sign, vertical arrow?)
        updates = (
            (self.upArrows_fig, 0, 1, True),
            (self.downArrows_fig, 1, -1, True),
            (self.leftArrows_fig, 2, -1, False),
            (self.rightArrows_fig, 3, 1, False),
        )
        for quiver, action, sign, vertical in updates:
            hidden = Mask[:, :, action]
            along = np.ma.masked_array(sign * arrowSize[:, :, action],
                                       mask=hidden)
            across = np.ma.masked_array(np.zeros((self.COLS, self.ROWS)),
                                        mask=hidden)
            C = np.ma.masked_array(arrowColors[:, :, action], mask=hidden)
            if vertical:
                quiver.set_UVC(across, along, C)
            else:
                quiver.set_UVC(along, across, C)
        plt.draw()
示例#31
0
文件: PST.py 项目: okkhoy/rlpy
    def showDomain(self, a=0):
        s = self.state
        if self.domain_fig is None:
            plt.figure("Domain")
            self.domain_fig = plt.figure(
                1, (UAVLocation.SIZE * self.dist_between_locations + 1, self.NUM_UAV + 1))
            plt.show()
        plt.clf()
         # Draw the environment
         # Allocate horizontal 'lanes' for UAVs to traverse

        # Formerly, we checked if this was the first time plotting; wedge shapes cannot be removed from
        # matplotlib environment, nor can their properties be changed, without clearing the figure
        # Thus, we must redraw the figure on each timestep
        #        if self.location_rect_vis is None:
        # Figure with x width corresponding to number of location states, UAVLocation.SIZE
        # and rows (lanes) set aside in y for each UAV (NUM_UAV total lanes).
        # Add buffer of 1
        self.subplot_axes = self.domain_fig.add_axes(
            [0, 0, 1, 1], frameon=False, aspect=1.)
        crashLocationX = 2 * \
            (self.dist_between_locations) * (UAVLocation.SIZE - 1)
        self.subplot_axes.set_xlim(0, 1 + crashLocationX + self.RECT_GAP)
        self.subplot_axes.set_ylim(0, 1 + self.NUM_UAV)
        self.subplot_axes.xaxis.set_visible(False)
        self.subplot_axes.yaxis.set_visible(False)

        # Assign coordinates of each possible uav location on figure
        self.location_coord = [0.5 + (old_div(self.LOCATION_WIDTH, 2)) +
                               (self.dist_between_locations) * i for i in range(UAVLocation.SIZE - 1)]
        self.location_coord.append(crashLocationX + old_div(self.LOCATION_WIDTH, 2))

         # Create rectangular patches at each of those locations
        self.location_rect_vis = [mpatches.Rectangle(
            [0.5 + (self.dist_between_locations) * i,
             0],
            self.LOCATION_WIDTH,
            self.NUM_UAV * 2,
            fc='w') for i in range(UAVLocation.SIZE - 1)]
        self.location_rect_vis.append(
            mpatches.Rectangle([crashLocationX,
                                0],
                               self.LOCATION_WIDTH,
                               self.NUM_UAV * 2,
                               fc='w'))
        [self.subplot_axes.add_patch(self.location_rect_vis[i])
         for i in range(4)]
        self.comms_line = [lines.Line2D(
            [0.5 + self.LOCATION_WIDTH + (self.dist_between_locations) * i,
             0.5 + self.LOCATION_WIDTH + (
                 self.dist_between_locations) * i + self.RECT_GAP],
            [self.NUM_UAV * 0.5 + 0.5,
             self.NUM_UAV * 0.5 + 0.5],
            linewidth=3,
            color='black',
            visible=False) for i in range(UAVLocation.SIZE - 2)]
        self.comms_line.append(
            lines.Line2D(
                [0.5 + self.LOCATION_WIDTH + (self.dist_between_locations) * 2,
                 crashLocationX],
                [self.NUM_UAV * 0.5 + 0.5,
                 self.NUM_UAV * 0.5 + 0.5],
                linewidth=3,
                color='black',
                visible=False))

        # Create location text below rectangles
        locText = ["Base", "Refuel", "Communication", "Surveillance"]
        self.location_rect_txt = [plt.text(
            0.5 + self.dist_between_locations * i + 0.5 * self.LOCATION_WIDTH,
            -0.3,
            locText[i],
            ha='center') for i in range(UAVLocation.SIZE - 1)]
        self.location_rect_txt.append(
            plt.text(crashLocationX + 0.5 * self.LOCATION_WIDTH, -0.3,
                     locText[UAVLocation.SIZE - 1], ha='center'))

        # Initialize list of circle objects

        uav_x = self.location_coord[UAVLocation.BASE]

        # Update the member variables storing all the figure objects
        self.uav_circ_vis = [mpatches.Circle(
            (uav_x,
             1 + uav_id),
            self.UAV_RADIUS,
            fc="w") for uav_id in range(0,
                                        self.NUM_UAV)]
        self.uav_text_vis = [None for uav_id in range(0, self.NUM_UAV)]  # f**k
        self.uav_sensor_vis = [mpatches.Wedge(
            (uav_x + self.SENSOR_REL_X,
             1 + uav_id),
            self.SENSOR_LENGTH,
            -30,
            30) for uav_id in range(0,
                                    self.NUM_UAV)]
        self.uav_actuator_vis = [mpatches.Wedge(
            (uav_x,
             1 + uav_id + self.ACTUATOR_REL_Y),
            self.ACTUATOR_HEIGHT,
            60,
            120) for uav_id in range(0,
                                     self.NUM_UAV)]

        # The following was executed when we used to check if the environment needed re-drawing: see above.
                # Remove all UAV circle objects from visualization
        #        else:
        #            [self.uav_circ_vis[uav_id].remove() for uav_id in range(0,self.NUM_UAV)]
        #            [self.uav_text_vis[uav_id].remove() for uav_id in range(0,self.NUM_UAV)]
        #            [self.uav_sensor_vis[uav_id].remove() for uav_id in range(0,self.NUM_UAV)]

        # For each UAV:
        # Draw a circle, with text inside = amt fuel remaining
        # Triangle on top of UAV for comms, black = good, red = bad
        # Triangle in front of UAV for surveillance
        sStruct = self.state2Struct(s)

        for uav_id in range(0, self.NUM_UAV):
            # Assign all the variables corresponding to this UAV for this iteration;
            # this could alternately be done with a UAV class whose objects keep track
            # of these variables.  Elect to use lists here since ultimately the state
            # must be a vector anyway.
            # State index corresponding to the location of this uav
            uav_location = sStruct.locations[uav_id]
            uav_fuel = sStruct.fuel[uav_id]
            uav_sensor = sStruct.sensor[uav_id]
            uav_actuator = sStruct.actuator[uav_id]

            # Assign coordinates on figure where UAV should be drawn
            uav_x = self.location_coord[uav_location]
            uav_y = 1 + uav_id

            # Update plot with this UAV
            self.uav_circ_vis[uav_id] = mpatches.Circle(
                (uav_x, uav_y), self.UAV_RADIUS, fc="w")
            self.uav_text_vis[uav_id] = plt.text(
                uav_x - 0.05,
                uav_y - 0.05,
                uav_fuel)
            if uav_sensor == SensorState.RUNNING:
                objColor = 'black'
            else:
                objColor = 'red'
            self.uav_sensor_vis[uav_id] = mpatches.Wedge(
                (uav_x + self.SENSOR_REL_X,
                 uav_y),
                self.SENSOR_LENGTH,
                -30,
                30,
                color=objColor)

            if uav_actuator == ActuatorState.RUNNING:
                objColor = 'black'
            else:
                objColor = 'red'
            self.uav_actuator_vis[uav_id] = mpatches.Wedge(
                (uav_x,
                 uav_y + self.ACTUATOR_REL_Y),
                self.ACTUATOR_HEIGHT,
                60,
                120,
                color=objColor)

            self.subplot_axes.add_patch(self.uav_circ_vis[uav_id])
            self.subplot_axes.add_patch(self.uav_sensor_vis[uav_id])
            self.subplot_axes.add_patch(self.uav_actuator_vis[uav_id])

        numHealthySurveil = np.sum(
            np.logical_and(
                sStruct.locations == UAVLocation.SURVEIL,
                sStruct.sensor))
        # We have comms coverage: draw a line between comms states to show this
        if (any(sStruct.locations == UAVLocation.COMMS)):
            for i in range(len(self.comms_line)):
                self.comms_line[i].set_visible(True)
                self.comms_line[i].set_color('black')
                self.subplot_axes.add_line(self.comms_line[i])
            # We also have healthy UAVs in surveillance; color the last location rectangle green
            if numHealthySurveil > 0:
                self.location_rect_vis[
                    len(self.location_rect_vis) - 1].set_color('green')
        plt.figure("Domain").canvas.draw()
        plt.figure("Domain").canvas.flush_events()
        sleep(0.5)
示例#32
0
                               kernel_args=[kernel_width],
                               active_threshold=active_threshold,
                               discover_threshold=discover_threshold,
                               normalization=False,
                               max_active_base_feat=100,
                               max_base_feat_sim=max_base_feat_sim)
    policy = SwimmerPolicy(representation)
    #policy = eGreedy(representation, epsilon=0.1)
    stat_bins_per_state_dim = 20
    # agent           = SARSA(representation,policy,domain,initial_learn_rate=initial_learn_rate,
    # lambda_=.0, learn_rate_decay_mode="boyan", boyan_N0=boyan_N0)
    opt["agent"] = SARSA(
        policy, representation, discount_factor=domain.discount_factor,
        lambda_=lambda_, initial_learn_rate=initial_learn_rate,
        learn_rate_decay_mode="boyan", boyan_N0=boyan_N0)
    experiment = Experiment(**opt)
    return experiment

if __name__ == '__main__':
    # Script entry point: build experiment #1, run it with both performance
    # and learning visualisation switched on, then plot the learning-phase
    # state counts for the first nine state dimensions.
    from rlpy.Tools.run import run_profiled
    # Profiling alternative: run_profiled(make_experiment)
    experiment = make_experiment(1)
    experiment.run(visualize_performance=1, visualize_learning=True)
    # Optional follow-ups: experiment.plot() / experiment.save()
    from rlpy.Tools import plt
    plt.figure()
    learn_counts = experiment.state_counts_learn
    for dim in range(9):
        plt.plot(learn_counts[dim], label="Dim {}".format(dim))
    plt.legend()
示例#33
0
    def showDomain(self, a=0):
        """
        Draw the computer network for the current state in the "Domain" figure.

        On the first call (``self.networkGraph is None``) the graph is built
        with one node per computer and one edge per entry of
        ``self.UNIQUE_EDGES``, laid out on a circle and drawn with white
        nodes and black edges.  On later calls the figure is cleared and
        redrawn from ``self.state``: computers whose status equals
        ``self.RUNNING`` are drawn white and all others red; an edge is drawn
        solid when both of its endpoints are running and dotted otherwise.

        :param a: action index; not used by the drawing code.
        """
        s = self.state
        plt.figure("Domain")

        if self.networkGraph is None:  # or self.networkPos is None:
            self.networkGraph = nx.Graph()
            # Add one node per computer.  Iterating over the zip keeps the
            # node count bounded by the shorter of NEIGHBORS and the state.
            for computer_id, _ in enumerate(zip(self.NEIGHBORS, s)):
                self.networkGraph.add_node(computer_id, node_color="w")
            # Add an edge between each pair of neighbors
            for uniqueEdge in self.UNIQUE_EDGES:
                self.networkGraph.add_edge(
                    uniqueEdge[0],
                    uniqueEdge[1],
                    edge_color="k")
            self.networkPos = nx.circular_layout(self.networkGraph)
            nx.draw_networkx_nodes(
                self.networkGraph,
                self.networkPos,
                node_color="w")
            # NOTE: the keyword is ``edge_color``; the former ``edges_color``
            # spelling is not a parameter of draw_networkx_edges.
            nx.draw_networkx_edges(
                self.networkGraph,
                self.networkPos,
                edge_color="k")
            nx.draw_networkx_labels(self.networkGraph, self.networkPos)
            plt.show()
        else:
            plt.clf()
            blackEdges = []
            redEdges = []
            greenNodes = []
            redNodes = []
            # Partition computers by whether they are currently running
            for computer_id, (_, compstatus) in enumerate(zip(self.NEIGHBORS, s)):
                if compstatus == self.RUNNING:
                    greenNodes.append(computer_id)
                else:
                    redNodes.append(computer_id)
            # Partition edges: healthy only when both endpoints are running
            for uniqueEdge in self.UNIQUE_EDGES:
                if (s[uniqueEdge[0]] == self.RUNNING
                        and s[uniqueEdge[1]] == self.RUNNING):
                    blackEdges.append(uniqueEdge)
                else:
                    # At least one endpoint is not running; these edges are
                    # drawn dotted (still in black) below.
                    redEdges.append(uniqueEdge)
            # Only draw a group if it is non-empty
            if redNodes:
                nx.draw_networkx_nodes(
                    self.networkGraph,
                    self.networkPos,
                    nodelist=redNodes,
                    node_color="r",
                    linewidths=2)
            if greenNodes:
                nx.draw_networkx_nodes(
                    self.networkGraph,
                    self.networkPos,
                    nodelist=greenNodes,
                    node_color="w",
                    linewidths=2)
            if blackEdges:
                nx.draw_networkx_edges(
                    self.networkGraph,
                    self.networkPos,
                    edgelist=blackEdges,
                    edge_color="k",
                    width=2,
                    style='solid')
            if redEdges:
                nx.draw_networkx_edges(
                    self.networkGraph,
                    self.networkPos,
                    edgelist=redEdges,
                    edge_color="k",
                    width=2,
                    style='dotted')
        nx.draw_networkx_labels(self.networkGraph, self.networkPos)
        # Push the updated drawing to the screen
        plt.figure("Domain").canvas.draw()
        plt.figure("Domain").canvas.flush_events()
示例#34
0
    def showDomain(self, a):
        """
        Render the mountain car for the current state and the action ``a``.

        The first call (``self.domain_fig is None``) creates the
        "Mountain Car Domain" figure: the hill profile ``sin(3x)`` filled in
        green, a goal marker at ``self.GOAL`` and an initially empty line
        used as the car.  Every call then places the car line along the
        hill's tangent at the current position and replaces the action
        arrow: one starting at the car's front (colour argument ``'k'``)
        when ``self.actions[a]`` is positive, one starting at its back
        (colour argument ``'r'``) when negative, and none when it is zero.

        Parts of this code were adopted from online source code by
        Jose Antonio Martin H.

        :param a: index into ``self.actions`` of the action to visualise.
        """
        fig_name = "Mountain Car Domain"
        # The state unpacks as (position, velocity); velocity is not drawn.
        position, _ = self.state

        if self.domain_fig is None:  # first call: build the static scene
            self.domain_fig = plt.figure(fig_name)
            axes = plt.gca()
            # Hill profile, filled from a floor just below its lowest point
            hill_x = np.linspace(self.XMIN, self.XMAX, 1000)
            hill_y = np.sin(3 * hill_x)
            floor = min(hill_y) - self.CAR_HEIGHT * 2
            ceiling = max(hill_y) + self.CAR_HEIGHT * 2
            axes.fill_between(hill_x, floor, hill_y, color='g')
            plt.xlim([self.XMIN - .2, self.XMAX])
            plt.ylim([floor, ceiling])
            # The car is a thick line whose endpoints are set on every call
            self.car = lines.Line2D([], [], linewidth=20, color='b', alpha=.8)
            axes.add_line(self.car)
            # Goal marker
            plt.plot(self.GOAL, np.sin(3 * self.GOAL), 'yd', markersize=10.0)
            plt.axis('off')
            axes.set_aspect('1')
        self.domain_fig = plt.figure(fig_name)

        # Centre of the car sits on the hill; its body follows the tangent.
        height = np.sin(3 * position)
        slope = np.arctan(3 * np.cos(3 * position))
        half_dx = self.CAR_WIDTH * np.cos(slope) / 2.
        half_dy = self.CAR_WIDTH * np.sin(slope) / 2.
        back_x, front_x = position - half_dx, position + half_dx
        back_y, front_y = height - half_dy, height + half_dy
        self.car.set_data([back_x, front_x], [back_y, front_y])

        # Drop the arrow from the previous call before drawing a new one
        if self.actionArrow is not None:
            self.actionArrow.remove()
            self.actionArrow = None

        acceleration = self.actions[a]
        if acceleration > 0:
            # Arrow leaving the front of the car, pointing along the slope
            tip_x = front_x + self.ARROW_LENGTH * np.cos(slope)
            tip_y = front_y + self.ARROW_LENGTH * np.sin(slope)
            self.actionArrow = fromAtoB(
                front_x, front_y, tip_x, tip_y,
                'k', "arc3,rad=0", 0, 0, 'simple')
        elif acceleration < 0:
            # Arrow leaving the back of the car, pointing against the slope
            tip_x = back_x - self.ARROW_LENGTH * np.cos(slope)
            tip_y = back_y - self.ARROW_LENGTH * np.sin(slope)
            self.actionArrow = fromAtoB(
                back_x, back_y, tip_x, tip_y,
                'r', "arc3,rad=0", 0, 0, 'simple')
        plt.draw()