def test_realistic_valid_examples(example, use_gurobi):
    """Match a realistic skeletonization against its consensus tree.

    Loads ``graph.obj`` (the skeletonization) and ``tree.obj`` (the consensus)
    from ``mouselight_examples/valid/<example>``. It then runs two
    ``TopologicalMatcher`` nodes over them, one plain and one with
    ``with_fallback=True``. The plain matching must contain as many connected
    components as the consensus.
    """
    example_dir = Path(__file__).parent / "mouselight_examples" / "valid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    # Unbounded in all three dimensions: request the graphs in full.
    unbounded = Roi(Coordinate((None,) * 3), Coordinate((None,) * 3))
    request = BatchRequest()
    for key in (matched, matched_with_fallback, consensus):
        request[key] = PointsSpec(roi=unbounded)

    # Settings shared by both matchers; only the output key and the
    # fallback flag differ between them.
    matcher_kwargs = dict(
        expected_edge_len=10,
        match_distance_threshold=76,
        max_gap_crossing=48,
        use_gurobi=use_gurobi,
        location_attr="location",
        penalty_attr="penalty",
    )

    sources = (
        GraphSource(example_dir / "graph.obj", [skeletonization]),
        GraphSource(example_dir / "tree.obj", [consensus]),
    )
    pipeline = (
        sources
        + MergeProvider()
        + TopologicalMatcher(skeletonization, consensus, matched, **matcher_kwargs)
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched_with_fallback,
            with_fallback=True,
            **matcher_kwargs,
        )
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)

    consensus_ccs = list(batch[consensus].connected_components)
    # NOTE(review): computed but never asserted on; only the plain
    # matching is checked below.
    fallback_ccs = list(batch[matched_with_fallback].connected_components)
    matched_ccs = list(batch[matched].connected_components)

    assert len(matched_ccs) == len(consensus_ccs)
