示例#1
0
def test_init():
    """An AABB keeps the given per-axis limits and rejects malformed ones."""
    limits = [(0, 1), (0, 1)]
    box = AABB(limits)

    # One (lower, upper) pair per axis, matching the input limits.
    assert len(limits) == len(box)
    for (lower, upper), axis in zip(limits, box):
        assert lower == axis[0]
        assert upper == axis[1]

    # A default-constructed AABB carries no limits at all.
    assert AABB().limits is None

    # Reversed bounds and 3-tuples are both invalid axis limits.
    for invalid in ([(-4, -10)], [(1, 2, 3)]):
        with pytest.raises(ValueError):
            AABB(invalid)
示例#2
0
def f(tr, tree, i):
    """Insert a slightly padded bounding box around ``tr`` into ``tree``.

    Args:
        tr: array-like of points; min/max are taken along axis 0, so this
            presumably is (n_points, n_dims) — TODO confirm with callers.
        tree: tree object receiving the box via ``tree.add``.
        i: value stored alongside the box.
    """
    # Column 0 holds the per-axis minima, column 1 the per-axis maxima.
    bounds = np.stack([np.min(tr, axis=0), np.max(tr, axis=0)], axis=1)
    # Grow every axis by 0.1 on both sides.
    bounds = bounds + np.array([-0.1, 0.1])
    box = AABB([tuple(row) for row in bounds])
    tree.add(box, i)
示例#3
0
def test_add():
    """Adding a box to an empty tree equals building the tree from that box."""
    box = AABB([(3, 4), (5, 6), (-3, 5)])

    grown = AABBTree()
    grown.add(box)

    assert grown == AABBTree(box)
    assert AABBTree() != grown
示例#4
0
def test_does_overlap():
    """does_overlap is correct for every insertion order and both traversals."""
    methods = ('DFS', 'BFS')
    query_hit = AABB([(-3, 3), (-3, 3)])
    query_miss_a = AABB([(0, 1), (5, 6)])
    query_miss_b = AABB([(6.5, 6.5), (5.5, 5.5)])

    # An empty tree overlaps nothing, whichever traversal is used.
    for query in (query_hit, query_miss_a, query_miss_b):
        for method in methods:
            assert not AABBTree().does_overlap(query, method=method)

    boxes = standard_aabbs()
    # The answers must not depend on the order boxes were inserted.
    for order in itertools.permutations(range(4)):
        tree = AABBTree()
        for idx in order:
            tree.add(boxes[idx])

        for method in methods:
            assert tree.does_overlap(query_hit, method=method)
            assert not tree.does_overlap(query_miss_a, method=method)
            assert not tree.does_overlap(query_miss_b, method=method)
    def find_near_candidates(self, x, d_max):
        """Return values in ``self.data`` whose boxes touch a square around ``x``.

        Args:
            x: point with at least two coordinates; only x[0] and x[1] are used.
            d_max: half-width of the square search window.

        Returns:
            The result of ``self.data.overlap_values`` for the window, or an
            empty list when either coordinate is not finite.
        """
        if not (math.isfinite(x[0]) and math.isfinite(x[1])):
            return []

        px, py = x[0], x[1]
        # Axis-aligned square of half-width d_max centred on the query point.
        window = AABB([(px - d_max, px + d_max), (py - d_max, py + d_max)])
        return self.data.overlap_values(window)
示例#6
0
def test_corners():
    """AABB.corners yields exactly the four corners of a 2-D box."""
    lims = [(0, 10), (5, 10)]
    (x_lo, x_hi), (y_lo, y_hi) = lims
    expected = [[x, y] for y in (y_lo, y_hi) for x in (x_lo, x_hi)]

    actual = AABB(lims).corners
    # Compare as sets in both directions, ignoring ordering.
    for corner in expected:
        assert corner in actual

    for corner in actual:
        assert corner in expected
示例#7
0
def test_overlap_values():
    """overlap_values returns the stored values of every overlapping box."""
    boxes = standard_aabbs()
    payloads = ['value 1', 3.14, None, None]

    hit = AABB([(-3, 3.1), (-3, 3)])
    miss_a = AABB([(0, 1), (5, 6)])
    miss_b = AABB([(6.5, 6.5), (5.5, 5.5)])

    for order in itertools.permutations(range(4)):
        tree = AABBTree()
        for idx in order:
            tree.add(boxes[idx], payloads[idx])

        found = tree.overlap_values(hit)
        # Only the first two standard boxes intersect ``hit``.
        assert len(found) == 2
        for expected in ('value 1', 3.14):
            assert expected in found

        assert tree.overlap_values(miss_a) == []
        assert tree.overlap_values(miss_b) == []

    assert AABBTree(hit).overlap_values(miss_b) == []
