def test_channel_segment_set_nodes(nodes):
    """Assigning ``nodes`` replaces the segment's nodes and its end points."""
    seg = ChannelSegment([0, 1])
    seg.nodes = nodes

    # The first node is the downstream end, the last node the upstream end.
    assert seg.downstream_node == nodes[0]
    assert seg.upstream_node == nodes[-1]

    assert_array_equal(seg.nodes, nodes)
    assert len(seg) == len(nodes)
def test_connector_add_downstream():
    """A segment whose upstream end meets the root's downstream end becomes the root."""
    old_root = ChannelSegment([0, 1])
    new_root = ChannelSegment([2, 0])

    connector = ChannelSegmentConnector(old_root)
    connector.add(new_root)

    # The added segment replaces the original root, which now hangs upstream of it.
    assert connector.root is new_root
    assert connector.root.count_segments(direction="upstream") == 1
    assert connector.root.downstream is None
def test_channel_segment_for_each(nodes):
    """``for_each`` hands a lone segment to the callback so its nodes can be read."""
    visited = []

    def _record(seg):
        visited.extend(list(seg.nodes))

    segment = ChannelSegment(nodes)
    segment.for_each(_record)

    assert_array_equal(visited, segment.nodes)
def test_channel_segment_add_downstream_node():
    """Setting ``downstream`` links the two segments in both directions."""
    upper = ChannelSegment([0, 1])
    lower = ChannelSegment([5, 6])

    # Neither segment starts out connected.
    assert upper.downstream is None
    assert not len(lower.upstream)

    upper.downstream = lower

    assert upper.downstream is lower
    assert upper in lower.upstream
def test_reindex_segment_nodes_with_downstream(nodes, last_node):
    """Reindexing a segment with a downstream neighbour keeps the shared node."""
    downstream_seg = ChannelSegment([0, 1])
    seg = ChannelSegment(nodes)
    downstream_seg.add_upstream(seg)

    reindexer = SegmentNodeReindexer(nodes=[last_node])
    reindexer(seg)

    # The first node is taken from the downstream segment's last node; the
    # remaining nodes are numbered consecutively after ``last_node``.
    expected_tail = [last_node + offset for offset in range(1, len(nodes))]
    assert seg.nodes[0] == downstream_seg.nodes[-1]
    assert seg.nodes[1:] == expected_tail
def test_channel_segment_add_upstream_node():
    """``add_upstream`` links the two segments in both directions."""
    lower = ChannelSegment([0, 1])
    upper = ChannelSegment([5, 6])

    # Neither segment starts out connected.
    assert not len(lower.upstream)
    assert upper.downstream is None

    lower.add_upstream(upper)

    assert upper in lower.upstream
    assert upper.downstream is lower
def test_connector_add_upstream():
    """A segment starting at the root's upstream end is attached upstream of it."""
    root_seg = ChannelSegment([0, 1])
    connector = ChannelSegmentConnector(root_seg)

    # A fresh connector holds only its root.
    assert connector.root is root_seg
    assert not len(connector.orphans)

    connector.add(ChannelSegment([1, 2]))

    # The root is unchanged and now has exactly one upstream segment.
    assert connector.root is root_seg
    assert connector.root.count_segments(direction="upstream") == 1
    assert connector.root.downstream is None
def test_create_links_with_downstream(nodes):
    """Links of a segment with a downstream neighbour start at the shared node."""
    downstream_seg = ChannelSegment([0, 1])
    seg = ChannelSegment(nodes)
    downstream_seg.add_upstream(seg)

    collector = SegmentLinkCollector()
    collector(seg)
    links = collector.links

    assert len(links) == len(seg) - 1

    # The first link is anchored on the downstream segment's last node
    # instead of on this segment's own first node.
    assert links[0] == (downstream_seg.nodes[-1], seg.nodes[1])

    remaining = links[1:]
    if remaining:
        heads = [head for head, _ in remaining]
        tails = [tail for _, tail in remaining]
        assert heads == nodes[1:-1]
        assert tails == nodes[2:]
def test_connector_add_orphan():
    """A disconnected segment is held as an orphan until a bridging segment arrives."""
    root_seg = ChannelSegment([0, 1])
    stray = ChannelSegment([2, 3])
    connector = ChannelSegmentConnector(root_seg)

    # [2, 3] shares no node with [0, 1], so it cannot be attached yet.
    connector.add(stray)
    assert connector.root is root_seg
    assert connector.root.count_segments(direction="upstream") == 0
    assert connector.root.downstream is None
    assert len(connector.orphans) == 1
    assert connector.orphans == (stray,)

    # [1, 2] bridges the root and the orphan, so both end up upstream.
    connector.add(ChannelSegment([1, 2]))
    assert connector.root.count_segments(direction="upstream") == 2
    assert connector.orphans == ()
