def test_mirror():
    """Check that SimpleAugment keeps a graph node and its array voxel aligned.

    A single node at (450, 550) and a single non-zero voxel at index (17, 17)
    are served from sources covering Roi((100, 200), (800, 600)) with a voxel
    size of (20, 20).  The same request is sent through two pipelines that
    differ only in ``mirror_probs``: with ``[0, 0]`` the node is expected at
    its original location, with ``[1, 1]`` at (550, 750).  In both cases the
    voxel under the node must hold the value 1.
    """
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )

    voxels = np.zeros([40, 30])
    voxels[17, 17] = 1
    array = Array(
        voxels,
        ArraySpec(roi=Roi((100, 200), (800, 600)), voxel_size=voxel_size),
    )

    def make_pipeline(mirror_probs):
        # Both pipelines are identical apart from the mirror probabilities.
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(
                mirror_only=[0, 1],
                transpose_only=[],
                mirror_probs=mirror_probs,
            )
        )

    def locate_single_node(batch, expected_location):
        # Assert there is exactly one node at the expected location and
        # return the index of the voxel it falls into.
        nodes = list(batch[graph_key].nodes)
        assert len(nodes) == 1
        location = nodes[0].location
        assert all(np.isclose(location, expected_location))
        offset = batch[array_key].spec.roi.get_offset()
        return Coordinate((location - offset) / voxel_size)

    default_pipeline = make_pipeline([0, 0])
    mirror_pipeline = make_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    with build(default_pipeline):
        batch = default_pipeline.request_batch(request)
        voxel_index = locate_single_node(batch, [450, 550])
        assert batch[array_key].data[voxel_index] == 1

    with build(mirror_pipeline):
        batch = mirror_pipeline.request_batch(request)
        voxel_index = locate_single_node(batch, [550, 750])
        assert (
            batch[array_key].data[voxel_index] == 1
        ), f"Node at {np.where(batch[array_key].data == 1)} not {voxel_index}"
def test_filter_components():
    """Request two ROIs through FilterComponents and check the component count.

    The request at offset (0, 0, 0) is expected to yield exactly one
    connected component, the one at offset (20, 20, 20) none.  The pipeline
    is rebuilt for each request.
    """
    raw = GraphKey("RAW")
    pipeline = TestSource() + FilterComponents(raw, 100, Coordinate((10, 10, 10)))

    # (request offset, expected number of connected components)
    cases = (
        ((0, 0, 0), 1),
        ((20, 20, 20), 0),
    )

    for offset, expected_components in cases:
        request = BatchRequest()
        request[raw] = GraphSpec(roi=Roi(offset, (20, 20, 20)))

        with build(pipeline):
            batch = pipeline.request_batch(request)
            assert raw in batch
            components = list(batch[raw].connected_components)
            assert len(components) == expected_components
def test_multi_transpose(self):
    """Transpose all three axes at random and validate nodes and array shapes.

    Over 100 requests the test checks that

    * whenever exactly one node is returned, its location matches one of the
      permutations of ``point`` and lies inside the returned graph ROI, which
      in turn lies inside the requested ROI;
    * every returned array has the shape announced by its spec;
    * at least one batch contained a node, and at least one node was moved
      away from ``point`` (i.e. a transposition was actually observed).
    """
    # Creating the keys registers them; they are accessed below through
    # ``GraphKeys`` / ``ArrayKeys``.
    GraphKey("TEST_GRAPH")
    ArrayKey("TEST_ARRAY1")
    ArrayKey("TEST_ARRAY2")

    point = np.array([50, 70, 100])
    transpose_dims = [0, 1, 2]

    pipeline = (
        (ArrayTestSource(), ExampleSource())
        + MergeProvider()
        + SimpleAugment(mirror_only=[], transpose_only=transpose_dims)
    )

    request = BatchRequest()
    offset = (0, 20, 33)
    request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi(offset, (100, 100, 120)))
    request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # Create all possible permutations of our transpose dims
    transpose_combinations = list(permutations(transpose_dims, 3))
    possible_loc = np.zeros((len(transpose_combinations), 3))

    # Transpose the point in all possible ways
    for i, comb in enumerate(transpose_combinations):
        possible_loc[i] = point[np.array(comb)]

    with build(pipeline):
        seen_transposed = False
        # Must start out False: initialising it to True made the final
        # ``assert seen_node`` pass even if no batch ever contained a node.
        seen_node = False

        for _ in range(100):
            batch = pipeline.request_batch(request)

            nodes = list(batch[GraphKeys.TEST_GRAPH].nodes)
            if len(nodes) == 1:
                seen_node = True
                node = nodes[0]

                # NOTE(review): ``in`` on a 2-D numpy array is an element-wise
                # "any match" test, not a row-membership test — kept as is to
                # avoid changing what this test accepts; confirm the intent.
                assert node.location in possible_loc

                seen_transposed = seen_transposed or any(
                    node.location[dim] != point[dim] for dim in range(3)
                )

                assert Roi((0, 20, 33), (100, 100, 120)).contains(
                    batch[GraphKeys.TEST_GRAPH].spec.roi
                )
                assert batch[GraphKeys.TEST_GRAPH].spec.roi.contains(node.location)

            for array_key, _ in batch.arrays.items():
                provided = batch.arrays[array_key]
                assert provided.data.shape == provided.spec.roi.get_shape()

        assert seen_transposed
        assert seen_node
