def test_merge_basics(self):
    """Check that MergeProvider combines a graph source and an array source.

    Requesting both PRESYN and GT_LABELS must return both; requesting only
    PRESYN must return only the graph. Merging two ``ArrayTestSoure`` providers
    (which provide the same type) must raise ``PipelineSetupError`` on build.
    """
    voxel_size = (1, 1, 1)
    # Keep the created keys rather than re-reading them from the registries.
    presyn = GraphKey("PRESYN")
    gt_labels = ArrayKey("GT_LABELS")

    graphsource = GraphTestSource(voxel_size)
    arraysource = ArrayTestSoure(voxel_size)
    pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation()
    window_request = Coordinate((50, 50, 50))

    with build(pipeline):
        # Check basic merging.
        request = BatchRequest()
        request.add(presyn, window_request)
        request.add(gt_labels, window_request)
        batch_res = pipeline.request_batch(request)
        self.assertIn(gt_labels, batch_res.arrays)
        self.assertIn(presyn, batch_res.graphs)

        # Check that a request for only one source also works.
        request = BatchRequest()
        request.add(presyn, window_request)
        batch_res = pipeline.request_batch(request)
        self.assertNotIn(gt_labels, batch_res.arrays)
        self.assertIn(presyn, batch_res.graphs)

    # Check that it fails when two sources provide the same type.
    arraysource2 = ArrayTestSoure(voxel_size)
    pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation()
    with self.assertRaises(PipelineSetupError):
        with build(pipeline_fail):
            pass
def test_transpose():
    """Check that SimpleAugment transposes a graph node and an array together.

    A single node at (450, 550) lies on the only nonzero voxel of a 2D array.
    With ``transpose_probs=[0, 0]`` the node must stay at (450, 550); with
    ``transpose_probs=[1, 1]`` it must be at (410, 590). In both cases the
    array voxel under the node must still be 1.
    """
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )
    data = np.zeros([40, 30])
    data[17, 17] = 1
    array = Array(
        data, ArraySpec(roi=Roi((100, 200), (800, 600)), voxel_size=voxel_size)
    )

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    def make_pipeline(transpose_probs):
        # Fresh source nodes per pipeline; the graph/array data is shared.
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(
                mirror_only=[],
                transpose_only=[0, 1],
                transpose_probs=transpose_probs,
            )
        )

    def check(pipeline, expected_location):
        # Request one batch; verify node position and array value under it.
        with build(pipeline):
            batch = pipeline.request_batch(request)
            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            assert np.allclose(node.location, expected_location)

            result = batch[array_key]
            node_voxel_index = Coordinate(
                (node.location - result.spec.roi.get_offset()) / voxel_size
            )
            assert (
                result.data[node_voxel_index] == 1
            ), f"Node at {np.where(result.data == 1)} not {node_voxel_index}"

    check(make_pipeline(transpose_probs=[0, 0]), [450, 550])
    check(make_pipeline(transpose_probs=[1, 1]), [410, 590])
def test_output(self):
    """Repeatedly request A and B through CustomRandomLocation.

    Every 20^3 batch must contain some nonzero data for both keys. A request
    whose ROIs have the shape of each source's full ROI is also issued each
    round; it only has to complete without raising.
    """
    key_a = ArrayKey("A")
    key_b = ArrayKey("B")
    source_a = TestSourceRandomLocation(key_a)
    source_b = TestSourceRandomLocation(key_b)
    pipeline = (source_a, source_b) + MergeProvider() + CustomRandomLocation()

    def request_for(roi_a, roi_b):
        # Build a two-key request from the given ROIs.
        return BatchRequest({
            key_a: ArraySpec(roi=roi_a),
            key_b: ArraySpec(roi=roi_b),
        })

    with build(pipeline):
        for _ in range(10):
            batch = pipeline.request_batch(
                request_for(
                    Roi((0, 0, 0), (20, 20, 20)),
                    Roi((0, 0, 0), (20, 20, 20)),
                )
            )
            self.assertGreater(np.sum(batch.arrays[key_a].data), 0)
            self.assertGreater(np.sum(batch.arrays[key_b].data), 0)

            # Request a ROI with the same shape as the entire ROI.
            pipeline.request_batch(
                request_for(
                    Roi((0, 0, 0), source_a.roi.get_shape()),
                    Roi((0, 0, 0), source_b.roi.get_shape()),
                )
            )
