Beispiel #1
0
def train_network():
    """Train the rotation-centre network, save its weights and review predictions.

    Loads data with ``read_train_data``, holds out 10% as a test split
    (``random_state=5``), scales the inputs by 1/255, fits the model built by
    ``build_model`` for 50 epochs with batch size 5, saves the weights to
    ``'rotation_center_net_weights'`` and prints the test score.

    It then steps through the test set interactively: each test image is
    shown with the predicted centre drawn as a red circle (radius 50), the
    predicted and actual targets are printed, and the user presses Enter to
    advance to the next image.
    """
    x, y = read_train_data()
    x_train, x_test, y_train, y_test = train_test_split(x,
                                                        y,
                                                        test_size=0.1,
                                                        random_state=5)

    # Cast to float first so the in-place division below is valid for
    # integer input; dividing by 255 assumes 8-bit pixels (TODO confirm).
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255  # normalize training data
    x_test /= 255
    y_train = y_train.astype('float32')
    y_test = y_test.astype('float32')

    model = build_model()

    epochs = 50
    batch_size = 5

    model.fit(x_train,
              y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))

    model.save_weights('rotation_center_net_weights')

    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score)

    predictions = model.predict(x_test)
    # Plot the predicted centre as a red circle overlaid on each test image.
    plt.ion()
    fig, ax = plt.subplots(1)
    plt.show()
    for i in range(predictions.shape[0]):
        image = ax.imshow(np.squeeze(x_test[i]), cmap='gray')
        circ = Circle((predictions[i, 0], predictions[i, 1]),
                      radius=50,
                      fill=False,
                      color='red')
        ax.add_patch(circ)
        print('Prediction: ', predictions[i, :])
        print('Actual: ', y_test[i, :])
        plt.draw()
        # plt.draw() only schedules a redraw; pausing briefly runs the GUI
        # event loop so the window repaints before input() blocks.
        plt.pause(0.001)
        input("Press [enter] to continue.")
        # Remove both artists so images do not pile up on the axes.
        circ.remove()
        image.remove()
Beispiel #2
0
class DraggablePoint(object):
    """A filled circle on ``parent.axes[0]`` that can be dragged with the mouse.

    Adapted from
    http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    The class attribute ``lock`` holds the instance currently being dragged
    (or ``None``). Only that instance reacts to motion/release events, so
    overlapping points are not dragged together.
    """

    # Instance currently being dragged; shared by all DraggablePoint objects.
    lock = None

    def __init__(self, parent, x=0.1, y=0.1, size=0.2, color=None):
        """Create the circle at ``(x, y)`` with radius ``size`` and connect events.

        ``parent.axes[0]`` receives the circle. ``color`` and the fixed
        dark-grey edge colour are converted with ``Util.rgb_array_to_hex_str``
        (presumably RGB triples in 0-255 — TODO confirm against Util).
        """
        color_edge = [20, 20, 20]
        self.parent = parent
        self.point = Circle((x, y), size,
                            facecolor=Util.rgb_array_to_hex_str(color),
                            edgecolor=Util.rgb_array_to_hex_str(color_edge),
                            zorder=99)
        self.x = x
        self.y = y
        self.parent.axes[0].add_patch(self.point)
        # (center at press time, press xdata, press ydata) while dragging.
        self.press = None
        self.background = None
        self.connect()

    def connect(self):
        """Register the press/release/motion handlers on the figure canvas."""
        self.cidpress = self.point.figure.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = self.point.figure.canvas.mpl_connect('button_release_event', self.on_release)
        self.cidmotion = self.point.figure.canvas.mpl_connect('motion_notify_event', self.on_motion)

    def on_press(self, event):
        """Start dragging if the click hits this circle and no other point is locked."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, attrd = self.point.contains(event)
        if not contains:
            return
        self.press = (self.point.center, event.xdata, event.ydata)
        DraggablePoint.lock = self
        self.refresh()

    def on_motion(self, event):
        """Move the circle by the pointer offset since the press (only while locked)."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        # Reset to the press-time centre, then apply the total offset.
        self.point.center, xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        x = self.point.center[0] + dx
        y = self.point.center[1] + dy
        self.set_point_pose(x, y)

    def set_point_pose(self, x, y):
        """Move the circle centre to ``(x, y)`` and redraw it.

        ``refresh`` syncs ``self.x``/``self.y`` with the new centre. If a
        ``pose`` attribute has been attached to this point (nothing in this
        class creates one), its ``position.x``/``position.y`` are updated too.
        """
        self.point.center = (x, y)
        self.refresh()

        # ``pose`` is never initialised by this class; accessing it
        # unconditionally raised AttributeError on every drag motion.
        pose = getattr(self, 'pose', None)
        if pose is not None:
            pose.position.x = self.x
            pose.position.y = self.y

    def on_release(self, event):
        """Finish the drag: release the lock, redraw and store the final centre."""
        if DraggablePoint.lock is not self:
            return
        self.press = None
        DraggablePoint.lock = None
        self.point.set_animated(False)

        self.background = None
        self.point.figure.canvas.draw()

        self.x = self.point.center[0]
        self.y = self.point.center[1]

    def refresh(self):
        """Redraw the figure without the circle, then blit the circle on top.

        Also copies the circle centre into ``self.x``/``self.y``.
        """
        self.point.set_animated(True)
        canvas = self.point.figure.canvas
        canvas.draw()
        self.background = canvas.copy_from_bbox(self.point.axes.bbox)
        axes = self.point.axes
        canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        canvas.blit(axes.bbox)
        self.x = self.point.center[0]
        self.y = self.point.center[1]
        self.point.set_animated(False)
        self.background = None

    def disconnect(self):
        """Unregister the handlers installed by ``connect``."""
        self.point.figure.canvas.mpl_disconnect(self.cidpress)
        self.point.figure.canvas.mpl_disconnect(self.cidrelease)
        self.point.figure.canvas.mpl_disconnect(self.cidmotion)

    def remove(self):
        """Remove the circle from its axes and disconnect the handlers."""
        self.point.remove()
        self.disconnect()
