def get_test_data_sources(setup_config):
    """Assemble a synthetic raw/points pipeline for tests.

    The pipeline merges a ``TestImageSource`` (raw) with a ``TestPointSource``
    (``MATCHED`` and ``NONEMPTY`` graphs). It then picks a random location that
    contains at least one ``MATCHED`` point, rasterizes ``MATCHED`` into a
    ``LABELS`` array and normalizes the raw array.

    Args:
        setup_config: mapping providing ``"INPUT_SHAPE"`` and ``"VOXEL_SIZE"``.

    Returns:
        ``(pipeline, raw, labels, nonempty_placeholder, matched)``. The keys are
        intended to be requested with size ``INPUT_SHAPE * VOXEL_SIZE``.
    """
    shape_vx = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = shape_vx * voxel_size
    # The balance radius below is expressed in multiples of the first voxel dim.
    micron_scale = voxel_size[0]

    raw = ArrayKey("RAW")
    matched = GraphKey("MATCHED")
    nonempty_placeholder = GraphKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    raw_source = TestImageSource(
        array=raw,
        array_specs={
            raw: ArraySpec(
                interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
            )
        },
        size=input_size * 3,
        voxel_size=voxel_size,
    )
    graph_source = TestPointSource(
        points=[matched, nonempty_placeholder],
        directed=False,
        size=input_size * 3,
        num_points=333,
    )
    locator = RandomLocation(
        ensure_nonempty=matched,
        ensure_centered=True,
        point_balance_radius=10 * micron_scale,
    )
    rasterizer = RasterizeSkeleton(
        points=matched,
        array=labels,
        array_spec=ArraySpec(
            interpolatable=False, voxel_size=voxel_size, dtype=np.uint64
        ),
    )
    pipeline = (
        (raw_source, graph_source)
        + MergeProvider()
        + locator
        + rasterizer
        + Normalize(raw)
    )
    return pipeline, raw, labels, nonempty_placeholder, matched
def test_merge_basics(self):
    """MergeProvider merges a graph and an array source and rejects duplicates."""
    voxel_size = (1, 1, 1)
    # Register the keys so they are reachable as GraphKeys.* / ArrayKeys.*.
    GraphKey("PRESYN")
    ArrayKey("GT_LABELS")

    graph_source = GraphTestSource(voxel_size)
    array_source = ArrayTestSoure(voxel_size)
    pipeline = (graph_source, array_source) + MergeProvider() + RandomLocation()
    window = Coordinate((50, 50, 50))

    with build(pipeline):
        # Both sources contribute to one batch.
        both = BatchRequest()
        both.add((GraphKeys.PRESYN), window)
        both.add((ArrayKeys.GT_LABELS), window)
        merged = pipeline.request_batch(both)
        self.assertTrue(ArrayKeys.GT_LABELS in merged.arrays)
        self.assertTrue(GraphKeys.PRESYN in merged.graphs)

        # Requesting only one of the sources also works.
        graph_only = BatchRequest()
        graph_only.add((GraphKeys.PRESYN), window)
        partial = pipeline.request_batch(graph_only)
        self.assertFalse(ArrayKeys.GT_LABELS in partial.arrays)
        self.assertTrue(GraphKeys.PRESYN in partial.graphs)

    # Two sources providing the same key must be rejected at setup time.
    duplicate_source = ArrayTestSoure(voxel_size)
    failing_pipeline = (
        (array_source, duplicate_source) + MergeProvider() + RandomLocation()
    )
    with self.assertRaises(PipelineSetupError):
        with build(failing_pipeline):
            pass
def test_6_neighborhood():
    """A 6-connected Neighborhood marks the expected voxels in its mask."""
    graph = GraphKey("GRAPH")
    neighborhood = ArrayKey("NEIGHBORHOOD")
    neighborhood_mask = ArrayKey("NEIGHBORHOOD_MASK")
    distance = 1

    specs = {
        neighborhood: ArraySpec(voxel_size=Coordinate((1, 1, 1))),
        neighborhood_mask: ArraySpec(voxel_size=Coordinate((1, 1, 1))),
    }
    pipeline = TestSource(graph) + Neighborhood(
        graph,
        neighborhood,
        neighborhood_mask,
        distance,
        array_specs=specs,
        k=6,
    )

    request = BatchRequest()
    request[neighborhood] = ArraySpec(roi=Roi((0, 0, 0), (10, 10, 10)))
    request[neighborhood_mask] = ArraySpec(roi=Roi((0, 0, 0), (10, 10, 10)))

    with build(pipeline):
        batch = pipeline.request_batch(request)
        n_data = batch[neighborhood].data
        n_mask = batch[neighborhood_mask].data

        # Voxels that must be set in the mask (duplicates removed).
        expected = (
            {(0, y, 0) for y in range(10) if y not in (0, 4)}
            | {(x, 5, 0) for x in range(10)}
            | {(x, 4, 0) for x in range(10) if x != 0}
        )
        masked_ind = list(expected)
        index_arrays = tuple(zip(*masked_ind))
        assert all(
            n_mask[index_arrays]
        ), f"expected {masked_ind} but saw {np.where(n_mask==1)}"
def test_mirror(self):
    """Mirroring keeps the single node at its original or mirrored coordinate."""
    test_graph = GraphKey("TEST_GRAPH")
    pipeline = TestSource() + SimpleAugment(mirror_only=[0, 1, 2], transpose_only=[])

    request = BatchRequest()
    request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi((0, 20, 33), (100, 100, 120)))
    bounds = Roi((0, 20, 33), (100, 100, 120))

    # For each dim: [original value, mirrored value].
    possible_loc = [[50, 49], [70, 29], [100, 86]]

    with build(pipeline):
        seen_mirrored = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            result_graph = batch[GraphKeys.TEST_GRAPH]
            nodes = list(result_graph.nodes)
            assert len(nodes) == 1
            (node,) = nodes
            logging.debug(node.location)

            location = node.location
            assert all(location[d] in possible_loc[d] for d in range(3))
            if not seen_mirrored:
                seen_mirrored = any(
                    location[d] == possible_loc[d][1] for d in range(3)
                )
            assert bounds.contains(result_graph.spec.roi)
            assert result_graph.spec.roi.contains(location)
        assert seen_mirrored