def test_square(self):
    """Mirror and transpose dims 1-2 over padded sources and check the outputs.

    Across 100 batches the graph must always hold exactly one node, and every
    returned array's data shape must equal its spec ROI shape.
    """
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")

    pipeline = (
        (ArrayTestSource(), TestSource())
        + MergeProvider()
        + Pad(test_array1, None)
        + Pad(test_array2, None)
        + Pad(test_graph, None)
        + SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
    )

    request = BatchRequest()
    request[test_graph] = GraphSpec(roi=Roi((0, 50, 65), (100, 100, 100)))
    request[test_array1] = ArraySpec(roi=Roi((0, 0, 15), (100, 200, 200)))
    request[test_array2] = ArraySpec(roi=Roi((0, 50, 65), (100, 100, 100)))

    with build(pipeline):
        for _ in range(100):
            batch = pipeline.request_batch(request)
            assert len(list(batch[test_graph].nodes)) == 1
            for array in batch.arrays.values():
                assert array.data.shape == array.spec.roi.get_shape()
def test_realistic_valid_examples(example, use_gurobi):
    """Run two TopologicalMatcher variants on a stored mouselight example.

    Both matchers read the same skeletonization/consensus graphs; the second
    is configured with ``with_fallback=True``. The plain matcher's output must
    have as many connected components as the consensus graph.
    """
    example_dir = Path(__file__).parent / "mouselight_examples" / "valid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    # ROI with None offset and shape (presumably unbounded in gunpowder).
    inf_roi = Roi(Coordinate((None,) * 3), Coordinate((None,) * 3))
    request = BatchRequest()
    for key in (matched, matched_with_fallback, consensus):
        request[key] = PointsSpec(roi=inf_roi)

    # Settings shared by both matcher nodes.
    matcher_kwargs = dict(
        expected_edge_len=10,
        match_distance_threshold=76,
        max_gap_crossing=48,
        use_gurobi=use_gurobi,
        location_attr="location",
        penalty_attr="penalty",
    )
    pipeline = (
        (
            GraphSource(example_dir / "graph.obj", [skeletonization]),
            GraphSource(example_dir / "tree.obj", [consensus]),
        )
        + MergeProvider()
        + TopologicalMatcher(skeletonization, consensus, matched, **matcher_kwargs)
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched_with_fallback,
            with_fallback=True,
            **matcher_kwargs,
        )
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)
        consensus_ccs = list(batch[consensus].connected_components)
        # NOTE(review): collected but never asserted on; TODO decide what the
        # fallback matcher's output should satisfy.
        matched_with_fallback_ccs = list(
            batch[matched_with_fallback].connected_components
        )
        matched_ccs = list(batch[matched].connected_components)

        assert len(matched_ccs) == len(consensus_ccs)