def test_realistic_invalid_examples(example, use_gurobi):
    """Run TopologicalMatcher on an "invalid" example with and without fallback.

    The graphs are read from ``mouselight_examples/invalid/<example>`` next
    to this file.  Without fallback the matched graph is expected to be
    present but empty; with ``with_fallback=True`` it must contain nodes.

    Args:
        example: name of the example directory.
        use_gurobi: forwarded to both TopologicalMatcher nodes.
    """
    example_dir = Path(__file__).parent / "mouselight_examples" / "invalid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    # Settings shared by both matcher nodes.
    matcher_kwargs = dict(
        expected_edge_len=10,
        match_distance_threshold=76,
        max_gap_crossing=48,
        use_gurobi=use_gurobi,
        location_attr="location",
        penalty_attr="penalty",
    )

    unbounded = Roi(Coordinate((None,) * 3), Coordinate((None,) * 3))
    request = BatchRequest()
    request[matched] = PointsSpec(roi=unbounded)
    request[matched_with_fallback] = PointsSpec(roi=unbounded)

    pipeline = (
        (
            GraphSource(example_dir / "graph.obj", [skeletonization]),
            GraphSource(example_dir / "tree.obj", [consensus]),
        )
        + MergeProvider()
        + TopologicalMatcher(skeletonization, consensus, matched, **matcher_kwargs)
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched_with_fallback,
            with_fallback=True,
            **matcher_kwargs,
        )
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)
        assert matched in batch

        strict_nodes = list(batch[matched].nodes)
        fallback_nodes = list(batch[matched_with_fallback].nodes)
        assert len(strict_nodes) == 0
        assert len(fallback_nodes) > 0
def test_pipeline3(self):
    """Shift an array together with its points and compare sampled values.

    An HDF5 array source and a CSV points source are merged, a random
    location containing points is picked and ShiftAugment is applied.  The
    array value found under every returned point must be one of the values
    found under the original points in ``self.fake_data``.
    """
    array_key = ArrayKey("TEST_ARRAY")
    points_key = GraphKey("TEST_POINTS")
    voxel_size = Coordinate((1, 1))

    hdf5_source = Hdf5Source(
        self.fake_data_file,
        {array_key: "testdata"},
        array_specs={
            array_key: ArraySpec(voxel_size=voxel_size, interpolatable=True)
        },
    )
    csv_source = CsvPointsSource(
        self.fake_points_file,
        points_key,
        GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
    )

    request = BatchRequest()
    shape = Coordinate((60, 60))
    request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
    request.add(points_key, shape)

    shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
    pipeline = (
        (hdf5_source, csv_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=points_key)
        + shift_node
    )

    with build(pipeline) as built:
        batch = built.request_batch(request)

    # Values under the original points ...
    target_vals = [self.fake_data[row][col] for row, col in
                   ((point[0], point[1]) for point in self.fake_points)]

    # ... and under the points that came back from the pipeline.
    result_data = batch[array_key].data
    result_points = list(batch[points_key].nodes)
    result_vals = []
    for point in result_points:
        row = int(point.location[0])
        col = int(point.location[1])
        result_vals.append(result_data[row][col])

    for result_val in result_vals:
        failure_msg = (
            "result value {} at points {} not in target values {} at points {}"
            .format(
                result_val,
                list(result_points),
                target_vals,
                self.fake_points,
            )
        )
        self.assertTrue(result_val in target_vals, msg=failure_msg)