示例#8
0
def test_eq():
    """Tree equality is structural, symmetric, and consistent with !=."""
    full = AABBTree()
    for lims in ([(2, 3)], [(4, 5)], [(-2, 2)]):
        full.add(AABB(lims))
    # A tree holding only the root box of ``full`` is a different tree.
    root_only = AABBTree(full.aabb)

    # Positive forms.
    assert full == full
    assert AABBTree() == AABBTree()
    assert full != AABBTree()
    assert AABBTree() != full
    assert AABBTree() != AABB()
    assert full != root_only
    assert root_only != full

    # Negated forms must agree with the positive ones.
    assert not full != full
    assert not AABBTree() != AABBTree()
    assert not full == AABBTree()
    assert not AABBTree() == full
    assert not AABBTree() == AABB()
    assert not full == root_only
    assert not root_only == full
示例#9
0
def test_overlap_aabbs():
    """Check overlap_aabbs for every insertion order and both traversal methods."""
    aabbs = standard_aabbs()
    values = ['value 1', 3.14, None, None]

    aabb5 = AABB([(-3, 3.1), (-3, 3)])
    aabb6 = AABB([(0, 1), (5, 6)])
    aabb7 = AABB([(6.5, 6.5), (5.5, 5.5)])

    for indices in itertools.permutations(range(4)):
        tree = AABBTree()
        for i in indices:
            tree.add(aabbs[i], values[i])

        for m in ('DFS', 'BFS'):
            # aabb5 straddles exactly the first two standard boxes.
            aabbs5 = tree.overlap_aabbs(aabb5, method=m)
            assert len(aabbs5) == 2
            for aabb in aabbs5:
                assert aabb in aabbs[:2]

            # Pass method=m so the miss cases are exercised for BOTH
            # traversal methods rather than the default one twice.
            assert tree.overlap_aabbs(aabb6, method=m) == []
            assert tree.overlap_aabbs(aabb7, method=m) == []

    for m in ('DFS', 'BFS'):
        assert AABBTree(aabb5).overlap_aabbs(aabb7, method=m) == []
    def find_near_candidates(self, lat_lon, d_max):
        """Return values in ``self.data`` whose boxes touch a window around a point.

        The window spans +/- d_max around ``lat_lon``, converted to degrees with
        ``LocalMap.get_scale_at``. The original comment says d_max is in meters
        and the conversion is approximate but good while d_max is much smaller
        than the Earth's circumference; presumably get_scale_at returns
        degrees-per-meter for (lat, lon) — verify against LocalMap.

        Returns an empty list when either coordinate is not finite.
        """
        if not (math.isfinite(lat_lon[0]) and math.isfinite(lat_lon[1])):
            return []

        lat, lon = lat_lon[0], lat_lon[1]
        scale_lat, scale_lon = LocalMap.get_scale_at(lat, lon)
        half_lat = scale_lat * d_max
        half_lon = scale_lon * d_max

        # Axis-aligned (lat, lon) box around the query point; every way whose
        # stored box overlaps it is a candidate.
        window = AABB([(lat - half_lat, lat + half_lat),
                       (lon - half_lon, lon + half_lon)])
        return self.data.overlap_values(window)
示例#11
0
def aabb_merge(tree):
    """Assert recursively that each internal node's box merges its children's.

    Leaves are accepted as-is; an AssertionError signals an inconsistent tree.
    """
    if tree.is_leaf:
        return
    assert tree.aabb == AABB.merge(tree.left.aabb, tree.right.aabb)
    for child in (tree.left, tree.right):
        aabb_merge(child)
示例#12
0
def test_merge():
    """AABB.merge returns the smallest box containing both inputs."""
    unit = AABB([(0, 1)])
    wide = AABB([(-1, 2)])
    # Merging with a containing box yields the container.
    assert wide == AABB.merge(unit, wide)

    shifted = AABB([(0.5, 3)])
    assert AABB([(0, 3)]) == AABB.merge(unit, shifted)

    # An empty AABB behaves as the identity for merge.
    assert unit == AABB.merge(unit, AABB())
    assert wide == AABB.merge(AABB(), wide)
    assert AABB() == AABB.merge(AABB(), AABB())

    # Boxes of different dimensionality cannot be merged.
    three_d = AABB([(-1, 0), (2, 3), (1, 5)])
    with pytest.raises(ValueError):
        AABB.merge(unit, three_d)
示例#13
0
def test_repr():
    """repr(AABB(limits)) reads like the constructor call."""
    limits = [(2, 3)]
    expected = f"AABB({limits!r})"
    assert repr(AABB(limits)) == expected