def test_relabel_components(self):
    """Points read from one SWC file are split into per-component labels.

    Writes the toy SWC and reads it back through ``SwcFileSource`` over a
    ROI starting at y=1. Expects 3 weakly connected components of 10 points
    each. Every component must carry a single label id that differs from the
    label of the component visited just before it.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))
    self._write_swc(swc_path, self._toy_swc_points())

    swc = PointsKey("SWC")
    source = SwcFileSource(swc_path, swc)
    roi = Roi((0, 1, 0), (11, 10, 1))
    with build(source):
        batch = source.request_batch(BatchRequest({swc: PointsSpec(roi=roi)}))

    points = batch.points[swc].data
    graph = nx.DiGraph()
    for node in points.values():
        graph.add_node(node.point_id, label_id=node.label_id)
        parent = node.parent_id
        # Link child -> parent only for a real parent (not -1, not itself)
        # that was also returned inside the requested ROI.
        has_parent = parent != -1 and parent != node.point_id and parent in points
        if has_parent:
            graph.add_edge(node.point_id, parent)

    components = list(nx.weakly_connected_components(graph))
    self.assertEqual(len(components), 3)

    last_label = None
    for component in components:
        self.assertEqual(len(component), 10)
        component_label = None
        for node_id in component:
            node_label = graph.nodes[node_id]["label_id"]
            if component_label is None:
                component_label = node_label
                self.assertNotEqual(component_label, last_label)
            self.assertEqual(node_label, component_label)
        last_label = component_label
def test_without_placeholder(self):
    """A points-only request through ElasticAugment must be rejected.

    Nothing in a points-only request provides ElasticAugment with a voxel
    size, so it must raise ``PipelineRequestError``. The same request plus
    the labels array must succeed and contain the labels.
    """
    test_labels = ArrayKey("TEST_LABELS")
    test_points = PointsKey("TEST_POINTS")

    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
    )

    request_size = Coordinate((40, 40, 40))

    def make_request(seed, *keys):
        """Seeded request asking for each of ``keys`` at ``request_size``."""
        request = BatchRequest(random_seed=seed)
        for key in keys:
            request.add(key, request_size)
        return request

    with build(pipeline):
        for seed in range(100):
            points_only = make_request(seed, test_points)
            with_labels = make_request(seed, test_points, test_labels)

            # No array to provide a voxel size to ElasticAugment
            with pytest.raises(PipelineRequestError):
                pipeline.request_batch(points_only)

            batch = pipeline.request_batch(with_labels)
            self.assertIn(test_labels, batch)
def test_ensure_center_non_zero(self):
    """With ``ensure_centered=True`` the center voxel lies on the skeleton.

    Rasterizes the toy skeleton after a centered ``RandomLocation`` and
    checks two things: the returned graph is non-empty, and the central voxel
    of the rasterized 5x5x5 volume is non-zero.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))
    self._write_swc(swc_path, self._toy_swc_points().to_nx_graph())

    swc = PointsKey("SWC")
    img = ArrayKey("IMG")

    pipeline = (
        SwcFileSource(
            swc_path, [swc], [PointsSpec(roi=Roi((0, 0, 0), (11, 11, 11)))]
        )
        + RandomLocation(ensure_nonempty=swc, ensure_centered=True)
        + RasterizeSkeleton(
            points=swc,
            array=img,
            array_spec=ArraySpec(
                interpolatable=False,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        )
    )

    shape = Coordinate((5, 5, 5))
    request = BatchRequest()
    request.add(img, shape)
    request.add(swc, shape)

    with build(pipeline):
        batch = pipeline.request_batch(request)

    raster = batch[img].data
    center = tuple(np.array(raster.shape) // 2)
    self.assertGreater(batch[swc].num_vertices(), 0)
    self.assertNotEqual(raster[center], 0)
def test_read_single_swc(self):
    """Every toy point is read back from the SWC at the location it was written.

    Writes the toy SWC and reads it back through ``SwcFileSource``. For
    every toy point, the point with the same id must be present and have an
    identical location. Coordinates are compared element-wise and in order,
    not as an unordered multiset, so swapped axes are detected.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))
    self._write_swc(swc_path, self._toy_swc_points())

    swc = PointsKey("SWC")
    source = SwcFileSource(swc_path, swc)
    with build(source):
        batch = source.request_batch(
            BatchRequest({swc: PointsSpec(roi=Roi((0, 0, 0), (11, 11, 1)))})
        )

    read_points = batch.points[swc].data
    for point in self._toy_swc_points():
        np.testing.assert_array_equal(
            read_points[point.point_id].location,
            point.location,
            err_msg="location mismatch for point {}".format(point.point_id),
        )
def test_create_boundary_nodes(self):
    """Cropping a skeleton to a thin ROI yields the expected chain of nodes.

    Writes the toy SWC with resolution (2, 2, 2) and requests the ROI with
    offset (0, 5, 0) and shape (1, 3, 1). It then walks from the node whose
    parent was not returned, following child edges. The visited locations
    and node ids are compared to the expected three-node chain.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))
    self._write_swc(
        swc_path, self._toy_swc_points(), {"resolution": np.array([2, 2, 2])}
    )

    swc = PointsKey("SWC")
    source = SwcFileSource(swc_path, swc)
    with build(source):
        batch = source.request_batch(
            BatchRequest({swc: PointsSpec(roi=Roi((0, 5, 0), (1, 3, 1)))})
        )

    points = batch.points[swc].data
    graph = nx.DiGraph()
    for node in points.values():
        graph.add_node(node.point_id, label_id=node.label_id, location=node.location)
        parent = node.parent_id
        if parent != -1 and parent != node.point_id and parent in points:
            graph.add_edge(node.point_id, parent)
        else:
            # No usable parent inside the ROI: this node is a root. If several
            # nodes qualify, the last one visited wins.
            root = node.point_id

    expected_path = [
        tuple(np.array([0.0, 5.0, 0.0])),
        tuple(np.array([0.0, 6.0, 0.0])),
        tuple(np.array([0.0, 7.0, 0.0])),
    ]
    expected_node_ids = [0, 1, 2]

    visited_locations = []
    visited_ids = []
    current = root
    while current is not None:
        visited_ids.append(current)
        visited_locations.append(tuple(graph.nodes[current]["location"]))
        # Edges point child -> parent, so predecessors are the children.
        # The walk stops at a leaf or at a branch.
        children = list(graph.predecessors(current))
        current = children[0] if len(children) == 1 else None

    # NOTE(review): assertCountEqual ignores order, so the walk order itself
    # is not verified here.
    self.assertCountEqual(visited_locations, expected_path)
    self.assertCountEqual(visited_ids, expected_node_ids)
def test_grow_labels_speed(self):
    """Rasterizing one skeleton and growing its labels must average < 0.1 s.

    Builds SWC -> RasterizeSkeleton -> GrowLabels over a 256**3 volume and
    times ``num_repeats`` requests, including the pipeline build.
    ``time.perf_counter`` is used because it is monotonic and
    high-resolution. ``time.time`` follows the wall clock, which can be
    adjusted mid-measurement.
    """
    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])

    swc_path = Path(self.path_to("test_swc.swc"))
    swc_points = self._get_points(np.array([1, 1, 1]), np.array([1, 1, 1]), bb)
    self._write_swc(swc_path, swc_points.graph)

    swc_key = PointsKey("SWC")
    labels_key = ArrayKey("LABELS")

    request = BatchRequest()
    request.add(labels_key, bb.get_shape())
    request.add(swc_key, bb.get_shape())

    pipeline = (
        tuple()
        + SwcFileSource(swc_path, [swc_key], [PointsSpec(roi=bb)])
        + RasterizeSkeleton(
            points=swc_key,
            array=labels_key,
            array_spec=ArraySpec(
                interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
            ),
        )
        + GrowLabels(array=labels_key, radius=3)
    )

    num_repeats = 10
    start = time.perf_counter()
    with build(pipeline):
        for _ in range(num_repeats):
            pipeline.request_batch(request)
    elapsed = time.perf_counter() - start

    self.assertLess(elapsed / num_repeats, 0.1)
