Esempio n. 1
0
 def show_domain(self, a=0):
     """Draw the chain of states and highlight goal and current states.

     On the first call (``self.circles is None``) this sets up the
     "FiftyChain" figure: the top panel of a 3x1 subplot grid holds one white
     circle per state plus a slightly larger ring behind the last state.
     Every call then paints all circles white, the ``self.GOAL_STATES``
     green and the current state ``self.state`` black, and redraws.

     :param a: action; accepted for interface compatibility, unused here.
     """
     s = self.state
     fig = plt.figure("FiftyChain")
     if self.circles is None:
         width = self.chain_size * 2 / 10.0
         self.domain_fig = plt.subplot(3, 1, 1)
         # Resize this figure directly. A ``plt.figure(1, (width, 2))`` call
         # only affects "FiftyChain" when it happens to be figure number 1;
         # otherwise it opens/activates an unrelated figure.
         fig.set_size_inches(width, 2)
         self.domain_fig.set_xlim(0, width)
         self.domain_fig.set_ylim(0, 2)
         # Make the last one a double circle (ring drawn behind it)
         self.domain_fig.add_patch(
             mpatches.Circle(
                 (0.2 + 2 / 10.0 * (self.chain_size - 1), self.Y),
                 self.RADIUS * 1.1,
                 fc="w",
             ))
         self.domain_fig.xaxis.set_visible(False)
         self.domain_fig.yaxis.set_visible(False)
         self.circles = [
             mpatches.Circle((0.2 + 2 / 10.0 * i, self.Y),
                             self.RADIUS,
                             fc="w") for i in range(self.chain_size)
         ]
         for circle in self.circles:
             self.domain_fig.add_patch(circle)
         # Show once, after every circle has been added. Calling show()
         # inside the loop re-entered the GUI loop per circle and, outside
         # interactive mode, blocked before the remaining circles existed.
         plt.show()
     for p in self.circles:
         p.set_facecolor("w")
     for p in self.GOAL_STATES:
         self.circles[p].set_facecolor("g")
     self.circles[s].set_facecolor("k")
     fig.canvas.draw()
     fig.canvas.flush_events()
Esempio n. 2
0
    def _plot_valfun(self, VMat, xlim=None, ylim=None):
        """Plot (or refresh) the value function as an image.

        The image lives in the "Value Function" figure and is cached on
        ``self.valueFunction_fig``. It is created when there is no cached
        image, or when the cached one is no longer attached to an axes of
        that figure (e.g. after the figure was cleared or closed). Otherwise
        only its data, extent and color normalization are replaced.

        :param VMat: 2-D array of values, drawn with ``origin="lower"``.
        :param xlim: optional ``(xmin, xmax)`` of the image extent.
        :param ylim: optional ``(ymin, ymax)`` of the image extent. Both
            limits must be given; otherwise the unit square is used.
        :returns: handle to the image (``self.valueFunction_fig``)
        """
        fig = plt.figure("Value Function")
        plt.title("Value Function")
        if xlim is not None and ylim is not None:
            extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
        else:
            extent = [0, 1, 0, 1]

        img = getattr(self, "valueFunction_fig", None)
        # Reuse the existing image instead of calling imshow every time,
        # which stacked one more image onto the axes per call.
        if img is None or img.axes is None or img.axes not in fig.axes:
            img = plt.imshow(
                VMat,
                cmap="ValueFunction",
                interpolation="nearest",
                origin="lower",
                extent=extent,
            )
            self.valueFunction_fig = img
        else:
            img.set_extent(extent)

        # Rescale colors to the current data range.
        norm = colors.Normalize(vmin=VMat.min(), vmax=VMat.max())
        img.set_data(VMat)
        img.set_norm(norm)
        plt.draw()
        return img
Esempio n. 3
0
 def _plot_impl(self,
                y="return",
                x="learning_steps",
                save=False,
                show=True):
     """Plot result series ``y`` against ``x`` in the "Performance" figure.

     Axis labels are looked up in ``rlpy.tools.results.default_labels`` and
     fall back to the raw key.

     :param y: key of ``self.result`` for the vertical axis.
     :param x: key of ``self.result`` for the horizontal axis.
     :param save: if true, save the figure as
         ``<self.full_path>/<exp_id:03>-performance.pdf``.
     :param show: if true, turn interactive mode off and call ``plt.show()``.
     """
     labels = rlpy.tools.results.default_labels
     performance_fig = plt.figure("Performance")
     res = self.result
     xs = res[x]
     plt.plot(xs, res[y], lw=2, markersize=4, marker=MARKERS[0])
     # Limits are derived from the data. Skip them for an empty series
     # (indexing / min() would raise) or when the x range would be empty.
     if len(xs) > 0 and xs[-1] != 0:
         plt.xlim(0, xs[-1] * 1.01)
     y_arr = np.array(res[y])
     if y_arr.size > 0:
         m = y_arr.min()
         M = y_arr.max()
         delta = M - m
         # A flat curve (delta == 0) keeps matplotlib's automatic limits.
         if delta > 0:
             plt.ylim(m - 0.1 * delta - 0.1, M + 0.1 * delta + 0.1)
     xlabel = labels[x] if x in labels else x
     ylabel = labels[y] if y in labels else y
     plt.xlabel(xlabel, fontsize=16)
     plt.ylabel(ylabel, fontsize=16)
     if save:
         path = os.path.join(self.full_path,
                             "{:03}-performance.pdf".format(self.exp_id))
         performance_fig.savefig(path, transparent=True, pad_inches=0.1)
     if show:
         plt.ioff()
         plt.show()