Beispiel #3
0
class _draggable_circles:
    def __init__(self, ax, position, radius, color, linestyle):
        self.ax = ax
        self.canvas = ax.figure.canvas
        self.position = position
        self.radius = radius
        self.circle = Circle(position,
                             radius,
                             color=color,
                             linestyle=linestyle,
                             fill=False)

        delta = min([
            self.ax.get_xlim()[1] - self.ax.get_xlim()[0],
            self.ax.get_ylim()[1] - self.ax.get_ylim()[0]
        ])
        self.currently_selected = False

        self.center_dot = Circle(position, delta / 200, color=color)
        self.circle_artist = self.ax.add_artist(self.circle)
        self.center_dot_artist = self.ax.add_artist(self.center_dot)
        self.center_dot_artist.set_visible(False)

        self.canvas.draw_idle()

    def circle_picker(self, mouseevent):
        if (mouseevent.xdata is None) or (mouseevent.ydata is None):
            return False, dict()
        center_xdata, center_ydata = self.circle.get_center()
        radius = self.circle.get_radius()
        tolerance = 0.05
        d = np.sqrt((center_xdata - mouseevent.xdata)**2 +
                    (center_ydata - mouseevent.ydata)**2)

        if d >= radius * (1 - tolerance) and d <= radius * (1 + tolerance):
            pickx = center_xdata
            picky = center_ydata
            props = dict(pickx=pickx, picky=picky)
            return True, props
        else:
            return False, dict()

    def click_position_finder(self, event):
        self.initial_click_position = (event.xdata, event.ydata)

    def drag_circle(self, event):
        if event.xdata and event.ydata:
            self.canvas.restore_region(self.background)
            centervector = (self.position[0] - self.initial_click_position[0],
                            self.position[1] - self.initial_click_position[1])
            newcenter = (centervector[0] + event.xdata,
                         centervector[1] + event.ydata)
            self.center_dot.set_center(newcenter)
            self.circle.set_center(newcenter)
            self.ax.draw_artist(self.circle_artist)
            self.ax.draw_artist(self.center_dot_artist)
            self.canvas.blit(self.ax.bbox)

    def change_circle_size(self, event):
        if event.xdata and event.ydata:
            self.canvas.restore_region(self.background)
            newradius = ((self.position[0] - event.xdata)**2 +
                         (self.position[1] - event.ydata)**2)**0.5
            self.circle.set_radius(newradius)
            self.ax.draw_artist(self.circle_artist)
            self.ax.draw_artist(self.center_dot_artist)
            self.canvas.blit(self.ax.bbox)

    def start_event(self, event):
        if self.currently_selected:
            return

        self.currently_selected = True
        self.center_dot_artist.set_visible(False)
        self.circle_artist.set_visible(False)
        self.canvas.draw()
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.circle_artist.set_visible(True)
        self.releaser = self.canvas.mpl_connect("button_press_event",
                                                self.releaseonclick)

        if event.button == 1:
            self.canvas.draw_idle()
            self.follower = self.canvas.mpl_connect("motion_notify_event",
                                                    self.change_circle_size)

        if event.button == 3:
            self.click_position_finder(event)
            self.center_dot_artist.set_visible(True)
            self.canvas.draw_idle()
            self.follower = self.canvas.mpl_connect("motion_notify_event",
                                                    self.drag_circle)

    def releaseonclick(self, event):
        self.radius = self.circle.get_radius()
        self.position = self.circle.get_center()
        self.center_dot_artist.set_visible(False)
        self.canvas.mpl_disconnect(self.follower)
        self.canvas.mpl_disconnect(self.releaser)
        self.canvas.draw_idle()
        self.currently_selected = False

    def clear(self):
        self.circle.remove()
        self.canvas.draw()
        return self.radius