def test_overlap(self):
    """Three identical, overlapping SWC files stay separate components.

    Writes the toy skeleton three times with offset (0, 0, 0) into one
    directory and reads the directory through ``SwcFileSource``. Expects 3
    weakly connected components of 41 points each. Every component must
    carry a single label id that differs from the label of the component
    visited just before it.
    """
    swc_dir = Path(self.path_to("test_swc_sources"))
    swc_dir.mkdir(parents=True, exist_ok=True)
    for index in range(3):
        self._write_swc(
            swc_dir / "{}.swc".format(index),
            self._toy_swc_points(),
            {"offset": np.array([0, 0, 0])},
        )

    swc = PointsKey("SWC")
    source = SwcFileSource(swc_dir, swc)
    roi = Roi((0, 0, 0), (11, 11, 1))
    with build(source):
        batch = source.request_batch(BatchRequest({swc: PointsSpec(roi=roi)}))

    points = batch.points[swc].data
    graph = nx.DiGraph()
    for node in points.values():
        graph.add_node(node.point_id, label_id=node.label_id)
        parent = node.parent_id
        # Only real parents (not -1, not itself) returned in the ROI get an edge.
        has_parent = parent != -1 and parent != node.point_id and parent in points
        if has_parent:
            graph.add_edge(node.point_id, parent)

    components = list(nx.weakly_connected_components(graph))
    self.assertEqual(len(components), 3)

    last_label = None
    for component in components:
        self.assertEqual(len(component), 41)
        component_label = None
        for node_id in component:
            node_label = graph.nodes[node_id]["label_id"]
            if component_label is None:
                component_label = node_label
                self.assertNotEqual(component_label, last_label)
            self.assertEqual(node_label, component_label)
        last_label = component_label
def test_placeholder(self):
    """A placeholder labels request must not change the augmented points.

    Each iteration makes two requests with the same random seed. One asks
    for the labels as a placeholder and the other for the real labels. Both
    batches must contain the same number of points at the same locations.
    """
    test_labels = ArrayKey("TEST_LABELS")
    test_points = PointsKey("TEST_POINTS")

    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
    )

    with build(pipeline):
        for seed in range(100):
            request_size = Coordinate((40, 40, 40))

            request_a = BatchRequest(random_seed=seed)
            request_a.add(test_points, request_size)
            request_a.add(test_labels, request_size, placeholder=True)

            request_b = BatchRequest(random_seed=seed)
            request_b.add(test_points, request_size)
            request_b.add(test_labels, request_size)

            batch_a = pipeline.request_batch(request_a)
            batch_b = pipeline.request_batch(request_b)

            # Materialize the nodes: zip() alone would stop at the shorter
            # sequence and hide a differing number of points.
            points_a = list(batch_a[test_points].nodes)
            points_b = list(batch_b[test_points].nodes)
            self.assertEqual(len(points_a), len(points_b))
            for a, b in zip(points_a, points_b):
                self.assertTrue(np.allclose(a.location, b.location))