Esempio n. 4
0
    def show_domain(self, a):
        """Draw the goal region and the car at its current pose.

        The figure, the translucent green goal circle and the axis limits are
        created on the first call. Each call replaces the previous car
        rectangle with a new one rotated by ``heading`` (radians) around the
        car position ``(x, y)`` taken from ``self.state``.

        :param a: action; accepted for interface compatibility, unused here.
        """
        s = self.state
        x, y, speed, heading = s
        car_xmin = x - self.REAR_WHEEL_RELATIVE_LOC
        car_ymin = y - self.CAR_WIDTH / 2
        if self.domain_fig is None:  # Need to initialize the figure
            self.domain_fig = plt.figure()
            ax = self.domain_fig.gca()
            # Goal
            ax.add_patch(
                plt.Circle(self.GOAL,
                           radius=self.GOAL_RADIUS,
                           color="g",
                           alpha=0.4))
            ax.set_xlim([self.XMIN, self.XMAX])
            ax.set_ylim([self.YMIN, self.YMAX])
            ax.set_aspect("1")
        # Always draw on the domain figure's own axes, not on whichever
        # figure happens to be current.
        ax = self.domain_fig.gca()

        # Car: Artist.remove() works on every matplotlib version, whereas
        # ax.patches is an immutable ArtistList in modern matplotlib
        # (ArtistList.remove was deprecated in 3.5 and removed in 3.7).
        if self.car_fig is not None:
            self.car_fig.remove()

        self.car_fig = mpatches.Rectangle([car_xmin, car_ymin],
                                          self.CAR_LENGTH,
                                          self.CAR_WIDTH,
                                          alpha=0.4)
        rotation = (mpl.transforms.Affine2D().rotate_deg_around(
            x, y, heading * 180 / np.pi) + ax.transData)
        self.car_fig.set_transform(rotation)
        ax.add_patch(self.car_fig)

        self.domain_fig.canvas.draw_idle()
Esempio n. 5
0
 def _init_domain_vis(self):
     """Build the "MountainCar" figure: hill, car artist and goal marker.

     Sets ``self.domain_fig`` / ``self.domain_ax`` to the figure and its
     single axes, and ``self.car`` to an initially empty thick ``Line2D``
     (its data is presumably filled in by the per-step drawing code). The
     hill ``y = sin(3x)`` is filled in green down to a floor two car heights
     below its lowest point. The goal is a yellow diamond on the hill surface.
     Axes decorations are hidden, and the figure is shown.
     """
     self.domain_fig = plt.figure("MountainCar")
     ax = self.domain_fig.add_subplot(111)
     self.domain_ax = ax

     hill_x = np.linspace(self.X_MIN, self.X_MAX, 1000)
     hill_y = np.sin(3 * hill_x)
     margin = self.CAR_HEIGHT * 2
     floor = min(hill_y) - margin
     ax.fill_between(hill_x, floor, hill_y, color="g")
     ax.set_xlim([self.X_MIN - 0.2, self.X_MAX])
     ax.set_ylim([floor, max(hill_y) + margin])

     self.car = lines.Line2D([], [], linewidth=20, color="b", alpha=0.8)
     ax.add_line(self.car)

     goal_height = np.sin(3 * self.GOAL)
     ax.plot(self.GOAL, goal_height, "yd", markersize=10.0)
     ax.axis("off")
     self.domain_fig.show()