Beispiel #4
0
class Enviroment(SimulationEnviroment):
    '''
    Creates a map of the vehicle and its surrounding environment. This map is
    used by the planner to avoid collisions.

    The centre body is built from ``vehicle`` and receives the vehicle's
    state on every ``update``. Other bodies are only added while they are
    within ``mapRadius`` of the centre body.

    This might be expanded later with a SLAM.
    '''
    def __init__(self, vehicle, mapRadius=80):
        # NOTE: SimulationEnviroment.__init__ is not called; the attributes
        # used by this class are all set here.
        self.bodies = []
        self.mapRadius = mapRadius
        self.centerBody = Body.fromVehicle(vehicle)
        self.centerBodyLine = None
        self.centerBodyArrow = None
        self.vehicle = vehicle
        self.radiusCircle = None
        self.bodyLines = []  # might be better changing this to a dict
        self.bodyArrows = []  # as to avoid having to keep bodies on track

    def addBody(self, body):
        """Track ``body`` if it is new and within ``mapRadius`` of the centre body."""
        if (body not in self.bodies
                and self.centerBody.getDistance(body) <= self.mapRadius):
            self.bodies.append(body)

    def lidarScan(self):
        """Update the rays of each LIDAR on the vehicle and trace them against the bodies."""
        for lidar in self.vehicle.lidarArray:
            lidar.updateRays()
            lidar.rayTracing(self.bodies)

    def update(self, t):
        """Advance tracked bodies to time ``t``, then copy the vehicle state into the centre body."""
        bodiesToRemove = []
        for idx, body in enumerate(self.bodies):
            body.updateStates(t)
            if self.centerBody.getDistance(body) > self.mapRadius:
                bodiesToRemove.append(idx)
        # NOTE(review): bodiesToRemove is collected but never applied (the
        # deletion was disabled). If re-enabled, delete in reverse index order
        # and keep bodyLines/bodyArrows in sync.

        orientation = self.vehicle.getOrientation()
        pos = self.vehicle.getPos()
        vel = self.vehicle.getVelocity()
        acc = self.vehicle.getAcc()
        omega = self.vehicle.getOmega()

        self.centerBody.setStates(pos[0], pos[1], vel, orientation, t)
        self.centerBody.setAcceleration(acc)
        self.centerBody.setOmega(omega)

    def createPlot(self, fig=None, ax=None):
        """Create the plot artists: LIDAR line, vehicle outline/arrow, map-radius circle.

        Uses ``fig``/``ax`` when given, otherwise a new figure with a single
        subplot. Only ``vehicle.lidar`` is drawn, not every LIDAR in
        ``vehicle.lidarArray``.
        """
        self.timeOfLastPlot = -1

        self.fig = plt.figure() if fig is None else fig
        self.ax = self.fig.add_subplot(111) if ax is None else ax

        lidarCoords = [
            point.coords if point is not None else None
            for point in self.vehicle.lidar.read()
        ]
        bodyLine, = self.ax.plot(lidarCoords, 'r-', linewidth=0.1)
        self.bodyLines.append(bodyLine)

        (verticesX, verticesY) = self.centerBody.getDrawingVertex()
        self.centerBodyLine, = self.ax.plot(verticesX, verticesY, 'r')
        self.centerBodyArrow = self._makeCenterArrow()
        self.ax.add_patch(self.centerBodyArrow)

        self.radiusCircle = self._makeRadiusCircle()
        self.ax.add_patch(self.radiusCircle)

        self._centerViewOnBody()
        self.ax.set_xlabel('Distance X')
        self.ax.set_ylabel('Distance Y')

    def plot(self, draw=True):
        """Update the artists created by ``createPlot`` to the current state.

        Each LIDAR x/y entry pairs the centre-body coordinate with a hit
        coordinate (``None`` for readings without a hit). When ``draw`` is
        true, the canvas is redrawn and the GUI loop runs briefly.
        """
        lidar = self.vehicle.lidar
        # Read once so x and y describe the same scan (two read() calls could
        # differ, and a generator would be exhausted by the first one).
        readings = list(lidar.read())
        pm_x = [(self.centerBody.x, pm.xy[0][0]) if pm is not None else None
                for pm in readings]
        pm_y = [(self.centerBody.y, pm.xy[1][0]) if pm is not None else None
                for pm in readings]

        bodyLine = self.bodyLines[0]
        bodyLine.set_ydata(pm_y)
        bodyLine.set_xdata(pm_x)

        (verticesX, verticesY) = self.centerBody.getDrawingVertex()
        self.centerBodyLine.set_ydata(verticesY)
        self.centerBodyLine.set_xdata(verticesX)

        # Replace the arrow and radius circle with fresh patches at the new pose.
        self.centerBodyArrow.remove()
        self.centerBodyArrow = self._makeCenterArrow()
        self.ax.add_patch(self.centerBodyArrow)

        self.radiusCircle.remove()
        self.radiusCircle = self._makeRadiusCircle()
        self.ax.add_patch(self.radiusCircle)

        self._centerViewOnBody()

        if draw:
            self.fig.canvas.draw()
            plt.pause(0.001)

    def _makeCenterArrow(self):
        """Return a cyan Arrow from the centre body, length 0.1 * v along its orientation."""
        return Arrow(
            self.centerBody.x,
            self.centerBody.y,
            0.1 * self.centerBody.v * cos(self.centerBody.orientation),
            0.1 * self.centerBody.v * sin(self.centerBody.orientation),
            color='c')

    def _makeRadiusCircle(self):
        """Return the dotted black outline of the map radius around the centre body."""
        return Circle((self.centerBody.x, self.centerBody.y),
                      self.mapRadius,
                      color='k',
                      linestyle=':',
                      fill=False)

    def _centerViewOnBody(self):
        """Set axes limits to the map radius plus a 10% margin around the centre body."""
        margin = self.mapRadius * 1.1
        self.ax.set_xlim([
            self.centerBody.x - margin,
            self.centerBody.x + margin
        ])
        self.ax.set_ylim([
            self.centerBody.y - margin,
            self.centerBody.y + margin
        ])
