Пример #1
0
    def showDomain(self, a=0):
        """
        Redraw the UAV domain for the current ``self.state``.

        The figure is cleared and rebuilt on every call. Wedge patches cannot
        be removed from an axes, or have their properties changed, without
        clearing the figure, so incremental updates are not used.

        Layout: one rectangle per location (``UAVLocation.SIZE`` of them,
        labelled Base / Refuel / Communication / Surveillance). There is one
        horizontal lane per UAV (``NUM_UAV`` lanes). Each UAV is drawn as a
        white circle labelled with its remaining fuel, plus two wedges:
        - a sensor wedge, offset horizontally by ``SENSOR_REL_X``;
        - an actuator wedge, offset vertically by ``ACTUATOR_REL_Y``.
        A wedge is black when its component is ``RUNNING`` and red otherwise.

        If at least one UAV is at ``UAVLocation.COMMS``, comms lines linking
        the location rectangles are shown. If, in addition, a UAV with a
        running sensor is at ``UAVLocation.SURVEIL``, the last location
        rectangle is coloured green.

        :param a: action; unused by this visualization.
        """
        s = self.state
        num_locations = UAVLocation.SIZE
        dist = self.dist_between_locations
        width = self.LOCATION_WIDTH

        if self.domain_fig is None:
            # x extent covers the location states; y gives one lane per UAV.
            # Both get a buffer of 1.
            self.domain_fig = plt.figure(
                1, (num_locations * dist + 1, self.NUM_UAV + 1))
            plt.show()
        # Clear *this* figure explicitly. plt.clf() would clear whichever
        # figure happens to be current.
        self.domain_fig.clf()

        self.subplot_axes = self.domain_fig.add_axes([0, 0, 1, 1],
                                                     frameon=False,
                                                     aspect=1.)
        ax = self.subplot_axes
        # The last location is placed further right than the regular spacing.
        last_location_x = 2 * dist * (num_locations - 1)
        ax.set_xlim(0, 1 + last_location_x + self.RECT_GAP)
        ax.set_ylim(0, 1 + self.NUM_UAV)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Left edge of every location rectangle.
        left_edges = [0.5 + dist * i for i in range(num_locations - 1)]
        left_edges.append(last_location_x)

        # Centre x-coordinate of every location; UAVs and labels go here.
        self.location_coord = [x + width / 2 for x in left_edges]

        self.location_rect_vis = [
            mpatches.Rectangle([x, 0], width, self.NUM_UAV * 2, fc='w')
            for x in left_edges
        ]
        for rect in self.location_rect_vis:
            ax.add_patch(rect)

        # Comms lines start at the right edge of each regular location.
        # The first ones span RECT_GAP; the final one reaches the last
        # location. They stay hidden unless there is comms coverage (below).
        comms_y = self.NUM_UAV * 0.5 + 0.5
        line_starts = [0.5 + width + dist * i
                       for i in range(num_locations - 1)]
        line_ends = [x + self.RECT_GAP for x in line_starts[:-1]]
        line_ends.append(last_location_x)
        self.comms_line = [
            lines.Line2D([x_start, x_end], [comms_y, comms_y],
                         linewidth=3,
                         color='black',
                         visible=False)
            for x_start, x_end in zip(line_starts, line_ends)
        ]

        # Location labels below the rectangles.
        locText = ["Base", "Refuel", "Communication", "Surveillance"]
        self.location_rect_txt = [
            ax.text(x, -0.3, label, ha='center')
            for x, label in zip(self.location_coord, locText)
        ]

        sStruct = self.state2Struct(s)
        self.uav_circ_vis = [None] * self.NUM_UAV
        self.uav_text_vis = [None] * self.NUM_UAV
        self.uav_sensor_vis = [None] * self.NUM_UAV
        self.uav_actuator_vis = [None] * self.NUM_UAV

        for uav_id in range(self.NUM_UAV):
            uav_x = self.location_coord[sStruct.locations[uav_id]]
            uav_y = 1 + uav_id  # one horizontal lane per UAV

            # White circle labelled with the remaining fuel.
            self.uav_circ_vis[uav_id] = mpatches.Circle((uav_x, uav_y),
                                                        self.UAV_RADIUS,
                                                        fc="w")
            self.uav_text_vis[uav_id] = ax.text(uav_x - 0.05, uav_y - 0.05,
                                                sStruct.fuel[uav_id])

            # Sensor wedge: black if running, red otherwise.
            if sStruct.sensor[uav_id] == SensorState.RUNNING:
                sensor_color = 'black'
            else:
                sensor_color = 'red'
            self.uav_sensor_vis[uav_id] = mpatches.Wedge(
                (uav_x + self.SENSOR_REL_X, uav_y),
                self.SENSOR_LENGTH,
                -30,
                30,
                color=sensor_color)

            # Actuator wedge: black if running, red otherwise.
            if sStruct.actuator[uav_id] == ActuatorState.RUNNING:
                actuator_color = 'black'
            else:
                actuator_color = 'red'
            self.uav_actuator_vis[uav_id] = mpatches.Wedge(
                (uav_x, uav_y + self.ACTUATOR_REL_Y),
                self.ACTUATOR_HEIGHT,
                60,
                120,
                color=actuator_color)

            ax.add_patch(self.uav_circ_vis[uav_id])
            ax.add_patch(self.uav_sensor_vis[uav_id])
            ax.add_patch(self.uav_actuator_vis[uav_id])

        locations = np.asarray(sStruct.locations)
        sensors = np.asarray(sStruct.sensor)
        if np.any(locations == UAVLocation.COMMS):
            # Comms coverage: show the lines linking the locations.
            for comms_line in self.comms_line:
                comms_line.set_visible(True)
                ax.add_line(comms_line)
            # A healthy UAV is also surveilling: colour the last location
            # rectangle green. Sensor health uses the same RUNNING test as
            # the wedge colours above.
            healthy_surveil = np.logical_and(
                locations == UAVLocation.SURVEIL,
                sensors == SensorState.RUNNING)
            if np.any(healthy_surveil):
                self.location_rect_vis[-1].set_color('green')

        # Redraw this domain's figure, not whichever figure is current.
        self.domain_fig.canvas.draw_idle()
        sleep(0.5)