Esempio n. 6
0
    def _plot_valfun(self, VMat):
        """Plot (or refresh) the CartPole value-function heat map.

        The figure, axes and image are created when either
        ``self.value_fn_fig`` or ``self.value_fn_img`` is still ``None``.
        Every call then loads ``VMat`` into the image, rescales the colors to
        ``VMat``'s min/max and redraws the figure's canvas itself, so callers
        do not need to call ``plt.draw()``.

        :param VMat: 2-D array of values, drawn with ``origin="lower"``.
        :returns: handle to the figure (``self.value_fn_fig``)
        """
        if self.value_fn_fig is None or self.value_fn_img is None:
            self.value_fn_fig = plt.figure("CartPole Value Function")
            self.value_fn_ax = self.value_fn_fig.add_subplot(111)
            self.value_fn_img = self.value_fn_ax.imshow(
                VMat,
                cmap="ValueFunction",
                interpolation="nearest",
                origin="lower",
                vmin=VMat.min(),
                vmax=VMat.max(),
            )
            self._init_ticks_common(self.value_fn_ax)
            self.value_fn_ax.set_title("CartPole Value Function")

        # Rescale colors to the current data range on every refresh.
        norm = colors.Normalize(vmin=VMat.min(), vmax=VMat.max())
        self.value_fn_img.set_data(VMat)
        self.value_fn_img.set_norm(norm)
        self.value_fn_fig.canvas.draw()
        return self.value_fn_fig
Esempio n. 7
0
    def _plot_policy(self, piMat):
        """Plot (or refresh) the CartPole policy as a color-coded image.

        The figure, axes and image are created on the first call (while
        ``self.policy_fig`` is ``None``), with the color range fixed to
        ``[0, self.num_actions]``. Every call then loads ``piMat`` into the
        image and redraws the figure's canvas itself, so callers do not need
        to call ``plt.draw()``.

        :param piMat: 2-D array of actions, drawn with ``origin="lower"``.
        :returns: handle to the figure (``self.policy_fig``)
        """
        if self.policy_fig is None:
            self.policy_fig = plt.figure("CartPole Policy")
            self.policy_ax = self.policy_fig.add_subplot(1, 1, 1)
            self.policy_img = self.policy_ax.imshow(
                piMat,
                cmap="InvertedPendulumActions",
                interpolation="nearest",
                origin="lower",
                vmin=0,
                vmax=self.num_actions,
            )
            self._init_ticks_common(self.policy_ax)
            self.policy_ax.set_title("CartPole Policy")

        self.policy_img.set_data(piMat)
        self.policy_fig.canvas.draw()
        return self.policy_fig
Esempio n. 8
0
    def show_domain(self, a=0, s=None):
        """Show a live plot of every state variable plus the chosen action.

        Redrawing each step is very slow. The plot is therefore only updated
        when ``self.t`` is a multiple of ``self.show_domain_every`` or has
        reached ``self.episode_cap``. On the first drawing call, one stacked
        subplot per series is created in the "HIVTreatment" figure, against
        an x axis of ``5 * step`` days, and the line handles are cached on
        ``self._state_graph_handles``. Later calls only swap in
        ``self.episode_data`` and rescale each axes.
        """
        off_beat = self.t % self.show_domain_every != 0
        if off_beat and not self.t >= self.episode_cap:
            return

        n_series = self.state_space_dims + 1
        series_names = list(self.state_names) + ["Action"]
        line_colors = ["b", "b", "b", "b", "r", "g", "k"]
        handles = getattr(self, "_state_graph_handles", None)
        fig = plt.figure("HIVTreatment", figsize=(12, 10))

        if handles is None:
            new_fig, axes = plt.subplots(n_series,
                                         sharex=True,
                                         num="HIVTreatment",
                                         figsize=(12, 10))
            new_fig.subplots_adjust(hspace=0.1)
            days = np.arange(self.episode_cap + 1) * 5
            handles = []
            for idx in range(n_series):
                ax = axes[idx]
                ax.set_ylabel(series_names[idx])
                ax.locator_params(tight=True, nbins=4)
                drawn = ax.plot(days,
                                self.episode_data[idx],
                                color=line_colors[idx])
                handles.append(drawn[0])
            self._state_graph_handles = handles
            # Only the bottom subplot (the last one created) gets the x label.
            ax.set_xlabel("Days")

        for idx in range(n_series):
            line = handles[idx]
            line.set_ydata(self.episode_data[idx])
            line_ax = line.axes
            line_ax.relim()
            line_ax.autoscale_view()
        fig.canvas.draw()
        fig.canvas.flush_events()