def test_mismatched_voxel_multiples():
    """
    Ensure we don't shift by half a voxel when transposing 2 axes.

    If voxel_size = [2, 2], and we transpose an array of shape [4, 6]:

        center = total_roi.get_center() -> [2, 3]

        # Get distance from center, then transpose
        dist_to_center = center - roi.get_offset() -> [2, 3]
        dist_to_center = transpose(dist_to_center)  -> [3, 2]

        # Using the transposed distance to center, get the offset.
        new_offset = center - dist_to_center -> [-1, 1]

        shape = transpose(shape) -> [6, 4]

        original = ((0, 0), (4, 6))
        transposed = ((-1, 1), (6, 4))

    This result is what we would expect from transposing, but it no longer
    fits the voxel grid.  dist_to_center should be limited to multiples of
    the lcm_voxel_size.  Instead we should get:

        original = ((0, 0), (4, 6))
        transposed = ((0, 0), (6, 4))
    """
    test_array = ArrayKey("TEST_ARRAY")

    # The marked voxel (index 2, 1) has Roi((4, 2), (2, 2)), which is
    # contained in Roi((0, 0), (6, 4)).
    source_data = np.zeros([3, 3])
    source_data[2, 1] = 1

    source_spec = ArraySpec(roi=Roi((0, 0), (6, 6)), voxel_size=(2, 2))
    source = ArraySource(test_array, Array(source_data, source_spec))

    # transpose_probs={(1, 0): 1} is used so the outcome is not random.
    augment = SimpleAugment(
        mirror_only=[],
        transpose_only=[0, 1],
        transpose_probs={(1, 0): 1},
    )
    pipeline = source + augment

    with build(pipeline):
        request = BatchRequest()
        request[test_array] = ArraySpec(roi=Roi((0, 0), (4, 6)))

        batch = pipeline.request_batch(request)
        augmented = batch[test_array].data
        assert augmented[1, 2] == 1, f"{augmented}"
def test_precache(self):
    """Run a torch Predict node behind a PreCache with two workers.

    Raises:
        RuntimeError: if CUDA has already been initialized in this process.
    """
    cuda_ready = torch.cuda.is_initialized()
    if cuda_ready:
        raise RuntimeError(
            "Cuda is already initialized in the main process! Will not be able "
            "to reinitialize in forked subprocesses.")

    predict_logger = logging.getLogger("gunpowder.torch.nodes.predict")
    predict_logger.setLevel(logging.INFO)

    a = ArrayKey("A")
    pred = ArrayKey("PRED")

    model = TestModel()

    # NOTE(review): this request is built but never sent through the pipeline.
    reference_request = BatchRequest()
    reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7)))
    reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5)))

    source = TestTorchTrain2DSource()
    predict = Predict(
        model=model,
        inputs={"a": a},
        outputs={0: pred},
        array_specs={pred: ArraySpec()},
    )
    pipeline = source + predict + PreCache(cache_size=3, num_workers=2)

    request = BatchRequest(
        {
            a: ArraySpec(roi=Roi((0, 0), (17, 17))),
            pred: ArraySpec(roi=Roi((0, 0), (15, 15))),
        }
    )

    # a single request is enough to check that a prediction is provided
    with build(pipeline):
        batch = pipeline.request_batch(request)
        assert pred in batch
def test_placeholder(self):
    """Compare point locations with and without a placeholder label request.

    For 100 seeds, two requests with the same ``random_seed`` are sent: one
    asks for the labels as a placeholder, the other for real.  The points
    returned by both must have (numerically) the same locations.
    """
    test_labels = ArrayKey("TEST_LABELS")
    test_points = PointsKey("TEST_POINTS")

    pipeline = (
        PointTestSource3D()
        + RandomLocation(ensure_nonempty=test_points)
        + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        + Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
    )

    def make_request(seed, size, labels_as_placeholder):
        # The ``placeholder`` keyword is only passed when requested, so the
        # non-placeholder request relies on ``add``'s own default.
        label_kwargs = {"placeholder": True} if labels_as_placeholder else {}
        request = BatchRequest(random_seed=seed)
        request.add(test_points, size)
        request.add(test_labels, size, **label_kwargs)
        return request

    with build(pipeline):
        for seed in range(100):
            request_size = Coordinate((40, 40, 40))
            request_a = make_request(seed, request_size, True)
            request_b = make_request(seed, request_size, False)

            batch_a = pipeline.request_batch(request_a)
            batch_b = pipeline.request_batch(request_b)

            node_pairs = zip(batch_a[test_points].nodes, batch_b[test_points].nodes)
            for node_a, node_b in node_pairs:
                assert all(np.isclose(node_a.location, node_b.location))
def setUp(self):
    """Create the common array keys, a test source and a default request."""
    super(ProviderTest, self).setUp()

    # create some common array keys to be used by concrete tests
    common_key_names = (
        "RAW",
        "GT_LABELS",
        "GT_AFFINITIES",
        "GT_AFFINITIES_MASK",
        "GT_MASK",
        "GT_IGNORE",
        "LOSS_SCALE",
    )
    for key_name in common_key_names:
        ArrayKey(key_name)

    self.test_source = TestSource()

    default_request = BatchRequest()
    default_request[ArrayKeys.RAW] = ArraySpec(roi=Roi((20, 20, 20), (10, 10, 10)))
    self.test_request = default_request