def test_channel_segment_many_upstream(segments):
    """Chaining each segment upstream of the previous one builds a linear network.

    The first segment is the root. Every other segment lies somewhere upstream of
    it, and nothing lies downstream.
    """
    # Use a new name instead of rebinding the fixture parameter to another type.
    channel_segments = [ChannelSegment(segment) for segment in segments]
    root = channel_segments[0]

    # Avoid ``next`` as a loop variable: it shadows the builtin.
    for downstream, upstream in pairwise(channel_segments):
        downstream.add_upstream(upstream)

    assert root.count_segments(direction="upstream") == len(channel_segments) - 1
    assert root.count_segments(direction="downstream") == 0
def test_channel_segment_many_flat_upstream(segments):
    """Attaching every other segment directly to the root gives a one-level network."""
    channel_segments = list(map(ChannelSegment, segments))
    root, *others = channel_segments
    for other in others:
        root.add_upstream(other)

    n_attached = len(channel_segments) - 1
    assert root.downstream is None
    assert len(root.upstream) == n_attached
    assert root.count_segments(direction="upstream") == n_attached
    assert root.count_segments(direction="downstream") == 0
def test_create_links(nodes):
    """An isolated segment yields one (head, tail) link per consecutive node pair."""
    seg = ChannelSegment(nodes)
    collector = SegmentLinkCollector()
    collector(seg)
    links = collector.links

    assert len(links) == len(seg) - 1

    # Transpose the (head, tail) pairs into two parallel lists.
    heads, tails = (list(column) for column in zip(*links))
    assert heads == nodes[:-1]
    assert tails == nodes[1:]
def test_create_links_with_existing(nodes):
    """New links are appended after links the collector was seeded with."""
    seed_links = [(1, 2), (3, 4)]
    seg = ChannelSegment(nodes)
    collector = SegmentLinkCollector(links=[(1, 2), (3, 4)])
    collector(seg)
    links = collector.links

    # The seeded links are kept, in order, at the front.
    assert links[:2] == seed_links

    added = links[2:]
    assert len(added) == len(seg) - 1

    heads, tails = (list(column) for column in zip(*added))
    assert heads == nodes[:-1]
    assert tails == nodes[1:]
def test_channel_segment_many_downstream(segments):
    """Chaining each segment downstream of the previous one builds a linear network.

    The first segment (``root``) is the most upstream segment and the last
    (``leaf``) the most downstream. Each must see all the others in the
    appropriate direction.
    """
    # Use a new name instead of rebinding the fixture parameter to another type.
    channel_segments = [ChannelSegment(segment) for segment in segments]

    # Avoid ``next`` as a loop variable: it shadows the builtin.
    for upstream, downstream in pairwise(channel_segments):
        upstream.downstream = downstream

    # The original assigned ``root`` twice; once is enough.
    root = channel_segments[0]
    leaf = channel_segments[-1]
    n_others = len(channel_segments) - 1

    assert root.count_segments(direction="upstream") == 0
    assert root.count_segments(direction="downstream") == n_others
    assert leaf.count_segments(direction="upstream") == n_others
    assert leaf.count_segments(direction="downstream") == 0
def test_construct_xy_of_node(shape_and_segment):
    """The collector records each segment node's grid coordinates, in node order."""
    shape, segment = shape_and_segment
    grid = RasterModelGrid(shape)

    collect_coordinates = SegmentNodeCoordinateCollector(grid)
    collect_coordinates(ChannelSegment(segment))
    xy_of_node = collect_coordinates.xy_of_node

    assert len(xy_of_node) == len(segment)

    x_of_node, y_of_node = zip(*xy_of_node)
    # assert_array_equal checks shape and reports the mismatching elements.
    # The old `(tuple == ndarray).all()` relied on reflected broadcasting and
    # gave no diagnostic output on failure.
    assert_array_equal(x_of_node, grid.x_of_node[segment])
    assert_array_equal(y_of_node, grid.y_of_node[segment])
def test_reindex_segment_nodes_orphan(nodes):
    """A default reindexer renumbers an orphan segment's nodes as 0, 1, 2, ..."""
    seg = ChannelSegment(nodes)
    SegmentNodeReindexer()(seg)
    assert seg.nodes == [*range(len(nodes))]
def test_reindex_segment_nodes_with_last_node(nodes, last_node):
    """Reindexing continues the numbering directly after ``last_node``."""
    seg = ChannelSegment(nodes)
    reindexer = SegmentNodeReindexer(nodes=[last_node])
    reindexer(seg)

    first_new = last_node + 1
    expected = [first_new + offset for offset in range(len(nodes))]
    assert seg.nodes == expected