コード例 #1
0
ファイル: RCCar.py プロジェクト: zhexiaozhe/rlpy
    def showDomain(self, a):
        """Render the current RC-car pose in the domain figure.

        On the first call the figure is created, the goal region is drawn as
        a translucent green circle and the axis limits are fixed to
        ``[XMIN, XMAX] x [YMIN, YMAX]`` with equal aspect. On every call the
        previously drawn car rectangle (if any) is removed and a new one is
        drawn, rotated by ``heading`` around ``(x, y)``.

        :param a: the action being taken; not used by this visualization.
        """
        # self.state is unpacked as (x, y, speed, heading); speed is unused.
        x, y, _speed, heading = self.state
        # Lower-left corner of the un-rotated car rectangle: (x, y) lies
        # REAR_WHEEL_RELATIVE_LOC in from the rear edge and at mid-width
        # (presumably the rear-axle position -- confirm against dynamics).
        car_xmin = x - self.REAR_WHEEL_RELATIVE_LOC
        car_ymin = y - self.CAR_WIDTH / 2.
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.figure()
            init_ax = self.domain_fig.gca()
            # Goal
            init_ax.add_patch(
                plt.Circle(self.GOAL,
                           radius=self.GOAL_RADIUS,
                           color='g',
                           alpha=.4))
            init_ax.set_xlim([self.XMIN, self.XMAX])
            init_ax.set_ylim([self.YMIN, self.YMAX])
            init_ax.set_aspect('equal')
        # Always draw into the domain figure's own axes, even if another
        # figure has since become pyplot's "current" figure.
        ax = self.domain_fig.gca()
        # Car: drop the previous rectangle. Artist.remove() works on every
        # matplotlib version; mutating ``ax.patches`` directly fails on
        # matplotlib >= 3.5, where it became a read-only ArtistList.
        if self.car_fig is not None:
            self.car_fig.remove()

        self.car_fig = mpatches.Rectangle([car_xmin, car_ymin],
                                          self.CAR_LENGTH,
                                          self.CAR_WIDTH,
                                          alpha=.4)
        # heading is in radians; rotate_deg_around expects degrees.
        rotation = mpl.transforms.Affine2D().rotate_deg_around(
            x, y, heading * 180 / np.pi) + ax.transData
        self.car_fig.set_transform(rotation)
        ax.add_patch(self.car_fig)

        self.domain_fig.canvas.draw_idle()
コード例 #2
0
    def showDomain(self, a):
        """Render the current RC-car pose, plus recorded slip points.

        On the first call the domain figure is created, the goal region is
        drawn as a translucent green circle and the axis limits are fixed to
        the domain bounds. On every call:

        * the previous car rectangle (if any) is removed;
        * if ``self.slips`` is non-empty, its (x, y) pairs are shown as blue
          ``'x'`` markers -- the first line on the axes is reused and only
          updated when its point count differs from ``self.slips``;
        * a new car rectangle is drawn, rotated by ``heading`` around (x, y).

        :param a: the action being taken; not used by this visualization.
        """
        # Kept for backward compatibility: other code may read self.gcf.
        if self.gcf is None:
            self.gcf = plt.gcf()

        # self.state is unpacked as (x, y, speed, heading); speed is unused.
        x, y, _speed, heading = self.state
        # Lower-left corner of the un-rotated car rectangle.
        car_xmin = x - self.REAR_WHEEL_RELATIVE_LOC
        car_ymin = y - self.CAR_WIDTH / 2.
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.figure()
            init_ax = self.domain_fig.gca()
            # Goal
            init_ax.add_patch(
                plt.Circle(
                    self.GOAL,
                    radius=self.GOAL_RADIUS,
                    color='g',
                    alpha=.4))
            init_ax.set_xlim([self.XMIN, self.XMAX])
            init_ax.set_ylim([self.YMIN, self.YMAX])
            init_ax.set_aspect('equal')
        # Draw into the domain figure's existing axes. (plt.axes() with no
        # arguments adds a brand-new Axes on modern matplotlib, so it must
        # not be used to look up the current one.)
        ax = self.domain_fig.gca()
        # Car: drop the previous rectangle. Artist.remove() works on every
        # matplotlib version; mutating ``ax.patches`` fails on >= 3.5.
        if self.car_fig is not None:
            self.car_fig.remove()

        if self.slips:
            slip_x, slip_y = zip(*self.slips)
            try:
                line = ax.lines[0]
            except IndexError:
                # First time there are slips: create the marker line.
                ax.plot(slip_x, slip_y, 'x', color='b')
            else:
                # Only refresh the line when it is out of sync with the data.
                if len(line.get_xdata()) != len(slip_x):
                    line.set_xdata(slip_x)
                    line.set_ydata(slip_y)

        self.car_fig = mpatches.Rectangle(
            [car_xmin,
             car_ymin],
            self.CAR_LENGTH,
            self.CAR_WIDTH,
            alpha=.4)
        # heading is in radians; rotate_deg_around expects degrees.
        rotation = mpl.transforms.Affine2D().rotate_deg_around(
            x, y, heading * 180 / np.pi) + ax.transData
        self.car_fig.set_transform(rotation)
        ax.add_patch(self.car_fig)

        self.domain_fig.canvas.draw_idle()
        # Let the GUI event loop run briefly so the frame is shown.
        plt.pause(0.001)