Esempio n. 9
0
    def _plot_policy(self,
                     piMat,
                     title="Policy",
                     var="policy_fig",
                     xlim=None,
                     ylim=None):
        """Plot (or refresh) a discrete policy as a color-coded image.

        While ``getattr(self, var, None)`` is ``None``, this creates figure
        ``title`` with:

        - a "jet"-based colormap whose first entry is forced to grey;
        - a ``BoundaryNorm`` with one color bin per action (``self.num_actions``
          bins over ``[0, self.num_actions]``);
        - the image and a colorbar.

        The image handle is stored as attribute ``var``. Every call then loads
        ``piMat`` into that image and redraws.

        :param piMat: 2-D array of actions, drawn with ``origin="lower"``.
        :param title: figure label and axes title.
        :param var: name of the attribute caching the image handle.
        :param xlim: optional ``(xmin, xmax)`` image extent, used with ``ylim``.
        :param ylim: optional ``(ymin, ymax)``. Both must be given, otherwise
            the unit square is used.
        :returns: the image handle stored in attribute ``var``
        """
        img = getattr(self, var, None)
        if img is None:
            plt.figure(title)
            # "jet" colormap with its first entry forced to grey
            jet = plt.cm.jet
            cmaplist = [jet(i) for i in range(jet.N)]
            cmaplist[0] = (0.5, 0.5, 0.5, 1.0)
            cmap = jet.from_list("Custom cmap", cmaplist, jet.N)

            # one color bin per action
            bounds = np.linspace(0, self.num_actions, self.num_actions + 1)
            norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
            if xlim is not None and ylim is not None:
                extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
            else:
                extent = [0, 1, 0, 1]
            img = plt.imshow(
                piMat,
                interpolation="nearest",
                origin="lower",
                cmap=cmap,
                norm=norm,
                extent=extent,
            )
            # setattr keeps the write symmetric with the getattr read above;
            # writing self.__dict__ directly bypassed properties and made the
            # later read fail for class-level attributes.
            setattr(self, var, img)
            plt.title(title)

            plt.colorbar()
        plt.figure(title)
        img.set_data(piMat)
        plt.draw()
        return img
Esempio n. 10
0
 def _init_vf_vis(self):
     """Create the 3-D "Value Function" figure and its evaluation grid.

     Sets ``self.vf_ax`` to a new 3-D axes in that figure, and
     ``self.vf_x`` / ``self.vf_xdot`` to the ``np.meshgrid`` of ``X_DISCR``
     points over ``[X_MIN, X_MAX]`` and ``XDOT_DISCR`` points over
     ``[XDOT_MIN, XDOT_MAX]``.

     :returns: the "Value Function" figure.
     """
     vf_fig = plt.figure("Value Function")
     ax = vf_fig.add_subplot(111, projection="3d")
     self.vf_ax = ax
     self.vf_x, self.vf_xdot = np.meshgrid(
         np.linspace(self.X_MIN, self.X_MAX, self.X_DISCR),
         np.linspace(self.XDOT_MIN, self.XDOT_MAX, self.XDOT_DISCR),
     )
     ax.set_xlabel(r"$x$")
     ax.set_ylabel(r"$\dot x$")
     return vf_fig
Esempio n. 11
0
 def show_domain(self, a=0):
     """Render the current blocks-world state as an image in figure "Domain".

     Builds a ``blocks x blocks`` uint8 grid ``world``:

     - row 0 is the table level;
     - a block resting on the table goes into row 0 of its own column;
     - a block resting on another block goes directly above that block, in
       the same column;
     - cell values are ``block index + 1``, so 0 means empty.

     ``s[A]`` names the block that ``A`` rests on, or ``A`` itself when it is
     on the table.

     The image is created on the first call; later calls only replace its
     data and redraw.

     NOTE(review): if the state contained a cycle (e.g. A on B and B on A),
     neither block could ever be placed and the while-loop would not
     terminate.

     :param a: action; accepted for interface compatibility, unused here.
     """
     # Draw the environment
     s = self.state
     world = np.zeros((self.blocks, self.blocks), "uint8")
     undrawn_blocks = np.arange(self.blocks)
     # Work-queue: place each block once the block it rests on is placed.
     while len(undrawn_blocks):
         A = undrawn_blocks[0]
         B = s[A]
         undrawn_blocks = undrawn_blocks[1:]
         if B == A:  # => A is on Table
             world[0, A] = A + 1  # +1 so that 0 can mean "empty" (white)
         else:
             # See if B is already drawn. findElemArray2D presumably returns
             # the (row, column) index arrays where world == B + 1 — confirm.
             i, j = findElemArray2D(B + 1, world)
             if len(i):
                 world[i + 1, j] = A + 1  # +1 so that 0 can mean "empty"
             else:
                 # B not placed yet: retry A after the remaining blocks
                 undrawn_blocks = np.hstack((undrawn_blocks, [A]))
     if self.domain_fig is None:
         plt.figure("Domain")
         # "BlocksWorld" is presumably a colormap registered elsewhere.
         self.domain_fig = plt.imshow(
             world,
             cmap="BlocksWorld",
             origin="lower",
             interpolation="nearest")  # ,vmin=0,vmax=self.blocks)
         plt.xticks(np.arange(self.blocks), fontsize=FONTSIZE)
         plt.yticks(np.arange(self.blocks), fontsize=FONTSIZE)
         # pl.tight_layout()
         plt.axis("off")
         plt.show()
     else:
         # Reuse the existing image; only its pixel data changes.
         self.domain_fig.set_data(world)
         plt.figure("Domain").canvas.draw()
         plt.figure("Domain").canvas.flush_events()