def test_multi_transpose(self):
    """Transpose all three dims and check the single node when it is present.

    Over 100 batches, whenever the graph contains exactly one node it must pass
    the membership check against all permutations of (50, 70, 100), lie inside
    the graph's spec ROI, and that ROI must lie inside the requested ROI. Every
    array's data shape must equal its spec ROI shape. At least one batch must
    contain the node and at least one must show it moved.
    """
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")

    point = np.array([50, 70, 100])
    transpose_dims = [0, 1, 2]

    pipeline = (
        (ArrayTestSource(), ExampleSource())
        + MergeProvider()
        + SimpleAugment(mirror_only=[], transpose_only=transpose_dims)
    )

    offset = (0, 20, 33)
    graph_shape = (100, 100, 120)
    request = BatchRequest()
    request[test_graph] = GraphSpec(roi=Roi(offset, graph_shape))
    request[test_array1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[test_array2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # One row per permutation of the point's coordinates.
    possible_loc = np.array(
        [point[list(perm)] for perm in permutations(transpose_dims, 3)]
    )

    with build(pipeline):
        seen_transposed = False
        # Must start False, otherwise `assert seen_node` below checks nothing.
        seen_node = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            nodes = list(batch[test_graph].nodes)

            if len(nodes) == 1:
                seen_node = True
                node = nodes[0]
                # NOTE(review): for ndarrays `x in arr` is `(arr == x).any()`,
                # so this passes when any single coordinate equals the entry in
                # the same column of any row; it does not require a whole row
                # to match. A stricter row-wise check may be intended.
                assert node.location in possible_loc
                seen_transposed = seen_transposed or any(
                    node.location[dim] != point[dim] for dim in range(3)
                )
                assert Roi(offset, graph_shape).contains(batch[test_graph].spec.roi)
                assert batch[test_graph].spec.roi.contains(node.location)

            for array in batch.arrays.values():
                assert array.data.shape == array.spec.roi.get_shape()

        assert seen_transposed
        assert seen_node
def test_pipeline3(self):
    """Shift-augment a 2D HDF5 array together with CSV points.

    After ShiftAugment, the array value under every returned point must be one
    of the values found in ``self.fake_data`` under ``self.fake_points``.
    """
    array_key = ArrayKey("TEST_ARRAY")
    points_key = GraphKey("TEST_POINTS")
    voxel_size = Coordinate((1, 1))

    array_spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)
    hdf5_source = Hdf5Source(
        self.fake_data_file,
        {array_key: "testdata"},
        array_specs={array_key: array_spec},
    )
    csv_source = CsvPointsSource(
        self.fake_points_file,
        points_key,
        GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
    )

    request = BatchRequest()
    shape = Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
    pipeline = (
        (hdf5_source, csv_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=points_key)
        + shift_node
    )

    with build(pipeline) as built:
        # Keep the returned batch separate from the request it answers.
        batch = built.request_batch(request)

        target_vals = [self.fake_data[p[0]][p[1]] for p in self.fake_points]
        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)
        result_vals = [
            result_data[int(node.location[0])][int(node.location[1])]
            for node in result_points
        ]

        for result_val in result_vals:
            failure = "result value {} at points {} not in target values {} at points {}".format(
                result_val,
                list(result_points),
                target_vals,
                self.fake_points,
            )
            self.assertTrue(result_val in target_vals, msg=failure)
def test_impossible(self):
    """Expect an AssertionError for these incompatible A/B ROI requests."""
    key_a = ArrayKey("A")
    key_b = ArrayKey("B")
    source_a = TestSourceRandomLocation(key_a)
    source_b = TestSourceRandomLocation(key_b)
    pipeline = (source_a, source_b) + MergeProvider() + CustomRandomLocation()

    with build(pipeline):
        impossible = BatchRequest({
            key_a: ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
            key_b: ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
        })
        with self.assertRaises(AssertionError):
            pipeline.request_batch(impossible)