Beispiel #5
0
class Plot_canvas(FigureCanvas):
    """Qt canvas for the car simulation: walls, finish area, car, sensor rays, path."""

    def __init__(self):
        fig = Figure(figsize=(4, 4), dpi=100)
        self.ax = fig.add_subplot(111)
        super().__init__(fig)
        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        # Line2D artists for the sensor rays currently on the axes.
        self.radars = []

    def _add_car(self, pos_angle):
        """Create and add the car disc and heading arrow.

        ``pos_angle`` is ``(x, y, heading_in_degrees, ...)``; only the first
        three entries are used.
        """
        pos = tuple(pos_angle[:2])
        angle = math.radians(pos_angle[2])
        self.car = Circle(pos, radius=3, color='mediumblue')
        self.head = Arrow(pos[0],
                          pos[1],
                          dx=5 * math.cos(angle),
                          dy=5 * math.sin(angle),
                          facecolor='gold',
                          edgecolor='k')
        self.ax.add_artist(self.car)
        self.ax.add_artist(self.head)

    def init_walls(self, pos_angle, ends, fin_pos):
        """Clear the axes and draw the track walls, the finish area and the car.

        ``ends`` is a sequence of (x, y) wall vertices joined by one polyline.
        ``fin_pos`` gives two opposite corners of the finish rectangle.
        """
        self.ax.cla()
        # cla() already detached the previous rays; forget them so that
        # update_car does not try to remove them a second time.
        self.radars = []
        self.ax.plot(*zip(*ends),
                     color='k')  # zip(*ends) transposes points into xs and ys
        x0, y0 = fin_pos[0][0], fin_pos[0][1]
        x1, y1 = fin_pos[1][0], fin_pos[1][1]
        # Anchor at the lower-left corner so the rectangle spans exactly the
        # two given corners whatever their order.
        finish_area = Rectangle(xy=(min(x0, x1), min(y0, y1)),
                                width=abs(x1 - x0),
                                height=abs(y1 - y0),
                                color='red')
        self.ax.add_artist(finish_area)
        self._add_car(pos_angle)
        self.draw()

    def update_car(self, pos_angle, inters):
        """Redraw the car at ``pos_angle`` with a gray ray to each point in ``inters``."""
        self.car.remove()
        self.head.remove()
        for radar in self.radars:
            radar.remove()

        pos = tuple(pos_angle[:2])
        self.radars = [
            Line2D(*zip(pos, inter), linestyle='-', color='gray')
            for inter in inters
        ]
        for radar in self.radars:
            self.ax.add_line(radar)
        self._add_car(pos_angle)
        self.draw()

    def collide(self):
        """Recolour the car dark red to signal a collision."""
        self.car.set_color('darkred')
        self.draw()

    def show_path(self, path_x, path_y):
        """Overlay the travelled path as a thick, semi-transparent gray line."""
        self.path = Line2D(path_x,
                           path_y,
                           linewidth=6,
                           solid_capstyle='round',
                           solid_joinstyle='round',
                           alpha=0.75,
                           color='gray')
        self.ax.add_line(self.path)
        self.draw()