示例#14
0
def init():
    """Create the window, shader program and GPU buffers for the ray tracer.

    Opens a GLFW window at a quarter of the primary monitor's resolution and
    registers the input callbacks. It then compiles the quad shader and
    uploads data to two shader storage buffers:
    scene objects plus AABB-tree nodes (binding point 0) and mesh triangles
    (binding point 1).

    Returns:
        tuple: ``(window, window_dict, object_buffer, tr_buffer, shader, quad)``.

    Raises:
        RuntimeError: if GLFW cannot be initialised or the window cannot be
            created.
        SystemExit: with status 1 when the quad shaders fail to compile (the
            compiler log is printed first).
    """
    if not glfw.init():
        raise RuntimeError('Failed to initialise GLFW')
    monitor = glfw.get_primary_monitor()
    mode = glfw.get_video_mode(monitor)
    # Open the window at a quarter of the monitor's resolution.
    width = mode.size.width // 4
    height = mode.size.height // 4
    window = glfw.create_window(width, height, 'Ray Tracer', None, None)
    if not window:
        glfw.terminate()
        raise RuntimeError('Failed to create the GLFW window')
    glfw.make_context_current(window)

    # Mutable render/input state shared with the GLFW callbacks below.
    window_dict = {
        'width': width,
        'height': height,
        'W': np.array([0, 0, 1]),
        'E': np.array([0, 0.15, -1]),
        'b': np.array([0, 1, 0]),
        'd': 1.5,
        'mouse': None,
        'yaw': 0,
        'pitch': 0,
        'up': False,
        'down': False,
        'left': False,
        'right': False,
        'show_cursor': False,
        'num_triangles': 0,
        'num_objects': 2,
        'fov': 90,
        'delta_t': 0,
        'sprint': False,
        'num_aabb': 0
    }
    window_dict = struct(window_dict)
    glfw.set_window_size_callback(
        window, lambda *args: window_size_callback(window_dict, *args))
    glfw.set_cursor_pos_callback(
        window, lambda *args: cursor_position_callback(window_dict, *args))
    glfw.set_key_callback(window,
                          lambda *args: key_callback(window_dict, *args))
    glfw.set_mouse_button_callback(
        window, lambda *args: mouse_button_callback(window_dict, *args))
    set_cursor(window, window_dict)

    # 4-vertex quad spanning clip space [-1, 1] x [-1, 1].
    quad = pyglet.graphics.vertex_list(
        4, ('v2f', (-1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0)))

    try:
        shader = from_files_names("quad_vshader.glsl", "quad_fshader.glsl")
    except ShaderCompilationError as e:
        print(e.logs)
        # Non-zero status so scripts/callers can detect the failure
        # (the builtin exit() with no argument reports success).
        raise SystemExit(1) from e

    # Numeric value of the GL_SHADER_STORAGE_BUFFER target enum.
    GL_SHADER_STORAGE_BUFFER = GLuint(0x90D2)

    # Scene objects and AABB-tree nodes share this record layout; SSBO
    # binding point 0.
    f_obj = '(1f)[type](3f)[pos](1f)[size1](1f)[size2](3f)[color](3f)[direction]'
    object_buffer = Buffer.array(format=f_obj, usage=GL_DYNAMIC_DRAW)
    object_buffer.bind(GL_SHADER_STORAGE_BUFFER)
    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, object_buffer.bid)

    # Mesh triangles; SSBO binding point 1.
    f_tr = '(3f)[vertex1](3f)[vertex2](3f)[vertex3](3f)[color](3f)[normal]'
    tr_buffer = Buffer.array(format=f_tr, usage=GL_DYNAMIC_DRAW)
    tr_buffer.bind(GL_SHADER_STORAGE_BUFFER)
    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, tr_buffer.bid)

    tree = AABBTree()

    # NOTE(review): get_obj_data presumably returns the triangle records and
    # inserts one box per triangle into ``tree`` (values offset by
    # num_objects) — confirm against its definition.
    data = get_obj_data('bunny.obj', tree, window_dict.num_objects)
    print(tree.depth)
    # Boxes for the two analytic scene objects (tree values 0 and 1).
    aabb = AABB([(-1, 1), (-1, 1), (9, 11)])
    tree.add(aabb, 0)
    aabb = AABB([(-15, 15), (-1.1, 0.9), (-15, 15)])
    tree.add(aabb, 1)

    tree_arr = []

    def number_node(node, counter):
        # Assign each node a sequential index in traversal order.
        node.num = counter['num']
        counter['num'] += 1

    def pack_node(node):
        # Encode a tree node as an f_obj record of type 2. The six AABB
        # limits are spread over the pos/direction slots:
        #   pos       = (xmin, xmax, ymin)
        #   size1     = left child's index, or -1 for a leaf
        #   size2     = right child's index, or the leaf's stored value
        #   color     = unused (zeros)
        #   direction = (ymax, zmin, zmax)
        left = node.left.num if node.left else -1
        right = node.right.num if node.right else node.value
        limits = [list(x) for x in node.aabb.limits]
        tree_arr.append([[2], limits[0] + [limits[1][0]], [left], [right],
                         [0, 0, 0], [limits[1][1]] + limits[2]])

    num = {'num': 0}
    traverse(tree, lambda x: number_node(x, num))
    traverse(tree, pack_node)

    window_dict.num_aabb = len(tree_arr)
    # The two analytic object records come first, followed by every tree node.
    # NOTE(review): node indices (``num``) start at 0 while nodes are stored
    # after the two objects; presumably the shader adds num_objects — confirm.
    object_data = [
        [[0], [0, 0, 10], [1], [0], [1, 0, 0], [0, 0, 0]],
        [[1], [0, -1, 1], [0], [0], [1, 0, 1], [0, 1, 0]],
    ]
    object_data += tree_arr

    object_buffer.init(object_data)
    print(tree.depth)

    window_dict.num_triangles = len(data)
    tr_buffer.init(data)

    return window, window_dict, object_buffer, tr_buffer, shader, quad
