def setup(self):
    """Prepare this node for drawing random locations from upstream.

    * If ``self.mask`` is set and ``self.min_masked > 0``, the complete mask
      is requested once from upstream and converted into an integral image,
      stored in ``self.mask_integral``.
    * If ``self.ensure_nonempty`` is set, all of its points are requested
      once and indexed in a ``KDTree`` stored in ``self.points``.
    * Finally, the ROI shape of every provided spec is cleared, because this
      node does not limit where requests may be placed (offsets are ignored).
    """
    upstream = self.get_upstream_provider()
    self.upstream_spec = upstream.spec

    if self.mask and self.min_masked > 0:

        assert self.mask in self.upstream_spec, (
            "Upstream provider does not have %s" % self.mask)
        self.mask_spec = self.upstream_spec.array_specs[self.mask]

        logger.info("requesting complete mask...")

        mask_request = BatchRequest({self.mask: self.mask_spec})
        mask_batch = upstream.request_batch(mask_request)

        logger.info("allocating mask integral array...")

        mask_data = mask_batch.arrays[self.mask].data
        # Each integral-image entry counts mask voxels and is therefore
        # bounded by mask_data.size: pick the smallest unsigned type that
        # can hold that value.
        mask_integral_dtype = np.uint64
        logger.debug("mask size is %s", mask_data.size)
        if mask_data.size < 2**32:
            mask_integral_dtype = np.uint32
        if mask_data.size < 2**16:
            mask_integral_dtype = np.uint16
        logger.debug("chose %s as integral array dtype", mask_integral_dtype)

        self.mask_integral = np.array(mask_data > 0, dtype=mask_integral_dtype)
        # integral_image may accumulate into a wider dtype than its input;
        # cast back so the dtype chosen above is what is actually stored
        # (values fit by construction).
        self.mask_integral = integral_image(
            self.mask_integral).astype(mask_integral_dtype)

    if self.ensure_nonempty:

        assert self.ensure_nonempty in self.upstream_spec, (
            "Upstream provider does not have %s" % self.ensure_nonempty)
        points_spec = self.upstream_spec.points_specs[self.ensure_nonempty]

        logger.info("requesting all %s points...", self.ensure_nonempty)

        points_request = BatchRequest({self.ensure_nonempty: points_spec})
        points_batch = upstream.request_batch(points_request)

        self.points = KDTree([
            p.location
            for p in points_batch[self.ensure_nonempty].data.values()
        ])

        logger.info("retrieved %d points", len(self.points.data))

    # clear bounding boxes of all provided arrays and points --
    # RandomLocation does not have limits (offsets are ignored)
    for key, spec in self.spec.items():
        # a spec without a ROI has no bounding box to clear (and calling
        # set_shape on None would raise)
        if spec.roi is not None:
            spec.roi.set_shape(None)
        self.updates(key, spec)
def prepare(self, request):
    """Request the labels (and optional masks) needed to compute affinities.

    The labels are requested with the spec of the requested affinities, but
    with the dtype cleared and the ROI grown by ``-self.padding_neg`` at the
    start and ``self.padding_pos`` at the end. ``self.labels_mask`` and
    ``self.unlabelled``, when set, are requested with a copy of that spec.

    If the labels are requested downstream together with a mask, their ROIs
    are asserted to be identical.
    """
    # Only compare ROIs when both keys are actually part of the request; for
    # an affinities-only request, request[self.labels] would otherwise raise
    # a KeyError instead of performing a meaningful check.
    if (self.labels_mask and self.labels in request
            and self.labels_mask in request):
        assert (
            request[self.labels].roi == request[self.labels_mask].roi
        ), ("requested GT label roi %s and GT label mask roi %s are not "
            "the same." % (request[self.labels].roi,
                           request[self.labels_mask].roi))

    if (self.unlabelled and self.labels in request
            and self.unlabelled in request):
        assert (
            request[self.labels].roi == request[self.unlabelled].roi
        ), ("requested GT label roi %s and GT unlabelled mask roi %s are not "
            "the same." % (request[self.labels].roi,
                           request[self.unlabelled].roi))

    deps = BatchRequest()

    # grow labels ROI to accommodate padding
    labels_roi = request[self.affinities].roi.grow(
        -self.padding_neg, self.padding_pos)
    deps[self.labels] = request[self.affinities].copy()
    deps[self.labels].dtype = None
    deps[self.labels].roi = labels_roi

    if self.labels_mask:
        deps[self.labels_mask] = deps[self.labels].copy()
    if self.unlabelled:
        deps[self.unlabelled] = deps[self.labels].copy()

    return deps