def test_relabel_components(self):
    """After relabeling, the toy SWC splits into 3 components of 10 nodes each."""
    path = Path(self.path_to("test_swc_source.swc"))
    # write test swc
    self._write_swc(path, self._toy_swc_points().to_nx_graph())

    swc = GraphKey("SWC")
    source = SwcFileSource(path, [swc])
    request = BatchRequest({swc: GraphSpec(roi=Roi((0, 1, 5), (11, 10, 1)))})
    with build(source):
        batch = source.request_batch(request)

    graph = batch.points[swc]
    graph.relabel_connected_components()
    components = list(graph.connected_components)
    self.assertEqual(len(components), 3)

    last_label = None
    for component in components:
        self.assertEqual(len(component), 10)
        component_label = None
        for node_id in component:
            node_label = graph.node(node_id).attrs["component"]
            if component_label is None:
                # First node fixes the component's label; it must differ
                # from the previous component's label.
                component_label = node_label
                self.assertNotEqual(component_label, last_label)
            self.assertEqual(node_label, component_label)
        last_label = component_label
def test_without_placeholder(self):
    """A points-only request fails; adding the label array makes it succeed."""
    test_labels = ArrayKey("TEST_LABELS")
    test_points = GraphKey("TEST_POINTS")

    snapshot = Snapshot(
        {test_labels: "volumes/labels"},
        output_dir=self.path_to(),
        output_filename="elastic_augment_test{id}-{iteration}.hdf",
    )
    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + snapshot
    )

    with build(pipeline):
        for seed in range(2):
            size = Coordinate((40, 40, 40))

            points_only = BatchRequest(random_seed=seed)
            points_only.add(test_points, size)

            points_and_labels = BatchRequest(random_seed=seed)
            points_and_labels.add(test_points, size)
            points_and_labels.add(test_labels, size)

            # No array to provide a voxel size to ElasticAugment
            with pytest.raises(PipelineRequestError):
                pipeline.request_batch(points_only)

            result = pipeline.request_batch(points_and_labels)
            self.assertIn(test_labels, result)
def test_output(self):
    """Crop restricts provided ROIs, either to an explicit ROI or by fractions."""
    cropped_roi_raw = Roi((400, 40, 40), (1000, 100, 100))
    cropped_roi_presyn = Roi((800, 80, 80), (800, 80, 80))
    # Register the key so it is reachable as GraphKeys.PRESYN.
    GraphKey("PRESYN")

    pipeline = (
        ExampleSourceCrop()
        + Crop(ArrayKeys.RAW, cropped_roi_raw)
        + Crop(GraphKeys.PRESYN, cropped_roi_presyn)
    )
    with build(pipeline):
        # assertEqual (not assertTrue(a == b)) reports both ROIs on failure.
        self.assertEqual(pipeline.spec[ArrayKeys.RAW].roi, cropped_roi_raw)
        self.assertEqual(pipeline.spec[GraphKeys.PRESYN].roi, cropped_roi_presyn)

    # Fractional crop: remove 25% from both ends of dim 0.
    pipeline = ExampleSourceCrop() + Crop(
        ArrayKeys.RAW,
        fraction_negative=(0.25, 0, 0),
        fraction_positive=(0.25, 0, 0),
    )
    expected_roi_raw = Roi((650, 20, 20), (900, 180, 180))
    with build(pipeline):
        logger.info(pipeline.spec[ArrayKeys.RAW].roi)
        logger.info(expected_roi_raw)
        self.assertEqual(pipeline.spec[ArrayKeys.RAW].roi, expected_roi_raw)
def test_square(self):
    """Mirror/transpose over padded sources keeps array data shapes equal to ROI shapes."""
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")

    merged = (ArrayTestSource(), TestSource()) + MergeProvider()
    padded = (
        merged
        + Pad(test_array1, None)
        + Pad(test_array2, None)
        + Pad(test_graph, None)
    )
    pipeline = padded + SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])

    request = BatchRequest()
    request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi((0, 50, 65), (100, 100, 100)))
    request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 15), (100, 200, 200)))
    request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 50, 65), (100, 100, 100)))

    with build(pipeline):
        for _ in range(100):
            batch = pipeline.request_batch(request)
            assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
            for array in batch.arrays.values():
                # Unit voxel size: ROI shape (world units) equals data shape.
                assert array.data.shape == array.spec.roi.get_shape()
def test_placeholder(self):
    """Requesting labels as a placeholder must not change augmented point locations."""
    test_labels = ArrayKey("TEST_LABELS")
    test_points = GraphKey("TEST_POINTS")
    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
    )

    with build(pipeline):
        for seed in range(2):
            size = Coordinate((40, 40, 40))

            with_placeholder = BatchRequest(random_seed=seed)
            with_placeholder.add(test_points, size)
            with_placeholder.add(test_labels, size, placeholder=True)

            with_labels = BatchRequest(random_seed=seed)
            with_labels.add(test_points, size)
            with_labels.add(test_labels, size)

            placeholder_batch = pipeline.request_batch(with_placeholder)
            labels_batch = pipeline.request_batch(with_labels)

            pairs = zip(
                placeholder_batch[test_points].nodes,
                labels_batch[test_points].nodes,
            )
            for node_a, node_b in pairs:
                assert all(np.isclose(node_a.location, node_b.location))