Esempio n. 12
0
    def show_domain(self, a=0):
        """Draw the chain of states and highlight the current one.

        On the first call (``self.circles is None``) figure 1 is built:

        - one white circle per state at x = 1, 3, 5, …;
        - a slightly larger ring behind the last state;
        - connectors drawn with ``from_a_to_b`` (presumably an arrow helper):
          forward ones just above the centre line between neighbouring states;
          red backward ones just below it for every pair except the last;
          a curved red connector at the left end.

        Every call paints the circle of ``self.state[0]`` black and all others
        white, then redraws.

        :param a: action; accepted for interface compatibility, unused here.
        """
        if self.circles is None:
            n = self.chain_size
            y = self.Y
            shift = self.SHIFT
            self.fig = plt.figure(1, (n * 2, 2))
            ax = self.fig.add_axes([0, 0, 1, 1], frameon=False, aspect=1.0)
            ax.set_xlim(0, n * 2)
            ax.set_ylim(0, 2)
            # Outer ring marking the last state of the chain
            ax.add_patch(
                mpatches.Circle((1 + 2 * (n - 1), y), self.RADIUS * 1.1,
                                fc="w"))
            self.circles = [
                mpatches.Circle((1 + 2 * k, y), self.RADIUS, fc="w")
                for k in range(n)
            ]
            for k, circle in enumerate(self.circles):
                ax.add_patch(circle)
                left = 1 + 2 * k
                right = 1 + 2 * (k + 1)
                if k < n - 1:
                    from_a_to_b(left + shift, y + shift, right - shift,
                                y + shift)
                if k < n - 2:
                    from_a_to_b(right - shift, y - shift, left + shift,
                                y - shift, "r")
            from_a_to_b(
                0.75,
                y - 1.5 * shift,
                0.75,
                y + 1.5 * shift,
                "r",
                connectionstyle="arc3,rad=-1.2",
            )
            self.fig.show()

        current = self.state[0]
        for k, circle in enumerate(self.circles):
            circle.set_facecolor("k" if current == k else "w")
        self.fig.canvas.draw()
Esempio n. 13
0
    def _init_domain_figure(self):
        """Create the CartPole figure and all artists that are animated later.

        Builds figure "CartPole <NAME>" with one full-size axes. It then adds,
        in this order:

        1. the cart rectangle;
        2. the (empty) pendulum arm line;
        3. the pendulum blob square;
        4. a hatched ground patch from ``self.GROUND_VERTS``;
        5. two empty text labels (``time_text`` at the right track limit,
           ``reward_text`` at the left, both at height ``self.LENGTH``).

        The view leaves ``LENGTH + 0.5`` of slack around the track for the
        swinging pendulum. If a track limit lies beyond 100 pendulum lengths,
        only ``[-slack, slack]`` around the origin is shown. Ends with
        ``plt.show()``.
        """
        fig = plt.figure("CartPole {}".format(self.NAME))
        self.domain_fig = fig
        ax = fig.add_axes([0, 0, 1, 1], frameon=True, aspect=1.0)
        self.domain_ax = ax

        pivot_y = self.PENDULUM_PIVOT_Y
        self.pendulum_arm = lines.Line2D([], [],
                                         linewidth=self.PEND_WIDTH,
                                         color="black")
        self.cart_box = mpatches.Rectangle(
            [0, pivot_y - self.RECT_HEIGHT / 2],
            self.RECT_WIDTH,
            self.RECT_HEIGHT,
            alpha=0.4,
        )
        self.cart_blob = mpatches.Rectangle(
            [0, pivot_y - self.BLOB_WIDTH / 2],
            self.BLOB_WIDTH,
            self.BLOB_WIDTH,
            alpha=0.4,
        )
        for artist_adder, artist in (
            (ax.add_patch, self.cart_box),
            (ax.add_line, self.pendulum_arm),
            (ax.add_patch, self.cart_blob),
        ):
            artist_adder(artist)

        # Hatched ground
        ax.add_patch(
            mpatches.PathPatch(mpath.Path(self.GROUND_VERTS), hatch="//"))

        low = self.POSITION_LIMITS[0]
        high = self.POSITION_LIMITS[1]
        self.time_text = ax.text(high, self.LENGTH, "")
        self.reward_text = ax.text(low, self.LENGTH, "")

        # Room for the pendulum to swing without being cut off
        slack = self.LENGTH + 0.5
        huge_track = low < -100 * self.LENGTH or high > 100 * self.LENGTH
        if huge_track:
            # Keep the cart visible instead of showing the whole track
            ax.set_xlim(-slack, slack)
        else:
            ax.set_xlim(low - slack, high + slack)
        ax.set_ylim(-slack, slack)
        ax.set_aspect("equal")

        plt.show()