示例#15
0
def test_overlaps_closed():
    """Boxes that merely touch count as overlapping when True is passed.

    The second positional argument presumably selects closed intervals.
    """
    point = AABB([(0, 0)])
    touching = AABB([(-1, 0)])
    right = AABB([(1, 2)])
    far_left = AABB([(-9, -8)])

    assert point.overlaps(touching, True)
    assert touching.overlaps(point, True)

    disjoint_pairs = (
        (point, right),
        (touching, right),
        (point, far_left),
        (touching, far_left),
    )
    for first, second in disjoint_pairs:
        assert not first.overlaps(second, True)

    assert not point.overlaps(AABB(), True)
示例#16
0
 def get_aabb(self, point, bound):
     """Return a 3-D AABB centred on ``point`` with half-extents ``bound``.

     ``point`` must expose ``.x``, ``.y`` and ``.z``; ``bound[0..2]`` are the
     half-widths along those axes.
     """
     centre = (point.x, point.y, point.z)
     return AABB([(c - bound[k], c + bound[k]) for k, c in enumerate(centre)])
示例#17
0
def endpoint_statistics(path_to_svg):
    '''
    For every segment endpoint in an SVG, find the distance to the nearest
    other geometry: the closest other endpoint, or a closer point on another
    segment (a T-junction).

    Given:
        path_to_svg: A path to an SVG file.

    Returns:
        A float array with one entry per endpoint (two per segment), in path
        and segment order. It is multiplied by the document's global transform
        scale when paths could not be flattened. It is normalized by the
        svg's long edge, taken from the viewBox if present, otherwise from
        the <svg> width/height attributes. If neither is present, a warning
        is printed and the distances are left unnormalized.
    '''

    global_scale = 1.0

    # Load outside the try so a missing/invalid file raises its own error
    # instead of a NameError on ``doc`` from the fallback branch.
    doc = Document(path_to_svg)
    try:
        flatpaths = doc.flatten_all_paths()
        paths = [path for (path, _, _) in flatpaths]
    except Exception:
        # flatten_all_paths() failed; fall back to unflattened paths and apply
        # the global transform scale instead. get_global_scale() is left
        # unguarded on purpose: if it cannot handle the transforms either,
        # we want to fail loudly.
        global_scale = get_global_scale(doc.tree)
        paths, _ = svg2paths(path_to_svg)

    ## First pass: gather all endpoints with their (path, segment, t) address.
    endpoints = []
    endpoint_addresses = []

    for path_index, path in enumerate(paths):
        for seg_index, seg in enumerate(path):
            for t in (0, 1):
                pt = seg.point(t)
                endpoints.append((pt.real, pt.imag))
                endpoint_addresses.append((path_index, seg_index, t))

    print("Creating spatial data structures:")
    ## Point-point queries.
    dist_finder = scipy.spatial.cKDTree(endpoints)
    ## Axis-aligned bounding box tree over the segments, so the T-junction
    ## search only examines segments near each endpoint instead of all of them.
    bbtree = AABBTree()
    for path_index, path in enumerate(paths):
        for seg_index, seg in enumerate(path):
            xmin, xmax, ymin, ymax = seg.bbox()
            bbtree.add(AABB([(xmin, xmax), (ymin, ymax)]),
                       (path_index, seg_index, seg))

    # Second pass: Gather all minimum distances
    print("Finding minimum distances:")

    minimum_distances = []
    for i, (pt, (path_index, seg_index,
                 t)) in enumerate(zip(endpoints, endpoint_addresses)):
        ## 1. Find the minimum distance to any other endpoints

        ## Find two closest points, since the point itself is in dist_finder with distance 0.
        mindist, closest_pt_indices = dist_finder.query([pt], k=2)

        ## These come back as 1-by-2 matrices.
        mindist = mindist[0]
        closest_pt_indices = closest_pt_indices[0]
        ## If we didn't find 2 points, then pt is the only point in this file.
        ## There is no point element in SVG, so that should never happen.

        assert len(closest_pt_indices) == 2
        ## If there are two or more other points identical to pt,
        ## then pt might not actually be one of the two returned, but both distances
        ## should be zero.
        assert i in closest_pt_indices or (mindist < eps).all()
        assert min(mindist) <= eps

        ## The larger distance corresponds to the point that is not pt.
        mindist = max(mindist)

        ## If we already found the minimum distance is 0, then there's no point also
        ## searching for T-junctions.
        if mindist < eps:
            minimum_distances.append(mindist)
            continue

        ## 2. Find the closest point on any other paths (T-junction).
        ## Only a segment whose bounding box comes within mindist of pt can
        ## contain a point closer than the nearest endpoint, so query a
        ## square of half-width mindist around pt.
        query = AABB([(pt[0] - mindist, pt[0] + mindist),
                      (pt[1] - mindist, pt[1] + mindist)])

        for other_path_index, other_seg_index, seg in bbtree.overlap_values(
                query):
            ## Don't compare the point with its own segment.
            if other_path_index == path_index and other_seg_index == seg_index:
                continue

            ## mindist shrinks as closer segments are found below. A box that
            ## overlapped the original query may therefore already be
            ## farther than the current mindist; reject it cheaply before
            ## computing the exact point-to-segment distance.
            xmin, xmax, ymin, ymax = seg.bbox()
            if (pt[0] < xmin - mindist or pt[0] > xmax + mindist
                    or pt[1] < ymin - mindist or pt[1] > ymax + mindist):
                continue

            ## Get the point to segment distance; keep it if it's smaller.
            dist_to_other_path = distance_point_to_segment(pt, seg)
            if dist_to_other_path < mindist:
                mindist = dist_to_other_path

            ## Terminate early if the minimum distance found already is 0
            if mindist < eps:
                break

        ## Accumulate the minimum distance
        minimum_distances.append(mindist)

    minimum_distances = global_scale * asfarray(minimum_distances)

    ## Divide by long edge.
    if 'viewBox' in doc.root.attrib:
        import re
        _, _, width, height = [
            float(v)
            for v in re.split('[ ,]+', doc.root.attrib['viewBox'].strip())
        ]
        long_edge = max(width, height)
        print("Normalizing by long edge:", long_edge)
        minimum_distances /= long_edge
    elif "width" in doc.root.attrib and "height" in doc.root.attrib:
        width = doc.root.attrib["width"].strip().strip("px")
        height = doc.root.attrib["height"].strip().strip("px")
        long_edge = max(float(width), float(height))
        print("Normalizing by long edge:", long_edge)
        minimum_distances /= long_edge
    else:
        print(
            "WARNING: No viewBox found in <svg>. Not normalizing by long edge."
        )
    print("Done")
    return minimum_distances