def test_two_disjoint_lines_intensity(self):
    """Intensity-blended fusion of two disjoint lines should equal their sum.

    Two straight skeletons ``3 * LABEL_RADIUS`` apart are written to SWC
    files. Each is rasterized into a grown labels array and a normalized raw
    array, and the two are fused by ``FusionAugment`` with
    ``blend_mode="intensity"``. The test expects ``fused == raw_a + raw_b``
    (norm of the difference ~ 0).
    """
    LABEL_RADIUS = 3
    RAW_RADIUS = 3
    BLEND_SMOOTHNESS = 3

    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])
    swc_paths = tuple(
        Path(self.path_to(file_name))
        for file_name in ("test_line_a.swc", "test_line_b.swc")
    )

    # Two lines a fixed distance apart, one per SWC file.
    intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
    for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
        line_points = self._get_points(intercept, slope, bb)
        self._write_swc(swc_path, line_points.to_nx_graph())

    fused = ArrayKey("FUSED")
    fused_labels = ArrayKey("FUSED_LABELS")
    fused_swc = PointsKey("FUSED_SWC")
    swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
    labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))
    raw_keys = tuple(ArrayKey(name) for name in ("RAW_A", "RAW_B"))

    request = BatchRequest()
    for key in (fused, fused_labels, fused_swc) + labels_keys + raw_keys + swc_keys:
        request.add(key, bb.get_shape())

    def line_source(index):
        """Read one SWC and rasterize it into grown labels and normalized raw."""
        return (
            tuple()
            + SwcFileSource(swc_paths[index], [swc_keys[index]], [PointsSpec(roi=bb)])
            + RasterizeSkeleton(
                points=swc_keys[index],
                array=labels_keys[index],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=labels_keys[index], radii=[LABEL_RADIUS])
            + RasterizeSkeleton(
                points=swc_keys[index],
                array=raw_keys[index],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=raw_keys[index], radii=[RAW_RADIUS])
            + Normalize(raw_keys[index])
        )

    data_sources = (line_source(0), line_source(1)) + MergeProvider()
    pipeline = data_sources + FusionAugment(
        raw_keys[0],
        raw_keys[1],
        labels_keys[0],
        labels_keys[1],
        swc_keys[0],
        swc_keys[1],
        fused,
        fused_labels,
        fused_swc,
        blend_mode="intensity",
        blend_smoothness=BLEND_SMOOTHNESS,
        num_blended_objects=0,
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)

    def padded(key):
        """Array data for ``key`` with a one-voxel zero border."""
        return np.pad(batch[key].data, (1,), "constant", constant_values=(0,))

    fused_data = padded(fused)
    a_data = padded(raw_keys[0])
    b_data = padded(raw_keys[1])

    diff = np.linalg.norm(fused_data - a_data - b_data)
    self.assertAlmostEqual(diff, 0)