Esempio n. 14
0
    def show_domain(self, a):
        """Draw the map with allies (blue circles) and intruders (grey triangles).

        The map image is created on the first call. Afterwards the previous
        frame's ally and intruder markers are removed and redrawn from
        ``self.state``. Its first ``2 * NUMBER_OF_AGENTS`` entries are ally
        coordinate pairs and the rest intruder pairs. Within each pair the
        second element is plotted on x and the first on y.

        :param a: action; accepted for interface compatibility, unused here.
        """
        s = self.state
        # Draw the environment
        fig = plt.figure("IntruderMonitoring")
        if self.domain_fig is None:
            self.domain_fig = plt.imshow(
                self.map,
                cmap="IntruderMonitoring",
                interpolation="nearest",
                vmin=0,
                vmax=3,
            )
            plt.xticks(np.arange(self.COLS), fontsize=FONTSIZE)
            plt.yticks(np.arange(self.ROWS), fontsize=FONTSIZE)
            plt.show()
        # Remove last frame's markers before drawing the new positions.
        if self.ally_fig is not None:
            self.ally_fig.pop(0).remove()
            self.intruder_fig.pop(0).remove()

        n_ally_coords = self.NUMBER_OF_AGENTS * 2
        s_ally = s[0:n_ally_coords].reshape((-1, 2))
        s_intruder = s[n_ally_coords:].reshape((-1, 2))
        self.ally_fig = plt.plot(
            s_ally[:, 1],
            s_ally[:, 0],
            "bo",
            markersize=30.0,
            alpha=0.7,
            markeredgecolor="k",
            markeredgewidth=2,
        )
        # Marker-only format ">" with color="gray". The former "g>" also set
        # a color in the format string, which matplotlib reports as
        # "color is redundantly defined" on every frame (the keyword won).
        self.intruder_fig = plt.plot(
            s_intruder[:, 1],
            s_intruder[:, 0],
            ">",
            color="gray",
            markersize=30.0,
            alpha=0.7,
            markeredgecolor="k",
            markeredgewidth=2,
        )
        fig.canvas.draw()
        fig.canvas.flush_events()
Esempio n. 15
0
 def show_domain(self, a=None):
     """Draw the swimmer's links and the action taken.

     Link midpoints are ``self.P @ T``, where row ``k`` of ``T`` is the unit
     direction ``(cos theta_k, sin theta_k)``. Each link extends half of
     ``self.lengths[k]`` on either side of its midpoint, and all points are
     shifted by ``self.pos_cm``.

     The first call draws:

     - a red origin marker;
     - the link polyline;
     - an action label;

     and fixes the view to x in [-5, 15], y in [-10, 10]. Later calls only
     update the polyline and label.

     :param a: optional index into ``self.actions``; the selected entry (or
         ``None``) is shown as text.
     """
     action = self.actions[a] if a is not None else a
     direction = np.empty((self.d, 2))
     direction[:, 0] = np.cos(self.theta)
     direction[:, 1] = np.sin(self.theta)
     midpoints = np.dot(self.P, direction)
     half_link = 0.5 * self.lengths[:, None] * direction
     tails = midpoints - half_link
     heads = midpoints + half_link
     xs = np.hstack([tails[:, 0], heads[:, 0]]) + self.pos_cm[0]
     ys = np.hstack([tails[:, 1], heads[:, 1]]) + self.pos_cm[1]

     fig = plt.figure("Swimmer")
     if self.swimmer_lines is not None:
         self.swimmer_lines.set_data(xs, ys)
         self.action_text.set_text(str(action))
     else:
         plt.plot(0.0, 0.0, "ro")
         self.swimmer_lines = plt.plot(xs, ys)[0]
         self.action_text = plt.text(-2, -8, str(action))
         plt.xlim(-5, 15)
         plt.ylim(-10, 10)
     fig.canvas.draw()
     fig.canvas.flush_events()