示例#18
0
    def t_junction_close_recursive(pt_with_index, distance, pre_seg_with_index,
                                   paths, bbtree, depth, snapped):
        """Try to close a near-miss T-junction at one segment endpoint.

        Searches ``bbtree`` for segments near the endpoint. If the closest one
        is within ``distance`` (but farther than ``eps``), the endpoint is
        moved onto it. Two cases:

        - The closest segment is ``pre_seg_with_index``, so the two segments
          depend on each other: both endpoints are moved to the average of
          their mutual closest points.
        - Otherwise: recurse on both endpoints of the closest segment first,
          then snap this endpoint only if its own segment was left unchanged
          by that recursion.

        Args:
            pt_with_index: ``((x, y), (path_index, seg_index, t))``. t must be
                0 (segment start) or 1 (segment end).
            distance: snapping threshold. Candidates are searched in a box of
                half-width ``2 * distance`` around the point.
            pre_seg_with_index: ``(segment, (path_index, seg_index))`` of the
                segment whose fix triggered this call. Only element [1] is read.
            paths: indexable collection of paths of segments. Modified in
                place and also returned.
            bbtree: AABB tree whose values are
                ``(path_index, seg_index, segment)``.
            depth: remaining recursion budget; 0 returns immediately.
            snapped: container of ``(path_index, seg_index, t)`` addresses to
                skip. It is only read here, never updated.

        Returns:
            ``paths``.

        Raises:
            ValueError: if t (or the chosen endpoint t1) is not 0 or 1.
            AssertionError: in the mutual case, if the reverse distance
                violates the threshold.

        NOTE(review): ``bbtree`` is never updated when ``paths`` is modified,
        so ``seg``/``min_seg`` may be stale copies of what is now stored in
        ``paths`` — confirm this is acceptable.
        """
        # Stop when the recursion budget is exhausted or this endpoint is
        # listed in ``snapped``.
        if depth == 0: return paths
        if pt_with_index[1] in snapped: return paths

        depth = depth - 1
        bbox_edge = 2 * distance
        pt = pt_with_index[0]
        path_index, seg_index, t = pt_with_index[1]
        # Query box of half-width 2 * distance around the endpoint.
        query = AABB([(pt[0] - bbox_edge, pt[0] + bbox_edge),
                      (pt[1] - bbox_edge, pt[1] + bbox_edge)])
        min_j_dist = float('inf')
        min_t = None
        min_path_index = None
        min_seg_index = None
        min_seg = None
        # Find the closest point on any other segment whose box overlaps the query.
        for other_path_index, other_seg_index, seg in bbtree.overlap_values(
                query):
            if other_path_index == path_index and other_seg_index == seg_index:
                continue
            try:
                # First entry of radialrange, used as (distance, t) of the
                # closest point on seg to pt.
                j_dist, j_t = seg.radialrange(complex(*pt))[0]
            except Exception as e:
                # Best effort: report and skip segments radialrange cannot handle.
                print(str(e))
                continue
            if min_j_dist > j_dist:
                min_j_dist = j_dist
                min_t = j_t
                min_path_index = other_path_index
                min_seg_index = other_seg_index
                min_seg = seg

        # A target segment was found: a real gap (> eps) smaller than the threshold.
        if min_j_dist < distance and min_j_dist > eps:
            # If fixing the current endpoint depends on pre_seg, fixing pre_seg
            # also depends on the current endpoint; in other words, both
            # segments must be fixed at the same time.
            if (min_path_index, min_seg_index) == pre_seg_with_index[1]:
                # Find the closest points of the two segments' endpoints to
                # each other.
                # Closest point on min_seg (pre_seg) to the current endpoint.
                point1 = min_seg.point(min_t)
                # Closest point on the current segment to the min_seg (pre_seg)
                # endpoint nearest (in parameter) to point1.
                t1 = 0 if min_t < 0.5 else 1
                dist2, t2 = paths[path_index][seg_index].radialrange(
                    min_seg.point(t1))[0]
                point2 = paths[path_index][seg_index].point(t2)

                # point2 should also satisfy the distance requirement.
                # NOTE(review): this is an assert, so it is skipped under -O.
                assert (dist2 < distance and dist2 > eps)
                # Fix both segments by moving both endpoints to the midpoint.
                avg_point = (point1 + point2) / 2
                # Update the current segment.
                if t == 0:
                    new_seg_1 = set_segment_by_point(
                        paths[path_index][seg_index], start=avg_point)
                elif t == 1:
                    new_seg_1 = set_segment_by_point(
                        paths[path_index][seg_index], end=avg_point)
                else:
                    raise ValueError("Invalid point value %f" % t)
                paths[path_index][seg_index] = new_seg_1

                # Update the previous segment.
                if t1 == 0:
                    new_seg_2 = set_segment_by_point(min_seg, start=avg_point)
                elif t1 == 1:
                    new_seg_2 = set_segment_by_point(min_seg, end=avg_point)
                else:
                    raise ValueError("Invalid point value %f" % t1)
                paths[min_path_index][min_seg_index] = new_seg_2

                # Return the result.
                return paths

            else:
                org_seg = paths[path_index][seg_index]
                # Recurse on both endpoints of min_seg to find out whether
                # there is an additional dependency.
                pt_start = ((min_seg.start.real, min_seg.start.imag),
                            (min_path_index, min_seg_index, 0))
                pre_seg = (paths[path_index][seg_index], (path_index,
                                                          seg_index))
                paths = t_junction_close_recursive(pt_start, distance, pre_seg,
                                                   paths, bbtree, depth,
                                                   snapped)

                pt_end = ((min_seg.end.real, min_seg.end.imag),
                          (min_path_index, min_seg_index, 1))
                paths = t_junction_close_recursive(pt_end, distance, pre_seg,
                                                   paths, bbtree, depth,
                                                   snapped)

                # After all dependent segments are fixed, snap the current
                # endpoint, unless the recursion already changed this segment.
                if org_seg == paths[path_index][seg_index]:
                    t_point = paths[min_path_index][min_seg_index].point(min_t)
                    if t == 0:
                        new_seg = set_segment_by_point(
                            paths[path_index][seg_index], start=t_point)
                    elif t == 1:
                        new_seg = set_segment_by_point(
                            paths[path_index][seg_index], end=t_point)
                    else:
                        raise ValueError("Invalid point value %f" % t)
                    paths[path_index][seg_index] = new_seg

                return paths
        else:
            # Nothing needs to change; return paths directly.
            return paths