def test_transpose():
    """SimpleAugment transposes a graph node and the matching array voxel together."""
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )
    data = np.zeros([40, 30])
    data[17, 17] = 1
    array = Array(
        data, ArraySpec(roi=Roi((100, 200), (800, 600)), voxel_size=voxel_size)
    )

    def make_pipeline(probs):
        # Same sources and augment; only the transpose probabilities differ.
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(
                mirror_only=[], transpose_only=[0, 1], transpose_probs=probs
            )
        )

    default_pipeline = make_pipeline([0, 0])
    transpose_pipeline = make_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    def check(pipeline, expected_location):
        # The node must land at expected_location, and the array voxel under
        # it must be the one set to 1.
        with build(pipeline):
            batch = pipeline.request_batch(request)
            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            assert all(np.isclose(node.location, expected_location))
            result = batch[array_key]
            node_voxel_index = Coordinate(
                (node.location - result.spec.roi.get_offset()) / voxel_size
            )
            assert (
                result.data[node_voxel_index] == 1
            ), f"Node at {np.where(result.data == 1)} not {node_voxel_index}"

    check(default_pipeline, [450, 550])
    check(transpose_pipeline, [410, 590])
def test_3d_basics(self):
    """Elastically augmented points move together with the label voxels they sit on."""
    test_labels = ArrayKey("TEST_LABELS")
    test_points = GraphKey("TEST_POINTS")
    test_raster = ArrayKey("TEST_RASTER")

    augment = ElasticAugment(
        [10, 10, 10],
        [0.1, 0.1, 0.1],
        [0, 2.0 * math.pi],
    )
    rasterize = RasterizeGraph(
        test_points,
        test_raster,
        settings=RasterizationSettings(radius=2, mode="peak"),
    )
    snapshot = Snapshot(
        {test_labels: "volumes/labels", test_raster: "volumes/raster"},
        dataset_dtypes={test_raster: np.float32},
        output_dir=self.path_to(),
        output_filename="elastic_augment_test{id}-{iteration}.hdf",
    )
    pipeline = PointTestSource3D() + augment + rasterize + snapshot

    for _ in range(5):
        with build(pipeline):
            request_roi = Roi((-20, -20, -20), (40, 40, 40))
            request = BatchRequest()
            request[test_labels] = ArraySpec(roi=request_roi)
            request[test_points] = GraphSpec(roi=request_roi)
            request[test_raster] = ArraySpec(roi=request_roi)

            batch = pipeline.request_batch(request)
            labels = batch[test_labels]
            points = batch[test_points]

            # the point at (0, 0, 0) should not have moved
            self.assertTrue(points.contains(0))

            begin = labels.spec.roi.get_begin()
            step = labels.spec.voxel_size
            labels_data_roi = (labels.spec.roi - begin) / step

            # points should have moved together with the voxels
            for point in points.nodes:
                voxel = (point.location - begin) / step
                voxel = Coordinate(int(round(c)) for c in voxel)
                if labels_data_roi.contains(voxel):
                    self.assertEqual(labels.data[voxel], point.id)
def test_3d(self):
    """Snapshot (every=2) writes a file on the first request but not the second.

    The graph is not in the main request; it is supplied to Snapshot via
    ``additional_request``. The test expects it to be written as
    ``graphs/graph-ids``.
    """
    test_graph = GraphKey("TEST_GRAPH")
    graph_spec = GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5)))
    test_array = ArrayKey("TEST_ARRAY")
    array_spec = ArraySpec(
        roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1))
    )
    test_array2 = ArrayKey("TEST_ARRAY2")
    array2_spec = ArraySpec(
        roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1))
    )

    snapshot_request = BatchRequest()
    snapshot_request.add(test_graph, Coordinate((5, 5, 5)))

    pipeline = ExampleSource(
        [test_graph, test_array, test_array2], [graph_spec, array_spec, array2_spec]
    ) + Snapshot(
        {
            test_graph: "graphs/graph",
            test_array: "volumes/array",
            test_array2: "volumes/array2",
        },
        output_dir=str(self.test_dir),
        every=2,
        additional_request=snapshot_request,
        output_filename="snapshot.hdf",
    )

    snapshot_file_path = Path(self.test_dir, "snapshot.hdf")

    with build(pipeline):
        request = BatchRequest()
        roi = Roi((0, 0, 0), (5, 5, 5))
        request[test_array] = ArraySpec(roi=roi)
        request[test_array2] = ArraySpec(roi=roi)

        pipeline.request_batch(request)
        assert snapshot_file_path.exists()
        # Open read-only and close the handle before unlinking. The previous
        # version leaked the handle, which also breaks unlink() on Windows.
        # Membership checks replace `f[key] is not None`, which could never
        # be False because indexing raises KeyError instead.
        with h5py.File(snapshot_file_path, "r") as f:
            assert "volumes/array" in f
            assert "graphs/graph-ids" in f
        snapshot_file_path.unlink()

        pipeline.request_batch(request)
        assert not snapshot_file_path.exists()
def test_multi_transpose(self):
    """Transposing all three dims only ever permutes the node's coordinates."""
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")
    point = np.array([50, 70, 100])
    transpose_dims = [0, 1, 2]

    pipeline = (ArrayTestSource(), ExampleSource()) + MergeProvider() + SimpleAugment(
        mirror_only=[], transpose_only=transpose_dims
    )

    request = BatchRequest()
    offset = (0, 20, 33)
    request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi(offset, (100, 100, 120)))
    request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # Every permutation of the node's coordinates is an acceptable location.
    possible_loc = np.array(
        [point[list(comb)] for comb in permutations(transpose_dims, 3)]
    )
    graph_bounds = Roi((0, 20, 33), (100, 100, 120))

    with build(pipeline):
        seen_transposed = False
        # Must start False: it used to start True, so `assert seen_node`
        # could never fail.
        seen_node = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            result_graph = batch[GraphKeys.TEST_GRAPH]
            nodes = list(result_graph.nodes)
            if len(nodes) == 1:
                seen_node = True
                node = nodes[0]
                # Compare whole rows: `location in ndarray` is true as soon
                # as ANY single element matches.
                assert any(
                    np.allclose(node.location, loc) for loc in possible_loc
                ), node.location
                seen_transposed = seen_transposed or any(
                    node.location[dim] != point[dim] for dim in range(3)
                )
                assert graph_bounds.contains(result_graph.spec.roi)
                assert result_graph.spec.roi.contains(node.location)
            for array in batch.arrays.values():
                assert array.data.shape == array.spec.roi.get_shape()
        assert seen_transposed
        assert seen_node