def setup(self):
    """Record the upstream bounding box and build a mask integral image.

    The total ROI of this node's spec is stored in ``self.roi``. If
    ``self.min_masked > 0``, the complete mask volume is requested once and
    converted into an integral image stored in ``self.mask_integral``.

    Raises:
        RuntimeError: if the spec has no total ROI, since random samples
            cannot be drawn from an unbounded provider.
    """
    spec = self.get_spec()

    self.roi = spec.get_total_roi()
    if self.roi is None:
        raise RuntimeError("Can not draw random samples from a provider that does not have a bounding box.")

    if self.min_masked > 0:

        assert self.mask_volume_type in spec.volumes, "Upstream provider does not have %s"%self.mask_volume_type
        self.mask_roi = spec.volumes[self.mask_volume_type]

        logger.info("requesting complete mask...")

        mask_request = BatchRequest({self.mask_volume_type: self.mask_roi})
        mask_batch = self.get_upstream_provider().request_batch(mask_request)

        logger.info("allocating mask integral volume...")

        mask_data = mask_batch.volumes[self.mask_volume_type].data
        # Each integral-image entry is bounded by mask_data.size: pick the
        # smallest unsigned type that can hold that value.
        mask_integral_dtype = np.uint64
        logger.debug("mask size is %s", mask_data.size)
        if mask_data.size < 2**32:
            mask_integral_dtype = np.uint32
        if mask_data.size < 2**16:
            mask_integral_dtype = np.uint16
        logger.debug("chose %s as integral volume dtype", mask_integral_dtype)

        self.mask_integral = np.array(mask_data > 0, dtype=mask_integral_dtype)
        # integral_image may accumulate into a wider dtype than its input;
        # cast back so the dtype chosen above actually takes effect.
        self.mask_integral = integral_image(
            self.mask_integral).astype(mask_integral_dtype)
def __init__(self, output_dir='snapshots', output_filename='{id}.hdf', every=1, additional_request=None):
    """Configure where and how often batch snapshots are saved.

    Args:
        output_dir (string):
            Directory to save the snapshots in. Will be created if it does
            not exist.
        output_filename (string):
            Template for output filenames. ``'{id}'`` is replaced with the
            ID of the batch, ``'{iteration}'`` with the training iteration
            (if training was performed on this batch).
        every (int):
            How often to save a batch: ``every=1`` stores every batch,
            ``every=2`` every second one, and so on. Values below 1 are
            treated as 1. By default, every batch is stored.
        additional_request (BatchRequest, optional):
            An additional request to merge with the passing request when a
            snapshot is to be made. If not given, only the volumes that are
            in the batch anyway are recorded.
    """
    self.output_dir = output_dir
    self.output_filename = output_filename
    self.every = max(1, every)
    if additional_request is None:
        additional_request = BatchRequest()
    self.additional_request = additional_request
    # number of batches seen so far
    self.n = 0
def prepare(self, request):
    """Seed numpy's global RNG from the request and forward ``self.array``.

    Seeding with ``request.random_seed`` makes any subsequent ``np.random``
    draws deterministic for a given request.
    """
    # TODO: move all randomness into the prepare method
    # TODO: write a test for this node
    np.random.seed(request.random_seed)

    deps = BatchRequest()
    upstream_spec = request[self.array].copy()
    deps[self.array] = upstream_spec
    return deps
def prepare(self, request):
    """Request the labels and every mask over the spec requested for scales.

    Each dependency gets its own copy of the scales spec, with the dtype
    cleared. The dtype requested for the scales describes the scales, not
    the labels or masks. Sharing one spec object between several keys (and
    with the downstream request) would also let a change to one leak into
    the others.
    """
    deps = BatchRequest()
    for key in [self.labels] + list(self.masks):
        spec = request[self.scales].copy()
        spec.dtype = None
        deps[key] = spec
    return deps
def provide(self, request):
    """Split ``request`` by owning provider and merge the resulting batches.

    Each requested key is routed to ``self.key_to_provider[key]``. Every
    provider receives one sub-request with its own random seed. The arrays,
    graphs and profiling statistics of all returned batches are collected
    into a single ``Batch``.
    """
    # create upstream requests
    upstream_requests = {}
    for key, spec in request.items():

        provider = self.key_to_provider[key]

        if provider not in upstream_requests:
            # use new random seeds per upstream request.
            # seeds picked by random should be deterministic since
            # the provided request already has a random seed.
            # randint includes its upper bound; numpy seeds must be
            # < 2**32, so the largest valid seed is 2**32 - 1.
            seed = random.randint(0, 2**32 - 1)
            upstream_requests[provider] = BatchRequest(random_seed=seed)

        upstream_requests[provider][key] = spec

    # execute requests, merge batches
    merged_batch = Batch()
    for provider, upstream_request in upstream_requests.items():

        batch = provider.request_batch(upstream_request)
        for key, array in batch.arrays.items():
            merged_batch.arrays[key] = array
        for key, graph in batch.graphs.items():
            merged_batch.graphs[key] = graph
        merged_batch.profiling_stats.merge_with(batch.profiling_stats)

    return merged_batch
def __init__(
    self,
    dataset_names,
    output_dir="snapshots",
    output_filename="{id}.zarr",
    every=1,
    additional_request=None,
    compression_type=None,
    dataset_dtypes=None,
    store_value_range=False,
):
    """Store the snapshot configuration; no file is touched here.

    Args:
        dataset_names: stored unchanged in ``self.dataset_names``.
        output_dir: directory for snapshot files.
        output_filename: filename template for snapshot files.
        every: presumably the snapshot period in batches; values below 1
            are clamped to 1.
        additional_request: extra request stored for snapshot batches;
            defaults to an empty ``BatchRequest``.
        compression_type: stored unchanged.
        dataset_dtypes: presumably per-dataset dtype overrides; defaults
            to an empty dict.
        store_value_range: stored unchanged.
    """
    self.dataset_names = dataset_names
    self.output_dir = output_dir
    self.output_filename = output_filename
    self.every = max(1, every)
    if additional_request is None:
        self.additional_request = BatchRequest()
    else:
        self.additional_request = additional_request
    self.n = 0  # batch counter
    self.compression_type = compression_type
    self.store_value_range = store_value_range
    self.dataset_dtypes = {} if dataset_dtypes is None else dataset_dtypes
    # file mode used when writing; fixed to "w" here
    self.mode = "w"
