コード例 #1
0
def test_channel_segment_set_nodes(nodes):
    """Assigning ``nodes`` replaces the segment's node list and its end nodes."""
    seg = ChannelSegment([0, 1])
    seg.nodes = nodes

    # The first node is reported as the downstream end, the last as upstream.
    assert seg.downstream_node == nodes[0]
    assert seg.upstream_node == nodes[-1]

    assert_array_equal(seg.nodes, nodes)
    assert len(seg) == len(nodes)
コード例 #2
0
def test_connector_add_downstream():
    """A segment ending at the root's first node ([2, 0] -> [0, 1]) becomes the new root."""
    upstream_seg = ChannelSegment([0, 1])
    downstream_seg = ChannelSegment([2, 0])
    connector = ChannelSegmentConnector(upstream_seg)

    connector.add(downstream_seg)

    new_root = connector.root
    assert new_root is downstream_seg
    # The former root now hangs off the new root as its single upstream segment.
    assert new_root.count_segments(direction="upstream") == 1
    assert new_root.downstream is None
コード例 #3
0
def test_channel_segment_for_each(nodes):
    """``for_each`` on a lone segment collects exactly that segment's nodes."""
    visited = []
    segment = ChannelSegment(nodes)

    # The callback receives a segment; gather its nodes in visitation order.
    segment.for_each(lambda seg: visited.extend(list(seg.nodes)))

    assert_array_equal(visited, segment.nodes)
コード例 #4
0
def test_channel_segment_add_downstream_node():
    """Setting ``downstream`` links the two segments in both directions."""
    seg = ChannelSegment([0, 1])
    target = ChannelSegment([5, 6])

    # Precondition: no link exists yet in either direction.
    assert seg.downstream is None
    assert len(target.upstream) == 0

    seg.downstream = target

    # The assignment also registers ``seg`` as upstream of ``target``.
    assert seg.downstream is target
    assert seg in target.upstream
コード例 #5
0
def test_reindex_segment_nodes_with_downstream(nodes, last_node):
    """Reindexing a segment with a downstream neighbour keeps the shared outlet node."""
    root = ChannelSegment([0, 1])
    segment = ChannelSegment(nodes)
    root.add_upstream(segment)

    SegmentNodeReindexer(nodes=[last_node])(segment)

    # The first node is tied to the root's last node; the remaining
    # ``len(nodes) - 1`` nodes are numbered consecutively after ``last_node``.
    first_new_id = last_node + 1
    expected_tail = list(range(first_new_id, first_new_id + len(nodes) - 1))

    assert segment.nodes[0] == root.nodes[-1]
    assert segment.nodes[1:] == expected_tail
コード例 #6
0
def test_channel_segment_add_upstream_node():
    """``add_upstream`` registers the tributary and back-links it downstream."""
    seg = ChannelSegment([0, 1])
    tributary = ChannelSegment([5, 6])

    # Precondition: the two segments are not yet linked.
    assert len(seg.upstream) == 0
    assert tributary.downstream is None

    seg.add_upstream(tributary)

    assert tributary in seg.upstream
    assert tributary.downstream is seg
コード例 #7
0
def test_connector_add_upstream():
    """A segment starting at the root's last node ([0, 1] -> [1, 2]) attaches upstream."""
    root_seg = ChannelSegment([0, 1])
    connector = ChannelSegmentConnector(root_seg)

    # A fresh connector has its seed segment as root and no orphans.
    assert connector.root is root_seg
    assert len(connector.orphans) == 0

    connector.add(ChannelSegment([1, 2]))

    # The root is unchanged and now has exactly one upstream segment.
    assert connector.root is root_seg
    assert connector.root.count_segments(direction="upstream") == 1
    assert connector.root.downstream is None
コード例 #8
0
def test_create_links_with_downstream(nodes):
    """Links of a segment with a downstream neighbour start at the neighbour's last node."""
    root = ChannelSegment([0, 1])
    segment = ChannelSegment(nodes)
    root.add_upstream(segment)

    collector = SegmentLinkCollector()
    collector(segment)
    links = collector.links

    assert len(links) == len(segment) - 1
    # The first link joins the root's last node to the segment's second node.
    assert links[0] == (root.nodes[-1], segment.nodes[1])

    # Any further links follow the segment's own consecutive node pairs.
    remaining = links[1:]
    if len(remaining) > 0:
        heads, tails = zip(*remaining)
        assert list(heads) == nodes[1:-1]
        assert list(tails) == nodes[2:]
コード例 #9
0
def test_connector_add_orphan():
    """A disconnected segment is held as an orphan until a bridging segment arrives."""
    first = ChannelSegment([0, 1])
    detached = ChannelSegment([2, 3])
    connector = ChannelSegmentConnector(first)

    # [2, 3] shares no node with [0, 1], so it is not attached to the root.
    connector.add(detached)

    assert connector.root is first
    assert connector.root.count_segments(direction="upstream") == 0
    assert connector.root.downstream is None

    assert len(connector.orphans) == 1
    assert connector.orphans == (detached,)

    # [1, 2] bridges the root and the orphan, so both end up upstream of the root.
    connector.add(ChannelSegment([1, 2]))
    assert connector.root.count_segments(direction="upstream") == 2
    assert connector.orphans == ()