Beispiel #6
0
class SimulationEnviroment(object):
    '''
    Creates a map of the vehicle and its surrounding environment. This map is
    used by the planner to avoid collisions.

    Holds a centre body, the other bodies and the LIDARs placed at the centre
    body. ``update`` advances the bodies and rescans.

    This might be expanded later with a SLAM.
    '''
    def __init__(self, centerBody, mapRadius=80):
        self.bodies = []
        self.mapRadius = mapRadius
        self.centerBody = centerBody
        self.centerBodyLine = None
        self.centerBodyArrow = None
        self.radiusCircle = None
        self.bodyLines = []  # might be better changing this to a dict
        self.bodyArrows = []  # as to avoid having to keep bodies on track
        self.lidarArray = []

    def addLIDAR(self, lidar):
        """Register a LIDAR; it is repositioned and rescanned by ``update``."""
        self.lidarArray.append(lidar)

    def getPointMap(self):
        """Return the concatenation of every LIDAR's ``pointMap``."""
        pm = []
        for lidar in self.lidarArray:
            pm.extend(lidar.pointMap)
        return pm

    def lidarScan(self):
        """Update the rays of each LIDAR and trace them against the bodies."""
        for lidar in self.lidarArray:
            lidar.updateRays()
            lidar.rayTracing(self.bodies)

    def addBody(self, body):
        """Track ``body`` unless it is already tracked."""
        if body not in self.bodies:
            self.bodies.append(body)

    def update(self, t):
        """Advance bodies to time ``t``, move every LIDAR to the centre body, and rescan.

        The centre body itself is not advanced here. LIDAR orientation is
        always set to 0.
        """
        for b in self.bodies:
            b.updateStates(t)

        pos = (self.centerBody.x, self.centerBody.y)
        for lidar in self.lidarArray:
            lidar.setPose(pos, 0)

        self.lidarScan()

    def createPlot(self, fig=None, ax=None):
        """Create the plot artists for all bodies, the centre body and the map radius.

        Uses ``fig``/``ax`` when given, otherwise a new figure with a single
        subplot.
        """
        self.timeOfLastPlot = -1

        self.fig = plt.figure() if fig is None else fig
        self.ax = self.fig.add_subplot(111) if ax is None else ax

        for b in self.bodies:
            (verticesX, verticesY) = b.getDrawingVertex()
            bodyLine, = self.ax.plot(verticesX, verticesY, 'k')
            self.bodyLines.append(bodyLine)
            directionArrow = self._makeDirectionArrow(b)
            self.bodyArrows.append(directionArrow)
            self.ax.add_patch(directionArrow)

        (verticesX, verticesY) = self.centerBody.getDrawingVertex()
        self.centerBodyLine, = self.ax.plot(verticesX, verticesY, 'r')
        # Scaled by the centre body's own speed; this previously used the last
        # loop body's ``v`` and raised NameError when there were no bodies.
        self.centerBodyArrow = self._makeDirectionArrow(self.centerBody)
        self.ax.add_patch(self.centerBodyArrow)

        self.radiusCircle = self._makeRangeCircle()
        self.ax.add_patch(self.radiusCircle)

        self._applyViewLimits()
        self.ax.set_xlabel('Distance X')
        self.ax.set_ylabel('Distance Y')

    def plot(self, draw=True):
        """Update the artists created by ``createPlot`` to the current state.

        Assumes no bodies were added after ``createPlot`` (artists are looked
        up by body index). When ``draw`` is true, the canvas is redrawn and
        the GUI loop runs briefly.
        """
        for idx, b in enumerate(self.bodies):
            (verticesX, verticesY) = b.getDrawingVertex()
            bodyLine = self.bodyLines[idx]
            bodyLine.set_ydata(verticesY)
            bodyLine.set_xdata(verticesX)

            self.bodyArrows[idx].remove()
            directionArrow = self._makeDirectionArrow(b)
            self.ax.add_patch(directionArrow)
            self.bodyArrows[idx] = directionArrow

        (verticesX, verticesY) = self.centerBody.getDrawingVertex()
        self.centerBodyLine.set_ydata(verticesY)
        self.centerBodyLine.set_xdata(verticesX)

        self.centerBodyArrow.remove()
        self.centerBodyArrow = self._makeDirectionArrow(self.centerBody)
        self.ax.add_patch(self.centerBodyArrow)

        self.radiusCircle.remove()
        self.radiusCircle = self._makeRangeCircle()
        self.ax.add_patch(self.radiusCircle)

        self._applyViewLimits()

        if draw:
            self.fig.canvas.draw()
            plt.pause(0.001)

    @staticmethod
    def _makeDirectionArrow(body):
        """Return a cyan Arrow from ``body``, length 0.1 * body.v along its orientation."""
        return Arrow(body.x,
                     body.y,
                     0.1 * body.v * cos(body.orientation),
                     0.1 * body.v * sin(body.orientation),
                     color='c')

    def _makeRangeCircle(self):
        """Return the dotted black outline of the map radius around the centre body."""
        return Circle((self.centerBody.x, self.centerBody.y),
                      self.mapRadius,
                      color='k',
                      linestyle=':',
                      fill=False)

    def _applyViewLimits(self):
        """Set axes limits to the map radius plus a 10% margin around the centre body."""
        margin = self.mapRadius * 1.1
        self.ax.set_xlim([
            self.centerBody.x - margin,
            self.centerBody.x + margin
        ])
        self.ax.set_ylim([
            self.centerBody.y - margin,
            self.centerBody.y + margin
        ])
def _save_slice_overlay(image, center, radius, alpha, out_path):
    """Draw ``image`` with a red disc overlay, save it borderless, show it, return ``out_path``."""
    fig = plt.figure(figsize=(30, 30))
    plt.imshow(image, cmap=plt.cm.gray)
    plt.axis("off")
    circ = Circle(center, radius, linewidth=0, alpha=alpha, color='red')
    plt.gca().add_artist(circ)
    # Strip every margin and tick so the PNG contains only the image.
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1,
                        bottom=0,
                        right=1,
                        left=0,
                        hspace=0,
                        wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Close explicitly: otherwise one 30x30-inch figure per slice stays open.
    plt.close(fig)
    return out_path