Пример #2
0
    def showDomain(self, a=0):
        """
        Plot the 2 links + action arrows.

        On the first call this does the following:
        - uses the current figure (``plt.gcf()``) as ``self.domain_fig``;
        - adds a single axes holding the two link lines (black and blue);
        - adds a horizontal red reference line at height ``LINK_LENGTH_1``.

        Every call then does the following:
        - replaces the curved torque arrow: black for a positive torque, red
          for a negative one, none for zero;
        - moves the links to ``self.state``.

        :param a: index into ``self.AVAIL_TORQUE`` of the action to display.
        """
        s = self.state
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.gcf()
            self.domain_ax = self.domain_fig.add_axes([0, 0, 1, 1],
                                                      frameon=True,
                                                      aspect=1.)
            ax = self.domain_ax
            self.link1 = lines.Line2D([], [], linewidth=2, color='black')
            self.link2 = lines.Line2D([], [], linewidth=2, color='blue')
            ax.add_line(self.link1)
            ax.add_line(self.link2)

            # Allow room for the arm to swing without being cut off.
            viewable_distance = self.LINK_LENGTH_1 + self.LINK_LENGTH_2 + 0.5
            ax.set_xlim(-viewable_distance, +viewable_distance)
            ax.set_ylim(-viewable_distance, viewable_distance)
            # Horizontal red reference bar at y = LINK_LENGTH_1.
            bar = lines.Line2D([-viewable_distance, viewable_distance],
                               [self.LINK_LENGTH_1, self.LINK_LENGTH_1],
                               linewidth=1,
                               color='red')
            ax.add_line(bar)

            plt.show()

        if self.action_arrow is not None:
            self.action_arrow.remove()
            self.action_arrow = None

        torque = self.AVAIL_TORQUE[a]
        SHIFT = .5
        if torque > 0:  # counterclockwise torque
            self.action_arrow = fromAtoB(SHIFT / 2.0,
                                         .5 * SHIFT,
                                         -SHIFT / 2.0,
                                         -.5 * SHIFT,
                                         'k',
                                         connectionstyle="arc3,rad=+1.2",
                                         ax=self.domain_ax)
        elif torque < 0:  # clockwise torque
            self.action_arrow = fromAtoB(-SHIFT / 2.0,
                                         .5 * SHIFT,
                                         +SHIFT / 2.0,
                                         -.5 * SHIFT,
                                         'r',
                                         connectionstyle="arc3,rad=-1.2",
                                         ax=self.domain_ax)

        # Link 1 hangs from the origin at angle s[0]. Link 2 continues from
        # its end at angle s[0] + s[1]. With x = L*sin(angle) and
        # y = -L*cos(angle), an angle of 0 points straight down.
        elbow_x = self.LINK_LENGTH_1 * np.sin(s[0])
        elbow_y = -self.LINK_LENGTH_1 * np.cos(s[0])
        tip_x = elbow_x + self.LINK_LENGTH_2 * np.sin(s[0] + s[1])
        tip_y = elbow_y - self.LINK_LENGTH_2 * np.cos(s[0] + s[1])
        self.link1.set_data([0., elbow_x], [0., elbow_y])
        self.link2.set_data([elbow_x, tip_x], [elbow_y, tip_y])

        # Redraw this domain's figure explicitly. plt.draw() would refresh
        # whichever figure is current, which may be a different plot.
        self.domain_fig.canvas.draw_idle()