def prepare(self, request):
    """Request the graph (and optional mask) needed to rasterize the array.

    The graph is requested in the array ROI grown by a rasterization
    context: ``ceil(radius)`` in 'ball' mode and ``ceil(2 * radius)`` in
    'peak' mode. This lets nodes just outside the array still contribute.
    The request is intersected with the graph ROI this node provides. If
    a mask is configured, it is requested over the same ROI snapped to the
    mask's voxel grid.

    Raises:
        RuntimeError: for an unknown raster mode.
    """
    # builtin int instead of the np.int alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24 (it was an alias of int anyway)
    if self.settings.mode == 'ball':
        context = np.ceil(self.settings.radius).astype(int)
    elif self.settings.mode == 'peak':
        context = np.ceil(2*self.settings.radius).astype(int)
    else:
        raise RuntimeError('unknown raster mode %s'%self.settings.mode)

    dims = self.array_spec.roi.dims()
    if len(context) == 1:
        # a single radius applies to all dimensions
        context = context.repeat(dims)

    # request graph in a larger area to get rasterization from outside
    # graph
    graph_roi = request[self.array].roi.grow(
        Coordinate(context),
        Coordinate(context))

    # however, restrict the request to the graph actually provided
    graph_roi = graph_roi.intersect(self.spec[self.graph].roi)

    deps = BatchRequest()
    deps[self.graph] = GraphSpec(roi=graph_roi)

    if self.settings.mask is not None:

        mask_voxel_size = self.spec[self.settings.mask].voxel_size
        assert self.spec[self.array].voxel_size == mask_voxel_size, (
            "Voxel size of mask and rasterized volume need to be equal")

        new_mask_roi = graph_roi.snap_to_grid(mask_voxel_size)
        deps[self.settings.mask] = ArraySpec(roi=new_mask_roi)

    return deps
def prepare(self, request):
    """Restrict the request for ``self.key`` to what upstream provides.

    The request's ROI for ``self.key`` is intersected in place with the
    upstream ROI. If nothing remains, a warning is logged and an empty ROI
    at the upstream offset is requested instead. Returns ``None`` when
    ``self.key`` is not requested.
    """
    upstream_spec = self.get_upstream_provider().spec

    logger.debug("request: %s" % request)
    logger.debug("upstream spec: %s" % upstream_spec)

    # TODO: remove this?
    if self.key not in request:
        return None

    requested_roi = request[self.key].roi.copy()
    upstream_roi = upstream_spec[self.key].roi

    # change request to fit into upstream spec
    fitted_roi = requested_roi.intersect(upstream_roi)
    request[self.key].roi = fitted_roi

    if request[self.key].roi.empty():

        logger.warning(
            "Requested %s ROI %s lies entirely outside of upstream "
            "ROI %s.", self.key, requested_roi, upstream_roi)

        # ensure a valid request by asking for empty ROI
        empty_shape = (0, ) * upstream_roi.dims()
        request[self.key].roi = Roi(upstream_roi.get_offset(), empty_shape)

    logger.debug("new request: %s" % request)

    deps = BatchRequest()
    deps[self.key] = request[self.key]
    return deps
def prepare(self, request):
    """Request each input array over the spec of its paired output array.

    ``self.arrays`` and ``self.output_arrays`` are paired positionally. When
    ``self.context`` is not None, the ROI is grown by it on both sides.
    """
    deps = BatchRequest()
    for source_key, target_key in zip(self.arrays, self.output_arrays):
        upstream_spec = request[target_key].copy()
        if self.context is not None:
            upstream_spec.roi = upstream_spec.roi.grow(
                self.context, self.context)
        deps[source_key] = upstream_spec
    return deps
def prepare(self, request):
    """Start in-process on first use, then forward every input request.

    ``self.start()`` is called once, on the first call, and only when
    ``self.spawn_subprocess`` is false. Every key in ``self.inputs`` is
    requested upstream with its downstream spec.
    """
    if not (self.initialized or self.spawn_subprocess):
        self.start()
        self.initialized = True

    deps = BatchRequest()
    for input_key in self.inputs.values():
        deps[input_key] = request[input_key]
    return deps
def __init__(self, every=1, additional_request=None, ignore_key=lambda key: False):
    """Configure an in-memory snapshot recorder.

    Args:
        every: snapshot period; values below 1 are clamped to 1.
        additional_request: extra request stored for snapshot batches;
            defaults to an empty ``BatchRequest``.
        ignore_key: predicate on a key; the default ignores nothing.
    """
    self.every = max(1, every)
    self.additional_request = (
        additional_request if additional_request is not None
        else BatchRequest()
    )
    self.n = 0  # batch counter
    self.snapshots = {}  # recorded snapshots, initially empty
    self.ignore_key = ignore_key
