class PreCache(BatchFilter):

    def __init__(self, request, cache_size=50, num_workers=20):
        '''
        Args:

            request (:class:`BatchRequest`): A BatchRequest used to pre-cache
                batches.

            cache_size (int): How many batches to pre-cache.

            num_workers (int): How many processes to spawn to fill the cache.
        '''
        self.request = copy.deepcopy(request)
        self.batches = multiprocessing.Queue(maxsize=cache_size)
        self.workers = ProducerPool(
            [lambda i=i: self.__run_worker(i) for i in range(num_workers)],
            queue_size=cache_size)

    def setup(self):
        self.workers.start()

    def teardown(self):
        self.workers.stop()

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        logger.debug("getting batch from queue...")
        batch = self.workers.get()

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch

    def __run_worker(self, i):
        request = copy.deepcopy(self.request)
        return self.get_upstream_provider().request_batch(request)
class GenericTrain(BatchFilter):
    '''Generic train node to perform one training iteration for each batch
    that passes through. This node alone does nothing and should be subclassed
    for concrete implementations.

    Args:

        inputs (dict): Dictionary from names of input layers in the network to
            :class:`ArrayKey` or to a batch attribute name as a string.

        outputs (dict): Dictionary from the names of output layers in the
            network to :class:`ArrayKey`. New arrays will be generated by this
            node for each entry (if requested downstream).

        gradients (dict): Dictionary from the names of output layers in the
            network to :class:`ArrayKey`. New arrays containing the gradient
            of an output with respect to the loss will be generated by this
            node for each entry (if requested downstream).

        array_specs (dict, optional): An optional dictionary of
            :class:`ArrayKey` to :class:`ArraySpec` to set the array specs of
            generated arrays (``outputs`` and ``gradients``). This is useful
            to set the ``voxel_size``, for example, if it differs from the
            voxel size of the input arrays. Only fields that are not ``None``
            in the given :class:`ArraySpec` will be used.

        spawn_subprocess (bool, optional): Whether to run ``train_step`` in a
            separate process. Default is false.
    '''

    def __init__(
            self,
            inputs,
            outputs,
            gradients,
            array_specs=None,
            spawn_subprocess=False):

        self.initialized = False

        self.inputs = inputs
        self.outputs = outputs
        self.gradients = gradients
        self.array_specs = {} if array_specs is None else array_specs
        self.spawn_subprocess = spawn_subprocess

        self.provided_arrays = list(self.outputs.values()) + list(
            self.gradients.values())

    def setup(self):

        # get common voxel size of inputs, or None if they differ
        common_voxel_size = None
        for key in self.inputs.values():

            if not isinstance(key, ArrayKey):
                continue

            voxel_size = self.spec[key].voxel_size

            if common_voxel_size is None:
                common_voxel_size = voxel_size
            elif common_voxel_size != voxel_size:
                common_voxel_size = None
                break

        # announce provided outputs
        for key in self.provided_arrays:

            if key in self.array_specs:
                spec = self.array_specs[key].copy()
            else:
                spec = ArraySpec()

            if spec.voxel_size is None and not spec.nonspatial:

                assert common_voxel_size is not None, (
                    "There is no common voxel size of the inputs, and no "
                    "ArraySpec has been given for %s that defines "
                    "voxel_size." % key)

                spec.voxel_size = common_voxel_size

            if spec.interpolatable is None:

                # default for predictions
                spec.interpolatable = False

            self.provides(key, spec)

        if self.spawn_subprocess:
            # start training as a producer pool, so that we can gracefully
            # exit if anything goes wrong
            self.worker = ProducerPool(
                [self.__produce_train_batch],
                queue_size=1)
            self.batch_in = multiprocessing.Queue(maxsize=1)
            self.worker.start()
        else:
            self.start()
            self.initialized = True

    def prepare(self, request):
        deps = BatchRequest()
        for key in self.inputs.values():
            deps[key] = request[key]
        return deps

    def teardown(self):
        if self.spawn_subprocess:
            # signal "stop"
            self.batch_in.put((None, None))
            try:
                self.worker.get(timeout=2)
            except NoResult:
                pass
            self.worker.stop()
        else:
            self.stop()

    def process(self, batch, request):

        start = time.time()

        if self.spawn_subprocess:

            self.batch_in.put((batch, request))

            try:
                out = self.worker.get()
            except WorkersDied:
                raise TrainProcessDied()

            for array_key in self.provided_arrays:
                if array_key in request:
                    batch.arrays[array_key] = out.arrays[array_key]

            batch.loss = out.loss
            batch.iteration = out.iteration

        else:

            self.train_step(batch, request)

        time_of_iteration = time.time() - start

        logger.info(
            "Train process: iteration=%d loss=%f time=%f",
            batch.iteration, batch.loss, time_of_iteration)

    def start(self):
        '''To be implemented in subclasses.

        This method will be called before the first call to
        :func:`train_step`, from the same process that :func:`train_step`
        will be called from. Use this to initialize your solver and training
        hardware.
        '''
        pass

    def train_step(self, batch, request):
        '''To be implemented in subclasses.

        In this method, an implementation should perform one training
        iteration on the given batch. ``batch.loss`` and ``batch.iteration``
        should be set. Output arrays should be created according to the given
        request and added to ``batch``.
        '''
        raise NotImplementedError(
            "Class %s does not implement 'train_step'" % self.name())

    def stop(self):
        '''To be implemented in subclasses.

        This method will be called after the last call to :func:`train_step`,
        from the same process that :func:`train_step` will be called from.
        Use this to tear down your solver and free training hardware.
        '''
        pass

    def _checkpoint_name(self, basename, iteration):
        return basename + '_checkpoint_' + '%i' % iteration

    def _get_latest_checkpoint(self, basename):

        def atoi(text):
            return int(text) if text.isdigit() else text

        def natural_keys(text):
            return [atoi(c) for c in re.split(r'(\d+)', text)]

        checkpoints = glob.glob(basename + '_checkpoint_*')
        checkpoints.sort(key=natural_keys)

        if len(checkpoints) > 0:
            checkpoint = checkpoints[-1]
            iteration = int(checkpoint.split('_')[-1])
            return checkpoint, iteration

        return None, 0

    def __produce_train_batch(self):
        '''Process one train batch.'''

        if not self.initialized:
            self.start()
            self.initialized = True

        batch, request = self.batch_in.get()

        # stop signal
        if batch is None:
            self.stop()
            return None

        self.train_step(batch, request)

        return batch
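# ---------------------------------------------------------------------------
# A minimal, hypothetical sketch of a GenericTrain subclass, showing the
# start/train_step/stop contract described above. The class name, the layer
# names 'data' and 'gt', and the fake mean-squared-error "training" are
# illustrative assumptions only; a real subclass would drive an actual solver.
# Array and ArraySpec are assumed to be imported at module level, as used by
# the nodes above.
# ---------------------------------------------------------------------------

class DummyTrain(GenericTrain):
    '''Hypothetical example: "trains" by computing a mean-squared error
    between two input arrays and copies the raw data into every requested
    output.'''

    def start(self):
        # called once, in the process that will run train_step();
        # initialize solver state and hardware here
        self.iteration = 0

    def train_step(self, batch, request):

        raw_key = self.inputs['data']  # assumed input layer name
        gt_key = self.inputs['gt']     # assumed input layer name

        raw = batch.arrays[raw_key].data
        gt = batch.arrays[gt_key].data

        # stand-in for one solver step
        self.iteration += 1
        batch.loss = float(np.mean((raw - gt)**2))
        batch.iteration = self.iteration

        # create requested outputs with the specs announced in setup()
        for name, key in self.outputs.items():
            if key in request:
                spec = self.spec[key].copy()
                spec.roi = request[key].roi
                batch.arrays[key] = Array(data=raw.copy(), spec=spec)

    def stop(self):
        # called once after the last train_step(); free resources here
        pass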
class Train(BatchFilter):
    '''Performs one training iteration for each batch that passes through.
    Adds the predicted affinities to the batch.
    '''

    def __init__(self, solver_parameters, use_gpu=None):

        # start training as a producer pool, so that we can gracefully exit if
        # anything goes wrong
        self.worker = ProducerPool(
            [lambda gpu=use_gpu: self.__train(gpu)],
            queue_size=1)
        self.batch_in = multiprocessing.Queue(maxsize=1)

        self.solver_parameters = solver_parameters
        self.solver_initialized = False

    def setup(self):
        self.worker.start()

    def teardown(self):
        self.worker.stop()

    def prepare(self, request):

        # remove request parts that we provide
        for volume_type in [VolumeType.LOSS_GRADIENT, VolumeType.PRED_AFFINITIES]:
            if volume_type in request.volumes:
                del request.volumes[volume_type]

    def process(self, batch, request):

        self.batch_in.put((batch, request))

        try:
            out = self.worker.get()
        except WorkersDied:
            raise TrainProcessDied()

        batch.volumes[VolumeType.PRED_AFFINITIES] = out.volumes[
            VolumeType.PRED_AFFINITIES]
        if VolumeType.LOSS_GRADIENT in request.volumes:
            batch.volumes[VolumeType.LOSS_GRADIENT] = out.volumes[
                VolumeType.LOSS_GRADIENT]

        batch.loss = out.loss
        batch.iteration = out.iteration

    def __train(self, use_gpu):

        start = time.time()

        if not self.solver_initialized:

            logger.info("Initializing solver...")

            if use_gpu is not None:

                logger.debug("Train process: using GPU %d" % use_gpu)
                caffe.enumerate_devices(False)
                caffe.set_devices((use_gpu,))
                caffe.set_mode_gpu()
                caffe.select_device(use_gpu, False)

            self.solver = caffe.get_solver(self.solver_parameters)
            if self.solver_parameters.resume_from is not None:
                logger.debug(
                    "Train process: restoring solver state from " +
                    self.solver_parameters.resume_from)
                self.solver.restore(self.solver_parameters.resume_from)

            self.net_io = NetIoWrapper(self.solver.net)

            self.solver_initialized = True

        batch, request = self.batch_in.get()

        data = {
            'data': batch.volumes[VolumeType.RAW].data[np.newaxis, np.newaxis, :],
            'aff_label': batch.volumes[VolumeType.GT_AFFINITIES].data[np.newaxis, :],
        }

        if self.solver_parameters.train_state.get_stage(0) == 'euclid':
            logger.debug(
                "Train process: preparing input data for Euclidean training")
            self.__prepare_euclidean(batch, data)
        else:
            logger.debug(
                "Train process: preparing input data for Malis training")
            self.__prepare_malis(batch, data)

        self.net_io.set_inputs(data)

        loss = self.solver.step(1)
        # self.__consistency_check()
        output = self.net_io.get_outputs()

        batch.volumes[VolumeType.PRED_AFFINITIES] = Volume(
            output['aff_pred'][0],
            batch.volumes[VolumeType.GT_AFFINITIES].roi,
            batch.volumes[VolumeType.GT_AFFINITIES].resolution,
            interpolate=True)
        batch.loss = loss
        batch.iteration = self.solver.iter

        if VolumeType.LOSS_GRADIENT in request.volumes:
            diffs = self.net_io.get_output_diffs()
            batch.volumes[VolumeType.LOSS_GRADIENT] = Volume(
                diffs['aff_pred'][0],
                batch.volumes[VolumeType.GT_AFFINITIES].roi,
                batch.volumes[VolumeType.GT_AFFINITIES].resolution,
                interpolate=True)

        time_of_iteration = time.time() - start
        logger.info(
            "Train process: iteration=%d loss=%f time=%f" %
            (self.solver.iter, batch.loss, time_of_iteration))

        return batch

    def __prepare_euclidean(self, batch, data):

        gt_affinities = batch.volumes[VolumeType.GT_AFFINITIES]

        # initialize error scale with 1s
        error_scale = np.ones(gt_affinities.data.shape, dtype=np.float)

        # set error_scale to 0 in masked-out areas
        if VolumeType.GT_MASK in batch.volumes:
            self.__mask_error_scale(
                error_scale,
                batch.volumes[VolumeType.GT_MASK].data)
        if VolumeType.GT_IGNORE in batch.volumes:
            self.__mask_error_scale(
                error_scale,
                batch.volumes[VolumeType.GT_IGNORE].data)

        # in the masked-in area, compute the fraction of positive samples
        masked_in = error_scale.sum()
        num_pos = (gt_affinities.data * error_scale).sum()
        frac_pos = float(num_pos) / masked_in if masked_in > 0 else 0
        frac_pos = np.clip(frac_pos, 0.05, 0.95)
        frac_neg = 1.0 - frac_pos

        # compute the class weights for positive and negative samples
        w_pos = 1.0 / (2.0 * frac_pos)
        w_neg = 1.0 / (2.0 * frac_neg)

        # scale the masked-in error_scale with the class weights
        # (note: this has to use the GT affinities, not the input dict)
        error_scale *= (
            (gt_affinities.data >= 0.5) * w_pos +
            (gt_affinities.data < 0.5) * w_neg)

        data['scale'] = error_scale[np.newaxis, :]

    def __mask_error_scale(self, error_scale, mask):
        for d in range(error_scale.shape[0]):
            error_scale[d] *= mask

    def __prepare_malis(self, batch, data):

        gt_labels = batch.volumes[VolumeType.GT_LABELS]
        next_id = gt_labels.data.max() + 1

        gt_pos_pass = gt_labels.data

        if VolumeType.GT_IGNORE in batch.volumes:
            gt_neg_pass = np.array(gt_labels.data)
            gt_neg_pass[batch.volumes[VolumeType.GT_IGNORE].data == 0] = next_id
        else:
            gt_neg_pass = gt_pos_pass

        data['comp_label'] = np.array([[gt_neg_pass, gt_pos_pass]])
        data['nhood'] = batch.affinity_neighborhood[np.newaxis, np.newaxis, :]

        # Why don't we update gt_affinities in the same way?
        # -> not needed
        #
        # GT affinities are all 0 in the masked area (because the masked area
        # is assumed to be set to background in batch.gt).
        #
        # In the negative pass:
        #
        #   We set all affinities inside GT regions to 1. Affinities in the
        #   masked area remain as predicted. The masked area belongs to one
        #   foreground region (introduced above). But we only count loss on
        #   edges connecting different labels -> loss in the masked-out area
        #   comes only from outside regions.
        #
        # In the positive pass:
        #
        #   We set all affinities outside GT regions to 0 -> no loss in the
        #   masked-out area.

    def __consistency_check(self):

        diffs = self.net_io.get_outputs()
        for k in diffs:
            assert not np.isnan(diffs[k]).any(), \
                "Detected NaN in output diff " + k
class TrainSyntist(BatchFilter):
    '''Performs one training iteration for each batch that passes through.
    Adds the predicted pre-synaptic boundary map (``PRED_BM_PRESYN``) to the
    batch.
    '''

    def __init__(self, solver_parameters, use_gpu=None):

        # start training as a producer pool, so that we can gracefully exit if
        # anything goes wrong
        self.worker = ProducerPool(
            [lambda gpu=use_gpu: self.__train(gpu)],
            queue_size=1)
        self.batch_in = multiprocessing.Queue(maxsize=1)

        self.solver_parameters = solver_parameters
        self.solver_initialized = False

    def setup(self):
        self.worker.start()

    def teardown(self):
        self.worker.stop()

    def prepare(self, request):

        # remove request parts that we provide
        for volume_type in [VolumeTypes.LOSS_GRADIENT, VolumeTypes.PRED_AFFINITIES]:
            if volume_type in request.volumes:
                del request.volumes[volume_type]

    def process(self, batch, request):

        self.batch_in.put((batch, request))

        try:
            out = self.worker.get()
        except WorkersDied:
            raise TrainProcessDied()

        batch.volumes[VolumeTypes.PRED_BM_PRESYN] = out.volumes[
            VolumeTypes.PRED_BM_PRESYN]
        if VolumeTypes.LOSS_GRADIENT in request.volumes:
            batch.volumes[VolumeTypes.LOSS_GRADIENT] = out.volumes[
                VolumeTypes.LOSS_GRADIENT]

        batch.loss = out.loss
        batch.iteration = out.iteration

    def __train(self, use_gpu):

        start = time.time()

        if not self.solver_initialized:

            logger.info("Initializing solver...")

            if use_gpu is not None:

                logger.debug("Train process: using GPU %d" % use_gpu)
                caffe.enumerate_devices(False)
                caffe.set_devices((use_gpu,))
                caffe.set_mode_gpu()
                caffe.select_device(use_gpu, False)

            self.solver = caffe.get_solver(self.solver_parameters)
            if self.solver_parameters.resume_from is not None:
                logger.debug(
                    "Train process: restoring solver state from " +
                    self.solver_parameters.resume_from)
                self.solver.restore(self.solver_parameters.resume_from)

            self.net_io = NetIoWrapper(self.solver.net)

            self.solver_initialized = True

        batch, request = self.batch_in.get()

        data = {
            'data': batch.volumes[VolumeTypes.RAW].data[np.newaxis, np.newaxis, :],
            'bm_presyn_label': batch.volumes[VolumeTypes.GT_BM_PRESYN].data[np.newaxis, :],
        }

        if self.solver_parameters.train_state.get_stage(0) == 'euclid':
            logger.debug(
                "Train process: preparing input data for Euclidean training")
            self.__prepare_euclidean(batch, data)
        else:
            raise Exception("should not enter malis phase at all")
            # unreachable while the exception above is raised
            logger.debug(
                "Train process: preparing input data for Malis training")
            self.__prepare_malis(batch, data)

        self.net_io.set_inputs(data)

        loss = self.solver.step(1)
        # self.__consistency_check()
        output = self.net_io.get_outputs()

        batch.volumes[VolumeTypes.PRED_BM_PRESYN] = Volume(
            output['bm_presyn_pred'],  # [0]
            batch.volumes[VolumeTypes.GT_BM_PRESYN].roi,
            batch.volumes[VolumeTypes.GT_BM_PRESYN].resolution,
            # interpolate=True
        )
        batch.loss = loss
        batch.iteration = self.solver.iter

        if VolumeTypes.LOSS_GRADIENT in request.volumes:
            diffs = self.net_io.get_output_diffs()
            batch.volumes[VolumeTypes.LOSS_GRADIENT] = Volume(
                diffs['bm_presyn_pred'][0],
                batch.volumes[VolumeTypes.GT_BM_PRESYN].roi,
                batch.volumes[VolumeTypes.GT_BM_PRESYN].resolution,
                # interpolate=True
            )

        time_of_iteration = time.time() - start
        logger.info(
            "Train process: iteration=%d loss=%f time=%f" %
            (self.solver.iter, batch.loss, time_of_iteration))

        return batch

    def __prepare_euclidean(self, batch, data):

        gt_bm_presyn = batch.volumes[VolumeTypes.GT_BM_PRESYN]

        frac_pos = np.clip(gt_bm_presyn.data.mean(), 0.05, 0.95)
        w_pos = 1.0 / (2.0 * frac_pos)
        w_neg = 1.0 / (2.0 * (1.0 - frac_pos))

        error_scale = self.__scale_errors(gt_bm_presyn.data, w_neg, w_pos)

        # if VolumeTypes.GT_MASK in batch.volumes:
        #     error_scale[batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN].data==0] = 0
        # if VolumeTypes.GT_IGNORE in batch.volumes:
        #     error_scale[batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN].data==0] = 0

        # if VolumeTypes.GT_MASK in batch.volumes:
        #     self.__mask_errors(batch, error_scale, batch.volumes[VolumeTypes.GT_MASK].data)
        # if VolumeTypes.GT_IGNORE in batch.volumes:
        #     self.__mask_errors(batch, error_scale, batch.volumes[VolumeTypes.GT_IGNORE].data)

        if VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN in batch.volumes:
            error_scale = np.multiply(
                error_scale,
                batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN].data)

        if VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN in batch.volumes:
            error_scale = np.multiply(
                error_scale,
                batch.volumes[VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN].data)

        data['scale'] = error_scale[np.newaxis, :]

    def __scale_errors(self, data, factor_low, factor_high):
        scaled_data = np.add(
            (data >= 0.5) * factor_high,
            (data < 0.5) * factor_low)
        return scaled_data

    def __mask_errors(self, batch, error_scale, mask):
        for d in range(len(batch.affinity_neighborhood)):
            error_scale[d] = np.multiply(error_scale[d], mask)

    # def __prepare_malis(self, batch, data):
    #
    #     gt_labels = batch.volumes[VolumeTypes.GT_LABELS]
    #     next_id = gt_labels.data.max() + 1
    #
    #     gt_pos_pass = gt_labels.data
    #
    #     if VolumeTypes.GT_IGNORE in batch.volumes:
    #         gt_neg_pass = np.array(gt_labels.data)
    #         gt_neg_pass[batch.volumes[VolumeTypes.GT_IGNORE].data==0] = next_id
    #     else:
    #         gt_neg_pass = gt_pos_pass
    #
    #     data['comp_label'] = np.array([[gt_neg_pass, gt_pos_pass]])
    #     data['nhood'] = batch.affinity_neighborhood[np.newaxis, np.newaxis, :]
    #
    #     # Why don't we update gt_affinities in the same way?
    #     # -> not needed
    #     #
    #     # GT affinities are all 0 in the masked area (because the masked
    #     # area is assumed to be set to background in batch.gt).
    #     #
    #     # In the negative pass:
    #     #
    #     #   We set all affinities inside GT regions to 1. Affinities in the
    #     #   masked area remain as predicted. The masked area belongs to one
    #     #   foreground region (introduced above). But we only count loss on
    #     #   edges connecting different labels -> loss in the masked-out area
    #     #   comes only from outside regions.
    #     #
    #     # In the positive pass:
    #     #
    #     #   We set all affinities outside GT regions to 0 -> no loss in the
    #     #   masked-out area.

    def __consistency_check(self):

        diffs = self.net_io.get_outputs()
        for k in diffs:
            assert not np.isnan(diffs[k]).any(), \
                "Detected NaN in output diff " + k
class Scan(BatchFilter):
    '''Iteratively requests batches of size ``reference`` from upstream
    providers in a scanning fashion, until all requested ROIs are covered. If
    the batch request to this node is empty, it will scan the complete
    upstream ROIs (and return nothing). Otherwise, it scans only the requested
    ROIs and returns a batch assembled from the smaller requests. In either
    case, the upstream requests will be contained in the downstream requested
    ROI or upstream ROIs.

    Args:

        reference (:class:`BatchRequest`): A reference :class:`BatchRequest`.
            This request will be shifted in a scanning fashion over the
            upstream ROIs of the requested arrays or points.

        num_workers (int, optional): If set to >1, upstream requests are made
            in parallel with that number of workers.

        cache_size (int, optional): If multiple workers are used, how many
            batches to hold at most.
    '''

    def __init__(self, reference, num_workers=1, cache_size=50):

        self.reference = reference.copy()
        self.num_workers = num_workers
        self.cache_size = cache_size
        self.workers = None
        if num_workers > 1:
            self.request_queue = multiprocessing.Queue(maxsize=0)
        self.batch = None

    def setup(self):

        if self.num_workers > 1:
            self.workers = ProducerPool(
                [self.__worker_get_chunk for _ in range(self.num_workers)],
                queue_size=self.cache_size)
            self.workers.start()

    def teardown(self):

        if self.num_workers > 1:
            self.workers.stop()

    def provide(self, request):

        empty_request = (len(request) == 0)
        if empty_request:
            scan_spec = self.spec
        else:
            scan_spec = request

        stride = self.__get_stride()
        shift_roi = self.__get_shift_roi(scan_spec)

        shifts = self.__enumerate_shifts(shift_roi, stride)
        num_chunks = len(shifts)

        logger.info("scanning over %d chunks", num_chunks)

        # the batch to return
        self.batch = Batch()

        if self.num_workers > 1:

            for shift in shifts:
                shifted_reference = self.__shift_request(self.reference, shift)
                self.request_queue.put(shifted_reference)

            for i in range(num_chunks):

                chunk = self.workers.get()

                if not empty_request:
                    self.__add_to_batch(request, chunk)

                logger.info("processed chunk %d/%d", i, num_chunks)

        else:

            for i, shift in enumerate(shifts):

                shifted_reference = self.__shift_request(self.reference, shift)
                chunk = self.__get_chunk(shifted_reference)

                if not empty_request:
                    self.__add_to_batch(request, chunk)

                logger.info("processed chunk %d/%d", i, num_chunks)

        batch = self.batch
        self.batch = None

        logger.debug("returning batch %s", batch)

        return batch

    def __get_stride(self):
        '''Get the maximal amount by which ``reference`` can be moved, such
        that it tiles the space.'''

        stride = None

        # get the least common multiple of all voxel sizes, we have to stride
        # at least that far
        lcm_voxel_size = self.spec.get_lcm_voxel_size(
            self.reference.array_specs.keys())

        # that's just the minimal size in each dimension
        for key, reference_spec in self.reference.items():

            shape = reference_spec.roi.get_shape()

            for d in range(len(lcm_voxel_size)):
                assert shape[d] >= lcm_voxel_size[d], (
                    "Shape of reference ROI %s for %s is smaller than the "
                    "least common multiple of the voxel sizes %s" %
                    (reference_spec.roi, key, lcm_voxel_size))

            if stride is None:
                stride = shape
            else:
                stride = Coordinate((min(a, b) for a, b in zip(stride, shape)))

        return stride

    def __get_shift_roi(self, spec):
        '''Get the minimal and maximal shift (as a ROI) to apply to
        ``self.reference``, such that it is still fully contained in
        ``spec``.
        '''

        total_shift_roi = None

        # get individual shift ROIs and intersect them
        for key, reference_spec in self.reference.items():

            if key not in spec:
                continue
            if spec[key].roi is None:
                continue

            # we have a reference ROI
            #
            #    [--------) [9]
            #    3        12
            #
            # and a spec ROI
            #
            #    [---------------) [16]
            #    16              32
            #
            # min and max shifts of the reference are
            #
            #    [--------) [9]
            #    16       25
            #              [--------) [9]
            #              23       32
            #
            # therefore, all possible ways to shift the reference such that
            # it is contained in the spec are at least 16-3=13 and at most
            # 23-3=20 (inclusive)
            #
            #    [-------) [8]
            #    13      21
            #
            # 1. the starting point is beginning of spec - beginning of
            #    reference
            # 2. the length is length of spec - length of reference + 1

            # 1. get the starting point of the shift ROI
            shift_begin = (
                spec[key].roi.get_begin() -
                reference_spec.roi.get_begin())

            # 2. get the shape of the shift ROI
            shift_shape = (
                spec[key].roi.get_shape() -
                reference_spec.roi.get_shape() +
                (1,) * reference_spec.roi.dims())

            # create a ROI...
            shift_roi = Roi(shift_begin, shift_shape)

            # ...and intersect it with previous shift ROIs
            if total_shift_roi is None:
                total_shift_roi = shift_roi
            else:
                total_shift_roi = total_shift_roi.intersect(shift_roi)
                if total_shift_roi.empty():
                    raise RuntimeError(
                        "There is no location where the ROIs of the "
                        "reference %s are contained in the request/upstream "
                        "ROIs %s." % (self.reference, spec))

        if total_shift_roi is None:
            raise RuntimeError(
                "None of the upstream ROIs are bounded (all ROIs are None). "
                "Scan needs at least one bounded upstream ROI.")

        return total_shift_roi

    def __enumerate_shifts(self, shift_roi, stride):
        '''Produces a sequence of shift coordinates starting at the beginning
        of ``shift_roi``, progressing with ``stride``. The maximum shift
        coordinate in any dimension will be the last point inside the shift
        roi in this dimension.'''

        min_shift = shift_roi.get_offset()
        max_shift = Coordinate(m - 1 for m in shift_roi.get_end())

        shift = np.array(min_shift)
        shifts = []

        dims = len(min_shift)

        logger.debug(
            "enumerating possible shifts of %s in %s", stride, shift_roi)

        while True:

            logger.debug("adding %s", shift)
            shifts.append(Coordinate(shift))

            if (shift == max_shift).all():
                break

            # count up dimensions
            for d in range(dims):

                if shift[d] >= max_shift[d]:

                    if d == dims - 1:
                        break

                    shift[d] = min_shift[d]

                else:

                    shift[d] += stride[d]
                    # snap to last possible shift, don't overshoot
                    if shift[d] > max_shift[d]:
                        shift[d] = max_shift[d]

                    break

        return shifts

    def __shift_request(self, request, shift):

        shifted = request.copy()
        for _, spec in shifted.items():
            spec.roi = spec.roi.shift(shift)

        return shifted

    def __worker_get_chunk(self):

        request = self.request_queue.get()
        return self.__get_chunk(request)

    def __get_chunk(self, request):

        return self.get_upstream_provider().request_batch(request)

    def __add_to_batch(self, spec, chunk):

        if self.batch.get_total_roi() is None:
            self.batch = self.__setup_batch(spec, chunk)

        for (array_key, array) in chunk.arrays.items():
            if array_key not in spec:
                continue
            self.__fill(
                self.batch.arrays[array_key].data, array.data,
                spec.array_specs[array_key].roi, array.spec.roi,
                self.spec[array_key].voxel_size)

        for (points_key, points) in chunk.points.items():
            if points_key not in spec:
                continue
            self.__fill_points(
                self.batch.points[points_key].data, points.data,
                spec.points_specs[points_key].roi, points.roi)

    def __setup_batch(self, batch_spec, chunk):
        '''Allocate a batch matching the sizes of ``batch_spec``, using
        ``chunk`` as template.'''

        batch = Batch()

        for (array_key, spec) in batch_spec.array_specs.items():

            roi = spec.roi
            voxel_size = self.spec[array_key].voxel_size

            # get the 'non-spatial' shape of the chunk-batch
            # and append the shape of the request to it
            array = chunk.arrays[array_key]
            shape = array.data.shape[:-roi.dims()]
            shape += (roi.get_shape() // voxel_size)

            spec = self.spec[array_key].copy()
            spec.roi = roi
            logger.info(
                "allocating array of shape %s for %s", shape, array_key)
            batch.arrays[array_key] = Array(data=np.zeros(shape), spec=spec)

        for (points_key, spec) in batch_spec.points_specs.items():

            roi = spec.roi
            spec = self.spec[points_key].copy()
            spec.roi = roi
            batch.points[points_key] = Points(data={}, spec=spec)

        logger.debug("setup batch to fill %s", batch)

        return batch

    def __fill(self, a, b, roi_a, roi_b, voxel_size):

        logger.debug("filling " + str(roi_b) + " into " + str(roi_a))

        roi_a = roi_a // voxel_size
        roi_b = roi_b // voxel_size

        common_roi = roi_a.intersect(roi_b)
        if common_roi.empty():
            return

        common_in_a_roi = common_roi - roi_a.get_offset()
        common_in_b_roi = common_roi - roi_b.get_offset()

        slices_a = common_in_a_roi.get_bounding_box()
        slices_b = common_in_b_roi.get_bounding_box()

        if len(a.shape) > len(slices_a):
            slices_a = (slice(None),) * (len(a.shape) - len(slices_a)) + slices_a
            slices_b = (slice(None),) * (len(b.shape) - len(slices_b)) + slices_b

        a[slices_a] = b[slices_b]

    def __fill_points(self, a, b, roi_a, roi_b):

        logger.debug(
            "filling points of " + str(roi_b) + " into points of " +
            str(roi_a))

        common_roi = roi_a.intersect(roi_b)
        if common_roi is None:
            return

        # find the max point_id in a so far
        max_point_id = 0
        for point_id, point in a.items():
            if point_id > max_point_id:
                max_point_id = point_id

        for point_id, point in b.items():
            if roi_a.contains(Coordinate(point.location)):
                a[point_id + max_point_id] = point
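# ---------------------------------------------------------------------------
# A minimal, hypothetical usage sketch of Scan (assuming the usual gunpowder
# pipeline operators `+` and `build`, and a toy source that serves zeros):
# one large downstream request is tiled with shifted copies of the small
# reference request and reassembled into a single batch. Keys, ROIs, and the
# ZerosSource are illustrative assumptions only.
# ---------------------------------------------------------------------------

def _example_scan_usage():

    from gunpowder import (
        Array, ArrayKey, ArraySpec, Batch, BatchProvider, BatchRequest,
        Coordinate, Roi, build)

    RAW = ArrayKey('RAW')

    class ZerosSource(BatchProvider):
        '''Toy source serving zeros on a fixed ROI.'''

        def setup(self):
            self.provides(RAW, ArraySpec(
                roi=Roi((0, 0, 0), (512, 512, 512)),
                voxel_size=Coordinate((1, 1, 1)),
                interpolatable=True))

        def provide(self, request):
            batch = Batch()
            spec = self.spec[RAW].copy()
            spec.roi = request[RAW].roi
            batch.arrays[RAW] = Array(
                np.zeros(request[RAW].roi.get_shape(), dtype=np.float32),
                spec)
            return batch

    # the small reference request that gets shifted over the requested ROI
    reference = BatchRequest()
    reference[RAW] = ArraySpec(roi=Roi((0, 0, 0), (64, 64, 64)))

    pipeline = ZerosSource() + Scan(reference, num_workers=4)

    # one big downstream request, assembled from many small upstream requests
    request = BatchRequest()
    request[RAW] = ArraySpec(roi=Roi((0, 0, 0), (512, 512, 512)))

    with build(pipeline):
        batch = pipeline.request_batch(request)
        return batch.arrays[RAW].data.shape  # expected: (512, 512, 512)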
class Predict(BatchFilter):
    '''Augments the batch with the predicted affinities.
    '''

    def __init__(self, prototxt, weights, use_gpu=None):

        for f in [prototxt, weights]:
            if not os.path.isfile(f):
                raise RuntimeError("%s does not exist" % f)

        # start prediction as a producer pool, so that we can gracefully exit
        # if anything goes wrong
        self.worker = ProducerPool(
            [lambda gpu=use_gpu: self.__predict(gpu)],
            queue_size=1)
        self.batch_in = multiprocessing.Queue(maxsize=1)

        self.prototxt = prototxt
        self.weights = weights
        self.net_initialized = False

    def setup(self):
        self.worker.start()

    def teardown(self):
        self.worker.stop()

    def prepare(self, request):

        # remove request parts that we provide
        if VolumeTypes.PRED_AFFINITIES in request.volumes:
            del request.volumes[VolumeTypes.PRED_AFFINITIES]

    def process(self, batch, request):

        self.batch_in.put(batch)

        try:
            out = self.worker.get()
        except WorkersDied:
            raise PredictProcessDied()

        affs = out.volumes[VolumeTypes.PRED_AFFINITIES]
        affs.roi = request.volumes[VolumeTypes.PRED_AFFINITIES]
        affs.resolution = batch.volumes[VolumeTypes.RAW].resolution

        batch.volumes[VolumeTypes.PRED_AFFINITIES] = affs

    def __predict(self, use_gpu):

        if not self.net_initialized:

            logger.info("Initializing net...")

            if use_gpu is not None:

                logger.debug("Predict process: using GPU %d" % use_gpu)
                caffe.enumerate_devices(False)
                caffe.set_devices((use_gpu,))
                caffe.set_mode_gpu()
                caffe.select_device(use_gpu, False)

            self.net = caffe.Net(self.prototxt, self.weights, caffe.TEST)
            self.net_io = NetIoWrapper(self.net)

            self.net_initialized = True

        start = time.time()

        batch = self.batch_in.get()

        fetch_time = time.time() - start

        self.net_io.set_inputs({
            'data': batch.volumes[VolumeTypes.RAW].data[np.newaxis, np.newaxis, :],
        })
        self.net.forward()
        output = self.net_io.get_outputs()

        predict_time = time.time() - start

        logger.info(
            "Predict process: time=%f (including %f waiting for batch)" %
            (predict_time, fetch_time))

        assert len(output['aff_pred'].shape) == 5, (
            "Got affinity prediction with unexpected number of dimensions, "
            "should be 1 (direction) + 3 (spatial) + 1 (batch, not used), "
            "but is %d" % len(output['aff_pred'].shape))

        batch.volumes[VolumeTypes.PRED_AFFINITIES] = Volume(
            output['aff_pred'][0],
            Roi(),
            (1, 1, 1))

        return batch
class PreCache(BatchFilter):
    '''Pre-cache repeated equal batch requests. For the first of a series of
    equal batch requests, a set of workers is spawned to pre-cache the batches
    in parallel processes. This way, subsequent requests can be served
    quickly.

    This node only makes sense if:

    1. Incoming batch requests are repeatedly the same.
    2. There is a source of randomness in upstream nodes.

    Args:

        cache_size (int): How many batches to hold at most in the cache.

        num_workers (int): How many processes to spawn to fill the cache.
    '''

    def __init__(self, cache_size=50, num_workers=20):

        self.current_request = None
        self.workers = None
        self.cache_size = cache_size
        self.num_workers = num_workers

    def teardown(self):
        if self.workers is not None:
            self.workers.stop()

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        if request != self.current_request:

            if self.workers is not None:
                logger.info(
                    "new request received, stopping current workers...")
                self.workers.stop()

            self.current_request = copy.deepcopy(request)

            logger.info("starting new set of workers...")
            self.workers = ProducerPool(
                [
                    lambda i=i: self.__run_worker(i)
                    for i in range(self.num_workers)
                ],
                queue_size=self.cache_size)
            self.workers.start()

        logger.debug("getting batch from queue...")
        batch = self.workers.get()

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch

    def __run_worker(self, i):
        return self.get_upstream_provider().request_batch(self.current_request)
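# ---------------------------------------------------------------------------
# A minimal, hypothetical usage sketch of PreCache: the same request is sent
# repeatedly, so after warm-up it is answered from the pre-filled queue. The
# `source` argument stands for any upstream provider, and RandomLocation
# supplies the randomness that makes the cached batches differ; the key and
# shape are illustrative assumptions only.
# ---------------------------------------------------------------------------

def _example_precache_usage(source):

    from gunpowder import (
        ArrayKey, BatchRequest, Coordinate, RandomLocation, build)

    RAW = ArrayKey('RAW')

    pipeline = (
        source +                 # hypothetical upstream provider
        RandomLocation() +       # source of randomness upstream of the cache
        PreCache(cache_size=50, num_workers=20))

    request = BatchRequest()
    request.add(RAW, Coordinate((64, 64, 64)))

    with build(pipeline):
        for i in range(1000):
            # the same request is sent repeatedly, so it is served from the
            # pre-filled queue
            batch = pipeline.request_batch(request)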
class GenericPredict(BatchFilter):
    '''Generic predict node to add predictions of a trained network to each
    batch that passes through. This node alone does nothing and should be
    subclassed for concrete implementations.

    Args:

        inputs (dict): Dictionary from names of input layers in the network to
            :class:`ArrayKey` or to a batch attribute name as a string.

        outputs (dict): Dictionary from the names of output layers in the
            network to :class:`ArrayKey`. New arrays will be generated by this
            node for each entry (if requested downstream).

        array_specs (dict, optional): An optional dictionary of
            :class:`ArrayKey` to :class:`ArraySpec` to set the array specs of
            generated arrays (``outputs``). This is useful to set the
            ``voxel_size``, for example, if it differs from the voxel size of
            the input arrays. Only fields that are not ``None`` in the given
            :class:`ArraySpec` will be used.

        spawn_subprocess (bool, optional): Whether to run ``predict`` in a
            separate process. Default is false.
    '''

    def __init__(
            self,
            inputs,
            outputs,
            array_specs=None,
            spawn_subprocess=False):

        self.initialized = False

        self.inputs = inputs
        self.outputs = outputs
        self.array_specs = {} if array_specs is None else array_specs
        self.spawn_subprocess = spawn_subprocess
        self.timer_start = None

    def setup(self):

        # get common voxel size of inputs, or None if they differ
        common_voxel_size = None
        for key in self.inputs.values():

            if not isinstance(key, ArrayKey):
                continue

            voxel_size = self.spec[key].voxel_size

            if common_voxel_size is None:
                common_voxel_size = voxel_size
            elif common_voxel_size != voxel_size:
                common_voxel_size = None
                break

        # announce provided outputs
        for key in self.outputs.values():

            if key in self.array_specs:
                spec = self.array_specs[key].copy()
            else:
                spec = ArraySpec()

            if spec.voxel_size is None and not spec.nonspatial:

                assert common_voxel_size is not None, (
                    "There is no common voxel size of the inputs, and no "
                    "ArraySpec has been given for %s that defines "
                    "voxel_size." % key)

                spec.voxel_size = common_voxel_size

            if spec.interpolatable is None:

                # default for predictions
                spec.interpolatable = False

            self.provides(key, spec)

        if self.spawn_subprocess:

            # start prediction as a producer pool, so that we can gracefully
            # exit if anything goes wrong
            self.worker = ProducerPool(
                [self.__produce_predict_batch],
                queue_size=1)
            self.batch_in = multiprocessing.Queue(maxsize=1)
            self.batch_in_lock = multiprocessing.Lock()
            self.batch_out_lock = multiprocessing.Lock()

            self.worker.start()

    def teardown(self):
        if self.spawn_subprocess:
            # signal "stop"
            self.batch_in.put((None, None))
            try:
                self.worker.get(timeout=2)
            except NoResult:
                pass
            self.worker.stop()
        else:
            self.stop()

    def prepare(self, request):

        if not self.initialized and not self.spawn_subprocess:
            self.start()
            self.initialized = True

        deps = BatchRequest()
        for key in self.inputs.values():
            deps[key] = request[key]
        return deps

    def process(self, batch, request):

        if self.spawn_subprocess:

            start = time.time()
            self.batch_in_lock.acquire()
            logger.debug(
                "waited for batch in lock for %.3fs", time.time() - start)

            start = time.time()
            self.batch_in.put((batch, request))
            logger.debug("queued batch for %.3fs", time.time() - start)

            start = time.time()
            with self.batch_out_lock:

                logger.debug(
                    "waited for batch out lock for %.3fs",
                    time.time() - start)

                start = time.time()
                self.batch_in_lock.release()
                logger.debug(
                    "released batch in lock for %.3fs", time.time() - start)

                try:
                    start = time.time()
                    out = self.worker.get()
                    logger.debug(
                        "retrieved batch for %.3fs", time.time() - start)
                except WorkersDied:
                    raise PredictProcessDied()

            for array_key in self.outputs.values():
                if array_key in request:
                    batch.arrays[array_key] = out.arrays[array_key]

        else:

            self.predict(batch, request)

    def start(self):
        '''To be implemented in subclasses.

        This method will be called before the first call to :func:`predict`,
        from the same process that :func:`predict` will be called from. Use
        this to initialize your model and hardware.
        '''
        pass

    def predict(self, batch, request):
        '''To be implemented in subclasses.

        In this method, an implementation should predict arrays on the given
        batch. Output arrays should be created according to the given request
        and added to ``batch``.
        '''
        raise NotImplementedError(
            "Class %s does not implement 'predict'" % self.name())

    def stop(self):
        '''To be implemented in subclasses.

        This method will be called after the last call to :func:`predict`,
        from the same process that :func:`predict` will be called from. Use
        this to tear down your model and free training hardware.
        '''
        pass

    def __produce_predict_batch(self):
        '''Process one batch.'''

        if not self.initialized:
            self.start()
            self.initialized = True

        if self.timer_start is not None:
            self.time_out = time.time() - self.timer_start
            logger.info(
                "batch in: %.3fs, predict: %.3fs, batch out: %.3fs",
                self.time_in, self.time_predict, self.time_out)

        self.timer_start = time.time()
        batch, request = self.batch_in.get()
        self.time_in = time.time() - self.timer_start

        # stop signal
        if batch is None:
            self.stop()
            return None

        self.timer_start = time.time()
        self.predict(batch, request)
        self.time_predict = time.time() - self.timer_start

        self.timer_start = time.time()

        return batch
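# ---------------------------------------------------------------------------
# A minimal, hypothetical sketch of a GenericPredict subclass: thresholding
# the input array stands in for a network forward pass. The class name, the
# layer name 'data', and the threshold are illustrative assumptions only;
# Array is assumed to be imported at module level, as used by the nodes
# above.
# ---------------------------------------------------------------------------

class ThresholdPredict(GenericPredict):
    '''Hypothetical example: "predicts" by thresholding the raw input and
    writing the result to every requested output.'''

    def start(self):
        # load weights / move the model to the GPU here in a real node
        pass

    def predict(self, batch, request):

        raw_key = self.inputs['data']  # assumed input layer name
        raw = batch.arrays[raw_key].data

        for name, key in self.outputs.items():
            if key not in request:
                continue
            spec = self.spec[key].copy()
            spec.roi = request[key].roi
            batch.arrays[key] = Array(
                data=(raw > 0.5).astype(np.float32),
                spec=spec)

    def stop(self):
        # free model resources here
        pass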
class PreCache(BatchFilter):
    '''Pre-cache repeated equal batch requests. For the first of a series of
    equal batch requests, a set of workers is spawned to pre-cache the batches
    in parallel processes. This way, subsequent requests can be served
    quickly.

    A note on changing the requests sent to ``PreCache``. Given requests A and
    B, if requests are sent in the sequence::

        A, ..., A, B, A, ..., A, B, A, ...

    ``PreCache`` will build a queue of batches that satisfy A, and handle
    requests B on demand. This prevents ``PreCache`` from discarding the queue
    on every Snapshot request. However, if B requests replace A as the most
    common request, i.e.::

        A, A, A, ..., A, B, B, B, ...

    ``PreCache`` will discard the A queue and build a B queue after it has
    seen more B requests than A requests out of the last 5 requests.

    This node only makes sense if:

    1. Incoming batch requests are repeatedly the same.
    2. There is a source of randomness in upstream nodes.

    Args:

        cache_size (``int``): How many batches to hold at most in the cache.

        num_workers (``int``): How many processes to spawn to fill the cache.
    '''

    def __init__(self, cache_size=50, num_workers=20):

        self.current_request = None
        self.workers = None
        self.cache_size = cache_size
        self.num_workers = num_workers

        # keep track of recent requests
        self.last_5 = deque([None] * 5, maxlen=5)

    def teardown(self):
        if self.workers is not None:
            self.workers.stop()

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        # update recent requests
        self.last_5.popleft()
        self.last_5.append(request)

        if request != self.current_request:

            current_count = sum([
                recent_request == self.current_request
                for recent_request in self.last_5
            ])
            new_count = sum([
                recent_request == request
                for recent_request in self.last_5
            ])

            if new_count > current_count or self.current_request is None:

                if self.workers is not None:
                    logger.info(
                        "new request received, stopping current workers...")
                    self.workers.stop()

                self.current_request = copy.deepcopy(request)

                logger.info("starting new set of workers...")
                self.workers = ProducerPool(
                    [
                        lambda i=i: self.__run_worker(i)
                        for i in range(self.num_workers)
                    ],
                    queue_size=self.cache_size,
                )
                self.workers.start()

                logger.debug("getting batch from queue...")
                batch = self.workers.get()
                timing.stop()
                batch.profiling_stats.add(timing)

            else:
                logger.debug("Resolving new request sequentially")
                batch = self.get_upstream_provider().request_batch(request)
                timing.stop()
                batch.profiling_stats.add(timing)

        else:
            logger.debug("getting batch from queue...")
            batch = self.workers.get()
            timing.stop()
            batch.profiling_stats.add(timing)

        return batch

    def __run_worker(self, i):
        return self.get_upstream_provider().request_batch(self.current_request)
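# ---------------------------------------------------------------------------
# A hypothetical sketch of the A/B pattern described in the docstring above:
# the frequent training request A is served from the pre-cache queue, while
# an occasional larger request B (e.g. for snapshots) is resolved
# sequentially without discarding the A queue. The pipeline, keys, and shapes
# are illustrative assumptions only.
# ---------------------------------------------------------------------------

def _example_alternating_requests(pipeline):

    from gunpowder import ArrayKey, BatchRequest, Coordinate, build

    RAW = ArrayKey('RAW')

    # A: the frequent training request, answered from the pre-cache queue
    request_a = BatchRequest()
    request_a.add(RAW, Coordinate((64, 64, 64)))

    # B: an occasional, larger request, resolved on demand
    request_b = BatchRequest()
    request_b.add(RAW, Coordinate((128, 128, 128)))

    with build(pipeline):
        for i in range(1000):
            batch = pipeline.request_batch(request_a)
            if i % 100 == 0:
                # B stays rare, so the A queue is kept
                snapshot_batch = pipeline.request_batch(request_b)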