def prepare(self, request):
    """Request the ground-truth graph needed for the neighborhood outputs.

    The neighborhood and its mask must be requested with the same ROI.  The
    graph ``self.gt`` is requested over that ROI grown, on both sides of
    every dimension, by ``ceil(self.distance)``.

    Args:
        request: the downstream request; must contain ``self.neighborhood``
            and ``self.neighborhood_mask``.

    Returns:
        A BatchRequest holding the upstream dependency for ``self.gt``.

    Raises:
        AssertionError: if the two outputs are requested with different ROIs.
    """
    deps = BatchRequest()

    # This check had degenerated into a bare ``roi, "message"`` tuple
    # expression that validated nothing; restore the comparison its message
    # describes.
    assert (
        request[self.neighborhood].roi == request[self.neighborhood_mask].roi
    ), f"Requested {self.neighborhood} and {self.neighborhood_mask} with different roi's"

    request_roi = request[self.neighborhood].roi

    # same growth in every dimension of the requested roi
    num_dims = len(request_roi.get_shape())
    grow_distance = Coordinate((np.ceil(self.distance),) * num_dims)
    request_roi = request_roi.grow(grow_distance, grow_distance)

    deps[self.gt] = GraphSpec(roi=request_roi)
    return deps
def visualize_embedding_pipeline(fusion_pipeline, train_embedding):
    """Run the embedding pipeline once and visualize the resulting snapshot.

    Args:
        fusion_pipeline: value stored under ``"FUSION_PIPELINE"`` in the
            configuration handed to ``embedding_pipeline``.
        train_embedding: value stored under ``"TRAIN_EMBEDDING"`` in that
            configuration.

    The input/output request sizes are the configured ``INPUT_SHAPE`` /
    ``OUTPUT_SHAPE`` multiplied by ``VOXEL_SIZE``.  After one batch has been
    requested, ``snapshots/snapshot_1.hdf`` is passed to ``visualize_hdf5``.
    """
    import copy  # local import: only needed to protect DEFAULT_CONFIG

    # Work on a copy: assigning into DEFAULT_CONFIG itself would leak the
    # overrides below into every later user of the module-level default.
    setup_config = copy.copy(DEFAULT_CONFIG)
    setup_config["FUSION_PIPELINE"] = fusion_pipeline
    setup_config["TRAIN_EMBEDDING"] = train_embedding

    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size

    pipeline, raw, output = embedding_pipeline(setup_config, get_test_data_sources)

    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)

    with build(pipeline):
        pipeline.request_batch(request)

    visualize_hdf5(Path("snapshots/snapshot_1.hdf"), tuple(voxel_size))
def prepare(self, request):
    """Translate a request for the fused outputs into upstream dependencies.

    Both raw volumes and both label volumes are requested with the spec of
    the fused raw volume.  If the fused points were requested, both point
    sets are requested over the fused raw volume's ROI.

    Returns:
        A BatchRequest with the upstream dependencies.
    """
    deps = BatchRequest()

    # The "base" and "add" raw volumes use the fused raw spec; the labels use
    # it too, so they have the same size as the raw data for mask generation.
    volume_keys = (
        self.raw_base,
        self.raw_add,
        self.labels_base,
        self.labels_add,
    )
    for volume_key in volume_keys:
        deps[volume_key] = request[self.raw_fused]

    # points are optional
    if self.points_fused in request:
        for points_key in (self.points_base, self.points_add):
            deps[points_key] = PointsSpec(roi=request[self.raw_fused].roi)

    return deps
def prepare(self, request: BatchRequest):
    """Request embeddings and mask over the ROI common to them and the MST.

    A ProviderSpec is assembled from this node's own specs for the embeddings
    and the mask plus the downstream spec of ``self.mst``; its common ROI is
    what gets requested for both arrays.

    Returns:
        A BatchRequest with the dependencies for ``self.embeddings`` and
        ``self.mask``.
    """
    deps = BatchRequest()

    array_keys = (self.embeddings, self.mask)
    combined_spec = ProviderSpec(
        array_specs={key: self.spec[key] for key in array_keys},
        graph_specs={self.mst: request[self.mst]},
    )
    common_roi = combined_spec.get_common_roi()

    for key in array_keys:
        deps[key] = ArraySpec(roi=common_roi)

    return deps
def test_output(self):
    """
    Fails due to probabilities being calculated in advance, rather than after
    creating each roi. The new approach does not account for all possible
    roi's containing each point, some of which may not contain its nearest
    neighbors.

    Requests 5000 random locations with ``point_balance_radius=100`` and
    checks that the relative frequencies of all returned point ids agree to
    one decimal place.
    """
    GraphKey('TEST_POINTS')
    points_key = GraphKeys.TEST_POINTS

    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=points_key, point_balance_radius=100
    )

    # number of times each point id was returned
    counts = {}

    with build(pipeline):
        for _ in range(5000):
            request = BatchRequest(
                {points_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))}
            )
            batch = pipeline.request_batch(request)

            points = {node.id: node for node in batch[points_key].nodes}

            self.assertTrue(len(points) > 0)
            # point 1 never comes together with points 2 or 3
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)

            for node in batch[points_key].nodes:
                counts[node.id] = counts.get(node.id, 0) + 1

    total = sum(counts.values())
    frequencies = {point_id: float(count) / total for point_id, count in counts.items()}

    # we should get roughly the same count for each point
    for first in frequencies:
        for second in frequencies:
            self.assertAlmostEqual(frequencies[first], frequencies[second], 1)