Esempio n. 16
0
    def show_domain(self, a=0):
        """Draw the computer network in figure "Domain", colored by status.

        First call: build ``self.networkGraph``, lay it out on a circle
        (``self.networkPos``), draw all nodes white and all edges black, and
        show the figure. The graph has one node per computer and one edge per
        entry of ``self.UNIQUE_EDGES``.

        Later calls clear the figure and redraw it:

        - computers whose status equals ``self.RUNNING`` as white nodes, the
          others as red nodes;
        - edges between two running computers solid, every other edge dotted.

        Node labels are drawn once per frame.

        :param a: action; accepted for interface compatibility, unused here.
        """
        s = self.state
        plt.figure("Domain")
        first_draw = self.networkGraph is None  # or self.networkPos is None

        if first_draw:
            self.networkGraph = nx.Graph()
            # One node per computer. zip() pairs NEIGHBORS with the state,
            # so the node count is the shorter of the two.
            for computer_id, _ in enumerate(zip(self.NEIGHBORS, s)):
                self.networkGraph.add_node(computer_id, node_color="w")
            for uniqueEdge in self.UNIQUE_EDGES:
                # Add an edge between each pair of neighbors
                self.networkGraph.add_edge(uniqueEdge[0],
                                           uniqueEdge[1],
                                           edge_color="k")
            self.networkPos = nx.circular_layout(self.networkGraph)
            nx.draw_networkx_nodes(self.networkGraph,
                                   self.networkPos,
                                   node_color="w")
            nx.draw_networkx_edges(self.networkGraph,
                                   self.networkPos,
                                   edge_color="k")
        else:
            plt.clf()
            running_nodes = []
            broken_nodes = []
            for computer_id, (_, compstatus) in enumerate(
                    zip(self.NEIGHBORS, s)):
                if compstatus == self.RUNNING:
                    running_nodes.append(computer_id)
                else:
                    broken_nodes.append(computer_id)
            working_edges = []
            broken_edges = []
            for uniqueEdge in self.UNIQUE_EDGES:
                if (s[uniqueEdge[0]] == self.RUNNING
                        and s[uniqueEdge[1]] == self.RUNNING):
                    # Then both computers are working
                    working_edges.append(uniqueEdge)
                else:
                    # At least one end is not running: drawn dotted below
                    broken_edges.append(uniqueEdge)
            # Only draw groups that are non-empty.
            if broken_nodes:
                nx.draw_networkx_nodes(
                    self.networkGraph,
                    self.networkPos,
                    nodelist=broken_nodes,
                    node_color="r",
                    linewidths=2,
                )
            if running_nodes:
                nx.draw_networkx_nodes(
                    self.networkGraph,
                    self.networkPos,
                    nodelist=running_nodes,
                    node_color="w",
                    linewidths=2,
                )
            if working_edges:
                nx.draw_networkx_edges(
                    self.networkGraph,
                    self.networkPos,
                    edgelist=working_edges,
                    edge_color="k",
                    width=2,
                    style="solid",
                )
            if broken_edges:
                nx.draw_networkx_edges(
                    self.networkGraph,
                    self.networkPos,
                    edgelist=broken_edges,
                    edge_color="k",
                    width=2,
                    style="dotted",
                )
        # Draw labels exactly once per frame. The first frame used to draw
        # them twice (before and after plt.show()), stacking duplicate text.
        nx.draw_networkx_labels(self.networkGraph, self.networkPos)
        if first_draw:
            plt.show()
        plt.figure("Domain").canvas.draw()
        plt.figure("Domain").canvas.flush_events()