def test_pipeline3(self):
    """After ShiftAugment, every point still sits on one of the original point values."""
    array_key = ArrayKey("TEST_ARRAY")
    points_key = GraphKey("TEST_POINTS")
    voxel_size = Coordinate((1, 1))

    array_spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)
    hdf5_source = Hdf5Source(
        self.fake_data_file, {array_key: "testdata"}, array_specs={array_key: array_spec}
    )
    csv_source = CsvPointsSource(
        self.fake_points_file,
        points_key,
        GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
    )

    request = BatchRequest()
    shape = Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
    pipeline = (
        (hdf5_source, csv_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=points_key)
        + shift_node
    )

    with build(pipeline) as built:
        batch = built.request_batch(request)

    target_vals = [self.fake_data[p[0]][p[1]] for p in self.fake_points]
    result_data = batch[array_key].data
    result_points = list(batch[points_key].nodes)
    result_vals = [
        result_data[int(node.location[0])][int(node.location[1])]
        for node in result_points
    ]

    for value in result_vals:
        self.assertTrue(
            value in target_vals,
            msg="result value {} at points {} not in target values {} at points {}".format(
                value,
                list(result_points),
                target_vals,
                self.fake_points,
            ),
        )
def test_output(self):
    """
    Fails due to probabilities being calculated in advance, rather than after
    creating each roi. The new approach does not account for all possible
    roi's containing each point, some of which may not contain its nearest
    neighbors.
    """
    GraphKey('TEST_POINTS')
    key = GraphKeys.TEST_POINTS
    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=key, point_balance_radius=100
    )

    # count the number of times we get each point
    counts = {}
    with build(pipeline):
        for _ in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))})
            )
            points = {node.id: node for node in batch[key].nodes}
            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)
            for node_id in points:
                counts[node_id] = counts.get(node_id, 0) + 1

    total = sum(counts.values())
    frequencies = [float(c) / total for c in counts.values()]

    # we should get roughly the same count for each point
    for first in frequencies:
        for second in frequencies:
            self.assertAlmostEqual(first, second, 1)
def test_read_single_swc(self):
    """Every toy SWC node is read back at exactly the location it was written to."""
    path = Path(self.path_to("test_swc_source.swc"))
    # write test swc
    self._write_swc(path, self._toy_swc_points().to_nx_graph())

    # read arrays
    swc = GraphKey("SWC")
    source = SwcFileSource(path, [swc])

    with build(source):
        batch = source.request_batch(
            BatchRequest({swc: GraphSpec(roi=Roi((0, 0, 5), (11, 11, 1)))})
        )

    read_graph = batch.points[swc]
    for node in self._toy_swc_points().nodes:
        # Compare coordinate by coordinate. assertCountEqual treated the
        # location as a multiset, so e.g. (1, 2, 3) matched (3, 2, 1).
        np.testing.assert_allclose(
            read_graph.node(node.id).location, node.location
        )
def test_req_full_roi(self):
    """Requesting the whole possible ROI still yields exactly one graph node."""
    GraphKey("TEST_GRAPH")
    key = GraphKeys.TEST_GRAPH
    possible_roi = Roi((0, 0, 0), (1000, 1000, 1000))

    source = SourceGraphLocation()
    tester = BatchTester(possible_roi, exact=False)
    pipeline = source + tester + RandomLocation(ensure_nonempty=key)

    with build(pipeline):
        full_request = BatchRequest(
            {key: GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))}
        )
        batch = pipeline.request_batch(full_request)
        assert len(list(batch[key].nodes)) == 1
def test_roi_one_point(self):
    """A 1x1x1 request always lands exactly on the single upstream node."""
    GraphKey("TEST_GRAPH")
    key = GraphKeys.TEST_GRAPH
    upstream_roi = Roi((500, 500, 500), (1, 1, 1))

    pipeline = (
        SourceGraphLocation()
        + BatchTester(upstream_roi, exact=True)
        + RandomLocation(ensure_nonempty=key)
    )

    def tiny_request():
        return BatchRequest({key: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))})

    with build(pipeline):
        for _ in range(500):
            batch = pipeline.request_batch(tiny_request())
            assert len(list(batch[key].nodes)) == 1
def test_dim_size_1(self):
    """Requests with a size-1 dimension still contain the single graph node."""
    GraphKey("TEST_GRAPH")
    key = GraphKeys.TEST_GRAPH
    upstream_roi = Roi((500, 401, 401), (1, 200, 200))

    pipeline = (
        SourceGraphLocation()
        + BatchTester(upstream_roi, exact=False)
        + RandomLocation(ensure_nonempty=key)
    )

    def flat_request():
        return BatchRequest({key: GraphSpec(roi=Roi((0, 0, 0), (1, 100, 100)))})

    # count the number of times we get each node
    with build(pipeline):
        for _ in range(500):
            batch = pipeline.request_batch(flat_request())
            assert len(list(batch[key].nodes)) == 1