def test_recenter(self):
    """Recenter keeps a skeleton voxel at the center of both neuron volumes.

    Builds two sources from the toy SWC: an image source (binarized
    rasterization, radius 0) and a label source (binarized rasterization
    grown by radius 1). ``GetNeuronPair`` pairs two copies of the neuron,
    and each copy is recentered on its own image. The test then checks:

    - the central voxel of ``img_a`` and ``img_b`` is 1;
    - every returned skeleton point lies on a voxel of value 1 in its own
      image.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))
    self._write_swc(swc_path, self._toy_swc_points())

    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    def source_spec():
        """A fresh points spec covering the whole toy skeleton."""
        return PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))

    def rasterized_source(swc_key, raster_key, binary_key, radius):
        """Rasterize the test SWC, binarize it and grow it by ``radius``."""
        return (
            SwcFileSource(swc_path, [swc_key], [source_spec()])
            + RasterizeSkeleton(
                points=swc_key,
                array=raster_key,
                array_spec=ArraySpec(
                    interpolatable=True,
                    dtype=np.uint32,
                    voxel_size=Coordinate((1, 1, 1)),
                ),
            )
            + BinarizeLabels(labels=raster_key, labels_binary=binary_key)
            + GrowLabels(array=binary_key, radius=radius)
        )

    # Points straight from the test SWC.
    swc_file_source = SwcFileSource(swc_path, [swc_source], [source_spec()])
    # Artificial image and label volumes rasterized from the same SWC.
    image_source = rasterized_source(img_swc, img_source, imgs, radius=0)
    label_source = rasterized_source(label_swc, labels_source, labels, radius=1)

    skeleton = tuple() + (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    data_shape = 9
    output_shape = Coordinate((data_shape, data_shape, data_shape))
    request = BatchRequest()
    for key in (points_a, points_b, img_a, img_b, labels_a, labels_b):
        request.add(key, output_shape)

    with build(pipeline):
        batch = pipeline.request_batch(request)

    data_a = batch[img_a].data
    assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
    data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))

    data_b = batch[img_b].data
    assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
    data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

    data_c = data_a + data_b
    data = np.array((data_a, data_b, data_c))

    # Point coordinates are shifted by +1 to account for the one-voxel
    # padding. The failure message reads the same voxel that was checked.
    for channel, points_key in ((0, points_a), (1, points_b)):
        for _, point in batch[points_key].data.items():
            index = (channel,) + tuple(int(x) + 1 for x in point.location)
            assert data[index] == 1, "data at {} is not 1, it is {}".format(
                point.location, data[index]
            )
def test_output_min_distance(self):
    """Check AddVectorMap output for two partner criteria.

    ``partner_criterion="min_distance"``: every non-zero vector checked near
    a presynaptic point must point exactly at that point's closest
    postsynaptic partner.

    ``partner_criterion="all"``: every non-zero vector checked must point at
    one of the point's partners. Per-partner vector counts must not exceed
    the smallest count by more than the number of partners.

    The same (ROI-shifted) request is reused for both pipelines.
    """
    voxel_size = Coordinate((20, 2, 2))

    # Constructed only for their side effect. NOTE(review): presumably this
    # registers the keys so that ArrayKeys.* / GraphKeys.* below resolve;
    # confirm against the key registry.
    ArrayKey("GT_VECTORS_MAP_PRESYN")
    PointsKey("PRESYN")
    PointsKey("POSTSYN")

    # vector map -> (source points, target points) it is computed from.
    arraytypes_to_source_target_pointstypes = {
        ArrayKeys.GT_VECTORS_MAP_PRESYN: (GraphKeys.PRESYN, GraphKeys.POSTSYN)
    }
    # vector map -> array whose labels the vectors must stay inside.
    arraytypes_to_stayinside_arraytypes = {
        ArrayKeys.GT_VECTORS_MAP_PRESYN: ArrayKeys.GT_LABELS
    }

    # test for partner criterion 'min_distance'
    radius_phys = 30
    pipeline_min_distance = AddVectorMapTestSource() + AddVectorMap(
        src_and_trg_points=arraytypes_to_source_target_pointstypes,
        voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
        radius_phys=radius_phys,
        partner_criterion="min_distance",
        stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
        pad_for_partners=(0, 0, 0),
    )

    with build(pipeline_min_distance):

        request = BatchRequest()
        raw_roi = pipeline_min_distance.spec[ArrayKeys.RAW].roi
        gt_labels_roi = pipeline_min_distance.spec[ArrayKeys.GT_LABELS].roi
        presyn_roi = pipeline_min_distance.spec[GraphKeys.PRESYN].roi

        request.add(ArrayKeys.RAW, raw_roi.get_shape())
        request.add(ArrayKeys.GT_LABELS, gt_labels_roi.get_shape())
        request.add(GraphKeys.PRESYN, presyn_roi.get_shape())
        request.add(GraphKeys.POSTSYN, presyn_roi.get_shape())
        request.add(ArrayKeys.GT_VECTORS_MAP_PRESYN, presyn_roi.get_shape())
        # Move every requested ROI by +1000 in each dimension. The request
        # object is mutated in place and reused by the 'all' pipeline below.
        for identifier, spec in request.items():
            spec.roi = spec.roi.shift((1000, 1000, 1000))

        batch = pipeline_min_distance.request_batch(request)

    # Node id -> node, for both point sets.
    presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
    postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
    vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
    # World-space offset of voxel (0, 0, 0) of the vector map.
    offset_vector_map_presyn = request[
        ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.get_offset()

    self.assertTrue(len(presyn_locs) > 0)
    self.assertTrue(len(postsyn_locs) > 0)

    for loc_id, point in presyn_locs.items():

        if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
                Coordinate(point.location)):
            self.assertTrue(batch.arrays[
                ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                    Coordinate(point.location)))

            # distance -> location for every partner present in the batch.
            dist_to_loc = {}
            for partner_id in point.attrs["partner_ids"]:
                if partner_id in postsyn_locs.keys():
                    partner_location = postsyn_locs[partner_id].location
                    dist_to_loc[
                        np.linalg.norm(partner_location - point.location)] = partner_location
            # NOTE(review): np.min raises ValueError if none of this point's
            # partners were returned (empty dist_to_loc).
            min_dist = np.min(list(dist_to_loc.keys()))
            relevant_partner_loc = dist_to_loc[min_dist]

            # Voxel bounding box of radius_phys around the point, clipped to
            # the vector map's spatial extent.
            presyn_loc_shifted_vx = (
                point.location - offset_vector_map_presyn) // voxel_size
            radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
            region_to_check = np.clip(
                [
                    (presyn_loc_shifted_vx - radius_vx),
                    (presyn_loc_shifted_vx + radius_vx),
                ],
                a_min=(0, 0, 0),
                a_max=vector_map_presyn.shape[-3:],
            )
            for x, y, z in itertools.product(
                    range(int(region_to_check[0][0]), int(region_to_check[1][0])),
                    range(int(region_to_check[0][1]), int(region_to_check[1][1])),
                    range(int(region_to_check[0][2]), int(region_to_check[1][2])),
            ):
                # NOTE(review): (x, y, z) is a voxel index relative to the
                # vector map offset, but point.location is in world units
                # (shifted by 1000 above). This distance is then likely never
                # < radius_phys, so the checks below may never execute.
                # Compare offset + voxel_size * (x, y, z) instead — confirm.
                if (np.linalg.norm((np.array(
                    (x, y, z)) - np.asarray(point.location))) < radius_phys):
                    vector = [
                        vector_map_presyn[dim][x, y, z]
                        for dim in range(vector_map_presyn.shape[0])
                    ]
                    if not np.sum(vector) == 0:
                        # World position of the voxel plus its vector must
                        # land on the closest partner.
                        trg_loc_of_vector_phys = (
                            np.asarray(offset_vector_map_presyn) +
                            (voxel_size * np.array([x, y, z])) +
                            np.asarray(vector))
                        self.assertTrue(
                            np.array_equal(trg_loc_of_vector_phys,
                                           relevant_partner_loc))

    # test for partner criterion 'all'
    pipeline_all = AddVectorMapTestSource() + AddVectorMap(
        src_and_trg_points=arraytypes_to_source_target_pointstypes,
        voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
        radius_phys=radius_phys,
        partner_criterion="all",
        stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
        pad_for_partners=(0, 0, 0),
    )

    with build(pipeline_all):
        batch = pipeline_all.request_batch(request)

    presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
    postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
    vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
    offset_vector_map_presyn = request[
        ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.get_offset()

    self.assertTrue(len(presyn_locs) > 0)
    self.assertTrue(len(postsyn_locs) > 0)

    for loc_id, point in presyn_locs.items():

        if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
                Coordinate(point.location)):
            self.assertTrue(batch.arrays[
                ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                    Coordinate(point.location)))

            # partner id -> location (as list), and partner id -> number of
            # vectors that point at it.
            partner_ids_to_locs_per_src, count_vectors_per_partner = {}, {}
            for partner_id in point.attrs["partner_ids"]:
                if partner_id in postsyn_locs.keys():
                    partner_ids_to_locs_per_src[partner_id] = postsyn_locs[
                        partner_id].location.tolist()
                    count_vectors_per_partner[partner_id] = 0

            presyn_loc_shifted_vx = (
                point.location - offset_vector_map_presyn) // voxel_size
            radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
            region_to_check = np.clip(
                [
                    (presyn_loc_shifted_vx - radius_vx),
                    (presyn_loc_shifted_vx + radius_vx),
                ],
                a_min=(0, 0, 0),
                a_max=vector_map_presyn.shape[-3:],
            )
            for x, y, z in itertools.product(
                    range(int(region_to_check[0][0]), int(region_to_check[1][0])),
                    range(int(region_to_check[0][1]), int(region_to_check[1][1])),
                    range(int(region_to_check[0][2]), int(region_to_check[1][2])),
            ):
                # NOTE(review): same voxel-index vs world-location comparison
                # as in the 'min_distance' section above.
                if (np.linalg.norm((np.array(
                    (x, y, z)) - np.asarray(point.location))) < radius_phys):
                    vector = [
                        vector_map_presyn[dim][x, y, z]
                        for dim in range(vector_map_presyn.shape[0])
                    ]
                    if not np.sum(vector) == 0:
                        trg_loc_of_vector_phys = (
                            np.asarray(offset_vector_map_presyn) +
                            (voxel_size * np.array([x, y, z])) +
                            np.asarray(vector))
                        # Each vector must target one of this point's partners.
                        self.assertTrue(trg_loc_of_vector_phys.tolist(
                        ) in partner_ids_to_locs_per_src.values())

                        for (
                                partner_id,
                                partner_loc,
                        ) in partner_ids_to_locs_per_src.items():
                            if np.array_equal(
                                    np.asarray(trg_loc_of_vector_phys),
                                    partner_loc):
                                count_vectors_per_partner[partner_id] += 1
            # Vectors should be spread roughly evenly over all partners.
            # NOTE(review): np.min raises ValueError if the point has no
            # partner in the batch.
            self.assertTrue(
                (list(count_vectors_per_partner.values()) -
                 np.min(list(count_vectors_per_partner.values())) <= len(
                     count_vectors_per_partner.keys())).all())
def test_two_disjoint_lines_softmask(self):
    """Labels-mask fusion of two disjoint lines should equal their sum.

    Two straight skeletons ``3 * LABEL_RADIUS`` apart are written to SWC
    files. Each is rasterized (via ``radius=``) into a labels array and a
    raw array, and the two are fused by ``FusionAugment`` with
    ``blend_mode="labels_mask"``. The test expects
    ``fused == raw_a + raw_b`` (norm of the difference ~ 0).

    The padded arrays are cast to float64 before subtracting. The raw arrays
    are rasterized as ``np.uint32``, and unsigned subtraction wraps around
    instead of going negative. That would distort both the reported
    difference and the visualization.
    """
    LABEL_RADIUS = 3
    RAW_RADIUS = 3
    # Exaggerated to make the blending problem visible.
    BLEND_SMOOTHNESS = 10

    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])
    swc_paths = tuple(
        Path(self.path_to(file_name))
        for file_name in ("test_line_a.swc", "test_line_b.swc")
    )

    # Create two lines separated by a given distance and write them to SWC files.
    intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
    for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
        swc_points = self._get_points(intercept, slope, bb)
        self._write_swc(swc_path, swc_points)

    fused = ArrayKey("FUSED")
    fused_labels = ArrayKey("FUSED_LABELS")
    swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
    labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))
    raw_keys = tuple(ArrayKey(name) for name in ("RAW_A", "RAW_B"))

    request = BatchRequest()
    for key in (fused, fused_labels) + labels_keys + raw_keys + swc_keys:
        request.add(key, bb.get_shape())

    def line_source(index):
        """Read one SWC and rasterize it into a labels and a raw array."""
        return (
            tuple()
            + SwcFileSource(swc_paths[index], swc_keys[index], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[index],
                array=labels_keys[index],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
            + RasterizeSkeleton(
                points=swc_keys[index],
                array=raw_keys[index],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=RAW_RADIUS,
            )
        )

    data_sources = (line_source(0), line_source(1)) + MergeProvider()
    pipeline = data_sources + FusionAugment(
        raw_keys[0],
        raw_keys[1],
        labels_keys[0],
        labels_keys[1],
        fused,
        fused_labels,
        blend_mode="labels_mask",
        blend_smoothness=BLEND_SMOOTHNESS,
        num_blended_objects=0,
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)

    def padded(key):
        """Zero-padded (one voxel) float64 copy of the data for ``key``."""
        data = np.pad(batch[key].data, (1,), "constant", constant_values=(0,))
        return data.astype(np.float64)

    fused_data = padded(fused)
    a_data = padded(raw_keys[0])
    b_data = padded(raw_keys[1])

    # Channels for optional visualization: fused, a + b, residual, a, b.
    all_data = np.zeros((5,) + fused_data.shape)
    all_data[0, :, :, :] = fused_data
    all_data[1, :, :, :] = a_data + b_data
    all_data[2, :, :, :] = fused_data - a_data - b_data
    all_data[3, :, :, :] = a_data
    all_data[4, :, :, :] = b_data

    # Visualize the problem when the optional viewer is available.
    if imported_volshow:
        volshow(all_data)

    diff = np.linalg.norm(fused_data - a_data - b_data)
    self.assertAlmostEqual(diff, 0)
def test_rasterize_speed(self):
    """Rasterizing two straight-line skeletons in a 256**3 volume must take < 0.1 s.

    The timed span covers the pipeline build, one request and padding the
    two label arrays. ``time.perf_counter`` is used because it is monotonic
    and high-resolution. ``time.time`` follows the wall clock, which can be
    adjusted mid-measurement.
    """
    # This is worryingly slow for such a small volume (256**3) and only 2
    # straight lines for skeletons.
    LABEL_RADIUS = 3

    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])
    swc_paths = tuple(
        Path(self.path_to(file_name))
        for file_name in ("test_line_a.swc", "test_line_b.swc")
    )

    # Create two lines separated by a given distance and write them to SWC files.
    intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
    for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
        swc_points = self._get_points(intercept, slope, bb)
        self._write_swc(swc_path, swc_points)

    swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
    labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))

    request = BatchRequest()
    for key in labels_keys + swc_keys:
        request.add(key, bb.get_shape())

    def line_source(index):
        """Read one SWC and rasterize it into its labels array."""
        return (
            tuple()
            + SwcFileSource(swc_paths[index], swc_keys[index], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[index],
                array=labels_keys[index],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )

    pipeline = (line_source(0), line_source(1)) + MergeProvider()

    start = time.perf_counter()
    with build(pipeline):
        batch = pipeline.request_batch(request)

    # Padding is part of the timed work, as before.
    a_data = np.pad(batch[labels_keys[0]].data, (1,), "constant", constant_values=(0,))
    b_data = np.pad(batch[labels_keys[1]].data, (1,), "constant", constant_values=(0,))
    elapsed = time.perf_counter() - start

    self.assertLess(elapsed, 0.1)
def test_get_neuron_pair(self):
    """GetNeuronPair places two copies of a neuron 1 to 3 units apart.

    Builds points, image and label sources from the toy SWC. It requests
    ten neuron pairs and, for each batch, checks two things:

    - all six output keys are present;
    - the closest distance between the two skeletons lies in ``[1, 3]``.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))
    self._write_swc(swc_path, self._toy_swc_points().to_nx_graph())

    swc_source = PointsKey("SWC_SOURCE")
    ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    data_shape = 5
    output_shape = Coordinate((data_shape, data_shape, data_shape))

    def source_spec():
        """A fresh points spec covering the whole toy skeleton."""
        return PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))

    def rasterized_source(swc_key, raster_key, binary_key, radius):
        """Rasterize the test SWC, binarize it and grow it by ``radius``."""
        return (
            SwcFileSource(swc_path, [swc_key], [source_spec()])
            + RasterizeSkeleton(
                points=swc_key,
                array=raster_key,
                array_spec=ArraySpec(
                    interpolatable=True,
                    dtype=np.uint32,
                    voxel_size=Coordinate((1, 1, 1)),
                ),
            )
            + BinarizeLabels(labels=raster_key, labels_binary=binary_key)
            + GrowLabels(array=binary_key, radius=radius)
        )

    # Points straight from the test SWC, plus a copy used only to keep
    # RandomLocation on a non-empty region.
    swc_file_source = SwcFileSource(
        swc_path, [swc_source, ensure_nonempty], [source_spec(), source_spec()]
    )
    image_source = rasterized_source(img_swc, img_source, imgs, radius=0)
    label_source = rasterized_source(label_swc, labels_source, labels, radius=1)

    skeleton = tuple() + (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=ensure_nonempty, ensure_centered=True)
    )

    pipeline = skeleton + GetNeuronPair(
        point_source=swc_source,
        nonempty_placeholder=ensure_nonempty,
        array_source=imgs,
        label_source=labels,
        points=(points_a, points_b),
        arrays=(img_a, img_b),
        labels=(labels_a, labels_b),
        seperate_by=(1, 3),
        shift_attempts=100,
        request_attempts=10,
        output_shape=output_shape,
    )

    requested = [points_a, points_b, img_a, img_b, labels_a, labels_b]
    request = BatchRequest()
    for key in requested:
        request.add(key, output_shape)

    with build(pipeline):
        for _ in range(10):
            batch = pipeline.request_batch(request)
            self.assertTrue(all(key in batch for key in requested))

            # Start from a cap of 5: if no pair is closer, the <= 3 check fails.
            closest = 5
            for a, b in itertools.product(batch[points_a].nodes, batch[points_b].nodes):
                closest = min(closest, np.linalg.norm(a.location - b.location))

            self.assertLessEqual(closest, 3)
            self.assertGreaterEqual(closest, 1)