コード例 #3
0
ファイル: PST.py プロジェクト: amoliu/consumable-irl
    def showDomain(self, a=0):
        """Redraw the UAV persistent-surveillance domain for the current state.

        The domain figure is cleared and rebuilt from scratch on every call.
        The original author found that Wedge patches could not be removed or
        restyled without clearing the figure. The layout has one white
        rectangle (column) per ``UAVLocation`` and one horizontal lane per
        UAV.

        For each UAV, a white circle is drawn at its current location with
        its fuel value as text. A sensor wedge is drawn in front of it and an
        actuator wedge above it. Each wedge is black when the component
        equals ``RUNNING`` and red otherwise.

        If any UAV is at ``UAVLocation.COMMS``, the comms links are drawn. If
        additionally some UAV at ``UAVLocation.SURVEIL`` has a nonzero
        ``sensor`` entry, the last location rectangle is colored green.
        The call then sleeps for 0.5 s.

        :param a: the action being taken; not used by this visualization.
        """
        s = self.state
        if self.domain_fig is None:
            # Width: one slot per location state; height: one lane per UAV;
            # plus a buffer of 1 in each direction.
            self.domain_fig = plt.figure(
                1, (UAVLocation.SIZE * self.dist_between_locations + 1,
                    self.NUM_UAV + 1))
            # NOTE(review): in a non-interactive backend plt.show() blocks
            # until the window is closed -- confirm this is intended.
            plt.show()
        # Clear the domain figure itself rather than pyplot's "current"
        # figure, which may be a different window.
        self.domain_fig.clf()

        self.subplot_axes = self.domain_fig.add_axes([0, 0, 1, 1],
                                                     frameon=False,
                                                     aspect=1.)
        ax = self.subplot_axes
        # The last location is placed apart from the evenly spaced others.
        last_location_x = 2 * \
            (self.dist_between_locations) * (UAVLocation.SIZE - 1)
        ax.set_xlim(0, 1 + last_location_x + self.RECT_GAP)
        ax.set_ylim(0, 1 + self.NUM_UAV)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # x coordinate at which a UAV at each location is drawn.
        self.location_coord = [
            0.5 + (self.LOCATION_WIDTH / 2) + (self.dist_between_locations) * i
            for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_coord.append(last_location_x + self.LOCATION_WIDTH / 2)

        # One white rectangle per location.
        self.location_rect_vis = [
            mpatches.Rectangle([0.5 + (self.dist_between_locations) * i, 0],
                               self.LOCATION_WIDTH,
                               self.NUM_UAV * 2,
                               fc='w') for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_rect_vis.append(
            mpatches.Rectangle([last_location_x, 0],
                               self.LOCATION_WIDTH,
                               self.NUM_UAV * 2,
                               fc='w'))
        # Add every location rectangle (not a hard-coded first four).
        for rect in self.location_rect_vis:
            ax.add_patch(rect)

        # Horizontal comms links between consecutive locations; hidden and
        # not added to the axes unless comms coverage exists (see below).
        self.comms_line = [
            lines.Line2D([
                0.5 + self.LOCATION_WIDTH +
                (self.dist_between_locations) * i, 0.5 + self.LOCATION_WIDTH +
                (self.dist_between_locations) * i + self.RECT_GAP
            ], [self.NUM_UAV * 0.5 + 0.5, self.NUM_UAV * 0.5 + 0.5],
                         linewidth=3,
                         color='black',
                         visible=False) for i in range(UAVLocation.SIZE - 2)
        ]
        self.comms_line.append(
            lines.Line2D([
                0.5 + self.LOCATION_WIDTH +
                (self.dist_between_locations) * 2, last_location_x
            ], [self.NUM_UAV * 0.5 + 0.5, self.NUM_UAV * 0.5 + 0.5],
                         linewidth=3,
                         color='black',
                         visible=False))

        # Location labels below the rectangles.
        locText = ["Base", "Refuel", "Communication", "Surveillance"]
        self.location_rect_txt = [
            ax.text(0.5 + self.dist_between_locations * i +
                    0.5 * self.LOCATION_WIDTH,
                    -0.3,
                    locText[i],
                    ha='center') for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_rect_txt.append(
            ax.text(last_location_x + 0.5 * self.LOCATION_WIDTH,
                    -0.3,
                    locText[UAVLocation.SIZE - 1],
                    ha='center'))

        # Per-UAV artists; every entry is (re)created in the loop below.
        self.uav_circ_vis = [None] * self.NUM_UAV
        self.uav_text_vis = [None] * self.NUM_UAV
        self.uav_sensor_vis = [None] * self.NUM_UAV
        self.uav_actuator_vis = [None] * self.NUM_UAV

        sStruct = self.state2Struct(s)

        for uav_id in range(self.NUM_UAV):
            uav_location = sStruct.locations[uav_id]
            uav_fuel = sStruct.fuel[uav_id]
            uav_sensor = sStruct.sensor[uav_id]
            uav_actuator = sStruct.actuator[uav_id]

            # Column of the UAV's location, lane of the UAV itself.
            uav_x = self.location_coord[uav_location]
            uav_y = 1 + uav_id

            self.uav_circ_vis[uav_id] = mpatches.Circle((uav_x, uav_y),
                                                        self.UAV_RADIUS,
                                                        fc="w")
            # Fuel value written inside the circle.
            self.uav_text_vis[uav_id] = ax.text(uav_x - 0.05, uav_y - 0.05,
                                                uav_fuel)
            sensor_color = ('black' if uav_sensor == SensorState.RUNNING
                            else 'red')
            self.uav_sensor_vis[uav_id] = mpatches.Wedge(
                (uav_x + self.SENSOR_REL_X, uav_y),
                self.SENSOR_LENGTH,
                -30,
                30,
                color=sensor_color)
            actuator_color = ('black' if uav_actuator == ActuatorState.RUNNING
                              else 'red')
            self.uav_actuator_vis[uav_id] = mpatches.Wedge(
                (uav_x, uav_y + self.ACTUATOR_REL_Y),
                self.ACTUATOR_HEIGHT,
                60,
                120,
                color=actuator_color)

            ax.add_patch(self.uav_circ_vis[uav_id])
            ax.add_patch(self.uav_sensor_vis[uav_id])
            ax.add_patch(self.uav_actuator_vis[uav_id])

        numHealthySurveil = np.sum(
            np.logical_and(sStruct.locations == UAVLocation.SURVEIL,
                           sStruct.sensor))
        # Comms coverage: show the links between locations.
        if np.any(sStruct.locations == UAVLocation.COMMS):
            for link in self.comms_line:
                link.set_visible(True)
                link.set_color('black')
                ax.add_line(link)
            # Surveillance is also active: color the last location green.
            if numHealthySurveil > 0:
                self.location_rect_vis[-1].set_color('green')
        self.domain_fig.canvas.draw_idle()
        sleep(0.5)
コード例 #4
0
    def _plot_state(self, fourDimState, a):
        """
        Draw one cartpole frame.

        :param fourDimState: Four-dimensional cartpole state
            (``theta, thetaDot, x, xDot``)
        :param a: index into ``AVAIL_FORCE`` of the force applied to the cart

        The figure and its persistent artists are built the first time they
        are missing: the pendulum arm, the cart box and blob, the hatched
        ground, and two empty text slots. After that, only the artists'
        positions are updated.

        A nonzero force is shown as an arrow beside the cart: black for a
        rightward force, red for a leftward one. Noise is not included.
        """
        needs_init = (self.domain_fig is None
                      or self.pendulumArm is None
                      or self.cartBox is None
                      or self.cartBlob is None)
        if needs_init:
            self.domain_fig = pl.figure("Domain")
            ax = self.domain_fig.add_axes([0, 0, 1, 1],
                                          frameon=True, aspect=1.)
            self.domain_ax = ax
            self.pendulumArm = lines.Line2D([], [],
                                            linewidth=self.PEND_WIDTH,
                                            color='black')
            # Both rectangles are vertically centred on the pivot height;
            # their x positions are set per frame below.
            half_box_height = old_div(self.RECT_HEIGHT, 2.0)
            half_blob = old_div(self.BLOB_WIDTH, 2.0)
            self.cartBox = mpatches.Rectangle(
                [0, self.PENDULUM_PIVOT_Y - half_box_height],
                self.RECT_WIDTH, self.RECT_HEIGHT, alpha=.4)
            self.cartBlob = mpatches.Rectangle(
                [0, self.PENDULUM_PIVOT_Y - half_blob],
                self.BLOB_WIDTH, self.BLOB_WIDTH, alpha=.4)
            ax.add_patch(self.cartBox)
            ax.add_line(self.pendulumArm)
            ax.add_patch(self.cartBlob)
            ground = mpatches.PathPatch(mpath.Path(self.GROUND_VERTS),
                                        hatch="//")
            ax.add_patch(ground)
            self.timeText = ax.text(self.POSITION_LIMITS[1], self.LENGTH, "")
            self.rewardText = ax.text(self.POSITION_LIMITS[0], self.LENGTH, "")

            # Leave headroom so the swinging pendulum is not clipped.
            margin = self.LENGTH + 0.5
            left_limit = self.POSITION_LIMITS[0]
            right_limit = self.POSITION_LIMITS[1]
            limits_are_huge = (left_limit < -100 * self.LENGTH
                               or right_limit > 100 * self.LENGTH)
            if limits_are_huge:
                # Very wide track: keep the view narrow so the cart stays
                # visible.
                ax.set_xlim(-margin, margin)
            else:
                ax.set_xlim(left_limit - margin, right_limit + margin)
            ax.set_ylim(-margin, margin)

            pl.show()

        force = self.AVAIL_FORCE[a]
        cart_x = fourDimState[StateIndex.X]
        theta = fourDimState[StateIndex.THETA]

        # theta is measured from vertical: cos gives height, sin the offset.
        bob_x = cart_x + self.LENGTH * np.sin(theta)
        bob_y = self.PENDULUM_PIVOT_Y + self.LENGTH * np.cos(theta)

        if self.DEBUG:
            print('Pendulum Position: ', bob_x, bob_y)

        self.pendulumArm.set_data([cart_x, bob_x],
                                  [self.PENDULUM_PIVOT_Y, bob_y])
        half_box_width = old_div(self.RECT_WIDTH, 2.0)
        self.cartBox.set_x(cart_x - half_box_width)
        self.cartBlob.set_x(cart_x - old_div(self.BLOB_WIDTH, 2.0))

        # Discard the arrow from the previous frame.
        if self.actionArrow is not None:
            self.actionArrow.remove()
            self.actionArrow = None

        if force > 0:
            # Rightward force: black arrow spanning from ACTION_ARROW_LENGTH
            # left of the cart up to the cart's left edge.
            self.actionArrow = fromAtoB(
                cart_x - self.ACTION_ARROW_LENGTH - half_box_width, 0,
                cart_x - half_box_width, 0,
                'k', "arc3,rad=0", 0, 0, 'simple', ax=self.domain_ax)
        elif force != 0:
            # Leftward force: red arrow spanning from ACTION_ARROW_LENGTH
            # right of the cart up to the cart's right edge.
            self.actionArrow = fromAtoB(
                cart_x + self.ACTION_ARROW_LENGTH + half_box_width, 0,
                cart_x + half_box_width, 0,
                'r', "arc3,rad=0", 0, 0, 'simple', ax=self.domain_ax)
        self.domain_fig.canvas.draw()