def prepare(self, request):
    """Size the input, target and weight requests from the requested outputs.

    The total ROI of all requested keys in ``self.outputs`` is grown
    symmetrically to exactly ``self.output_size``. That output ROI is then
    grown symmetrically to ``self.input_size``. Inputs are requested over
    the input ROI; targets and ``self.weights`` over the output ROI. Both
    growths must be even in every dimension (asserted).

    Background (original TODO): train nodes have no prepare method, so
    whatever reaches them is used as inputs/targets. Requesting ground
    truth of ``input_size`` would then be compared against network output
    of ``output_size`` and the loss would fail. This method sizes the
    dependencies explicitly instead.
    """
    deps = BatchRequest()

    # collect the specs of the outputs that were actually requested
    requested_outputs = BatchRequest()
    for output_key in self.outputs.values():
        if output_key in request:
            requested_outputs[output_key] = request[output_key].copy()
    requested_roi = requested_outputs.get_total_roi()

    # grow the requested outputs symmetrically to the network output size
    pad = self.output_size - requested_roi.get_shape()
    even = Coordinate([0] * len(self.output_size))
    assert Coordinate([x % 2 for x in pad]) == even
    output_roi = requested_roi.grow(pad // 2, pad // 2)
    assert output_roi.get_shape() == self.output_size

    # grow the output ROI symmetrically to the network input size
    pad = self.input_size - output_roi.get_shape()
    assert Coordinate([x % 2 for x in pad]) == even
    input_roi = output_roi.grow(pad // 2, pad // 2)

    for input_key in self.inputs.values():
        deps[input_key] = ArraySpec(roi=input_roi)
    for target_key in self.targets.values():
        deps[target_key] = ArraySpec(roi=output_roi)
    deps[self.weights] = ArraySpec(roi=output_roi)

    return deps
def prepare(self, request):
    """Request the source array over the ROI requested for the target.

    Returns ``None`` (no dependencies) when the target is not requested.
    """
    deps = BatchRequest()

    if self.target in request:
        logger.debug("preparing upsampling of " + str(self.source))

        target_roi = request[self.target].roi
        logger.debug("request ROI is %s" % target_roi)

        # add or merge to batch request
        deps[self.source] = ArraySpec(roi=target_roi)
        return deps

    return None
def process(self, batch, request):
    """Randomly corrupt sections of ``VolumeTypes.RAW`` in place.

    For every section along ``self.axis``, a single draw ``r`` from
    ``random.random()`` is compared against cumulative thresholds:

    * ``r < prob_missing``: the section is set to 0;
    * ``r < prob_missing + prob_low_contrast``: values are scaled around
      the section mean by ``self.contrast_scale``;
    * ``r < ... + prob_artifact``: an artifact (RAW and ALPHA_MASK) of the
      section's shape is requested from ``self.artifact_source`` and
      alpha-blended into the section;
    * otherwise the section is left unchanged.
    """
    total_roi = batch.get_total_roi()
    assert total_roi.dims()==3, "DefectAugment works on 3D batches only"

    missing_below = self.prob_missing
    low_contrast_below = missing_below + self.prob_low_contrast
    artifact_below = low_contrast_below + self.prob_artifact

    raw = batch.volumes[VolumeTypes.RAW]
    num_dims = total_roi.dims()

    for c in range(total_roi.get_shape()[self.axis]):

        r = random.random()

        # select section c along self.axis, everything in other dimensions
        selector = tuple(
            slice(c, c + 1) if d == self.axis else slice(None, None)
            for d in range(num_dims)
        )

        if r < missing_below:
            logger.debug("Zero-out " + str(selector))
            raw.data[selector] = 0
            continue

        if r < low_contrast_below:
            logger.debug("Lower contrast " + str(selector))
            section = raw.data[selector]
            mean = section.mean()
            section -= mean
            section *= self.contrast_scale
            section += mean
            raw.data[selector] = section
            continue

        if r < artifact_below:
            logger.debug("Add artifact " + str(selector))
            section = raw.data[selector]

            artifact_request = BatchRequest()
            artifact_request.add_volume_request(VolumeTypes.RAW, section.shape)
            artifact_request.add_volume_request(VolumeTypes.ALPHA_MASK, section.shape)
            logger.debug("Requesting artifact batch " + str(artifact_request))

            artifact_batch = self.artifact_source.request_batch(artifact_request)
            alpha = artifact_batch.volumes[VolumeTypes.ALPHA_MASK].data
            artifact = artifact_batch.volumes[VolumeTypes.RAW].data

            assert artifact.dtype == section.dtype
            assert alpha.dtype == np.float32
            assert alpha.min() >= 0.0
            assert alpha.max() <= 1.0

            raw.data[selector] = section*(1.0 - alpha) + artifact*alpha
def prepare(self, request):
    """Request the partner points and stay-inside arrays for requested arrays.

    For every requested array in ``self.array_to_src_trg_points``:

    * its source points are requested over the array ROI;
    * its target points are requested over the array ROI grown by
      ``self.pad_for_partners``, so partners just outside the ROI are
      included.

    For every requested array in
    ``self.array_keys_to_stayinside_array_keys``, the mapped array is
    requested with a deep copy of the array's spec.
    """
    deps = BatchRequest()

    for array_key, (src_key, trg_key) in self.array_to_src_trg_points.items():
        if array_key not in request:
            continue
        array_roi = request[array_key].roi
        deps[src_key] = PointsSpec(array_roi)
        padded_roi = array_roi.grow(self.pad_for_partners, self.pad_for_partners)
        deps[trg_key] = PointsSpec(padded_roi)

    for array_key, stayinside_key in self.array_keys_to_stayinside_array_keys.items():
        if array_key not in request:
            continue
        deps[stayinside_key] = copy.deepcopy(request[array_key])

    return deps
def prepare(self, request):
    """Forward requested dataset keys and add extras on snapshot batches.

    Every requested key contained in ``self.dataset_names`` is forwarded
    with its spec unchanged. ``self.record_snapshot`` is set when
    ``self.n`` is a multiple of ``self.every``. In that case the array and
    graph specs of ``self.additional_request`` are added for keys not
    already present; existing requests are never overwritten.
    """
    deps = BatchRequest()

    forwarded = [
        (key, spec) for key, spec in request.items()
        if key in self.dataset_names
    ]
    for key, spec in forwarded:
        deps[key] = spec

    self.record_snapshot = self.n % self.every == 0
    if not self.record_snapshot:
        return deps

    # append additional array and graph requests, don't overwrite existing ones
    extra_specs = list(self.additional_request.array_specs.items())
    extra_specs += list(self.additional_request.graph_specs.items())
    for key, spec in extra_specs:
        if key not in deps:
            deps[key] = spec

    return deps
def __init__(
        self,
        dataset_names,
        output_dir='snapshots',
        output_filename='{id}.hdf',
        every=1,
        additional_request=None,
        compression_type=None,
        dataset_dtypes=None):
    """Store the snapshot configuration; no file is touched here.

    Args:
        dataset_names: stored unchanged in ``self.dataset_names``.
        output_dir: directory for snapshot files.
        output_filename: filename template for snapshot files.
        every: presumably the snapshot period in batches; values below 1
            are clamped to 1.
        additional_request: extra request stored for snapshot batches;
            defaults to an empty ``BatchRequest``.
        compression_type: stored unchanged.
        dataset_dtypes: presumably per-dataset dtype overrides; defaults
            to an empty dict.
    """
    self.dataset_names = dataset_names
    self.output_dir = output_dir
    self.output_filename = output_filename
    self.every = max(1, every)
    if additional_request is None:
        self.additional_request = BatchRequest()
    else:
        self.additional_request = additional_request
    self.n = 0  # batch counter
    self.compression_type = compression_type
    self.dataset_dtypes = dataset_dtypes if dataset_dtypes is not None else {}
def provide(self, request):
    """Split ``request`` by owning provider and merge the resulting batches.

    Each requested key is routed to ``self.key_to_provider[key]``. One
    sub-request is issued per provider, and the arrays and points of all
    returned batches are collected into a single ``Batch``.
    """
    # group the requested specs by the provider that owns each key
    sub_requests = {}
    for key, spec in request.items():
        owner = self.key_to_provider[key]
        sub_request = sub_requests.get(owner)
        if sub_request is None:
            sub_request = sub_requests[owner] = BatchRequest()
        sub_request[key] = spec

    # execute the sub-requests and merge their results
    merged_batch = Batch()
    for owner, sub_request in sub_requests.items():
        partial = owner.request_batch(sub_request)
        merged_batch.arrays.update(partial.arrays.items())
        merged_batch.points.update(partial.points.items())

    return merged_batch
def provide(self, request):
    """Split ``request`` by owning provider and merge the resulting batches.

    Each requested key is routed to ``self.key_to_provider[key]``. One
    sub-request is issued per provider. The arrays, graphs and profiling
    statistics of all returned batches are collected into a single
    ``Batch``.
    """
    # group the requested specs by the provider that owns each key
    sub_requests = {}
    for key, spec in request.items():
        owner = self.key_to_provider[key]
        sub_request = sub_requests.get(owner)
        if sub_request is None:
            sub_request = sub_requests[owner] = BatchRequest()
        sub_request[key] = spec

    # execute the sub-requests and merge their results
    merged_batch = Batch()
    for owner, sub_request in sub_requests.items():
        partial = owner.request_batch(sub_request)
        merged_batch.arrays.update(partial.arrays.items())
        merged_batch.graphs.update(partial.graphs.items())
        merged_batch.profiling_stats.merge_with(partial.profiling_stats)

    return merged_batch
def __select_random_location_with_points(
        self, request, lcm_shift_roi, lcm_voxel_size):
    """Select a random shift such that the shifted request contains a point.

    The method repeatedly:

    1. picks a random point of ``self.points``;
    2. computes all shifts (in units of ``lcm_voxel_size``) that place it
       inside the requested points ROI;
    3. intersects those shifts with ``lcm_shift_roi`` and draws one via
       ``self.__select_random_location``.

    It then requests the shifted ROI from upstream and accepts the shift
    with probability ``1 / num_points``, where ``num_points`` is the number
    of points the shifted ROI contains. This compensates for the bias
    introduced by close-by points.

    Loops until a shift is accepted. It does not terminate if no point
    admits a shift inside ``lcm_shift_roi``.

    Returns:
        The accepted shift, as returned by ``self.__select_random_location``.
    """
    request_points_roi = request[self.ensure_nonempty].roi

    # random.choice needs an indexable sequence; dict key views are not
    # indexable in Python 3. Materialize the ids once rather than on every
    # retry (same order, so the same ids are drawn).
    all_point_ids = list(self.points.data.keys())

    while True:

        # How to pick shifts that ensure that a randomly chosen point is
        # contained in the request ROI:
        #
        #
        # request         point
        # [---------)     .
        # 0        +10    17
        #
        #         least shifted to contain point
        #         [---------)
        #         8        +10
        #         ==
        #         point-request.begin-request.shape+1
        #
        #                  most shifted to contain point:
        #                  [---------)
        #                  17       +10
        #                  ==
        #                  point-request.begin
        #
        #         all possible shifts
        #         [---------)
        #         8        +10
        #         ==
        #         point-request.begin-request.shape+1
        #                  ==
        #                  request.shape
        #
        # In the most shifted scenario, it could happen that the point lies
        # exactly at the lower boundary (17 in the example). This will cause
        # trouble if later we mirror the batch. The point would end up lying
        # on the other boundary, which is exclusive and thus not part of the
        # ROI. Therefore, we have to ensure that the point is well inside
        # the shifted ROI, not just on the boundary:
        #
        #         all possible shifts
        #         [--------)
        #         8       +9
        #                 ==
        #                 request.shape-1

        # pick a random point
        point_id = choice(all_point_ids)
        point = self.points.data[point_id]

        logger.debug(
            "select random point %d at %s",
            point_id, point.location)

        # get the lcm voxel that contains this point
        lcm_location = Coordinate(point.location/lcm_voxel_size)
        logger.debug(
            "belongs to lcm voxel %s",
            lcm_location)

        # mark all dimensions in which the point lies on the lower boundary
        # of the lcm voxel
        on_lower_boundary = lcm_location*lcm_voxel_size == point.location
        logger.debug(
            "lies on the lower boundary of the lcm voxel in dimensions %s",
            on_lower_boundary)

        # for each of these dimensions, we have to change the shape of the
        # shift ROI using the following correction
        lower_boundary_correction = Coordinate((
            -1 if o else 0
            for o in on_lower_boundary
        ))
        logger.debug(
            "lower bound correction for shape of shift ROI %s",
            lower_boundary_correction)

        # get the request ROI's shape in lcm
        lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
        lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
        logger.debug("Point request ROI: %s", request_points_roi)
        logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

        # get all possible starting points of lcm_roi_shape that contain
        # lcm_location
        lcm_shift_roi_begin = (
            lcm_location - lcm_roi_begin - lcm_roi_shape +
            Coordinate((1,)*len(lcm_location))
        )
        lcm_shift_roi_shape = (
            lcm_roi_shape + lower_boundary_correction
        )
        lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
        logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

        # intersect with total shift ROI
        if not lcm_point_shift_roi.intersects(lcm_shift_roi):
            logger.debug(
                "reject random shift, random point %s shift ROI %s does "
                "not intersect total shift ROI %s",
                point.location, lcm_point_shift_roi, lcm_shift_roi)
            continue
        lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

        # select a random shift from all possible shifts
        random_shift = self.__select_random_location(
            lcm_point_shift_roi,
            lcm_voxel_size)
        logger.debug("random shift: %s", random_shift)

        # count all points inside the shifted ROI
        points_request = BatchRequest()
        points_request[self.ensure_nonempty] = PointsSpec(
            roi=request_points_roi.shift(random_shift))
        logger.debug("points request: %s", points_request)
        points_batch = self.get_upstream_provider().request_batch(points_request)

        point_ids = points_batch.points[self.ensure_nonempty].data.keys()
        assert point_id in point_ids, (
            "Requested batch to contain point %s, but got points "
            "%s"%(point_id, point_ids))
        num_points = len(point_ids)

        # accept this shift with p=1/num_points
        #
        # This is to compensate the bias introduced by close-by points.
        accept = random() <= 1.0/num_points
        if accept:
            return random_shift
def process(self, batch, request):
    """Apply the per-slice augmentations chosen in ``prepare``, in place.

    ``self.slice_to_augmentation`` maps a slice index along ``self.axis``
    to one of:

    * 'zero_out': the slice is set to 0;
    * 'low_contrast' (or 'lower_contrast'): values are scaled around the
      slice mean by ``self.contrast_scale``;
    * 'artifact': an artifact and its alpha mask are requested from
      ``self.artifact_source`` and alpha-blended into the slice;
    * 'deformed_slice': the slice is warped with the fields stored in
      ``self.deform_slice_transformations``.

    If any slice was deformed, the data is afterwards cropped by
    ``self.deformation_strength`` voxels on both sides of every non-axis
    dimension. The ROI is reset to the requested one (presumably undoing
    the growth applied when the request was prepared).
    """
    assert batch.get_total_roi().dims() == 3, "defectaugment works on 3d batches only"

    raw = batch.arrays[self.intensities]
    raw_voxel_size = self.spec[self.intensities].voxel_size

    for c, augmentation_type in self.slice_to_augmentation.items():

        section_selector = tuple(
            slice(None if d != self.axis else c, None if d != self.axis else c+1)
            for d in range(raw.spec.roi.dims())
        )

        if augmentation_type == 'zero_out':
            raw.data[section_selector] = 0

        elif augmentation_type in ('low_contrast', 'lower_contrast'):
            # accept both spellings: 'lower_contrast' is the label that has
            # been stored for this augmentation, and matching only
            # 'low_contrast' silently skipped it
            section = raw.data[section_selector]

            mean = section.mean()
            section -= mean
            section *= self.contrast_scale
            section += mean

            raw.data[section_selector] = section

        elif augmentation_type == 'artifact':

            section = raw.data[section_selector]

            alpha_voxel_size = self.artifact_source.spec[self.artifacts_mask].voxel_size

            assert raw_voxel_size == alpha_voxel_size, ("Can only alpha blend RAW with "
                                                        "ALPHA_MASK if both have the same "
                                                        "voxel size")

            artifact_request = BatchRequest()
            artifact_request.add(self.artifacts, Coordinate(section.shape) * raw_voxel_size, voxel_size=raw_voxel_size)
            artifact_request.add(self.artifacts_mask, Coordinate(section.shape) * alpha_voxel_size, voxel_size=raw_voxel_size)
            logger.debug("Requesting artifact batch %s", artifact_request)

            artifact_batch = self.artifact_source.request_batch(artifact_request)
            artifact_alpha = artifact_batch.arrays[self.artifacts_mask].data
            artifact_raw = artifact_batch.arrays[self.artifacts].data

            assert artifact_alpha.dtype == np.float32
            assert artifact_alpha.min() >= 0.0
            assert artifact_alpha.max() <= 1.0

            raw.data[section_selector] = section*(1.0 - artifact_alpha) + artifact_raw*artifact_alpha

        elif augmentation_type == 'deformed_slice':

            section = raw.data[section_selector].squeeze()

            # cubic interpolation if the spec is interpolatable, otherwise
            # nearest neighbor (order 0)
            interpolation = 3 if self.spec[self.intensities].interpolatable else 0

            # load the deformation fields that were prepared for this slice
            flow_x, flow_y, line_mask = self.deform_slice_transformations[c]

            # apply the deformation fields
            shape = section.shape
            section = map_coordinates(
                section, (flow_y, flow_x), mode='constant', order=interpolation
            ).reshape(shape)

            # things can get smaller than 0 at the boundary, so we clip
            section = np.clip(section, 0., 1.)

            # zero-out data below the line mask
            section[line_mask] = 0.

            raw.data[section_selector] = section

    # in case we needed to change the ROI due to a deformation augment,
    # restore original ROI and crop the array data
    if 'deformed_slice' in self.slice_to_augmentation.values():
        old_roi = request[self.intensities].roi
        logger.debug("resetting roi to %s" % old_roi)
        margin = self.deformation_strength
        # slice(0, -0) would be empty, so a zero margin must end at None
        crop_stop = -margin if margin else None
        crop = tuple(
            slice(None) if d == self.axis else slice(margin, crop_stop)
            for d in range(raw.spec.roi.dims())
        )
        raw.data = raw.data[crop]
        raw.spec.roi = old_roi
def prepare(self, request):
    """Choose an augmentation per slice and build the upstream request.

    For every slice along ``self.axis`` of the requested intensities ROI,
    one draw ``r`` from ``random.random()`` is compared against the
    cumulative thresholds ``prob_missing``, ``+ prob_low_contrast``,
    ``+ prob_artifact`` and ``+ prob_deform``. The result selects
    'zero_out', 'low_contrast', 'artifact' or 'deformed_slice' (or no
    augmentation). The choices are stored in ``self.slice_to_augmentation``.

    For deformed slices, the deformation fields are prepared here and
    stored in ``self.deform_slice_transformations``. The upstream ROI is
    then grown by ``self.deformation_strength`` voxels in every dimension
    except ``self.axis``.

    Returns:
        BatchRequest: the (possibly grown) intensities request.
    """
    deps = BatchRequest()

    # we prepare the augmentations, by determining which slices
    # will be augmented by which method
    # If one of the slices is augmented with 'deform',
    # we prepare these trafos already
    # and request a bigger roi from upstream

    prob_missing_threshold = self.prob_missing
    prob_low_contrast_threshold = prob_missing_threshold + self.prob_low_contrast
    prob_artifact_threshold = prob_low_contrast_threshold + self.prob_artifact
    prob_deform_slice = prob_artifact_threshold + self.prob_deform

    spec = request[self.intensities].copy()
    roi = spec.roi
    logger.debug("downstream request ROI is %s" % roi)
    raw_voxel_size = self.spec[self.intensities].voxel_size

    # store the mapping slice to augmentation type in a dict
    self.slice_to_augmentation = {}
    # store the transformations for deform slice
    self.deform_slice_transformations = {}
    for c in range((roi / raw_voxel_size).get_shape()[self.axis]):
        r = random.random()

        if r < prob_missing_threshold:
            logger.debug("Zero-out " + str(c))
            self.slice_to_augmentation[c] = 'zero_out'

        elif r < prob_low_contrast_threshold:
            logger.debug("Lower contrast " + str(c))
            # use the label process() dispatches on
            self.slice_to_augmentation[c] = 'low_contrast'

        elif r < prob_artifact_threshold:
            logger.debug("Add artifact " + str(c))
            self.slice_to_augmentation[c] = 'artifact'

        elif r < prob_deform_slice:
            logger.debug("Add deformed slice " + str(c))
            self.slice_to_augmentation[c] = 'deformed_slice'
            # get the shape of a single slice
            slice_shape = (roi / raw_voxel_size).get_shape()
            slice_shape = slice_shape[:self.axis] + slice_shape[self.axis+1:]
            self.deform_slice_transformations[c] = self.__prepare_deform_slice(slice_shape)

    # prepare transformation and
    # request bigger upstream roi for deformed slice
    if 'deformed_slice' in self.slice_to_augmentation.values():

        # create roi sufficiently large to feed deformation
        logger.debug("before growth: %s" % spec.roi)
        growth = Coordinate(
            tuple(0 if d == self.axis else raw_voxel_size[d] * self.deformation_strength
                  for d in range(spec.roi.dims()))
        )
        logger.debug("growing request by %s" % str(growth))
        source_roi = roi.grow(growth, growth)

        # update request ROI to get all voxels necessary to perform
        # transformation
        spec.roi = source_roi
        logger.debug("upstream request roi is %s" % spec.roi)

    deps[self.intensities] = spec
    # return the dependencies; without this the method returned None and
    # the spec built above (including any grown ROI) was discarded
    return deps
def prepare(self, request):
    """Request every key of ``self.dataset_names`` with its downstream spec."""
    deps = BatchRequest()
    for dataset_key in self.dataset_names:
        deps[dataset_key] = request[dataset_key]
    return deps
def prepare(self, request):
    """Forward the downstream spec of every key in ``self.inputs`` upstream."""
    deps = BatchRequest()
    input_keys = self.inputs.values()
    for input_key in input_keys:
        deps[input_key] = request[input_key]
    return deps
def prepare(self, request):
    """Forward the request for ``self.array`` upstream unchanged."""
    deps = BatchRequest()
    requested_spec = request[self.array]
    deps[self.array] = requested_spec
    return deps
def setup(self):
    """Prepare this node for drawing random locations from upstream.

    * If ``self.mask`` is set and ``self.min_masked > 0``, the complete mask
      is requested once and converted into an integral image
      (``self.mask_integral``), stored with the smallest sufficient dtype.
    * If ``self.ensure_nonempty`` is set, all of its nodes are requested
      once and indexed in a ``cKDTree`` (``self.points``). Each node gets
      weight ``1 / (number of nodes within self.point_balance_radius,
      itself included)``. The running sum of these weights is stored in
      ``self.cumulative_weights``.
    * Finally, the ROI shape of every provided spec that has a ROI is
      cleared, since this node does not limit where requests may be placed.
    """
    upstream = self.get_upstream_provider()
    self.upstream_spec = upstream.spec

    if self.mask and self.min_masked > 0:

        assert self.mask in self.upstream_spec, (
            "Upstream provider does not have %s"%self.mask)
        self.mask_spec = self.upstream_spec.array_specs[self.mask]

        logger.info("requesting complete mask...")

        mask_request = BatchRequest({self.mask: self.mask_spec})
        mask_batch = upstream.request_batch(mask_request)

        logger.info("allocating mask integral array...")

        mask_data = mask_batch.arrays[self.mask].data
        # integral values are bounded by mask_data.size
        mask_integral_dtype = np.uint64
        logger.debug("mask size is %s", mask_data.size)
        if mask_data.size < 2**32:
            mask_integral_dtype = np.uint32
        if mask_data.size < 2**16:
            mask_integral_dtype = np.uint16
        logger.debug("chose %s as integral array dtype", mask_integral_dtype)

        self.mask_integral = np.array(mask_data > 0, dtype=mask_integral_dtype)
        self.mask_integral = integral_image(self.mask_integral).astype(mask_integral_dtype)

    if self.ensure_nonempty:

        assert self.ensure_nonempty in self.upstream_spec, (
            "Upstream provider does not have %s"%self.ensure_nonempty)
        graph_spec = self.upstream_spec.graph_specs[self.ensure_nonempty]

        logger.info("requesting all %s points...", self.ensure_nonempty)

        nonempty_request = BatchRequest({self.ensure_nonempty: graph_spec})
        nonempty_batch = upstream.request_batch(nonempty_request)

        # gather the node locations once instead of iterating the graph's
        # nodes a second time for the density query below
        locations = [
            p.location for p in nonempty_batch[self.ensure_nonempty].nodes
        ]
        self.points = cKDTree(locations)

        # every node finds at least itself, so len(...) >= 1; presumably
        # these weights are used to down-weight nodes in dense regions
        point_counts = self.points.query_ball_point(
            locations,
            r=self.point_balance_radius,
        )
        weights = [1 / len(point_count) for point_count in point_counts]
        self.cumulative_weights = list(itertools.accumulate(weights))

        logger.debug("retrieved %d points", len(self.points.data))

    # clear bounding boxes of all provided arrays and points --
    # RandomLocation does not have limits (offsets are ignored)
    for key, spec in self.spec.items():
        if spec.roi is not None:
            spec.roi.set_shape(None)
            self.updates(key, spec)
def prepare(self, request):
    """Request the labels with the spec requested for the gradients.

    The spec is copied so the downstream request is not aliased. Its dtype
    is cleared because any dtype requested for the gradients describes the
    gradients, not the labels.
    """
    deps = BatchRequest()
    labels_spec = request[self.gradient_array_key].copy()
    labels_spec.dtype = None
    deps[self.label_array_key] = labels_spec
    return deps
def prepare(self, request):
    """Request ``self.array`` upstream with the same spec but any dtype.

    The spec is copied before its dtype is cleared, so the downstream
    request's own spec is never modified. Depending on whether
    ``BatchRequest.__setitem__`` copies, the original code could clear the
    dtype on the caller's spec.
    """
    deps = BatchRequest()
    upstream_spec = request[self.array].copy()
    upstream_spec.dtype = None
    deps[self.array] = upstream_spec
    return deps