コード例 #10
0
def test_channel_segment_many_upstream(segments):
    """Chaining each segment upstream of the previous one builds a linear network.

    The first segment is the root, so every other segment is counted as
    upstream of it and nothing is downstream of it.
    """
    # Use a new name instead of rebinding the ``segments`` fixture (node lists)
    # to a list of ChannelSegment objects.
    chain = [ChannelSegment(segment_nodes) for segment_nodes in segments]
    root = chain[0]

    # Loop names avoid shadowing the builtin ``next``.
    for downstream_seg, upstream_seg in pairwise(chain):
        downstream_seg.add_upstream(upstream_seg)

    assert root.count_segments(direction="upstream") == len(chain) - 1
    assert root.count_segments(direction="downstream") == 0
コード例 #11
0
def test_channel_segment_many_flat_upstream(segments):
    """Attaching every segment directly to the root gives a one-level fan-in."""
    channel_segments = [ChannelSegment(segment_nodes) for segment_nodes in segments]
    root, tributaries = channel_segments[0], channel_segments[1:]

    for tributary in tributaries:
        root.add_upstream(tributary)

    expected = len(channel_segments) - 1
    assert root.downstream is None
    assert len(root.upstream) == expected

    assert root.count_segments(direction="upstream") == expected
    assert root.count_segments(direction="downstream") == 0
コード例 #12
0
def test_create_links(nodes):
    """A lone segment yields one (head, tail) link per consecutive pair of its nodes."""
    segment = ChannelSegment(nodes)

    collect_links = SegmentLinkCollector()
    collect_links(segment)
    links = collect_links.links

    assert len(links) == len(segment) - 1
    # Compare against the expected pairs directly. Unpacking ``zip(*links)``
    # into two names raises ValueError when there are no links, which would
    # mask the real assertion.
    assert links == list(zip(nodes[:-1], nodes[1:]))
コード例 #13
0
def test_create_links_with_existing(nodes):
    """New links are appended after any links the collector was seeded with."""
    segment = ChannelSegment(nodes)

    collect_links = SegmentLinkCollector(links=[(1, 2), (3, 4)])
    collect_links(segment)
    links = collect_links.links

    # The seeded links are kept, in order, at the front.
    assert links[:2] == [(1, 2), (3, 4)]

    new_links = links[2:]
    assert len(new_links) == len(segment) - 1
    # Compare against the expected pairs directly. ``heads, tails =
    # zip(*new_links)`` raises ValueError when no new links were added.
    assert new_links == list(zip(nodes[:-1], nodes[1:]))
コード例 #14
0
def test_channel_segment_many_downstream(segments):
    """Chaining each segment downstream of the previous one builds a linear network.

    The first segment ends up with nothing upstream and all others downstream.
    The last segment ends up with all others upstream and nothing downstream.
    """
    # Use a new name instead of rebinding the ``segments`` fixture.
    chain = [ChannelSegment(segment_nodes) for segment_nodes in segments]

    # Loop names avoid shadowing the builtin ``next``.
    for upstream_seg, downstream_seg in pairwise(chain):
        upstream_seg.downstream = downstream_seg

    # Assigned once. The original assigned ``root = segments[0]`` twice and
    # called the most-upstream segment "root".
    top, outlet = chain[0], chain[-1]

    assert top.count_segments(direction="upstream") == 0
    assert top.count_segments(direction="downstream") == len(chain) - 1

    assert outlet.count_segments(direction="upstream") == len(chain) - 1
    assert outlet.count_segments(direction="downstream") == 0
コード例 #15
0
def test_construct_xy_of_node(shape_and_segment):
    """Collected coordinates match the grid's x/y values for every node in the segment."""
    # Local import keeps this test self-contained. ``assert_array_equal``
    # checks shape and reports the mismatching elements on failure, unlike
    # ``(tuple == ndarray).all()``.
    from numpy.testing import assert_array_equal

    shape, segment = shape_and_segment
    grid = RasterModelGrid(shape)

    collect_coordinates = SegmentNodeCoordinateCollector(grid)
    collect_coordinates(ChannelSegment(segment))
    xy_of_node = collect_coordinates.xy_of_node

    assert len(xy_of_node) == len(segment)

    # xy_of_node holds one (x, y) pair per node; split it into two sequences.
    x_of_node, y_of_node = zip(*xy_of_node)
    assert_array_equal(x_of_node, grid.x_of_node[segment])
    assert_array_equal(y_of_node, grid.y_of_node[segment])
コード例 #16
0
def test_reindex_segment_nodes_orphan(nodes):
    """With no starting node given, an orphan segment is renumbered from zero."""
    segment = ChannelSegment(nodes)
    SegmentNodeReindexer()(segment)

    expected = list(range(len(nodes)))
    assert segment.nodes == expected
コード例 #17
0
def test_reindex_segment_nodes_with_last_node(nodes, last_node):
    """Nodes are renumbered consecutively, starting just after ``last_node``."""
    segment = ChannelSegment(nodes)
    SegmentNodeReindexer(nodes=[last_node])(segment)

    start = last_node + 1
    assert segment.nodes == list(range(start, start + len(nodes)))