def test_two_transpose(self):
    """Transposing dims 1 and 2 keeps each coordinate at its original or transposed value."""
    test_graph = GraphKey("TEST_GRAPH")
    test_array1 = ArrayKey("TEST_ARRAY1")
    test_array2 = ArrayKey("TEST_ARRAY2")
    transpose_dims = [1, 2]

    sources = (ArrayTestSource(), ExampleSource())
    pipeline = sources + MergeProvider() + SimpleAugment(
        mirror_only=[], transpose_only=transpose_dims
    )

    request = BatchRequest()
    request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi((0, 20, 33), (100, 100, 120)))
    request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # For each dim: [untransposed value, transposed value].
    possible_loc = [[50, 50], [70, 100], [100, 70]]
    bounds = Roi((0, 20, 33), (100, 100, 120))

    with build(pipeline):
        seen_transposed = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            result_graph = batch[GraphKeys.TEST_GRAPH]
            nodes = list(result_graph.nodes)
            assert len(nodes) == 1
            node = nodes[0]
            logging.debug(node.location)

            assert all(node.location[d] in possible_loc[d] for d in range(3))
            if not seen_transposed:
                seen_transposed = any(
                    node.location[d] != possible_loc[d][0] for d in range(3)
                )
            assert bounds.contains(result_graph.spec.roi)
            assert result_graph.spec.roi.contains(node.location)

            for array in batch.arrays.values():
                assert array.data.shape == array.spec.roi.get_shape()
        assert seen_transposed
def test_output(self):
    """RandomLocation(ensure_nonempty) returns each graph node about equally often."""
    GraphKey("TEST_GRAPH")
    key = GraphKeys.TEST_GRAPH
    pipeline = TestSourceRandomLocation() + RandomLocation(ensure_nonempty=key)

    # count the number of times we get each node
    counts = {}
    with build(pipeline):
        for _ in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))})
            )
            nodes = list(batch[key].nodes)
            node_ids = [n.id for n in nodes]
            self.assertTrue(len(nodes) > 0)
            self.assertTrue(
                (1 in node_ids) != (2 in node_ids or 3 in node_ids),
                node_ids,
            )
            for node_id in node_ids:
                counts[node_id] = counts.get(node_id, 0) + 1

    total = sum(counts.values())
    frequencies = [float(c) / total for c in counts.values()]

    # we should get roughly the same count for each point
    for first in frequencies:
        for second in frequencies:
            self.assertAlmostEqual(first, second, 1)
def test_equal_probability(self):
    """Small random windows hit each point with roughly equal probability."""
    GraphKey('TEST_POINTS')
    key = GraphKeys.TEST_POINTS
    pipeline = ExampleSourceRandomLocation() + RandomLocation(ensure_nonempty=key)

    # count the number of times we get each point
    counts = {}
    with build(pipeline):
        for _ in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({key: GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))})
            )
            points = {node.id: node for node in batch[key].nodes}
            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)
            for point_id in points:
                counts[point_id] = counts.get(point_id, 0) + 1

    total = sum(counts.values())
    frequencies = [float(c) / total for c in counts.values()]

    # we should get roughly the same count for each point
    for first in frequencies:
        for second in frequencies:
            self.assertAlmostEqual(first, second, 1)
def test_filter_components():
    """FilterComponents yields one component in the first ROI and none in the second."""
    raw = GraphKey("RAW")
    pipeline = TestSource() + FilterComponents(raw, 100, Coordinate((10, 10, 10)))

    def count_components(roi):
        # Build a fresh pipeline context, request `raw` in `roi`, count components.
        request = BatchRequest()
        request[raw] = GraphSpec(roi=roi)
        with build(pipeline):
            batch = pipeline.request_batch(request)
            assert raw in batch
            return len(list(batch[raw].connected_components))

    assert count_components(Roi((0, 0, 0), (20, 20, 20))) == 1
    assert count_components(Roi((20, 20, 20), (20, 20, 20))) == 0
def test_keep_node_ids(self):
    """With keep_ids=True, walking from the root follows the expected locations."""
    swc_path = Path(self.path_to("test_swc_source.swc"))
    # write test swc
    self._write_swc(
        swc_path,
        self._toy_swc_points().to_nx_graph(),
        {"resolution": np.array([2, 2, 2])},
    )

    swc = GraphKey("SWC")
    source = SwcFileSource(swc_path, [swc], keep_ids=True)
    with build(source):
        batch = source.request_batch(
            BatchRequest({swc: GraphSpec(roi=Roi((0, 5, 10), (1, 2, 1)))})
        )

    graph = batch.points[swc]

    # root is only node with in_degree 0
    roots = [n for n, d in graph.in_degree() if d == 0]
    current = roots[0]

    # edge nodes can't keep the same id in case one node has multiple children
    # in the roi.
    expected_path = [
        tuple(np.array([0.0, 5.0, 10.0])),
        tuple(np.array([0.0, 6.0, 10.0])),
        tuple(np.array([0.0, 7.0, 10.0])),
    ]

    # Follow single-successor links until the chain branches or ends.
    visited = []
    while current is not None:
        node = graph.node(current)
        visited.append(tuple(node.location))
        children = list(graph.successors(node))
        current = children[0] if len(children) == 1 else None

    for actual, expected in zip(visited, expected_path):
        assert all(np.isclose(actual, expected))
def test_output(self):
    """Pad grows the provided ROIs; a padded label corner sums to 1 * 10 * 10."""
    graph = GraphKey("TEST_GRAPH")
    labels = ArrayKey("TEST_LABELS")

    pipeline = (
        TestSourcePad()
        + Pad(labels, Coordinate((20, 20, 20)), value=1)
        + Pad(graph, Coordinate((10, 10, 10)))
    )

    with build(pipeline):
        # assertEqual (not assertTrue(a == b)) reports both ROIs on failure.
        self.assertEqual(
            pipeline.spec[labels].roi, Roi((180, 0, 0), (1840, 220, 220))
        )
        self.assertEqual(
            pipeline.spec[graph].roi, Roi((190, 10, 10), (1820, 200, 200))
        )

        batch = pipeline.request_batch(
            BatchRequest({labels: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))})
        )
        self.assertEqual(np.sum(batch.arrays[labels].data), 1 * 10 * 10)