Пример #3
0
    def showDomain(self, a):
        """
        Render the mountain car and an arrow for the chosen acceleration.

        Parts of this code were adopted from Jose Antonio Martin H.
        <*****@*****.**> online source code.

        The first call builds the static scene in the "Mountain Car Domain"
        figure:
        - a green hill ``y = sin(3 x)`` over ``[XMIN, XMAX]``;
        - a yellow diamond marking ``GOAL``;
        - an initially empty thick blue line used for the car.

        Every call then does the following:
        - re-activates that figure;
        - places the car line along the hill's tangent at the current
          position;
        - replaces the action arrow: black from the car's front when
          ``self.actions[a] > 0``, red from its rear when
          ``self.actions[a] < 0``, none otherwise.

        :param a: index into ``self.actions`` of the action to display.
        """
        pos, _velocity = self.state
        fig_name = "Mountain Car Domain"

        if self.domain_fig is None:  # first call: build the static scene
            self.domain_fig = plt.figure(fig_name)
            hill_x = np.linspace(self.XMIN, self.XMAX, 1000)
            hill_y = np.sin(3 * hill_x)
            y_low = min(hill_y) - self.CAR_HEIGHT * 2
            y_high = max(hill_y) + self.CAR_HEIGHT * 2
            axes = plt.gca()
            axes.fill_between(hill_x, y_low, hill_y, color='g')
            plt.xlim([self.XMIN - .2, self.XMAX])
            plt.ylim([y_low, y_high])
            # The car is a thick line segment that is moved each frame.
            self.car = lines.Line2D([], [], linewidth=20, color='b', alpha=.8)
            axes.add_line(self.car)
            # Goal marker.
            plt.plot(self.GOAL, np.sin(3 * self.GOAL), 'yd', markersize=10.0)
            plt.axis('off')
            axes.set_aspect('1')
        # Make the domain figure current again so the drawing below targets it.
        self.domain_fig = plt.figure(fig_name)

        # Tilt the car to the hill's tangent: d/dx sin(3x) = 3 cos(3x).
        tilt = np.arctan(3 * np.cos(3 * pos))
        half_dx = self.CAR_WIDTH * np.cos(tilt) / 2.
        half_dy = self.CAR_WIDTH * np.sin(tilt) / 2.
        mid_y = np.sin(3 * pos)
        rear_x, rear_y = pos - half_dx, mid_y - half_dy
        front_x, front_y = pos + half_dx, mid_y + half_dy
        self.car.set_data([rear_x, front_x], [rear_y, front_y])

        # Replace the previous action arrow, if any.
        if self.actionArrow is not None:
            self.actionArrow.remove()
            self.actionArrow = None

        arrow_dx = self.ARROW_LENGTH * np.cos(tilt)
        arrow_dy = self.ARROW_LENGTH * np.sin(tilt)
        push = self.actions[a]
        if push > 0:
            # Forward push: arrow leaves the front of the car along the slope.
            self.actionArrow = fromAtoB(
                front_x, front_y,
                front_x + arrow_dx, front_y + arrow_dy,
                'k', "arc3,rad=0",
                0, 0, 'simple'
            )
        elif push < 0:
            # Backward push: arrow leaves the rear of the car along the slope.
            self.actionArrow = fromAtoB(
                rear_x, rear_y,
                rear_x - arrow_dx, rear_y - arrow_dy,
                'r', "arc3,rad=0",
                0, 0, 'simple'
            )
        plt.draw()