def test_prepare1(self):
    """Call ShiftAugment.prepare directly and inspect the derived settings.

    With ``sigma=1`` and ``shift_axis=0`` on a 2D request, the node is
    expected to report ``ndim == 2`` and ``shift_sigmas == (0.0, 1.0)``.
    """
    key = ArrayKey("TEST_ARRAY")
    unit_voxel = Coordinate((1, 1))

    hdf5_source = Hdf5Source(
        self.fake_data_file,
        {key: "testdata"},
        array_specs={key: ArraySpec(voxel_size=unit_voxel, interpolatable=True)},
    )

    request = BatchRequest()
    request.add(key, Coordinate((3, 3)), voxel_size=Coordinate((1, 1)))

    shift_node = ShiftAugment(sigma=1, shift_axis=0)
    pipeline = hdf5_source + shift_node

    with build(pipeline):
        shift_node.prepare(request)

        expected_sigmas = (0.0, 1.0)
        self.assertTrue(shift_node.ndim == 2)
        self.assertTrue(shift_node.shift_sigmas == expected_sigmas)
def test_read_single_swc(self):
    """Write the toy SWC graph to disk and read it back through SwcFileSource.

    Every node of the toy graph must be found in the returned batch, under
    the same id, with the same location components.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    toy_graph = self._toy_swc_points().to_nx_graph()
    self._write_swc(swc_path, toy_graph)

    # read it back
    swc = GraphKey("SWC")
    source = SwcFileSource(swc_path, [swc])
    read_roi = Roi((0, 0, 5), (11, 11, 1))

    with build(source):
        batch = source.request_batch(BatchRequest({swc: GraphSpec(roi=read_roi)}))

        for expected in self._toy_swc_points().nodes:
            found = batch.points[swc].node(expected.id)
            self.assertCountEqual(expected.location, found.location)
def test_impossible(self):
    """A request the custom random location cannot satisfy must fail.

    Two arrays are requested at ROIs far apart from each other; requesting
    the batch is expected to raise an AssertionError.
    """
    a = ArrayKey("A")
    b = ArrayKey("B")

    sources = (TestSourceRandomLocation(a), TestSourceRandomLocation(b))
    pipeline = sources + MergeProvider() + CustomRandomLocation()

    with build(pipeline), self.assertRaises(AssertionError):
        pipeline.request_batch(
            BatchRequest(
                {
                    a: ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
                    b: ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
                }
            )
        )
def test_roi_one_point(self):
    """Request a 1x1x1 ROI 500 times; each batch must hold exactly one node.

    A BatchTester configured with Roi((500, 500, 500), (1, 1, 1)) and
    ``exact=True`` sits upstream of the RandomLocation node.
    """
    GraphKey("TEST_GRAPH")
    graph_key = GraphKeys.TEST_GRAPH

    upstream_roi = Roi((500, 500, 500), (1, 1, 1))
    pipeline = (
        SourceGraphLocation()
        + BatchTester(upstream_roi, exact=True)
        + RandomLocation(ensure_nonempty=graph_key)
    )

    with build(pipeline):
        for _ in range(500):
            request = BatchRequest({graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))})
            batch = pipeline.request_batch(request)

            returned_nodes = list(batch[graph_key].nodes)
            assert len(returned_nodes) == 1
def test_pipeline2(self):
    """Smoke test: ShiftAugment on an array with anisotropic voxels (3, 1).

    Only checks that a batch can be requested without raising.
    """
    key = ArrayKey("TEST_ARRAY")

    hdf5_source = Hdf5Source(
        self.fake_data_file,
        {key: "testdata"},
        array_specs={
            key: ArraySpec(voxel_size=Coordinate((3, 1)), interpolatable=True)
        },
    )

    request = BatchRequest()
    request.add(key, Coordinate((3, 3)), voxel_size=Coordinate((3, 1)))

    shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=1, shift_axis=0)
    pipeline = hdf5_source + shift_node

    with build(pipeline) as built:
        built.request_batch(request)
def test_req_full_roi(self):
    """Request a 1000^3 ROI once and expect exactly one node in the batch.

    The upstream BatchTester is given Roi((0, 0, 0), (1000, 1000, 1000))
    with ``exact=False``.
    """
    GraphKey("TEST_GRAPH")
    graph_key = GraphKeys.TEST_GRAPH

    possible_roi = Roi((0, 0, 0), (1000, 1000, 1000))
    pipeline = (
        SourceGraphLocation()
        + BatchTester(possible_roi, exact=False)
        + RandomLocation(ensure_nonempty=graph_key)
    )

    with build(pipeline):
        request = BatchRequest(
            {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))}
        )
        batch = pipeline.request_batch(request)

        returned_nodes = list(batch[graph_key].nodes)
        assert len(returned_nodes) == 1
def test_dim_size_1(self):
    """Request ROIs of size 1 along the first axis; expect one node each time.

    500 requests of shape (1, 100, 100) are made; the upstream BatchTester is
    given Roi((500, 401, 401), (1, 200, 200)) with ``exact=False``.
    """
    GraphKey("TEST_GRAPH")
    graph_key = GraphKeys.TEST_GRAPH

    upstream_roi = Roi((500, 401, 401), (1, 200, 200))
    pipeline = (
        SourceGraphLocation()
        + BatchTester(upstream_roi, exact=False)
        + RandomLocation(ensure_nonempty=graph_key)
    )

    with build(pipeline):
        for _ in range(500):
            request = BatchRequest(
                {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 100, 100)))}
            )
            batch = pipeline.request_batch(request)

            returned_nodes = list(batch[graph_key].nodes)
            assert len(returned_nodes) == 1
def test_create_boundary_nodes(self):
    """Read a cropped SWC and check the path that remains inside the ROI.

    The toy SWC is written with a resolution of (2, 2, 2) and read back for
    Roi((0, 5, 0), (1, 3, 1)).  The returned points are turned into a
    directed graph (edges point from child to parent) and walked from the
    root along single-child links.  The visited ids must be {0, 1, 2} and the
    visited locations (0, 5, 0), (0, 6, 0), (0, 7, 0), in any order.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(
        swc_path, self._toy_swc_points(), {"resolution": np.array([2, 2, 2])}
    )

    # read arrays
    swc = PointsKey("SWC")
    source = SwcFileSource(swc_path, swc)

    with build(source):
        batch = source.request_batch(
            BatchRequest({swc: PointsSpec(roi=Roi((0, 5, 0), (1, 3, 1)))})
        )

    points = batch.points[swc].data

    temp_g = nx.DiGraph()
    root = None
    for point in points.values():
        temp_g.add_node(
            point.point_id, label_id=point.label_id, location=point.location
        )
        has_parent_in_batch = (
            point.parent_id != -1
            and point.parent_id != point.point_id
            and point.parent_id in points
        )
        if has_parent_in_batch:
            temp_g.add_edge(point.point_id, point.parent_id)
        else:
            # no usable parent: treat this point as the root
            root = point.point_id

    # Fail with a clear message instead of a NameError when no root exists.
    self.assertIsNotNone(root, "no root point found in the returned batch")

    expected_path = [
        tuple(np.array([0.0, 5.0, 0.0])),
        tuple(np.array([0.0, 6.0, 0.0])),
        tuple(np.array([0.0, 7.0, 0.0])),
    ]
    expected_node_ids = [0, 1, 2]

    walked_path = []
    walked_ids = []
    current = root
    while current is not None:
        walked_ids.append(current)
        walked_path.append(tuple(temp_g.nodes[current]["location"]))
        # Edges run child -> parent, so the predecessors of a node are its
        # children.  Stop at a leaf or at a branching point.
        children = list(temp_g.predecessors(current))
        current = children[0] if len(children) == 1 else None

    self.assertCountEqual(walked_path, expected_path)
    self.assertCountEqual(walked_ids, expected_node_ids)