Esempio n. 17
0
File: pst.py Progetto: kngwyu/rlpy3
    def show_domain(self, a=0):
        """Redraw the UAV domain for the current state in figure "Domain".

        Rather than updating artists in place, every call clears the domain
        figure and rebuilds all artists on a fresh full-size axes.

        Locations:

        - one white rectangle per location, with its name written below it;
        - the last rectangle is placed further right, at ``crashLocationX``.

        Each UAV gets a lane at height ``1 + uav_id`` holding:

        - a circle labelled with its fuel;
        - a sensor wedge in front of it, black when ``SensorState.RUNNING``
          and red otherwise;
        - an actuator wedge above it, black when ``ActuatorState.RUNNING``
          and red otherwise.

        Comms status:

        - black comms lines are shown when any UAV is at ``UAVLocation.COMMS``;
        - if, in addition, a UAV at ``UAVLocation.SURVEIL`` has a truthy sensor
          state, the last (surveillance) rectangle is colored green.

        Sleeps 0.5 s after drawing.

        :param a: action; accepted for interface compatibility, unused here.
        """
        s = self.state
        if self.domain_fig is None:
            # Create (or fetch) the figure labelled "Domain" with the intended
            # size. A separate plt.figure(1, size) call returns whichever
            # figure has number 1, which is an unrelated figure whenever
            # another one was opened before "Domain".
            self.domain_fig = plt.figure(
                "Domain",
                figsize=(UAVLocation.SIZE * self.dist_between_locations + 1,
                         self.NUM_UAV + 1),
            )
            plt.show()
        # Clear this figure explicitly. plt.clf() clears the *current*
        # figure, which after the first frame was no longer the one drawn on.
        self.domain_fig.clf()

        # x spans the location states (UAVLocation.SIZE), y has one lane per
        # UAV (NUM_UAV lanes), plus a buffer of 1.
        ax = self.domain_fig.add_axes([0, 0, 1, 1],
                                      frameon=False,
                                      aspect=1.0)
        self.subplot_axes = ax
        crashLocationX = 2 * (self.dist_between_locations) * (
            UAVLocation.SIZE - 1)
        ax.set_xlim(0, 1 + crashLocationX + self.RECT_GAP)
        ax.set_ylim(0, 1 + self.NUM_UAV)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # x coordinate of each possible UAV location on the figure
        self.location_coord = [
            0.5 + self.LOCATION_WIDTH / 2 + (self.dist_between_locations) * i
            for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_coord.append(crashLocationX + self.LOCATION_WIDTH / 2)

        # One rectangular patch per location
        self.location_rect_vis = [
            mpatches.Rectangle(
                [0.5 + (self.dist_between_locations) * i, 0],
                self.LOCATION_WIDTH,
                self.NUM_UAV * 2,
                fc="w",
            ) for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_rect_vis.append(
            mpatches.Rectangle([crashLocationX, 0],
                               self.LOCATION_WIDTH,
                               self.NUM_UAV * 2,
                               fc="w"))
        for rect in self.location_rect_vis:
            ax.add_patch(rect)

        # Comms lines between consecutive locations; shown only when a UAV
        # is at the comms location (see below).
        self.comms_line = [
            lines.Line2D(
                [
                    0.5 + self.LOCATION_WIDTH +
                    (self.dist_between_locations) * i,
                    0.5 + self.LOCATION_WIDTH +
                    (self.dist_between_locations) * i + self.RECT_GAP,
                ],
                [self.NUM_UAV * 0.5 + 0.5, self.NUM_UAV * 0.5 + 0.5],
                linewidth=3,
                color="black",
                visible=False,
            ) for i in range(UAVLocation.SIZE - 2)
        ]
        self.comms_line.append(
            lines.Line2D(
                [
                    0.5 + self.LOCATION_WIDTH +
                    (self.dist_between_locations) * (UAVLocation.SIZE - 2),
                    crashLocationX,
                ],
                [self.NUM_UAV * 0.5 + 0.5, self.NUM_UAV * 0.5 + 0.5],
                linewidth=3,
                color="black",
                visible=False,
            ))

        # Location names below the rectangles
        locText = ["Base", "Refuel", "Communication", "Surveillance"]
        self.location_rect_txt = [
            ax.text(
                0.5 + self.dist_between_locations * i +
                0.5 * self.LOCATION_WIDTH,
                -0.3,
                locText[i],
                ha="center",
            ) for i in range(UAVLocation.SIZE - 1)
        ]
        self.location_rect_txt.append(
            ax.text(
                crashLocationX + 0.5 * self.LOCATION_WIDTH,
                -0.3,
                locText[UAVLocation.SIZE - 1],
                ha="center",
            ))

        # For each UAV: a circle with the remaining fuel as text, a wedge in
        # front for the sensor and a wedge on top for the actuator
        # (black = working, red = failed).
        sStruct = self.state2Struct(s)
        self.uav_circ_vis = []
        self.uav_text_vis = []
        self.uav_sensor_vis = []
        self.uav_actuator_vis = []
        for uav_id in range(0, self.NUM_UAV):
            uav_location = sStruct.locations[uav_id]
            uav_fuel = sStruct.fuel[uav_id]
            uav_sensor = sStruct.sensor[uav_id]
            uav_actuator = sStruct.actuator[uav_id]

            # Figure coordinates where this UAV is drawn
            uav_x = self.location_coord[uav_location]
            uav_y = 1 + uav_id

            circ = mpatches.Circle((uav_x, uav_y), self.UAV_RADIUS, fc="w")
            fuel_text = ax.text(uav_x - 0.05, uav_y - 0.05, uav_fuel)
            sensor_color = ("black" if uav_sensor == SensorState.RUNNING
                            else "red")
            sensor = mpatches.Wedge(
                (uav_x + self.SENSOR_REL_X, uav_y),
                self.SENSOR_LENGTH,
                -30,
                30,
                color=sensor_color,
            )
            actuator_color = ("black"
                              if uav_actuator == ActuatorState.RUNNING else
                              "red")
            actuator = mpatches.Wedge(
                (uav_x, uav_y + self.ACTUATOR_REL_Y),
                self.ACTUATOR_HEIGHT,
                60,
                120,
                color=actuator_color,
            )

            self.uav_circ_vis.append(circ)
            self.uav_text_vis.append(fuel_text)
            self.uav_sensor_vis.append(sensor)
            self.uav_actuator_vis.append(actuator)
            ax.add_patch(circ)
            ax.add_patch(sensor)
            ax.add_patch(actuator)

        numHealthySurveil = np.sum(
            np.logical_and(sStruct.locations == UAVLocation.SURVEIL,
                           sStruct.sensor))
        # We have comms coverage: draw the lines between locations to show it
        if any(sStruct.locations == UAVLocation.COMMS):
            for line in self.comms_line:
                line.set_visible(True)
                line.set_color("black")
                ax.add_line(line)
            # A UAV with a working sensor is also surveilling: color the
            # surveillance (last) location green.
            if numHealthySurveil > 0:
                self.location_rect_vis[-1].set_color("green")
        self.domain_fig.canvas.draw()
        self.domain_fig.canvas.flush_events()
        sleep(0.5)