示例#19
0
def standard_aabbs():
    """Return four fixed 2-D test boxes.

    The first two share the y-range (0, 1); the last two share (5, 6).
    """
    specs = [
        ((0, 1), (0, 1)),
        ((3, 4), (0, 1)),
        ((5, 6), (5, 6)),
        ((7, 8), (5, 6)),
    ]
    return [AABB(list(spec)) for spec in specs]
示例#20
0
def test_next():
    """Iterating past the last axis raises StopIteration."""
    aabb = AABB([(0, 1), (0, 1)])
    # Presumably _i is the internal iteration cursor; 3 is past both axes.
    aabb._i = 3
    with pytest.raises(StopIteration):
        next(aabb)
 def insert(self, element):
     """Add ``element`` to ``self.data`` under its 2-D bounding box.

     ``element.get_axis_aligned_bounding_box()`` must return a pair of corner
     points (lower, upper); only their first two coordinates are used.
     """
     lower, upper = element.get_axis_aligned_bounding_box()
     box = AABB([(lower[k], upper[k]) for k in (0, 1)])
     self.data.add(box, element)
示例#22
0
def test_eq():
    """AABB equality compares limits, is symmetric, and agrees with !=."""
    lims_a = [(-2, 3), (1, 2.3)]
    lims_b = [(-2, 3), (2, 2.3)]

    box_a = AABB(lims_a)
    box_b = AABB(lims_b)
    box_b_copy = AABB(lims_b)

    # Positive forms.
    assert box_a == box_a
    assert box_a != box_b
    assert box_b != box_a
    assert box_b == box_b_copy

    # An AABB never equals its raw limits, a different-dimension box,
    # or an empty AABB.
    assert box_a != lims_a
    assert AABB([(2, 3)]) != box_a
    assert AABB() == AABB()
    assert box_a != AABB()
    assert AABB() != box_a

    # Negated forms must agree with the positive ones.
    assert not box_a != box_a
    assert not box_a == box_b
    assert not box_b == box_a
    assert not box_b != box_b_copy

    assert not box_a == lims_a
    assert not AABB([(2, 3)]) == box_a
    assert not AABB() != AABB()
    assert not box_a == AABB()
    assert not AABB() == box_a