def test_read_single_swc(self):
    """Write the toy SWC points to disk and read them back via SwcFileSource.

    Every toy point must be present in the batch under its ``point_id`` with
    the same location components.
    """
    swc_path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(swc_path, self._toy_swc_points())

    # read it back
    swc = PointsKey("SWC")
    source = SwcFileSource(swc_path, swc)
    read_roi = Roi((0, 0, 0), (11, 11, 1))

    with build(source):
        batch = source.request_batch(BatchRequest({swc: PointsSpec(roi=read_roi)}))

        for expected in self._toy_swc_points():
            found = batch.points[swc].data[expected.point_id]
            self.assertCountEqual(expected.location, found.location)
def prepare(self, request: BatchRequest, seed: int,
            direction: Coordinate) -> BatchRequest:
    """Build the upstream request for one seeded, shifted copy of ``request``.

    Only request everything with the given seed.

    Args:
        request: the downstream request.
        seed: random seed stored on the upstream request.
        direction: offset by which every requested ROI is shifted.  It is
            first reduced by its remainder modulo the lcm voxel size of the
            requested arrays, so that the shift is a multiple of that size.

    Returns:
        The upstream BatchRequest.  (The previous annotation announced a
        ``Tuple[BatchRequest, int]``, but a single request has always been
        returned.)
    """
    dps = BatchRequest(random_seed=seed)

    if self.nonempty_placeholder is not None:
        # request nonempty placeholder of size request total roi
        # grow such that it can be cropped down to two different locations
        growth = self._get_growth()
        grown_roi = request.get_total_roi().grow(growth, growth)
        dps[self.nonempty_placeholder] = GraphSpec(roi=grown_roi, placeholder=True)

    # snap the shift to the voxel grid shared by all requested arrays
    array_keys = list(request.array_specs.keys())
    voxel_size = self.spec.get_lcm_voxel_size(array_keys)
    direction = Coordinate(direction)
    direction -= Coordinate(tuple(np.array(direction) % np.array(voxel_size)))

    def first_requested(keys):
        # First of ``keys`` present in the request, or None.
        return next((key for key in keys if key in request), None)

    def request_shifted(source_key, target_key):
        # Copy the downstream spec of ``target_key`` and shift its roi.
        dps[source_key] = copy.deepcopy(request[target_key])
        dps[source_key].roi = dps[source_key].roi.shift(direction)

    # Use the spec of the first target that was actually requested.  The
    # old code tested "any target requested" but then always read the spec
    # of the first target, raising a KeyError when only a later one was
    # requested.
    requested_points = first_requested(self.points)
    if requested_points is not None:
        request_shifted(self.point_source, requested_points)

    requested_array = first_requested(self.arrays)
    if requested_array is not None:
        request_shifted(self.array_source, requested_array)

    requested_labels = first_requested(self.labels)
    if requested_labels is not None:
        request_shifted(self.label_source, requested_labels)

    # Extra keys keep their original rule: only the first target counts.
    for source, targets in self.extra_keys.items():
        if targets[0] in request:
            request_shifted(source, targets[0])

    return dps