def test_overlap(self):
    """Three offset copies of the toy SWC read as 3 components of 41 nodes each."""
    swc_dir = Path(self.path_to("test_swc_sources"))
    swc_dir.mkdir(parents=True, exist_ok=True)

    # write test swc: one file per y-offset 0, 1, 2
    toy_graph = self._toy_swc_points().to_nx_graph
    for shift in range(3):
        self._write_swc(
            swc_dir / "{}.swc".format(shift),
            toy_graph(),
            {"offset": np.array([0, shift, 0])},
        )

    swc = GraphKey("SWC")
    source = SwcFileSource(swc_dir, [swc])
    with build(source):
        batch = source.request_batch(
            BatchRequest({swc: GraphSpec(roi=Roi((0, 0, 5), (11, 13, 1)))})
        )

    graph = batch.points[swc]
    graph.relabel_connected_components()
    components = list(graph.connected_components)
    self.assertEqual(len(components), 3)

    last_label = None
    for component in components:
        self.assertEqual(len(component), 41)
        component_label = None
        for node_id in component:
            node_label = graph.node(node_id).attrs["component"]
            if component_label is None:
                component_label = node_label
                self.assertNotEqual(component_label, last_label)
            self.assertEqual(node_label, component_label)
        last_label = component_label
def test_output_min_distance(self):
    """Check AddVectorMap under both partner criteria.

    Part 1 ("min_distance"): every non-zero vector within ``radius_phys`` of
    a presynaptic point must point exactly at that point's closest partner.
    Part 2 ("all"): every non-zero vector must point at one of the point's
    partners. Per point, vector counts across partners may differ from the
    minimum count by at most the number of partners.
    """
    voxel_size = Coordinate((20, 2, 2))

    # Register keys so they are reachable via ArrayKeys / GraphKeys.
    # NOTE(review): ArrayKeys.RAW and ArrayKeys.GT_LABELS are used below but
    # not registered here; presumably registered elsewhere (e.g. by the test
    # source) — confirm.
    ArrayKey("GT_VECTORS_MAP_PRESYN")
    GraphKey("PRESYN")
    GraphKey("POSTSYN")

    arraytypes_to_source_target_pointstypes = {
        ArrayKeys.GT_VECTORS_MAP_PRESYN: (GraphKeys.PRESYN, GraphKeys.POSTSYN)
    }
    arraytypes_to_stayinside_arraytypes = {
        ArrayKeys.GT_VECTORS_MAP_PRESYN: ArrayKeys.GT_LABELS
    }

    # test for partner criterion 'min_distance'
    radius_phys = 30
    pipeline_min_distance = AddVectorMapTestSource() + AddVectorMap(
        src_and_trg_points=arraytypes_to_source_target_pointstypes,
        voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
        radius_phys=radius_phys,
        partner_criterion="min_distance",
        stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
        pad_for_partners=(0, 0, 0),
    )

    with build(pipeline_min_distance):
        # Request every key at the full size of the corresponding provided ROI.
        request = BatchRequest()
        raw_roi = pipeline_min_distance.spec[ArrayKeys.RAW].roi
        gt_labels_roi = pipeline_min_distance.spec[ArrayKeys.GT_LABELS].roi
        presyn_roi = pipeline_min_distance.spec[GraphKeys.PRESYN].roi

        request.add(ArrayKeys.RAW, raw_roi.get_shape())
        request.add(ArrayKeys.GT_LABELS, gt_labels_roi.get_shape())
        request.add(GraphKeys.PRESYN, presyn_roi.get_shape())
        request.add(GraphKeys.POSTSYN, presyn_roi.get_shape())
        request.add(ArrayKeys.GT_VECTORS_MAP_PRESYN, presyn_roi.get_shape())
        # Shift every requested ROI by (1000, 1000, 1000). This request is
        # reused for the 'all' pipeline below.
        for identifier, spec in request.items():
            spec.roi = spec.roi.shift((1000, 1000, 1000))

        batch = pipeline_min_distance.request_batch(request)

    presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
    postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
    vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
    offset_vector_map_presyn = request[
        ArrayKeys.GT_VECTORS_MAP_PRESYN
    ].roi.get_offset()

    self.assertTrue(len(presyn_locs) > 0)
    self.assertTrue(len(postsyn_locs) > 0)

    for loc_id, point in presyn_locs.items():
        # Only points inside the requested vector-map ROI are checked.
        if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
            Coordinate(point.location)
        ):
            self.assertTrue(
                batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                    Coordinate(point.location)
                )
            )

            # Map distance -> location over this point's partners present in
            # the batch, then pick the closest one.
            dist_to_loc = {}
            for partner_id in point.attrs["partner_ids"]:
                if partner_id in postsyn_locs.keys():
                    partner_location = postsyn_locs[partner_id].location
                    dist_to_loc[
                        np.linalg.norm(partner_location - point.location)
                    ] = partner_location
            min_dist = np.min(list(dist_to_loc.keys()))
            relevant_partner_loc = dist_to_loc[min_dist]

            # Voxel-space box of +/- radius around the point, clipped to the
            # vector map's last three (spatial) dims.
            presyn_loc_shifted_vx = (
                point.location - offset_vector_map_presyn
            ) // voxel_size
            radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
            region_to_check = np.clip(
                [
                    (presyn_loc_shifted_vx - radius_vx),
                    (presyn_loc_shifted_vx + radius_vx),
                ],
                a_min=(0, 0, 0),
                a_max=vector_map_presyn.shape[-3:],
            )
            for x, y, z in itertools.product(
                range(int(region_to_check[0][0]), int(region_to_check[1][0])),
                range(int(region_to_check[0][1]), int(region_to_check[1][1])),
                range(int(region_to_check[0][2]), int(region_to_check[1][2])),
            ):
                # NOTE(review): (x, y, z) are voxel indices relative to the
                # vector-map offset, but point.location is in world units, so
                # this distance mixes units — confirm intent.
                if (
                    np.linalg.norm(
                        (np.array((x, y, z)) - np.asarray(point.location))
                    )
                    < radius_phys
                ):
                    # Vector stored at this voxel, one entry per channel.
                    vector = [
                        vector_map_presyn[dim][x, y, z]
                        for dim in range(vector_map_presyn.shape[0])
                    ]
                    if not np.sum(vector) == 0:
                        # World position of the voxel plus its vector = target.
                        trg_loc_of_vector_phys = (
                            np.asarray(offset_vector_map_presyn)
                            + (voxel_size * np.array([x, y, z]))
                            + np.asarray(vector)
                        )
                        self.assertTrue(
                            np.array_equal(
                                trg_loc_of_vector_phys, relevant_partner_loc
                            )
                        )

    # test for partner criterion 'all'
    pipeline_all = AddVectorMapTestSource() + AddVectorMap(
        src_and_trg_points=arraytypes_to_source_target_pointstypes,
        voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
        radius_phys=radius_phys,
        partner_criterion="all",
        stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
        pad_for_partners=(0, 0, 0),
    )

    with build(pipeline_all):
        batch = pipeline_all.request_batch(request)

    presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
    postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
    vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
    offset_vector_map_presyn = request[
        ArrayKeys.GT_VECTORS_MAP_PRESYN
    ].roi.get_offset()

    self.assertTrue(len(presyn_locs) > 0)
    self.assertTrue(len(postsyn_locs) > 0)

    for loc_id, point in presyn_locs.items():
        if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
            Coordinate(point.location)
        ):
            self.assertTrue(
                batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                    Coordinate(point.location)
                )
            )

            # Partner id -> location (as a list, for `in` comparisons), and a
            # per-partner counter of vectors that point at it.
            partner_ids_to_locs_per_src, count_vectors_per_partner = {}, {}
            for partner_id in point.attrs["partner_ids"]:
                if partner_id in postsyn_locs.keys():
                    partner_ids_to_locs_per_src[partner_id] = postsyn_locs[
                        partner_id
                    ].location.tolist()
                    count_vectors_per_partner[partner_id] = 0

            presyn_loc_shifted_vx = (
                point.location - offset_vector_map_presyn
            ) // voxel_size
            radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
            region_to_check = np.clip(
                [
                    (presyn_loc_shifted_vx - radius_vx),
                    (presyn_loc_shifted_vx + radius_vx),
                ],
                a_min=(0, 0, 0),
                a_max=vector_map_presyn.shape[-3:],
            )
            for x, y, z in itertools.product(
                range(int(region_to_check[0][0]), int(region_to_check[1][0])),
                range(int(region_to_check[0][1]), int(region_to_check[1][1])),
                range(int(region_to_check[0][2]), int(region_to_check[1][2])),
            ):
                # NOTE(review): same voxel-index vs world-unit mix as above.
                if (
                    np.linalg.norm(
                        (np.array((x, y, z)) - np.asarray(point.location))
                    )
                    < radius_phys
                ):
                    vector = [
                        vector_map_presyn[dim][x, y, z]
                        for dim in range(vector_map_presyn.shape[0])
                    ]
                    if not np.sum(vector) == 0:
                        trg_loc_of_vector_phys = (
                            np.asarray(offset_vector_map_presyn)
                            + (voxel_size * np.array([x, y, z]))
                            + np.asarray(vector)
                        )
                        # The target must be one of this point's partners.
                        self.assertTrue(
                            trg_loc_of_vector_phys.tolist()
                            in partner_ids_to_locs_per_src.values()
                        )
                        for (
                            partner_id,
                            partner_loc,
                        ) in partner_ids_to_locs_per_src.items():
                            if np.array_equal(
                                np.asarray(trg_loc_of_vector_phys), partner_loc
                            ):
                                count_vectors_per_partner[partner_id] += 1
            # Counts per partner must not exceed the minimum count by more
            # than the number of partners.
            self.assertTrue(
                (
                    list(count_vectors_per_partner.values())
                    - np.min(list(count_vectors_per_partner.values()))
                    <= len(count_vectors_per_partner.keys())
                ).all()
            )