def plot(radii, centroids, img_array, path_, chances=None, real=False):
    """Save overlay images of the top candidate on the slices around each z.

    For every distinct z centroid value (sorted), the rows with the highest
    ``chances`` at that z are kept. The first of them is drawn as a red disc
    of radius ``2 * rad`` at ``(Y, X)`` over slices ``z-5 .. z+4`` of
    ``img_array``, with opacity peaking at the middle slices. Slices outside
    the stack are skipped. Each overlay is saved as
    ``<root>/<path_>/<i+1>/<slice>.png``, where ``<root>`` is ``Real`` when
    ``real`` is true and ``Predicted`` otherwise.

    Args:
        radii: one radius per centroid.
        centroids: ``(X, Y, Z)`` rows, one per candidate.
        img_array: stack of 2-D slices indexed by z (presumably shaped
            ``(z, rows, cols)`` — TODO confirm).
        path_: sub-directory name under ``Real``/``Predicted``.
        chances: per-candidate scores; required when ``real`` is false.
        real: ground-truth mode; every candidate gets chance 1.

    Returns:
        ``(paths, sub_dfs)``: for each z, the list of saved PNG paths and the
        selected candidate rows.
    """
    df_cent = pd.DataFrame(centroids, columns=['X', 'Y', 'Z'])
    df_cent['rad'] = radii
    df_cent['chances'] = 1 if real else chances
    root = 'Real' if real else 'Predicted'
    # Opacity per slice offset, strongest around the centroid slice.
    alphas = [0.1, 0.15, 0.2, 0.25, 0.3, 0.3, 0.25, 0.2, 0.15, 0.1]

    z_unique = sorted(df_cent['Z'].unique())
    paths = []
    sub_dfs = []
    for i, z in enumerate(z_unique):
        out_dir = os.path.join(root, str(path_), str(i + 1))
        # makedirs also creates missing parents, where os.mkdir would fail.
        os.makedirs(out_dir, exist_ok=True)

        df_tmp = df_cent[df_cent['Z'] == z]
        df_tmp = df_tmp[df_tmp['chances'] == max(df_tmp['chances'])]
        center = (df_tmp['Y'].values[0], df_tmp['X'].values[0])
        radius = 2 * df_tmp['rad'].values[0]

        first_slice = int(z) - 5
        paths_in = []
        for k, alpha in enumerate(alphas):
            slice_idx = first_slice + k
            # Skip offsets outside the stack rather than wrapping around via a
            # negative index or raising IndexError at the volume edges.
            if not 0 <= slice_idx < len(img_array):
                continue
            out_path = os.path.join(out_dir, '{}.png'.format(slice_idx))
            paths_in.append(_save_slice_overlay(img_array[slice_idx], center,
                                                radius, alpha, out_path))
        sub_dfs.append(df_tmp)
        paths.append(paths_in)
    return paths, sub_dfs
Beispiel #8
0
class PlotMap:
    '''  class contains method to plot the pss data, swath boundary, map and
         active receivers
    '''
    def __init__(self, start_date, maptype=None, swaths_selected=None):
        '''  create the map window and load the pss data around start_date

             start_date: date shown first; pss data for the day before and
                 the day after is loaded as well (see init_pss_dataframes)
             maptype: basemap selector compared in setup_map against
                 maptypes[0] (local basemap) and maptypes[1] (osm basemap);
                 any other value draws no basemap
             swaths_selected: passed to the swath boundary filter in setup_map
        '''
        self.date = start_date
        self.maptype = maptype
        self.swaths_selected = swaths_selected
        # slots for the day before, the day of and the day after self.date
        self.pss_dataframes = [None, None, None]
        self.init_pss_dataframes()

        self.fig, self.ax = self.setup_map(figsize=FIGSIZE)

        connect = self.fig.canvas.mpl_connect
        connect('button_press_event', self.on_click)
        connect('key_press_event', self.on_key)
        connect('resize_event', self.on_resize)

        # 250 ms timer that calls self.blit; presumably started by on_resize
        # to redraw once resizing settles (on_resize not visible — confirm)
        self.resize_timer = self.fig.canvas.new_timer(interval=250)
        self.resize_timer.add_callback(self.blit)

        # start event loop
        self.artists_on_stage = False
        self.background = None
        # self.show / self.blit are defined outside the visible source;
        # show(block=False) presumably displays the window without blocking
        self.show(block=False)
        plt.pause(0.1)
        self.blit()

    def setup_map(self, figsize):
        ''' create the figure, draw the swath boundary and the optional basemap

            returns (fig, ax); the axes limits are reset to the extent of the
            swath boundary after the basemap has been added
        '''
        fig, ax = plt.subplots(figsize=figsize)

        # outline of the selected swaths (4th item of the filter result)
        _, _, _, boundary_gpd = GeoData().filter_geo_data_by_swaths(
            swaths_selected=self.swaths_selected,
            swaths_only=True,
            source_boundary=True,
        )
        boundary_gpd = self.convert_to_map(boundary_gpd)
        boundary_gpd.plot(ax=ax, facecolor='none', edgecolor=EDGECOLOR)

        # remember the data extent before a basemap can change the limits
        swath_extent = ax.axis()
        logger.info(f'extent data swaths: {swath_extent}')

        if self.maptype == maptypes[0]:
            add_basemap_local(ax)
        elif self.maptype == maptypes[1]:
            add_basemap_osm(ax)
        # any other maptype: no basemap background

        ax.axis(swath_extent)
        return fig, ax

    def setup_artists(self):
        ''' create the per-frame artists: one scatter per force level, the
            date label, the active-receiver polygon and the source circle
        '''
        # each scatter starts with a single placeholder point at (0, 0)
        self.vib_artists = {
            force_level: self.ax.scatter(
                [0],
                [0],
                s=MARKERSIZE,
                marker='o',
                facecolor=force_attr[0],
            )
            for force_level, force_attr in force_attrs.items()
        }

        # date label placed in axes coordinates, near the top-right corner
        label_x, label_y = 0.80, 0.95
        self.date_artist = self.ax.text(label_x, label_y, '',
                                        transform=self.ax.transAxes)

        self.actrecv_artist = Polygon(np.array([[0, 0]]),
                                      closed=True,
                                      edgecolor='red',
                                      fill=False)
        self.ax.add_patch(self.actrecv_artist)

        self.cp_artist = Circle((0, 0), radius=SOURCE_CENTER, fc=SOURCE_COLOR)
        self.ax.add_patch(self.cp_artist)

        self.artists_on_stage = True

    def remove_artists(self):
        ''' remove the artists created by setup_artists; no-op when none are
            on stage
        '''
        if not self.artists_on_stage:
            return

        for vib_artist in self.vib_artists.values():
            vib_artist.remove()
        self.date_artist.remove()
        self.actrecv_artist.remove()
        self.cp_artist.remove()
        self.artists_on_stage = False

    def init_pss_dataframes(self):
        ''' load the pss data of the day before, the day of and the day after
            self.date into pss_dataframes[0], [1] and [2]
        '''
        for slot, day_offset in enumerate((-1, 0, 1)):
            day = self.date + timedelta(day_offset)
            day_gpd = get_vps_force_for_date_range(day, day, MEDIUM_FORCE,
                                                   HIGH_FORCE)
            self.pss_dataframes[slot] = self.convert_to_map(day_gpd)

    def update_right_pss_dataframes(self):
        ''' shift the three-day window one day forward: drop the oldest slot
            and load the data for self.date + 1 day into the last slot

            presumably called after self.date was advanced — TODO confirm
        '''
        frames = self.pss_dataframes
        frames[0], frames[1] = frames[1], frames[2]
        next_day = self.date + timedelta(1)
        frames[2] = self.convert_to_map(
            get_vps_force_for_date_range(next_day, next_day, MEDIUM_FORCE,
                                         HIGH_FORCE))

    def update_left_pss_dataframes(self):
        ''' shift the three-day window one day back: drop the newest slot
            and load the data for self.date - 1 day into the first slot

            presumably called after self.date was moved back — TODO confirm
        '''
        frames = self.pss_dataframes
        frames[2], frames[1] = frames[1], frames[0]
        previous_day = self.date - timedelta(1)
        frames[0] = self.convert_to_map(
            get_vps_force_for_date_range(previous_day, previous_day,
                                         MEDIUM_FORCE, HIGH_FORCE))

    def plot_pss_data(self, index):
        '''  plot pss force data in three ranges LOW, MEDIUM, HIGH '''
        vib_pss_gpd = self.pss_dataframes[index]
        self.date_artist.set_text(self.date.strftime("%d %m %y"))

        # plot the VP grouped by force_level
        if not vib_pss_gpd.empty:
            for force_level, vib_pss in vib_pss_gpd.groupby('force_level'):
                if pts := [(xy.x, xy.y)
                           for xy in vib_pss['geometry'].to_list()]:
                    self.vib_artists[force_level].set_offsets(pts)

                else:
                    self.vib_artists[force_level].set_offsets([[0, 0]])

        else:
Beispiel #9
0
class Particle:
    """A two-dimensional particle with position, velocity, radius and status.

    Besides kinematics, each particle carries a ``status`` (compared against
    the module-level ``STATUS`` sequence in :meth:`advance`), a display
    ``color`` and several counters (``infect_time``, ``hospital_time``,
    ``time``).  :meth:`advance` also uses these counters to draw a new
    random velocity at intervals.
    """

    def __init__(self, x, y, vx, vy, radius, status, color):
        """Initialize the particle's position, velocity, radius, status and color.

        Position and velocity are stored as float arrays even when integer
        arguments are given.  With an integer dtype, the component setters
        would silently truncate fractional values (such as the random
        velocities assigned in :meth:`advance`).  The in-place update
        ``self.r += self.v * dt`` would also fail with a casting error.
        """
        self.r = np.array((x, y), dtype=float)
        self.v = np.array((vx, vy), dtype=float)
        self.radius = radius
        self.status = status
        self.color = color
        self.infect_time = 0
        self.hospital_time = 0
        self.time = 0
        # Number of advance() ticks to keep the current velocity (0..9).
        self.wander_step_duration = int(np.random.random() * 10)
        self.last_change_time = 0

    # For convenience, map the components of the particle's position and
    # velocity vector onto the attributes x, y, vx and vy.
    @property
    def x(self):
        return self.r[0]

    @x.setter
    def x(self, value):
        self.r[0] = value

    @property
    def y(self):
        return self.r[1]

    @y.setter
    def y(self, value):
        self.r[1] = value

    @property
    def vx(self):
        return self.v[0]

    @vx.setter
    def vx(self, value):
        self.v[0] = value

    @property
    def vy(self):
        return self.v[1]

    @vy.setter
    def vy(self, value):
        self.v[1] = value

    def contact(self, other):
        """Return whether this particle's circle touches ``other``'s.

        Two particles are in contact when the distance between their centres
        is less than the sum of their radii plus the module-level
        ``contact_distance``.
        """
        return np.hypot(*(
            self.r - other.r)) < self.radius + other.radius + contact_distance

    def draw(self, ax):
        """Create a Circle patch for this particle, add it to Axes ``ax``
        and return it (it is also kept in ``self.circle``)."""
        # NOTE(review): the position array itself (not a copy) is passed as
        # the centre; whether the patch follows later in-place updates of
        # self.r depends on matplotlib internals -- confirm before relying on it.
        self.circle = Circle(xy=self.r, radius=self.radius, color=self.color)
        ax.add_patch(self.circle)
        return self.circle

    def remove(self):
        """Remove the patch created by :meth:`draw` (draw must be called first)."""
        self.circle.remove()

    def advance(self, dt):
        """Advance the particle's position forward in time by ``dt``.

        Particles whose status equals ``STATUS[3]`` or ``STATUS[4]`` do not
        move.  Otherwise the tick counter is incremented.  Once more than
        ``wander_step_duration`` ticks have passed since the last change, a
        new random velocity is drawn.  Then the position is updated.
        """
        if self.status in (STATUS[3], STATUS[4]):
            return

        self.time += 1
        if (self.time - self.last_change_time) > self.wander_step_duration:
            # Random components truncated to multiples of 0.01.
            self.vx = int(np.random.randn() * 7) / 100
            self.vy = int(np.random.randn() * 7) / 100
            self.last_change_time = self.time

        # In-place update: keeps the same position array object.
        self.r += self.v * dt