def test_impossible(self):
    """A request the custom random location cannot satisfy must fail.

    Two arrays are requested at ROIs far apart from each other; requesting
    the batch is expected to raise a PipelineRequestError.
    """
    a = ArrayKey("A")
    b = ArrayKey("B")
    null_key = ArrayKey("NULL")

    sources = (ExampleSourceRandomLocation(a), ExampleSourceRandomLocation(b))
    pipeline = sources + MergeProvider() + CustomRandomLocation(null_key)

    with build(pipeline), self.assertRaises(PipelineRequestError):
        pipeline.request_batch(
            BatchRequest(
                {
                    a: ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
                    b: ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
                }
            )
        )
def test_mismatched_voxel_multiples():
    """
    Ensure we don't shift by half a voxel when transposing 2 axes.

    If voxel_size = [2, 2], and we transpose an array of shape [4, 6]:

        center = total_roi.get_center() -> [2, 3]

        # Get distance from center, then transpose
        dist_to_center = center - roi.get_offset() -> [2, 3]
        dist_to_center = transpose(dist_to_center)  -> [3, 2]

        # Using the transposed distance to center, get the correct offset.
        new_offset = center - dist_to_center -> [-1, 1]

        shape = transpose(shape) -> [6, 4]

        original = ((0, 0), (4, 6))
        transposed = ((-1, 1), (6, 4))

    This result is what we would expect from transposing, but it no longer
    fits the voxel grid.  dist_to_center should be limited to multiples of
    the lcm_voxel_size.

    The transposition is random here, so up to 100 batches are requested
    until one is seen whose first row sums to 1.
    """
    test_array = ArrayKey("TEST_ARRAY")

    source = CornerSource(test_array, voxel_size=(2, 2))
    pipeline = source + SimpleAugment(transpose_only=[0, 1])

    request = BatchRequest()
    request[test_array] = ArraySpec(roi=Roi((0, 0), (4, 6)))

    with build(pipeline):
        transposed = False
        for _ in range(100):
            batch = pipeline.request_batch(request)
            first_row_sum = batch[test_array].data.sum(axis=1)[0]
            if first_row_sum == 1:
                transposed = True
                break

        assert transposed, "Data was never transposed!"
def test_two_transpose(self):
    """Transpose axes 1 and 2 at random and validate node and array shapes.

    For each of 100 requests exactly one node must come back.  Each of its
    coordinates must be one of the values allowed for that axis, the returned
    graph ROI must lie inside the requested one and contain the node, and
    every array must have the shape announced by its spec.  At least one
    batch must show a node whose coordinates differ from the first allowed
    value of some axis.
    """
    # Creating the keys registers them; they are accessed below through
    # ``GraphKeys`` / ``ArrayKeys``.
    GraphKey("TEST_GRAPH")
    ArrayKey("TEST_ARRAY1")
    ArrayKey("TEST_ARRAY2")

    transpose_dims = [1, 2]
    pipeline = (
        (ArrayTestSource(), ExampleSource())
        + MergeProvider()
        + SimpleAugment(mirror_only=[], transpose_only=transpose_dims)
    )

    request = BatchRequest()
    request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi((0, 20, 33), (100, 100, 120)))
    request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
    request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

    # allowed coordinate values per axis; the first entry is the untransposed one
    possible_loc = [[50, 50], [70, 100], [100, 70]]
    axes = range(3)

    with build(pipeline):
        seen_transposed = False

        for _ in range(100):
            batch = pipeline.request_batch(request)

            assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
            node = list(batch[GraphKeys.TEST_GRAPH].nodes)[0]
            logging.debug(node.location)

            for dim in axes:
                assert node.location[dim] in possible_loc[dim]

            if not seen_transposed:
                seen_transposed = any(
                    node.location[dim] != possible_loc[dim][0] for dim in axes
                )

            returned_roi = batch[GraphKeys.TEST_GRAPH].spec.roi
            assert Roi((0, 20, 33), (100, 100, 120)).contains(returned_roi)
            assert batch[GraphKeys.TEST_GRAPH].spec.roi.contains(node.location)

            for array_key, _ in batch.arrays.items():
                data_shape = batch.arrays[array_key].data.shape
                spec_shape = batch.arrays[array_key].spec.roi.get_shape()
                assert data_shape == spec_shape

        assert seen_transposed