def test_fast_transform_no_recompute(self):
    """Compare the fast points transform of ``ElasticAugment`` to the reference.

    Both pipelines use the same augmentation parameters, and each run uses
    the same random seed for both. The fast pipeline enables
    ``use_fast_points_transform`` and disables ``recompute_missing_points``.

    For every seed the test checks that:

    * every point present in both outputs lies within distance 1 of its
      reference location;
    * at least one reference point is missing from the fast output.

    Finally, averaged over all runs, the fast pipeline must be quicker than
    the reference pipeline.
    """
    test_labels = ArrayKey("TEST_LABELS")
    test_points = GraphKey("TEST_POINTS")
    test_raster = ArrayKey("TEST_RASTER")

    fast_pipeline = (
        DensePointTestSource3D()
        + ElasticAugment(
            [10, 10, 10],
            [0.1, 0.1, 0.1],
            [0, 2.0 * math.pi],
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        + RasterizeGraph(
            test_points,
            test_raster,
            settings=RasterizationSettings(radius=2, mode="peak"),
        )
    )

    reference_pipeline = (
        DensePointTestSource3D()
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + RasterizeGraph(
            test_points,
            test_raster,
            settings=RasterizationSettings(radius=2, mode="peak"),
        )
    )

    def _run(pipeline, seed):
        # Build `pipeline`, request labels/points/raster over the same ROI
        # with `seed`, and return (elapsed seconds, {node id: node}).
        with build(pipeline):
            request_roi = Roi((0, 0, 0), (40, 40, 40))

            request = BatchRequest(random_seed=seed)
            request[test_labels] = ArraySpec(roi=request_roi)
            request[test_points] = GraphSpec(roi=request_roi)
            request[test_raster] = ArraySpec(roi=request_roi)

            # perf_counter is monotonic and high resolution, unlike time.time
            start = time.perf_counter()
            batch = pipeline.request_batch(request)
            elapsed = time.perf_counter() - start
            return elapsed, {node.id: node for node in batch[test_points].nodes}

    timings = []
    for i in range(5):
        # NOTE(review): originally commented as "seed chosen specifically to
        # make this test fail"; presumably these seeds make the fast transform
        # drop points (see assertGreater(missing, 0) below) — confirm.
        seed = i + 15

        t_fast_run, points_fast = _run(fast_pipeline, seed)
        t_ref_run, points_reference = _run(reference_pipeline, seed)
        timings.append((t_fast_run, t_ref_run))

        missing = 0
        for point_id, point in points_reference.items():
            if point_id not in points_fast:
                missing += 1
                continue
            fast_location = points_fast[point_id].location
            diff = point.location - fast_location
            self.assertAlmostEqual(
                np.linalg.norm(diff),
                0,
                delta=1,
                # The fast location is what was returned; the reference
                # location is what was expected.
                msg="fast transform returned location {} but expected {} for point {}".format(
                    fast_location, point.location, point_id
                ),
            )

        # With recompute_missing_points=False some points should be dropped.
        self.assertGreater(missing, 0)

    # Compare mean timings over all runs once, so a single noisy sample
    # cannot decide the outcome.
    t_fast, t_ref = [np.mean(x) for x in zip(*timings)]
    self.assertLess(t_fast, t_ref)
def test_random_seed(self):
    """Check that ``ElasticAugment`` is reproducible under a fixed seed.

    The same request (``random_seed=10``) is issued five times, each time
    with a freshly built pipeline. For every batch the test checks that:

    * point 0 is present;
    * every point whose rounded voxel location lies inside the labels array
      sits on a voxel labelled with that point's id.

    Finally, all five batches must yield exactly the same sequence of
    ``(id, location)`` pairs, including the same number of points.
    """
    test_labels = ArrayKey('TEST_LABELS')
    test_points = GraphKey('TEST_POINTS')
    test_raster = ArrayKey('TEST_RASTER')

    pipeline = (
        PointTestSource3D()
        + ElasticAugment(
            [10, 10, 10],
            [0.1, 0.1, 0.1],
            [0, 2.0 * math.pi])  # rotate randomly
        + RasterizeGraph(
            test_points,
            test_raster,
            settings=RasterizationSettings(radius=2, mode='peak'))
        + Snapshot(
            {
                test_labels: 'volumes/labels',
                test_raster: 'volumes/raster'
            },
            dataset_dtypes={test_raster: np.float32},
            output_dir=self.path_to(),
            output_filename='elastic_augment_test{id}-{iteration}.hdf'))

    batch_points = []
    for _ in range(5):
        with build(pipeline):
            request_roi = Roi((-20, -20, -20), (40, 40, 40))

            request = BatchRequest(random_seed=10)
            request[test_labels] = ArraySpec(roi=request_roi)
            request[test_points] = GraphSpec(roi=request_roi)
            request[test_raster] = ArraySpec(roi=request_roi)

            batch = pipeline.request_batch(request)
            labels = batch[test_labels]
            points = batch[test_points]
            batch_points.append(
                tuple((node.id, tuple(node.location)) for node in points.nodes))

            # The original intent: the point at (0, 0, 0) should not have
            # moved. Only its presence in the output is asserted here.
            data = {node.id: node for node in points.nodes}
            self.assertIn(0, data)

            # ROI of the labels array in voxel units, starting at the origin.
            labels_data_roi = (
                labels.spec.roi - labels.spec.roi.get_begin()) / labels.spec.voxel_size

            # points should have moved together with the voxels
            for node in points.nodes:
                loc = node.location - labels.spec.roi.get_begin()
                loc = loc / labels.spec.voxel_size
                loc = Coordinate(int(round(x)) for x in loc)
                if labels_data_roi.contains(loc):
                    self.assertEqual(labels.data[loc], node.id)

    # Compare whole batches: zip(*batch_points) would silently truncate to
    # the shortest batch and miss runs that returned fewer points.
    first, *rest = batch_points
    for other in rest:
        self.assertEqual(other, first)
def test_ensure_centered(self):
    """
    Expected failure due to emergent behavior of two desired rules:
    1) Points on the upper bound of Roi are not considered contained
    2) When considering a point as a center of a random location,
        scale by the number of points within some delta distance

    If two points are equally likely to be chosen, and centering a roi on
    either of them means the other is on the bounding box of the roi, then
    it can be the case that if the roi is centered on one of them, the roi
    contains only that one, but if the roi is centered on the second, then
    both are considered contained, breaking the equal likelihood of picking
    each point.
    """
    # Registers GraphKeys.TEST_POINTS.
    GraphKey("TEST_POINTS")

    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=GraphKeys.TEST_POINTS, ensure_centered=True)

    # count the number of times we get each point
    histogram = {}

    with build(pipeline):
        for _ in range(5000):
            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_POINTS:
                    GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))
                }))

            points = batch[GraphKeys.TEST_POINTS].data
            locations = tuple(
                [Coordinate(point.location) for point in points.values()])
            # With ensure_centered the chosen point sits at the ROI center.
            self.assertTrue(
                Coordinate([50, 50, 50]) in locations,
                f"locations: {tuple([point.location for point in points.values()])}"
            )
            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)

            # Previously `histogram[node.id] += 1` referenced an undefined
            # name and raised NameError on the first repeated id.
            for point_id in points:
                histogram[point_id] = histogram.get(point_id, 0) + 1

    total = sum(histogram.values())
    for k, v in histogram.items():
        histogram[k] = float(v) / total

    # we should get roughly the same count for each point
    for i in histogram.keys():
        for j in histogram.keys():
            self.assertAlmostEqual(histogram[i], histogram[j], 1, histogram)