Beispiel #10
0
class Crosshair:
    """Crosshair of a given color and size drawn on a given canvas.

    The crosshair consists of a horizontal and a vertical line, an optional
    circle around its centre and a text label.  ``setup`` must be called
    before ``toggle`` or ``move``.
    """
    def __init__(self):
        self.canvas = []
        self.size = []
        self.color = []
        self.zorder = []
        self.x = []
        self.y = []
        self.visible = False
        # Defaults so that wipe()/toggle() do not raise AttributeError when
        # called before setup().
        self.circle = False
        self.text = ''

    def setup(self,
              canvas,
              size,
              x,
              y,
              text='',
              zorder=2,
              color='blue',
              circle=False):
        """Configure the crosshair and build its artists (not yet shown).

        Args:
            canvas: object exposing ``.axes`` and ``.draw()``; the artists
                are added to ``canvas.axes`` by ``toggle``.
            size: half-length of each line, and radius of the optional circle.
            x, y: centre of the crosshair in data coordinates.
            text: label drawn at the centre when the crosshair is shown.
            zorder: drawing order for all artists.
            color: color for all artists.
            circle: whether to draw a circle around the crosshair.
        """
        # get input and save to internal namespace
        self.canvas = canvas
        self.size = size
        self.color = color
        self.zorder = zorder
        self.x = x
        self.y = y
        self.visible = False
        self.circle = circle  # should a circle be drawn around crosshair?
        self.text = text

        # create lines
        self.horizontalLine = Line2D([self.x - self.size, self.x + self.size],
                                     [self.y, self.y],
                                     linestyle='-',
                                     alpha=0.5,
                                     linewidth=1.5,
                                     color=color,
                                     zorder=zorder)
        self.verticalLine = Line2D([self.x, self.x],
                                   [self.y - self.size, self.y + self.size],
                                   linestyle='-',
                                   alpha=0.5,
                                   linewidth=1.5,
                                   color=color,
                                   zorder=zorder)

        if self.circle:
            self.circularLine = Circle((self.x, self.y),
                                       self.size,
                                       fill=False,
                                       linestyle='-',
                                       alpha=0.5,
                                       linewidth=1.5,
                                       zorder=zorder,
                                       color=color)

    def toggle(self):
        """Make the crosshair (in)visible, depending on its previous state.

        Errors (e.g. calling before ``setup``) are reported by printing the
        traceback rather than raised.
        """
        try:
            # if crosshair is invisible
            if not self.visible:
                self.canvas.axes.add_line(self.horizontalLine)
                self.canvas.axes.add_line(self.verticalLine)
                if self.circle:
                    self.canvas.axes.add_patch(self.circularLine)
                self.annotation = self.canvas.axes.text(self.x,
                                                        self.y,
                                                        self.text,
                                                        color=self.color,
                                                        zorder=self.zorder,
                                                        fontsize=10)
                self.canvas.draw()
                self.visible = True

            # If crosshair is currently visible
            else:
                self.horizontalLine.remove()
                self.verticalLine.remove()
                self.annotation.remove()
                if self.circle:
                    self.circularLine.remove()
                self.canvas.draw()
                self.visible = False
        except Exception:
            # Best-effort: report the failure without propagating it.
            # print_exc() writes the traceback itself and returns None.
            traceback.print_exc()

    def move(self, x, y):
        """Shift the crosshair by the offsets ``(x, y)``.

        The geometry of the lines and of the optional circle is updated to
        the new centre even while hidden, so a later ``toggle`` shows the
        crosshair at the right place.  When visible, the label is moved
        too and the canvas is redrawn.
        """
        self.x += x
        self.y += y

        self.horizontalLine.set_data(
            [self.x - self.size, self.x + self.size], [self.y, self.y])
        self.verticalLine.set_data(
            [self.x, self.x], [self.y - self.size, self.y + self.size])
        if self.circle:
            # Absolute centre. The patch is already on the axes when visible,
            # so it must not be added a second time.
            self.circularLine.center = (self.x, self.y)

        if self.visible:
            self.annotation.set_position((self.x, self.y))
            self.canvas.draw()

    def wipe(self):
        """Remove the crosshair from the canvas (if shown) and reset its state."""
        if self.visible:
            self.toggle()

        self.x = []
        self.y = []
        self.horizontalLine = []
        self.verticalLine = []
        if self.circle:
            self.circularLine = []
        self.visible = False