def provide(self, request):
    """Read the requested volumes from the HDF5 file and return them in a batch.

    Raises RuntimeError if a requested volume type is not provided by this
    source, or if its requested ROI is not contained in this source's ROI.
    """
    timing = Timing(self)
    timing.start()

    spec = self.get_spec()
    batch = Batch()

    # open the file once per request; h5py handles are not fork/thread safe
    with h5py.File(self.filename, 'r') as f:

        for (volume_type, roi) in request.volumes.items():

            if volume_type not in spec.volumes:
                raise RuntimeError("Asked for %s which this source does not provide"%volume_type)

            if not spec.volumes[volume_type].contains(roi):
                raise RuntimeError("%s's ROI %s outside of my ROI %s"%(volume_type,roi,spec.volumes[volume_type]))

            logger.debug("Reading %s in %s..."%(volume_type,roi))

            # shift request roi into dataset
            dataset_roi = roi.shift(-spec.volumes[volume_type].get_offset())

            batch.volumes[volume_type] = Volume(
                self.__read(f, self.datasets[volume_type], dataset_roi),
                roi=roi,
                resolution=self.resolutions[volume_type])

    logger.debug("done")

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def __setup_batch(self, batch_spec, chunk):
    '''Allocate a batch matching the sizes of ``batch_spec``, using
    ``chunk`` as template.'''

    batch = Batch()

    for (array_key, spec) in batch_spec.array_specs.items():

        roi = spec.roi
        voxel_size = self.spec[array_key].voxel_size

        # get the 'non-spatial' shape of the chunk-batch
        # and append the shape of the request to it
        array = chunk.arrays[array_key]
        shape = array.data.shape[:-roi.dims()]
        # spatial shape is measured in voxels, hence the division by voxel_size
        shape += (roi.get_shape() // voxel_size)

        spec = self.spec[array_key].copy()
        spec.roi = roi
        logger.info("allocating array of shape %s for %s", shape, array_key)
        batch.arrays[array_key] = Array(data=np.zeros(shape), spec=spec)

    for (points_key, spec) in batch_spec.points_specs.items():

        roi = spec.roi
        spec = self.spec[points_key].copy()
        spec.roi = roi
        # points start empty and are filled in chunk by chunk later
        batch.points[points_key] = Points(data={}, spec=spec)

    logger.debug("setup batch to fill %s", batch)

    return batch
def provide(self, request):
    """Split ``request`` among upstream providers and merge their batches.

    Each requested key is routed to the provider registered for it. Every
    upstream request receives its own random seed; the seeds are still
    deterministic because ``random`` is itself seeded by the incoming
    request's random seed.
    """
    # group the requested keys by the provider that serves them
    per_provider = {}
    for key, spec in request.items():
        upstream = self.key_to_provider[key]
        if upstream not in per_provider:
            # use new random seeds per upstream request.
            # seeds picked by random should be deterministic since
            # the provided request already has a random seed.
            per_provider[upstream] = BatchRequest(
                random_seed=random.randint(0, 2**32))
        per_provider[upstream][key] = spec

    # fetch each partial batch and fold it into one merged result
    merged = Batch()
    for upstream, upstream_request in per_provider.items():
        partial = upstream.request_batch(upstream_request)
        merged.arrays.update(partial.arrays)
        merged.graphs.update(partial.graphs)
        merged.profiling_stats.merge_with(partial.profiling_stats)

    return merged
def provide(self, request):
    """Read each requested array from its dataset and return them in a batch."""
    timing = Timing(self)
    timing.start()

    batch = Batch()

    for (array_key, request_spec) in request.array_specs.items():

        logger.debug("Reading %s in %s...", array_key, request_spec.roi)

        voxel_size = self.spec[array_key].voxel_size

        # scale request roi to voxel units
        dataset_roi = request_spec.roi / voxel_size

        # shift request roi into dataset
        dataset_roi = dataset_roi - self.spec[array_key].roi.get_offset(
        ) / voxel_size

        # create array spec
        array_spec = self.spec[array_key].copy()
        array_spec.roi = request_spec.roi

        # add array to batch
        batch.arrays[array_key] = Array(
            self.__read(self.datasets[array_key], dataset_roi),
            array_spec)

    logger.debug("done")

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def provide(self, request):
    """Read each requested array (restricted to its channel ids) from the
    backing file and return them in a batch."""
    timing = Timing(self)
    timing.start()

    batch = Batch()

    with self._open_file(self.filename) as data_file:

        for (array_key, request_spec) in request.array_specs.items():

            voxel_size = self.spec[array_key].voxel_size

            # scale request roi to voxel units
            dataset_roi = request_spec.roi / voxel_size

            # shift request roi into dataset
            dataset_roi = dataset_roi - self.spec[
                array_key].roi.get_offset() / voxel_size

            # create array spec
            array_spec = self.spec[array_key].copy()
            array_spec.roi = request_spec.roi

            # add array to batch
            batch.arrays[array_key] = Array(
                self.__read(data_file, self.datasets[array_key],
                            dataset_roi, self.channel_ids[array_key]),
                array_spec)

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def provide(self, request: BatchRequest) -> Batch:
    """Provide the requested point graphs, trimmed to the requested ROIs.

    For each provided points key that appears in ``request``, queries the
    kdtree for node ids inside the ROI, extracts the corresponding
    subgraph (optionally with neighbors, to handle boundary crossings),
    and adds the trimmed graph to the batch.
    """
    timing = Timing(self, "provide")
    timing.start()

    # BUG FIX: the batch must be created once, before the loop. The
    # original re-created it inside the loop, discarding the graphs added
    # for previously processed points keys.
    batch = Batch()

    for points_key in self.points:

        if points_key not in request:
            continue

        # Retrieve all points in the requested region using a kdtree for speed
        point_ids = self._query_kdtree(
            self.data.tree,
            (
                np.array(request[points_key].roi.get_begin()),
                np.array(request[points_key].roi.get_end()),
            ),
        )

        # To account for boundary crossings we must retrieve neighbors of all points
        # in the graph. This is too slow for large queries and less important
        points_subgraph = self._subgraph_points(
            point_ids,
            with_neighbors=len(point_ids) < len(self._graph.nodes) // 2)
        nodes = [
            Node(id=node, location=attrs["location"], attrs=attrs)
            for node, attrs in points_subgraph.nodes.items()
        ]
        edges = [Edge(u, v) for u, v in points_subgraph.edges]
        return_graph = Graph(nodes, edges, GraphSpec(roi=request[points_key].roi))

        # Handle boundary cases
        return_graph = return_graph.trim(request[points_key].roi)

        batch.points[points_key] = return_graph

        logger.debug(
            "Graph points source provided {} points for roi: {}".format(
                len(list(batch.points[points_key].nodes)),
                request[points_key].roi))
        logger.debug(
            f"Providing {len(list(points_subgraph.nodes))} nodes to {points_key}"
        )

    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def provide(self, request):
    """Read pre-/postsynaptic point annotations from the HDF5 file for
    every requested points key.

    Pre- and postsynaptic locations are read together so their ids are
    unique and partner locations can be found.
    """
    timing_process = Timing(self)
    timing_process.start()

    batch = Batch()

    with h5py.File(self.filename, 'r') as hdf_file:

        # if pre and postsynaptic locations required, their id
        # SynapseLocation dictionaries should be created together s.t. ids
        # are unique and allow to find partner locations
        if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
            assert self.kind == 'synapse'
            # If only PRESYN or POSTSYN requested, assume PRESYN ROI = POSTSYN ROI.
            pre_key = PointsKeys.PRESYN if PointsKeys.PRESYN in request.points_specs else PointsKeys.POSTSYN
            post_key = PointsKeys.POSTSYN if PointsKeys.POSTSYN in request.points_specs else PointsKeys.PRESYN
            presyn_points, postsyn_points = self.__get_syn_points(
                pre_roi=request.points_specs[pre_key].roi,
                post_roi=request.points_specs[post_key].roi,
                syn_file=hdf_file)
            points = {
                PointsKeys.PRESYN: presyn_points,
                PointsKeys.POSTSYN: postsyn_points
            }
        else:
            assert self.kind == 'presyn' or self.kind == 'postsyn'
            synkey = list(self.datasets.items())[0][0]  # only key of dic.
            presyn_points, postsyn_points = self.__get_syn_points(
                pre_roi=request.points_specs[synkey].roi,
                post_roi=request.points_specs[synkey].roi,
                syn_file=hdf_file)
            points = {
                synkey: presyn_points if self.kind == 'presyn' else postsyn_points
            }

        for (points_key, request_spec) in request.points_specs.items():
            logger.debug("Reading %s in %s...", points_key, request_spec.roi)
            points_spec = self.spec[points_key].copy()
            points_spec.roi = request_spec.roi
            # BUG FIX: the original format string, "Number of points len()",
            # had no placeholder, so the point count was never logged.
            logger.debug("Number of points %d", len(points[points_key]))
            batch.points[points_key] = Points(data=points[points_key],
                                              spec=points_spec)

    timing_process.stop()
    batch.profiling_stats.add(timing_process)
    return batch
def provide(self, request):
    """Scan over the upstream provider chunk by chunk.

    For an empty request, the whole upstream spec is scanned and nothing
    is assembled (only side effects of downstream nodes matter); otherwise
    the chunks are stitched into the requested batch.
    """
    empty_request = (len(request) == 0)
    if empty_request:
        scan_spec = self.spec
    else:
        scan_spec = request

    stride = self.__get_stride()
    shift_roi = self.__get_shift_roi(scan_spec)

    shifts = self.__enumerate_shifts(shift_roi, stride)
    num_chunks = len(shifts)

    logger.info("scanning over %d chunks", num_chunks)

    # the batch to return
    self.batch = Batch()

    if self.num_workers > 1:

        # enqueue all shifted requests, then collect the chunks produced
        # by the worker pool (completion order, not submission order)
        for shift in shifts:
            shifted_reference = self.__shift_request(self.reference, shift)
            self.request_queue.put(shifted_reference)

        for i in range(num_chunks):

            chunk = self.workers.get()

            if not empty_request:
                self.__add_to_batch(request, chunk)

            # NOTE(review): progress is logged 0-based ("chunk 0/N");
            # i + 1 would read more naturally — confirm intended
            logger.info("processed chunk %d/%d", i, num_chunks)

    else:

        for i, shift in enumerate(shifts):

            shifted_reference = self.__shift_request(self.reference, shift)
            chunk = self.__get_chunk(shifted_reference)

            if not empty_request:
                self.__add_to_batch(request, chunk)

            logger.info("processed chunk %d/%d", i, num_chunks)

    batch = self.batch
    self.batch = None

    logger.debug("returning batch %s", batch)

    return batch
def __setup_batch(self, request, chunk_batch):
    """Allocate an empty batch with zero-filled volumes for ``request``.

    Affinity volumes get a leading channel dimension of size 3; every
    volume inherits the resolution of the chunk batch's RAW volume.
    """
    batch = Batch()
    raw_resolution = chunk_batch.volumes[VolumeTypes.RAW].resolution

    for volume_type, roi in request.volumes.items():
        base_shape = roi.get_shape()
        is_affinity = (
            volume_type == VolumeTypes.PRED_AFFINITIES
            or volume_type == VolumeTypes.GT_AFFINITIES)
        shape = ((3,) + base_shape) if is_affinity else base_shape
        batch.volumes[volume_type] = Volume(
            data=np.zeros(shape),
            roi=roi,
            resolution=raw_resolution)

    return batch
def provide(self, request):
    """Provide the requested pre-/postsynaptic point graphs.

    Pre- and postsynaptic locations are read together so their ids stay
    unique and partner locations can be found. Raises RuntimeError if a
    requested key is not provided or its ROI is outside this source's ROI.
    """
    timing = Timing(self)
    timing.start()
    batch = Batch()

    # if pre and postsynaptic locations requested, their id : SynapseLocation
    # dictionaries should be created together s.t. the ids are unique and
    # allow to find partner locations
    if GraphKey.PRESYN in request.points or GraphKey.POSTSYN in request.points:
        try:
            # either both have the same roi, or only one of them is requested
            assert request.points[GraphKey.PRESYN] == request.points[
                GraphKey.POSTSYN]
        # BUG FIX: indexing request.points raises KeyError (not
        # AssertionError) when only one of the two keys is requested, so
        # the "only one requested" case crashed instead of being accepted.
        except (AssertionError, KeyError):
            assert GraphKey.PRESYN not in request.points or GraphKey.POSTSYN not in request.points

        if GraphKey.PRESYN in request.points:
            presyn_points, postsyn_points = self.__read_syn_points(
                roi=request.points[GraphKey.PRESYN])
        elif GraphKey.POSTSYN in request.points:
            presyn_points, postsyn_points = self.__read_syn_points(
                roi=request.points[GraphKey.POSTSYN])

    for (points_key, roi) in request.points.items():
        # check if requested points can be provided
        if points_key not in self.spec:
            raise RuntimeError(
                "Asked for %s which this source does not provide" % points_key)
        # check if request roi lies within provided roi
        if not self.spec[points_key].roi.contains(roi):
            raise RuntimeError(
                "%s's ROI %s outside of my ROI %s" %
                (points_key, roi, self.spec[points_key].roi))

        logger.debug("Reading %s in %s..." % (points_key, roi))
        id_to_point = {
            GraphKey.PRESYN: presyn_points,
            GraphKey.POSTSYN: postsyn_points
        }[points_key]

        batch.points[points_key] = Graph(data=id_to_point,
                                         spec=GraphSpec(roi=roi))

    logger.debug("done")
    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def process(self, batch, request):
    """Apply CLAHE to each input array and write the result to its output key.

    Input data must already be in [0, 1]; constant arrays are passed
    through unchanged.
    """
    output = Batch()
    for in_key, out_key in zip(self.arrays, self.output_arrays):
        array = batch[in_key]
        data = array.data
        d_min = data.min()
        d_max = data.max()
        assert (
            d_min >= 0 and d_max <= 1
        ), f"Clahe expects data in range (0,1), got ({d_min}, {d_max})"
        # constant data: CLAHE is undefined, pass through unchanged
        if np.isclose(d_max, d_min):
            output[out_key] = Array(data, array.spec)
            continue
        if self.normalize:
            # stretch to the full [0, 1] range before equalizing
            data = (data - d_min) / (d_max - d_min)
        shape = data.shape
        data_dims = len(shape)
        kernel_dims = len(self.kernel_size)
        # leading dims not covered by the kernel (e.g. channels) are
        # processed independently
        extra_dims = data_dims - kernel_dims
        voxel_size = array.spec.voxel_size
        for index in itertools.product(*[range(s) for s in shape[:extra_dims]]):
            data[index] = clahe(
                data[index],
                # kernel size is given in world units; convert to voxels
                kernel_size=Coordinate(self.kernel_size / voxel_size),
                clip_limit=self.clip_limit,
                nbins=self.nbins,
            )
        assert (
            data.min() >= 0 and data.max() <= 1
        ), f"Clahe should output data in range (0,1), got ({data.min()}, {data.max()})"
        output[out_key] = Array(data, array.spec).crop(request[out_key].roi)
    return output
def process(self, batch, request):
    """Apply multidimensional CLAHE (mclahe) to each input array.

    In slice-wise mode, the kernel is applied independently per index of
    the leading (non-kernel) dimensions; otherwise a single pass is run
    over the full array with a kernel of size 1 along the extra dims.
    """
    output = Batch()
    for in_key, out_key in zip(self.arrays, self.output_arrays):
        array = batch[in_key]
        data = array.data
        shape = data.shape
        data_dims = len(shape)
        kernel_dims = len(self.kernel_size)
        # leading dims not covered by the kernel (e.g. channels)
        extra_dims = data_dims - kernel_dims
        if self.slice_wise:
            # equalize each leading-index slice independently
            for index in itertools.product(
                    *[range(s) for s in shape[:extra_dims]]):
                data[index] = mclahe(
                    data[index],
                    kernel_size=self.kernel_size,
                    clip_limit=self.clip_limit,
                    n_bins=self.nbins,
                    use_gpu=False,
                    adaptive_hist_range=self.adaptive_hist_range,
                )
        else:
            # single pass: pad the kernel with 1s over the extra dims
            full_kernel = np.array(
                (1, ) * extra_dims + tuple(self.kernel_size), dtype=int)
            data = mclahe(
                data,
                kernel_size=full_kernel,
                clip_limit=self.clip_limit,
                n_bins=self.nbins,
                # use_gpu=False,
            ).astype(self.spec[out_key].dtype)
        output[out_key] = Array(data, array.spec).crop(request[out_key].roi)
    return output
def provide(self, request):
    """Provide all CSV points that fall inside the requested ROI.

    Point ids are the row indices into ``self.data``, so they are stable
    across requests.
    """
    timing = Timing(self)
    timing.start()

    min_bb = request[self.points].roi.get_begin()
    max_bb = request[self.points].roi.get_end()

    logger.debug(
        "CSV points source got request for %s",
        request[self.points].roi)

    # boolean mask over rows inside [min_bb, max_bb) per dimension
    # BUG FIX: np.bool was deprecated and removed in NumPy 1.24; the
    # builtin bool is the documented replacement.
    point_filter = np.ones((self.data.shape[0],), dtype=bool)
    for d in range(self.ndims):
        point_filter = np.logical_and(point_filter, self.data[:, d] >= min_bb[d])
        point_filter = np.logical_and(point_filter, self.data[:, d] < max_bb[d])

    filtered = self.data[point_filter]
    ids = np.arange(len(self.data))[point_filter]

    points_data = {
        i: Point(p)
        for i, p in zip(ids, filtered)
    }
    points_spec = PointsSpec(roi=request[self.points].roi.copy())

    batch = Batch()
    batch.points[self.points] = Points(points_data, points_spec)

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def process(self, batch, request):
    """Upsample ``self.source`` by ``self.factor`` into ``self.target``.

    The source array is cropped to the requested target ROI and each
    spatial axis is repeated by its factor.
    """
    outputs = Batch()

    # nothing to do if the target was not requested
    if self.target not in request:
        return

    input_roi = batch.arrays[self.source].spec.roi
    request_roi = request[self.target].roi
    assert input_roi.contains(request_roi)

    # upsample
    logger.debug("upsampling %s with %s", self.source, self.factor)

    data = batch.arrays[self.source].crop(request_roi).data
    for axis, repeats in enumerate(self.factor):
        data = np.repeat(data, repeats, axis=axis)

    # create output array
    out_spec = self.spec[self.target].copy()
    out_spec.roi = request_roi
    outputs.arrays[self.target] = Array(data, out_spec)
    return outputs
def provide(self, request):
    """Fetch, filter, and subsample a graph for every requested key.

    Nodes closer than ``min_dist`` (per ``dist_attribute``) are dropped,
    the node count is capped at ``num_nodes`` by uniform subsampling, and
    the result is cropped to the requested ROI.
    """
    timing = Timing(self)
    timing.start()

    batch = Batch()
    for key, spec in request.items():
        logger.debug(f"fetching {key} in roi {spec.roi}")
        requested_graph = self.graph_provider.get_graph(
            spec.roi,
            edge_inclusion="either",
            node_inclusion="dangling",
            node_attrs=self.node_attrs,
            edge_attrs=self.edge_attrs,
            nodes_filter=self.nodes_filter,
            edges_filter=self.edges_filter,
        )
        logger.debug(
            f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
        )

        # drop nodes below the distance threshold; iterate over a copy
        # since we remove nodes while iterating
        for node, attrs in list(requested_graph.nodes.items()):
            if self.dist_attribute in attrs:
                if attrs[self.dist_attribute] < self.min_dist:
                    requested_graph.remove_node(node)

        logger.debug(
            f"{len(requested_graph.nodes)} nodes remaining after filtering by distance"
        )

        # cap the node count by uniform random subsampling
        if len(requested_graph.nodes) > self.num_nodes:
            nodes = list(requested_graph.nodes)
            nodes_to_keep = set(random.sample(nodes, self.num_nodes))

            for node in list(requested_graph.nodes()):
                if node not in nodes_to_keep:
                    requested_graph.remove_node(node)

        # normalize node attributes expected by Graph.from_nx_graph
        for node, attrs in requested_graph.nodes.items():
            attrs["location"] = np.array(attrs[self.position_attribute],
                                         dtype=np.float32)
            attrs["id"] = node

        if spec.directed:
            requested_graph = requested_graph.to_directed()
        else:
            requested_graph = requested_graph.to_undirected()

        logger.debug(
            f"providing {key} with {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
        )

        points = Graph.from_nx_graph(requested_graph, spec)
        points.crop(spec.roi)
        batch[key] = points

    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def provide(self, request):
    """Provide all CSV points inside the requested ROI as a Graph."""
    timing = Timing(self)
    timing.start()

    min_bb = request[self.points].roi.get_begin()
    max_bb = request[self.points].roi.get_end()

    logger.debug("CSV points source got request for %s",
                 request[self.points].roi)

    # boolean mask over rows inside [min_bb, max_bb) per dimension
    # BUG FIX: np.bool was deprecated and removed in NumPy 1.24; the
    # builtin bool is the documented replacement.
    point_filter = np.ones((self.data.shape[0], ), dtype=bool)
    for d in range(self.ndims):
        point_filter = np.logical_and(point_filter,
                                      self.data[:, d] >= min_bb[d])
        point_filter = np.logical_and(point_filter,
                                      self.data[:, d] < max_bb[d])

    points_data = self._get_points(point_filter)
    points_spec = GraphSpec(roi=request[self.points].roi.copy())

    batch = Batch()
    # the CSV source provides nodes only, hence the empty edge list
    batch.graphs[self.points] = Graph(points_data, [], points_spec)

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def provide(self, request):
    """Process daisy blocks with ``num_workers`` processes and return an
    empty batch.

    Only empty requests are accepted; all the work happens as a side
    effect of ``__get_chunks``.
    """
    if len(request) != 0:
        raise RuntimeError(
            "requests made to DaisyRequestBlocks have to be empty")

    if self.num_workers > 1:
        # fan out over worker processes, then wait for all of them
        self.workers = [
            multiprocessing.Process(target=self.__get_chunks)
            for _ in range(self.num_workers)
        ]
        for process in self.workers:
            process.start()
        for process in self.workers:
            process.join()
    else:
        # single-worker mode runs inline
        self.__get_chunks()

    return Batch()
def provide(self, request):
    """Fetch a graph per requested key, repairing or rejecting nodes
    without a location.

    Nodes written only as edge endpoints have no position attribute; they
    are removed (or, with ``fail_on_inconsistent_node``, cause a
    ValueError). Connected components are relabeled and the graph is
    cropped to the requested ROI.
    """
    timing = Timing(self)
    timing.start()

    batch = Batch()
    for key, spec in request.items():
        logger.debug(f"fetching {key} in roi {spec.roi}")
        requested_graph = self.graph_provider.get_graph(
            spec.roi,
            edge_inclusion=self.edge_inclusion,
            node_inclusion=self.node_inclusion,
            node_attrs=self.node_attrs,
            edge_attrs=self.edge_attrs,
            nodes_filter=self.nodes_filter,
            edges_filter=self.edges_filter,
        )
        logger.debug(
            f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
        )

        # collect nodes that lack coordinates; they cannot be removed
        # while iterating, so they are handled in a second pass
        failed_nodes = []

        for node, attrs in requested_graph.nodes.items():
            try:
                attrs["location"] = np.array(
                    attrs[self.position_attribute], dtype=np.float32)
            except KeyError:
                logger.warning(
                    f"node: {node} was written (probably part of an edge), but never given coordinates!"
                )
                failed_nodes.append(node)
            attrs["id"] = node

        for node in failed_nodes:
            if self.fail_on_inconsistent_node:
                raise ValueError(
                    f"Mongodb contains node {node} without location! "
                    f"It was probably written as part of an edge")
            requested_graph.remove_node(node)

        if spec.directed:
            requested_graph = requested_graph.to_directed()
        else:
            requested_graph = requested_graph.to_undirected()

        points = Graph.from_nx_graph(requested_graph, spec)
        points.relabel_connected_components()
        points.crop(spec.roi)
        batch[key] = points

        logger.debug(f"{key} with {len(list(points.nodes))} nodes")

    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def provide(self, request):
    """Read CREMI synapse annotations for the requested points keys.

    PRESYN and POSTSYN are read together so their ids stay unique and
    partner locations can be found.
    """
    timing = Timing(self)
    timing.start()

    batch = Batch()

    with h5py.File(self.filename, 'r') as hdf_file:

        # if pre and postsynaptic locations required, their id
        # SynapseLocation dictionaries should be created together s.t. ids
        # are unique and allow to find partner locations
        if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
            assert request.points_specs[
                PointsKeys.PRESYN].roi == request.points_specs[
                    PointsKeys.POSTSYN].roi
            # Cremi specific, ROI offset corresponds to offset present in the
            # synapse location relative to the raw data.
            dataset_offset = self.spec[PointsKeys.PRESYN].roi.get_offset()
            presyn_points, postsyn_points = self.__get_syn_points(
                roi=request.points_specs[PointsKeys.PRESYN].roi,
                syn_file=hdf_file,
                dataset_offset=dataset_offset)

        for (points_key, request_spec) in request.points_specs.items():
            logger.debug("Reading %s in %s...", points_key, request_spec.roi)
            # NOTE(review): presyn_points/postsyn_points are only bound when
            # PRESYN or POSTSYN was requested; a request for any other
            # points key would fail here — confirm callers only request
            # PRESYN/POSTSYN from this source
            id_to_point = {
                PointsKeys.PRESYN: presyn_points,
                PointsKeys.POSTSYN: postsyn_points
            }[points_key]

            points_spec = self.spec[points_key].copy()
            points_spec.roi = request_spec.roi
            batch.points[points_key] = Points(data=id_to_point,
                                              spec=points_spec)

    logger.debug("done")
    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def provide(self, request):
    """Read the requested volumes via the per-type reader functions.

    Raises RuntimeError if a requested volume type is not provided or its
    ROI is outside this source's ROI.
    """
    timing = Timing(self)
    timing.start()

    spec = self.get_spec()

    batch = Batch()

    logger.debug("providing batch with resolution of {}".format(
        self.resolution))

    for (volume_type, roi) in request.volumes.items():

        if volume_type not in spec.volumes:
            raise RuntimeError(
                "Asked for %s which this source does not provide" %
                volume_type)

        if not spec.volumes[volume_type].contains(roi):
            raise RuntimeError(
                "%s's ROI %s outside of my ROI %s" %
                (volume_type, roi, spec.volumes[volume_type]))

        # dispatch to the reader for this volume type; interpolatable
        # only for raw data
        read, interpolate = {
            VolumeType.RAW: (self.__read_raw, True),
            VolumeType.GT_LABELS: (self.__read_gt, False),
            VolumeType.GT_MASK: (self.__read_gt_mask, False),
        }[volume_type]

        logger.debug("Reading %s in %s..." % (volume_type, roi))
        batch.volumes[volume_type] = Volume(
            read(roi),
            roi=roi,
            # TODO: get resolution from repository
            resolution=self.resolution,
            interpolate=interpolate)

    logger.debug("done")

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def provide(self, request):
    """Read the requested volumes from the HDF5 file, tagging each with
    whether it may be interpolated.

    Raises RuntimeError if a requested volume type is not provided or its
    ROI is outside this source's ROI.
    """
    timing = Timing(self)
    timing.start()

    spec = self.get_spec()

    batch = Batch()

    with h5py.File(self.filename, 'r') as f:

        for (volume_type, roi) in request.volumes.items():

            if volume_type not in spec.volumes:
                raise RuntimeError(
                    "Asked for %s which this source does not provide" %
                    volume_type)

            if not spec.volumes[volume_type].contains(roi):
                raise RuntimeError(
                    "%s's ROI %s outside of my ROI %s" %
                    (volume_type, roi, spec.volumes[volume_type]))

            # real-valued volumes may be interpolated; label/mask volumes
            # must not be
            interpolate = {
                VolumeType.RAW: True,
                VolumeType.GT_LABELS: False,
                VolumeType.GT_MASK: False,
                VolumeType.ALPHA_MASK: True,
            }[volume_type]

            logger.debug("Reading %s in %s..." % (volume_type, roi))
            batch.volumes[volume_type] = Volume(
                self.__read(f, self.datasets[volume_type], roi),
                roi=roi,
                resolution=self.resolutions[volume_type],
                interpolate=interpolate)

    logger.debug("done")

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def __setup_batch(self, batch_spec, reference):
    """Allocate zero-filled volumes mirroring ``reference``.

    RAW spans the input ROI and is interpolatable; affinity volumes span
    the output ROI with a leading channel dimension of 3; everything else
    spans the plain output ROI.
    """
    batch = Batch(batch_spec)

    for volume_type, volume in reference.volumes.items():

        if volume_type == VolumeType.RAW:
            shape = batch_spec.input_roi.get_shape()
            interpolate = True
        elif volume_type in (VolumeType.GT_AFFINITIES,
                             VolumeType.PRED_AFFINITIES):
            shape = (3,) + batch_spec.output_roi.get_shape()
            interpolate = False
        else:
            shape = batch_spec.output_roi.get_shape()
            interpolate = False

        # keep the dtype of the reference volume
        empty = np.zeros(shape, volume.data.dtype)
        batch.volumes[volume_type] = Volume(empty, interpolate)

    return batch
def provide(self, request):
    """Read each requested array from either its dataset or its mask."""
    timing = Timing(self)
    timing.start()

    batch = Batch()

    for (array_key, request_spec) in request.array_specs.items():

        logger.debug("Reading %s in %s...", array_key, request_spec.roi)

        voxel_size = self.spec[array_key].voxel_size

        # scale request roi to voxel units
        dataset_roi = request_spec.roi / voxel_size

        # shift request roi into dataset
        dataset_roi = dataset_roi - self.spec[array_key].roi.get_offset(
        ) / voxel_size

        # create array spec
        array_spec = self.spec[array_key].copy()
        array_spec.roi = request_spec.roi

        # read the data
        if array_key in self.datasets:
            data = self.__read_array(self.datasets[array_key], dataset_roi)
        elif array_key in self.masks:
            data = self.__read_mask(self.masks[array_key], dataset_roi)
        else:
            assert False, (
                "Encountered a request for %s that is neither a volume "
                "nor a mask." % array_key)

        # add array to batch
        batch.arrays[array_key] = Array(data, array_spec)

    logger.debug("done")

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def provide(self, request):
    """Read the single provided array from a CloudVolume at ``self.mip``."""
    timing = Timing(self)
    timing.start()

    batch = Batch()

    # open a fresh CloudVolume handle per request (not fork/thread safe)
    cv = CloudVolume(self.cloudvolume_url, use_https=True, mip=self.mip)

    request_spec = request.array_specs[self.array_key]
    array_key = self.array_key
    logger.debug("Reading %s in %s...", array_key, request_spec.roi)

    voxel_size = self.array_spec.voxel_size

    # scale request roi to voxel units
    dataset_roi = request_spec.roi / voxel_size

    # shift request roi into dataset
    dataset_roi = dataset_roi - self.spec[
        array_key].roi.get_offset() / voxel_size

    # create array spec
    array_spec = self.array_spec.copy()
    array_spec.roi = request_spec.roi
    # array_spec.voxel_size = array_spec.voxel_size

    # add array to batch
    batch.arrays[array_key] = Array(
        self.__read(cv, dataset_roi),
        array_spec)

    logger.debug("done")

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def provide(self, request):
    """Route each requested key to its provider and merge the partial
    batches (arrays and points) into one."""
    # create upstream requests, grouped by provider
    per_provider = {}
    for key, spec in request.items():
        upstream = self.key_to_provider[key]
        if upstream not in per_provider:
            per_provider[upstream] = BatchRequest()
        per_provider[upstream][key] = spec

    # execute requests, merge batches
    merged = Batch()
    for upstream, upstream_request in per_provider.items():
        partial = upstream.request_batch(upstream_request)
        merged.arrays.update(partial.arrays)
        merged.points.update(partial.points)

    return merged
def provide(self, request):
    """Fan the request out to the responsible providers and merge the
    resulting batches, including their profiling statistics."""
    # group requested specs by the provider serving each key
    requests_by_provider = {}
    for key, spec in request.items():
        provider = self.key_to_provider[key]
        sub_request = requests_by_provider.setdefault(provider, BatchRequest())
        sub_request[key] = spec

    # collect every partial batch into one merged result
    merged_batch = Batch()
    for provider, sub_request in requests_by_provider.items():
        partial = provider.request_batch(sub_request)
        for key, array in partial.arrays.items():
            merged_batch.arrays[key] = array
        for key, graph in partial.graphs.items():
            merged_batch.graphs[key] = graph
        merged_batch.profiling_stats.merge_with(partial.profiling_stats)

    return merged_batch
def process(self, batch, request):
    """Insert a new dimension at ``self.axis`` for each configured array
    present in ``batch``.

    Raises ValueError when the axis would fall inside the spatial
    dimensions of a spatial array.
    """
    outputs = Batch()
    for key in self.arrays:
        if key not in batch:
            continue

        source = batch[key]
        if not source.spec.nonspatial:
            # unsqueezing is only allowed in the leading, non-spatial dims
            spatial_dims = request[key].roi.dims()
            if self.axis > source.data.ndim - spatial_dims:
                raise ValueError((
                    f"Unsqueeze.axis={self.axis} not permitted. "
                    "Unsqueeze only supported for "
                    "non-spatial dimensions of Array."
                ))

        expanded = copy.deepcopy(source)
        expanded.data = np.expand_dims(source.data, self.axis)
        outputs[key] = expanded

    return outputs
def provide(self, request: BatchRequest) -> Batch:
    """Grow ``n_obj`` random trees inside the requested ROI and return
    them as one disjoint graph."""
    # seed both RNGs from the request so results are reproducible
    random.seed(request.random_seed)
    np.random.seed(request.random_seed)

    timing = Timing(self, "provide")
    timing.start()

    batch = Batch()

    roi = request[self.points].roi
    region_shape = roi.get_shape()

    trees = []
    for _ in range(self.n_obj):
        # retry up to 100 times until the tree's node count falls in the
        # requested range; NOTE(review): if no attempt succeeds, the last
        # (out-of-range) tree is kept anyway — confirm intended
        for _ in range(100):
            root = np.random.random(len(region_shape)) * region_shape
            tree = self._grow_tree(
                root, Roi((0,) * len(region_shape), region_shape)
            )
            if self.num_nodes[0] <= len(tree.nodes) <= self.num_nodes[1]:
                break
        trees.append(tree)

    # logger.info("{} trees got, expected {}".format(len(trees), self.n_obj))
    trees_graph = nx.disjoint_union_all(trees)

    # convert node positions from ROI-local voxels to world coordinates
    points = {
        node_id: Node(np.floor(node_attrs["pos"]) + roi.get_begin())
        for node_id, node_attrs in trees_graph.nodes.items()
    }

    batch[self.points] = Graph(points, request[self.points],
                               list(trees_graph.edges))

    timing.stop()
    batch.profiling_stats.add(timing)

    # self._plot_tree(tree)

    return batch
def process(self, batch, request):
    """Downsample ``self.source`` into ``self.target`` by strided slicing."""
    outputs = Batch()

    # build one strided slice per dimension; a scalar factor applies the
    # same stride to every spatial dimension
    if isinstance(self.factor, tuple):
        slices = tuple(slice(None, None, step) for step in self.factor)
    else:
        num_dims = batch[self.source].spec.roi.dims()
        slices = (slice(None, None, self.factor),) * num_dims

    logger.debug("downsampling %s with %s", self.source, slices)

    downsampled = batch.arrays[self.source].data[slices]

    # create output array
    target_spec = self.spec[self.target].copy()
    target_spec.roi = request[self.target].roi
    outputs.arrays[self.target] = Array(downsampled, target_spec)

    return outputs
def provide(self, request: BatchRequest) -> Batch:
    """Provide the SWC points inside the requested ROI, with boundary
    crossings repaired."""
    timing = Timing(self)
    timing.start()

    logger.debug("Swc points source got request for %s",
                 request[self.points].roi)

    # Retrieve all points in the requested region using a kdtree for speed
    points = self._query_kdtree(
        self.data.tree,
        (
            np.array(request[self.points].roi.get_begin()),
            np.array(request[self.points].roi.get_end()),
        ),
    )

    # Obtain subgraph that contains these points. Keep track of edges that
    # are present in the main graph, but not the subgraph
    sub_graph, predecessors, successors = self._points_to_graph(points)

    # Handle boundary cases
    self._handle_boundary_crossings(
        sub_graph, predecessors, successors, request[self.points].roi
    )

    # Convert graph into Points format
    points_data = self._graph_to_data(sub_graph)

    points_spec = PointsSpec(roi=request[self.points].roi.copy())
    batch = Batch()
    batch.points[self.points] = Points(points_data, points_spec)

    timing.stop()
    batch.profiling_stats.add(timing)
    return batch