def get_test_data_sources(setup_config):
    """Build a synthetic training pipeline and return it with its keys.

    Merges a test image source with two test point sources (a directed
    consensus and an undirected skeletonization). The result is then passed
    through a centered ``RandomLocation``, a Gurobi-backed
    ``TopologicalMatcher`` and ``RasterizeSkeleton`` (matched points ->
    labels), followed by ``Normalize`` on the raw array. The sources cover
    three times ``INPUT_SHAPE * VOXEL_SIZE``.

    Args:
        setup_config: mapping with at least ``"INPUT_SHAPE"``,
            ``"VOXEL_SIZE"`` and ``"FUSION_PIPELINE"``. When
            ``FUSION_PIPELINE`` is truthy, RandomLocation is kept non-empty
            on the placeholder key instead of the consensus.

    Returns:
        tuple: ``(data_sources, raw, labels, consensus, nonempty_placeholder,
        skeletonization, matched)``.
    """
    input_shape = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = input_shape * voxel_size
    # Scale factor applied to the world-unit distances below.
    micron_scale = voxel_size[0]
    source_size = input_size * 3

    # New keys. Note: these are intended to be requested with size input_size.
    raw = ArrayKey("RAW")
    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    nonempty_placeholder = PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = (
        nonempty_placeholder if setup_config["FUSION_PIPELINE"] else consensus
    )

    sources = (
        TestImageSource(
            array=raw,
            array_specs={
                raw: ArraySpec(
                    interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                )
            },
            size=source_size,
            voxel_size=voxel_size,
        ),
        TestPointSource(
            points=[consensus, nonempty_placeholder],
            directed=True,
            size=source_size,
            num_points=30,
        ),
        TestPointSource(
            points=[skeletonization],
            directed=False,
            size=source_size,
            num_points=333,
        ),
    )

    data_sources = (
        sources
        + MergeProvider()
        + RandomLocation(
            ensure_nonempty=ensure_nonempty,
            ensure_centered=True,
            point_balance_radius=10 * micron_scale,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            match_distance_threshold=50 * micron_scale,
            max_gap_crossing=30 * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint64
            ),
        )
        + Normalize(raw)
    )

    return (
        data_sources,
        raw,
        labels,
        consensus,
        nonempty_placeholder,
        skeletonization,
        matched,
    )