def test_two_transpose(self):
    """Transpose dims 1 and 2 and check the single node's coordinates.

    Over 100 batches the graph must hold exactly one node. Each of its
    coordinates must be one of the allowed values for that dimension. The node
    must lie inside the graph's spec ROI, which must lie inside the requested
    ROI. Every array's shape must match its spec ROI, and at least one batch
    must show a coordinate different from the untransposed one.
    """
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")

    transpose_dims = [1, 2]
    pipeline = (
        (ArrayTestSource(), ExampleSource())
        + MergeProvider()
        + SimpleAugment(mirror_only=[], transpose_only=transpose_dims)
    )

    request = BatchRequest()
    request[test_graph] = GraphSpec(roi=Roi((0, 20, 33), (100, 100, 120)))
    request[test_array1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[test_array2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # Allowed values per dimension; the first entry is treated as the
    # untransposed value.
    possible_loc = [[50, 50], [70, 100], [100, 70]]

    with build(pipeline):
        seen_transposed = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            nodes = list(batch[test_graph].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            logging.debug(node.location)

            assert all(node.location[d] in possible_loc[d] for d in range(3))
            if not seen_transposed:
                seen_transposed = any(
                    node.location[d] != possible_loc[d][0] for d in range(3)
                )

            assert Roi((0, 20, 33), (100, 100, 120)).contains(
                batch[test_graph].spec.roi
            )
            assert batch[test_graph].spec.roi.contains(node.location)
            for array in batch.arrays.values():
                assert array.data.shape == array.spec.roi.get_shape()

        assert seen_transposed
def test_impossible(self):
    """Expect a PipelineRequestError for these incompatible A/B ROI requests."""
    key_a = ArrayKey("A")
    key_b = ArrayKey("B")
    null_key = ArrayKey("NULL")
    source_a = ExampleSourceRandomLocation(key_a)
    source_b = ExampleSourceRandomLocation(key_b)
    pipeline = (source_a, source_b) + MergeProvider() + CustomRandomLocation(null_key)

    with build(pipeline):
        impossible = BatchRequest({
            key_a: ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
            key_b: ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
        })
        with self.assertRaises(PipelineRequestError):
            pipeline.request_batch(impossible)
def test_two_disjoint_lines_intensity(self):
    """Fuse two rasterized, non-touching lines with ``blend_mode="intensity"``.

    Two straight lines 3 * LABEL_RADIUS apart are written to SWC files. Each is
    rasterized into a grown labels array and a grown, normalized raw array, and
    FusionAugment fuses them. The fused raw data must equal the sum of the two
    individual raw arrays.
    """
    LABEL_RADIUS = 3
    RAW_RADIUS = 3
    BLEND_SMOOTHNESS = 3

    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])
    swc_files = ("test_line_a.swc", "test_line_b.swc")
    swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

    # Create two lines separated by a given distance and write them to swc files.
    intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
    for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
        swc_points = self._get_points(intercept, slope, bb)
        self._write_swc(swc_path, swc_points.to_nx_graph())

    fused = ArrayKey("FUSED")
    fused_labels = ArrayKey("FUSED_LABELS")
    fused_swc = PointsKey("FUSED_SWC")
    swc_keys = (PointsKey("SWC_A"), PointsKey("SWC_B"))
    labels_keys = (ArrayKey("LABELS_A"), ArrayKey("LABELS_B"))
    raw_keys = (ArrayKey("RAW_A"), ArrayKey("RAW_B"))

    # Request every key over the whole bounding box.
    request = BatchRequest()
    for key in (fused, fused_labels, fused_swc, *labels_keys, *raw_keys, *swc_keys):
        request.add(key, bb.get_shape())

    def make_source(swc_path, swc_key, labels_key, raw_key):
        # One SWC file rasterized into grown labels and grown, normalized raw.
        def array_spec():
            return ArraySpec(
                interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
            )

        return (
            SwcFileSource(swc_path, [swc_key], [PointsSpec(roi=bb)])
            + RasterizeSkeleton(points=swc_key, array=labels_key, array_spec=array_spec())
            + GrowLabels(array=labels_key, radii=[LABEL_RADIUS])
            + RasterizeSkeleton(points=swc_key, array=raw_key, array_spec=array_spec())
            + GrowLabels(array=raw_key, radii=[RAW_RADIUS])
            + Normalize(raw_key)
        )

    data_sources = (
        make_source(swc_paths[0], swc_keys[0], labels_keys[0], raw_keys[0]),
        make_source(swc_paths[1], swc_keys[1], labels_keys[1], raw_keys[1]),
    ) + MergeProvider()

    pipeline = data_sources + FusionAugment(
        raw_keys[0],
        raw_keys[1],
        labels_keys[0],
        labels_keys[1],
        swc_keys[0],
        swc_keys[1],
        fused,
        fused_labels,
        fused_swc,
        blend_mode="intensity",
        blend_smoothness=BLEND_SMOOTHNESS,
        num_blended_objects=0,
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)

        fused_data = batch[fused].data
        a_data = batch[raw_keys[0]].data
        b_data = batch[raw_keys[1]].data
        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
def test_recenter(self):
    """Check that Recenter puts each extracted skeleton on its image.

    A toy SWC file is read three times: once as points, and twice rasterized
    (binarized) into an image and a label volume. GetNeuronPair extracts two
    skeletons and Recenter recenters each on its image. Each image must be 1 at
    its center voxel, and every point of the matching skeleton must fall on a
    voxel equal to 1.
    """
    path = Path(self.path_to("test_swc_source.swc"))
    # Write the test swc.
    self._write_swc(path, self._toy_swc_points())

    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    def swc_spec():
        return PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))

    def rasterized_source(swc_key, raw_key, binary_key, radius):
        # Artificial volume: rasterize the swc points, binarize, then grow.
        return (
            SwcFileSource(path, [swc_key], [swc_spec()])
            + RasterizeSkeleton(
                points=swc_key,
                array=raw_key,
                array_spec=ArraySpec(
                    interpolatable=True,
                    dtype=np.uint32,
                    voxel_size=Coordinate((1, 1, 1)),
                ),
            )
            + BinarizeLabels(labels=raw_key, labels_binary=binary_key)
            + GrowLabels(array=binary_key, radius=radius)
        )

    # Points from the test swc, plus image and label sources rasterized from it.
    swc_file_source = SwcFileSource(path, [swc_source], [swc_spec()])
    image_source = rasterized_source(img_swc, img_source, imgs, 0)
    label_source = rasterized_source(label_swc, labels_source, labels, 1)

    skeleton = (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )
    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    data_shape = 9
    request = BatchRequest()
    for key in (points_a, points_b, img_a, img_b, labels_a, labels_b):
        request.add(key, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

        padded = []
        for img_key in (img_a, img_b):
            data = batch[img_key].data
            assert data[tuple(np.array(data.shape) // 2)] == 1
            # Pad by one voxel so locations just outside the array still index.
            padded.append(np.pad(data, (1,), "constant", constant_values=(0,)))

        for data, points_key in zip(padded, (points_a, points_b)):
            for point in batch[points_key].data.values():
                # +1 compensates for the one-voxel padding.
                index = tuple(int(x) + 1 for x in point.location)
                assert data[index] == 1, "data at {} is not 1, its {}".format(
                    point.location, data[index]
                )
def test_both(self):
    """Mirror and transpose all three dims and check the single node.

    Over 100 batches, whenever the graph contains exactly one node it must pass
    the membership check against every transposition of every mirrored
    location. The node must lie inside the graph's spec ROI, and that ROI must
    lie inside the requested ROI. Every array's data shape must equal its spec
    ROI shape. At least one batch must contain the node and at least one must
    show it moved from (50, 70, 100).
    """
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")

    og_point = np.array([50, 70, 100])
    transpose_dims = [0, 1, 2]
    mirror_dims = [0, 1, 2]

    pipeline = (
        (ArrayTestSource(), TestSource())
        + MergeProvider()
        + Pad(test_array1, None)
        + Pad(test_array2, None)
        + Pad(test_graph, None)
        + SimpleAugment(mirror_only=mirror_dims, transpose_only=transpose_dims)
    )

    offset = (0, 20, 33)
    graph_shape = (100, 100, 120)
    request = BatchRequest()
    request[test_graph] = GraphSpec(roi=Roi(offset, graph_shape))
    request[test_array1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[test_array2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # All mirrored locations: every combination of the per-dimension
    # candidates [[49, 50], [70, 29], [100, 86]].
    mirror_combs = [
        [49, 70, 100], [49, 29, 86], [49, 70, 86], [49, 29, 100],
        [50, 70, 100], [50, 29, 86], [50, 70, 86], [50, 29, 100],
    ]
    # Every transposition of every mirrored location, shape (8, 6, 3).
    possible_loc = np.array([
        [np.array(loc)[list(perm)] for perm in permutations(transpose_dims, 3)]
        for loc in mirror_combs
    ])

    with build(pipeline):
        seen_transposed = False
        # Must start False, otherwise `assert seen_node` below checks nothing.
        seen_node = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            nodes = list(batch[test_graph].nodes)

            if len(nodes) == 1:
                seen_node = True
                node = nodes[0]
                # NOTE(review): for ndarrays `x in arr` is `(arr == x).any()`,
                # so this passes when any single coordinate matches the same
                # column of any candidate; it does not require a whole
                # location to match. A stricter check may be intended.
                assert node.location in possible_loc
                seen_transposed = seen_transposed or any(
                    node.location[dim] != og_point[dim] for dim in range(3)
                )
                assert Roi(offset, graph_shape).contains(batch[test_graph].spec.roi)
                assert batch[test_graph].spec.roi.contains(node.location)

            for array in batch.arrays.values():
                assert array.data.shape == array.spec.roi.get_shape()

        assert seen_transposed
        assert seen_node
def test_two_disjoint_lines_softmask(self):
    """Fuse two rasterized, non-touching lines with ``blend_mode="labels_mask"``.

    Two straight lines 3 * LABEL_RADIUS apart are written to SWC files. Each is
    rasterized into a labels and a raw array, and FusionAugment fuses them. The
    fused raw data must equal the sum of the two raw arrays. If ``volshow`` was
    importable, the fused, summed, difference and individual volumes are shown.
    """
    LABEL_RADIUS = 3
    RAW_RADIUS = 3
    # Exaggerated to show the problem.
    BLEND_SMOOTHNESS = 10

    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])
    swc_files = ("test_line_a.swc", "test_line_b.swc")
    swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

    # Create two lines separated by a given distance and write them to swc files.
    intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
    for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
        swc_points = self._get_points(intercept, slope, bb)
        self._write_swc(swc_path, swc_points)

    fused = ArrayKey("FUSED")
    fused_labels = ArrayKey("FUSED_LABELS")
    swc_keys = (PointsKey("SWC_A"), PointsKey("SWC_B"))
    labels_keys = (ArrayKey("LABELS_A"), ArrayKey("LABELS_B"))
    raw_keys = (ArrayKey("RAW_A"), ArrayKey("RAW_B"))

    # Request every key over the whole bounding box.
    request = BatchRequest()
    for key in (fused, fused_labels, *labels_keys, *raw_keys, *swc_keys):
        request.add(key, bb.get_shape())

    def make_source(swc_path, swc_key, labels_key, raw_key):
        # Read one SWC file and rasterize it into a labels and a raw array.
        def rasterize(array_key, radius):
            return RasterizeSkeleton(
                points=swc_key,
                array=array_key,
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=radius,
            )

        return (
            SwcFileSource(swc_path, swc_key, PointsSpec(roi=bb))
            + rasterize(labels_key, LABEL_RADIUS)
            + rasterize(raw_key, RAW_RADIUS)
        )

    data_sources = (
        make_source(swc_paths[0], swc_keys[0], labels_keys[0], raw_keys[0]),
        make_source(swc_paths[1], swc_keys[1], labels_keys[1], raw_keys[1]),
    ) + MergeProvider()

    pipeline = data_sources + FusionAugment(
        raw_keys[0],
        raw_keys[1],
        labels_keys[0],
        labels_keys[1],
        fused,
        fused_labels,
        blend_mode="labels_mask",
        blend_smoothness=BLEND_SMOOTHNESS,
        num_blended_objects=0,
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)

        fused_data = np.pad(batch[fused].data, (1,), "constant", constant_values=(0,))
        a_data = np.pad(batch[raw_keys[0]].data, (1,), "constant", constant_values=(0,))
        b_data = np.pad(batch[raw_keys[1]].data, (1,), "constant", constant_values=(0,))

        if imported_volshow:
            # Visualize the problem. Only built when shown, since these are
            # five full volumes.
            # Channels: fused, a + b, fused - a - b, a, b.
            all_data = np.zeros((5,) + fused_data.shape)
            all_data[0, :, :, :] = fused_data
            all_data[1, :, :, :] = a_data + b_data
            all_data[2, :, :, :] = fused_data - a_data - b_data
            all_data[3, :, :, :] = a_data
            all_data[4, :, :, :] = b_data
            volshow(all_data)

        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
def test_rasterize_speed(self):
    """Time building the pipeline and rasterizing two line skeletons.

    This is worryingly slow for such a small volume (256**3) and only 2
    straight lines for skeletons. Only the pipeline build, the request and
    the teardown are timed. The 0.1 s budget makes the test sensitive to
    machine load.
    """
    LABEL_RADIUS = 3

    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])
    swc_files = ("test_line_a.swc", "test_line_b.swc")
    swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

    # Create two lines separated by a given distance and write them to swc files.
    intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
    for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
        swc_points = self._get_points(intercept, slope, bb)
        self._write_swc(swc_path, swc_points)

    swc_keys = (PointsKey("SWC_A"), PointsKey("SWC_B"))
    labels_keys = (ArrayKey("LABELS_A"), ArrayKey("LABELS_B"))

    # Request every key over the whole bounding box.
    request = BatchRequest()
    for key in (*labels_keys, *swc_keys):
        request.add(key, bb.get_shape())

    def make_source(swc_path, swc_key, labels_key):
        # Read one SWC file and rasterize it into a labels array.
        return SwcFileSource(swc_path, swc_key, PointsSpec(roi=bb)) + RasterizeSkeleton(
            points=swc_key,
            array=labels_key,
            array_spec=ArraySpec(
                interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
            ),
            radius=LABEL_RADIUS,
        )

    pipeline = (
        make_source(swc_paths[0], swc_keys[0], labels_keys[0]),
        make_source(swc_paths[1], swc_keys[1], labels_keys[1]),
    ) + MergeProvider()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    with build(pipeline):
        batch = pipeline.request_batch(request)
    elapsed = time.perf_counter() - start

    for key in labels_keys:
        self.assertIsNotNone(batch[key].data)
    self.assertLess(elapsed, 0.1)
def test_get_neuron_pair(self):
    """Check that GetNeuronPair returns all outputs with nearby skeletons.

    For each of 10 batches, every requested key must be present. The minimum
    distance between any SKELETON_A node and any SKELETON_B node must lie in
    [1, 3], matching ``seperate_by=(1, 3)``.
    """
    path = Path(self.path_to("test_swc_source.swc"))
    # Write the test swc.
    self._write_swc(path, self._toy_swc_points().to_nx_graph())

    swc_source = PointsKey("SWC_SOURCE")
    ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    data_shape = 5
    output_shape = Coordinate((data_shape, data_shape, data_shape))

    def swc_spec():
        return PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))

    def rasterized_source(swc_key, raw_key, binary_key, radius):
        # Artificial volume: rasterize the swc points, binarize, then grow.
        return (
            SwcFileSource(path, [swc_key], [swc_spec()])
            + RasterizeSkeleton(
                points=swc_key,
                array=raw_key,
                array_spec=ArraySpec(
                    interpolatable=True,
                    dtype=np.uint32,
                    voxel_size=Coordinate((1, 1, 1)),
                ),
            )
            + BinarizeLabels(labels=raw_key, labels_binary=binary_key)
            + GrowLabels(array=binary_key, radius=radius)
        )

    # Points from the test swc, plus image and label sources rasterized from it.
    swc_file_source = SwcFileSource(
        path, [swc_source, ensure_nonempty], [swc_spec(), swc_spec()]
    )
    image_source = rasterized_source(img_swc, img_source, imgs, 0)
    label_source = rasterized_source(label_swc, labels_source, labels, 1)

    skeleton = (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=ensure_nonempty, ensure_centered=True)
    )
    pipeline = skeleton + GetNeuronPair(
        point_source=swc_source,
        nonempty_placeholder=ensure_nonempty,
        array_source=imgs,
        label_source=labels,
        points=(points_a, points_b),
        arrays=(img_a, img_b),
        labels=(labels_a, labels_b),
        seperate_by=(1, 3),
        shift_attempts=100,
        request_attempts=10,
        output_shape=output_shape,
    )

    output_keys = (points_a, points_b, img_a, img_b, labels_a, labels_b)
    request = BatchRequest()
    for key in output_keys:
        request.add(key, output_shape)

    with build(pipeline):
        for _ in range(10):
            batch = pipeline.request_batch(request)
            missing = [key for key in output_keys if key not in batch]
            assert not missing, f"keys missing from batch: {missing}"

            # Starting at 5 makes an empty node set fail the <= 3 check.
            min_dist = 5
            for node_a, node_b in itertools.product(
                batch[points_a].nodes, batch[points_b].nodes
            ):
                min_dist = min(min_dist, np.linalg.norm(node_a.location - node_b.location))
            self.assertLessEqual(min_dist, 3)
            self.assertGreaterEqual(min_dist, 1)