def test_output(self):
    """Sample 5000 random locations and compare how often each node appears.

    Every batch must contain at least one node, and node 1 must never appear
    together with node 2 or 3.  The relative frequencies of all node ids must
    agree to one decimal place.
    """
    GraphKey("TEST_GRAPH")
    graph_key = GraphKeys.TEST_GRAPH

    pipeline = TestSourceRandomLocation() + RandomLocation(ensure_nonempty=graph_key)

    # number of times each node id was returned
    counts = {}

    with build(pipeline):
        for _ in range(5000):
            request = BatchRequest(
                {graph_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))}
            )
            batch = pipeline.request_batch(request)

            nodes = list(batch[graph_key].nodes)
            node_ids = [node.id for node in nodes]

            self.assertTrue(len(nodes) > 0)
            self.assertTrue(
                (1 in node_ids) != (2 in node_ids or 3 in node_ids),
                node_ids,
            )

            for node in batch[graph_key].nodes:
                counts[node.id] = counts.get(node.id, 0) + 1

    total = sum(counts.values())
    frequencies = {node_id: float(count) / total for node_id, count in counts.items()}

    # we should get roughly the same count for each point
    for first in frequencies:
        for second in frequencies:
            self.assertAlmostEqual(frequencies[first], frequencies[second], 1)
def test_equal_probability(self):
    """Sample 5000 small random locations and compare point frequencies.

    Each request covers a (10, 10, 10) ROI.  Every batch must contain at
    least one point, and point 1 must never appear together with point 2 or
    3.  The relative frequencies of all point ids must agree to one decimal
    place.
    """
    GraphKey('TEST_POINTS')
    points_key = GraphKeys.TEST_POINTS

    pipeline = ExampleSourceRandomLocation() + RandomLocation(
        ensure_nonempty=points_key
    )

    # number of times each point id was returned
    counts = {}

    with build(pipeline):
        for _ in range(5000):
            request = BatchRequest(
                {points_key: GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))}
            )
            batch = pipeline.request_batch(request)

            points = {node.id: node for node in batch[points_key].nodes}

            self.assertTrue(len(points) > 0)
            self.assertTrue((1 in points) != (2 in points or 3 in points), points)

            for point in batch[points_key].nodes:
                counts[point.id] = counts.get(point.id, 0) + 1

    total = sum(counts.values())
    frequencies = {point_id: float(count) / total for point_id, count in counts.items()}

    # we should get roughly the same count for each point
    for first in frequencies:
        for second in frequencies:
            self.assertAlmostEqual(frequencies[first], frequencies[second], 1)
def test_grow_labels_speed(self):
    """Time rasterizing and growing labels over a 256^3 volume.

    An SWC file is written, read back, rasterized into a uint32 label array
    and grown with radius 3.  The average wall-clock time per request over
    10 requests (including pipeline setup) must stay below 0.1 seconds.
    """
    bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
    voxel_size = Coordinate([1, 1, 1])

    swc_path = Path(self.path_to("test_swc.swc"))
    swc_points = self._get_points(np.array([1, 1, 1]), np.array([1, 1, 1]), bb)
    self._write_swc(swc_path, swc_points.graph)

    # keys for the skeleton and the label array derived from it
    swc_key = PointsKey("SWC")
    labels_key = ArrayKey("LABELS")

    # request both over the whole bounding box
    request = BatchRequest()
    request.add(labels_key, bb.get_shape())
    request.add(swc_key, bb.get_shape())

    # The chain deliberately starts from an empty tuple, as before.
    label_spec = ArraySpec(interpolatable=False, dtype=np.uint32, voxel_size=voxel_size)
    pipeline = (
        ()
        + SwcFileSource(swc_path, [swc_key], [PointsSpec(roi=bb)])
        + RasterizeSkeleton(points=swc_key, array=labels_key, array_spec=label_spec)
        + GrowLabels(array=labels_key, radius=3)
    )

    num_repeats = 10
    started = time.time()
    with build(pipeline):
        for _ in range(num_repeats):
            pipeline.request_batch(request)
    elapsed = time.time() - started

    self.assertLess(elapsed / num_repeats, 0.1)