示例#23
0
def virus(bSaveGif=False):
    """Run the outbreak simulation and plot it frame by frame.

    Creates the people (adding each one's box to the module-level ``tree``)
    and marks the initial infections. It then steps the simulation
    ``MaxTime`` days. Every ``int(MaxTime / frames)`` days it draws a scatter
    plot of people by status (top subplot) and the per-status count curves
    (bottom subplot).

    Args:
        bSaveGif: if True, each drawn frame is saved as a PNG under
            ``gtmpFolder`` and its filename appended to ``keyFrames``.
            Otherwise the figure is updated interactively via ``plt.pause``.

    Status codes, per the legend labels below: 0 susceptible, 1 infected,
    2 exposed, 3 recovered, 4 self-recovered, 5 dead.

    Relies on module-level names defined elsewhere: N, gN0, MaxTime, frames,
    MoveStep, gInfectDis, gUsePP, job_server, infectOtherPP, Person,
    personList, tree, keyFrames, gtmpFolder, plt, np, AABB.
    """
    def scatterLegend(personlist, x, y):
        """Scatter-plot people grouped by status, with a legend.

        x and y are column indices into each (posionX, posionY) pair, so
        (0, 1) plots X against Y. Only non-empty groups are drawn.
        """
        type0 = []
        type1 = []
        type2 = []
        type3 = []
        type4 = []
        type5 = []
        for aperson in personlist:
            if aperson.status == 0:
                type0.append(np.array((aperson.posionX, aperson.posionY)))
            elif aperson.status == 1:
                type1.append(np.array((aperson.posionX, aperson.posionY)))
            elif aperson.status == 2:
                type2.append(np.array((aperson.posionX, aperson.posionY)))
            elif aperson.status == 3:
                type3.append(np.array((aperson.posionX, aperson.posionY)))
            elif aperson.status == 4:
                type4.append(np.array((aperson.posionX, aperson.posionY)))
            elif aperson.status == 5:
                type5.append(np.array((aperson.posionX, aperson.posionY)))

        type0 = np.array(type0)
        type1 = np.array(type1)
        type2 = np.array(type2)
        type3 = np.array(type3)
        type4 = np.array(type4)
        type5 = np.array(type5)

        # Some status groups can be empty; only plot (and label) non-empty
        # ones, since an empty array cannot be indexed with [:, x].
        handles = []
        labels = []
        if (len(type0) != 0):
            g0 = plt.scatter(type0[:, x],
                             type0[:, y],
                             color='darkblue',
                             marker='.')
            handles.append(g0)
            labels.append('易感者')
        if (len(type1) != 0):
            g1 = plt.scatter(type1[:, x], type1[:, y], c='red', marker='.')
            handles.append(g1)
            labels.append('感染者')
        if (len(type2) != 0):
            g2 = plt.scatter(type2[:, x], type2[:, y], c='orange', marker='.')
            handles.append(g2)
            labels.append('潜伏者')
        if (len(type3) != 0):
            g3 = plt.scatter(type3[:, x], type3[:, y], c='green', marker='.')
            handles.append(g3)
            labels.append('康复者')
        if (len(type4) != 0):
            g4 = plt.scatter(type4[:, x],
                             type4[:, y],
                             c='lightgreen',
                             marker='.')
            handles.append(g4)
            labels.append('自愈者')
        if (len(type5) != 0):
            g5 = plt.scatter(type5[:, x], type5[:, y], c='gray', marker='x')
            handles.append(g5)
            labels.append('死亡者')
        #plt.legend(handles=[g0, g1, g2,  g3], labels=['Susceptible', 'Infection', 'Exposed', 'Recovery'])
        plt.legend(handles=handles, labels=labels, loc=1)

    def drawFig():
        """Redraw the daily count curve for every status."""
        plt.cla()
        plt.plot(nSuscept, color='darkblue', label='Susceptible', marker='.')
        plt.plot(nInfect, color='red', label='Infection', marker='.')
        plt.plot(nExposed, color='orange', label='Exposed', marker='.')
        plt.plot(nRecovery, color='green', label='Recovery', marker='.')
        plt.plot(nSelfRecovery,
                 color='lightgreen',
                 label='SelfRecovery',
                 marker='.')
        plt.plot(nDeath, color='grey', label='Death', marker='.')
        #plt.title('SEIR Model')
        plt.legend(loc=1)
        plt.xlabel('Day')
        plt.ylabel('Number')

    def DoStep(day, personlist):
        """Advance the simulation by one day.

        Moves every living person, spreads infection, updates each person's
        status and appends the day's per-status counts to the history lists.
        """
        nDaySuscept = 0
        nDayExposed = 0
        nDayInfect = 0
        nDayRecovery = 0
        nDaySelfRecovery = 0
        nDayDeath = 0

        # NOTE(review): ``tree`` is declared global but not used in DoStep.
        # The boxes added at set-up are never moved along with the people —
        # confirm whether anything relies on them being current.
        global tree
        for aperson in personlist:
            if aperson.status == 5:  # has death
                pass
            else:
                aperson.move(MoveStep)

        if gUsePP:
            parts = 4
            start = 0
            end = len(personlist)
            step = int((end - start) / parts + 1)

            jobs = []

            for index in range(parts):
                starti = start + index * step
                endi = min(start + (index + 1) * step, end)
                # Submit infectOtherPP for people [starti, endi) to the
                # parallel job server with args (day, personlist, starti, endi).
                # NOTE(review): the job objects in ``jobs`` are never called or
                # awaited, so no result is collected. The workers also
                # presumably operate on a copy of personlist; confirm that
                # infections actually propagate when gUsePP is set.
                jobs.append(
                    job_server.submit(infectOtherPP,
                                      (day, personlist, starti, endi)))
            #job_server.print_stats()

        else:
            for aperson in personlist:
                aperson.infectOther(day, personlist)
        for aperson in personlist:
            aperson.update(day)
            if aperson.status == 0:
                nDaySuscept += 1
            if aperson.status == 1:
                nDayInfect += 1
            if aperson.status == 2:
                nDayExposed += 1
            if aperson.status == 3:
                nDayRecovery += 1
            if aperson.status == 4:
                nDaySelfRecovery += 1
            if aperson.status == 5:
                nDayDeath += 1

        nSuscept.append(nDaySuscept)
        nInfect.append(nDayInfect)
        nExposed.append(nDayExposed)
        nRecovery.append(nDayRecovery)
        nSelfRecovery.append(nDaySelfRecovery)
        nDeath.append(nDayDeath)


    # --- Simulation set-up and main loop ---

    # Per-day count history for each status, read by drawFig().
    nSuscept = []
    nExposed = []
    nInfect = []
    nRecovery = []
    nSelfRecovery = []
    nDeath = []

    # Create each person (position presumably randomised by Person) and index
    # a square box of half-width gInfectDis // 2 around them in ``tree``.
    # NOTE(review): range(0, N - 1) creates N - 1 people, not N. Also,
    # ``//`` floors the half-width. Confirm both are intended.
    for i in range(0, N - 1):
        aperson = Person(i)
        personList.append(aperson)
        aabb1 = AABB([(aperson.posionX - gInfectDis // 2,
                       aperson.posionX + gInfectDis // 2),
                      (aperson.posionY - gInfectDis // 2,
                       aperson.posionY + gInfectDis // 2)])
        tree.add(aabb1, aperson.id)

    # Initial infections.
    # NOTE(review): range(0, gN0 - 1) infects gN0 - 1 people — confirm.
    for i in range(0, gN0 - 1):
        personList[i].status = 1
    for i in range(0, MaxTime):
        DoStep(i, personList)
        # Draw every int(MaxTime / frames) days.
        # NOTE(review): raises ZeroDivisionError when frames > MaxTime.
        if (i % int(MaxTime / frames) == 0):
            plt.subplot(2, 1, 1)  # 2-row, 1-column subplot grid; draw in the first slot.
            plt.cla()
            label = 'Day {0}'.format(i)
            plt.title(label)
            scatterLegend(personList, 0, 1)
            plt.subplot(2, 1, 2)  # 2-row, 1-column subplot grid; draw in the second slot.
            drawFig()
            if (bSaveGif):
                filename = gtmpFolder + r"outbreak" + str(i) + ".png"
                plt.savefig(filename, dpi=150, bbox_inches='tight')
                keyFrames.append(filename)
            else:
                plt.pause(0.01)
示例#24
0
def test_str():
    """str() of an AABB matches str() of the limits it was built from."""
    limits = [(2, 3), (-20, 24), (2.3, 6.71)]
    box = AABB(limits)
    assert str(box) == str(limits)