Пример #4
0
    def _plot_state(self, fourDimState, a):
        """
        Draw the cart-pole in ``fourDimState`` and an arrow for action ``a``.

        :param fourDimState: Four-dimensional cartpole state
            (``theta, thetaDot, x, xDot``), read via ``StateIndex.X`` and
            ``StateIndex.THETA``.
        :param a: index into ``self.AVAIL_FORCE`` selecting the force applied
            to the cart.

        The "Domain" figure is built when it, or any of the cart/pendulum
        artists, is missing. On each call the pendulum arm and cart are moved
        to the current state. The force is shown as an arrow pushing on the
        side of the cart, at the cart's vertical centre: black for a
        rightward force, red for a leftward one, none for zero. The arrow
        does not include noise.
        """
        s = fourDimState
        half_rect = self.RECT_WIDTH / 2.0
        half_blob = self.BLOB_WIDTH / 2.0

        if (self.domain_fig is None or self.pendulumArm is None) or \
           (self.cartBox is None or self.cartBlob is None):  # Need to initialize the figure
            self.domain_fig = pl.figure("Domain")
            self.domain_ax = self.domain_fig.add_axes(
                [0, 0, 1, 1], frameon=True, aspect=1.)
            self.pendulumArm = lines.Line2D(
                [],
                [],
                linewidth=self.PEND_WIDTH,
                color='black')
            # Cart body and blob are centred vertically on the pivot height;
            # their x position is set on every call below.
            self.cartBox = mpatches.Rectangle(
                [0,
                 self.PENDULUM_PIVOT_Y - self.RECT_HEIGHT / 2.0],
                self.RECT_WIDTH,
                self.RECT_HEIGHT,
                alpha=.4)
            self.cartBlob = mpatches.Rectangle(
                [0,
                 self.PENDULUM_PIVOT_Y - half_blob],
                self.BLOB_WIDTH,
                self.BLOB_WIDTH,
                alpha=.4)
            self.domain_ax.add_patch(self.cartBox)
            self.domain_ax.add_line(self.pendulumArm)
            self.domain_ax.add_patch(self.cartBlob)
            # Draw ground
            groundPath = mpath.Path(self.GROUND_VERTS)
            groundPatch = mpatches.PathPatch(groundPath, hatch="//")
            self.domain_ax.add_patch(groundPatch)
            self.timeText = self.domain_ax.text(
                self.POSITION_LIMITS[1],
                self.LENGTH,
                "")
            self.rewardText = self.domain_ax.text(
                self.POSITION_LIMITS[0],
                self.LENGTH,
                "")
            # Allow room for the pendulum to swing without being cut off.
            viewableDistance = self.LENGTH + 0.5
            if self.POSITION_LIMITS[0] < -100 * self.LENGTH or self.POSITION_LIMITS[1] > 100 * self.LENGTH:
                # Huge position limits: fix the view width so the cart is
                # still visible at a useful size.
                self.domain_ax.set_xlim(-viewableDistance, viewableDistance)
            else:
                self.domain_ax.set_xlim(
                    self.POSITION_LIMITS[0] - viewableDistance,
                    self.POSITION_LIMITS[1] + viewableDistance)
            self.domain_ax.set_ylim(-viewableDistance, viewableDistance)

            pl.show()

        forceAction = self.AVAIL_FORCE[a]
        curX = s[StateIndex.X]
        curTheta = s[StateIndex.THETA]

        pendulumBobX = curX + self.LENGTH * np.sin(curTheta)
        pendulumBobY = self.PENDULUM_PIVOT_Y + self.LENGTH * np.cos(curTheta)

        if self.DEBUG:
            print('Pendulum Position: ', pendulumBobX, pendulumBobY)

        # Update the pendulum arm and the cart on the figure.
        self.pendulumArm.set_data(
            [curX, pendulumBobX], [self.PENDULUM_PIVOT_Y, pendulumBobY])
        self.cartBox.set_x(curX - half_rect)
        self.cartBlob.set_x(curX - half_blob)

        if self.actionArrow is not None:
            self.actionArrow.remove()
            self.actionArrow = None

        if forceAction != 0:
            # Draw the arrow at the cart's vertical centre (the pivot height)
            # rather than a hard-coded y = 0. This keeps it aligned with the
            # cart when PENDULUM_PIVOT_Y is non-zero.
            arrow_y = self.PENDULUM_PIVOT_Y
            if forceAction > 0:
                # Rightward force: arrow pushes on the cart's left side.
                tail_x = curX - self.ACTION_ARROW_LENGTH - half_rect
                head_x = curX - half_rect
                color = 'k'
            else:
                # Leftward force: arrow pushes on the cart's right side.
                tail_x = curX + self.ACTION_ARROW_LENGTH + half_rect
                head_x = curX + half_rect
                color = 'r'
            self.actionArrow = fromAtoB(
                tail_x, arrow_y,
                head_x, arrow_y,
                color, "arc3,rad=0",
                0, 0, 'simple', ax=self.domain_ax
            )
        self.domain_fig.canvas.draw()