def _scale_block(block_id, blocking, ds_in, ds_bd, ds_out,
                 offset, erode_by, erode_3d, channel):
    """Fit the objects of one block to the boundary map and add them to ``ds_out``.

    The block is read with a halo obtained from ``compute_halo(erode_by,
    erode_3d)``. Blocks without any non-zero value in ``ds_in`` are skipped.
    Otherwise the objects are passed through ``vu.fit_to_hmap``, cropped to
    the inner block, shifted by ``offset`` and added onto the data already
    stored in ``ds_out`` at the foreground positions.
    """
    fu.log("start processing block %i" % block_id)

    # the halo of the block is derived from the erosion parameters
    block = blocking.getBlockWithHalo(block_id, compute_halo(erode_by, erode_3d))
    outer_bb = vu.block_to_bb(block.outerBlock)
    inner_bb = vu.block_to_bb(block.innerBlock)
    crop_bb = vu.block_to_bb(block.innerBlockLocal)

    objects = ds_in[outer_bb]
    # nothing to scale if the block has no foreground
    if np.sum(objects != 0) == 0:
        fu.log_block_success(block_id)
        return

    # for a 4d boundary dataset only the requested channel is loaded
    if ds_bd.ndim == 4:
        boundary_bb = (slice(channel, channel + 1),) + outer_bb
    else:
        boundary_bb = outer_bb
    boundaries = ds_bd[boundary_bb].squeeze()

    objects, _ = vu.fit_to_hmap(objects, boundaries, erode_by, erode_3d)
    objects = objects[crop_bb]
    foreground = objects != 0
    objects[foreground] += offset

    # read-modify-write: add the objects onto the existing output
    current = ds_out[inner_bb]
    current[foreground] += objects[foreground]
    ds_out[inner_bb] = current

    fu.log_block_success(block_id)
def _threshold_block(block_id, blocking, ds_in, ds_out, threshold,
                     threshold_mode, channel, sigma):
    """Threshold one block of ``ds_in`` and write the result to ``ds_out`` as uint8.

    If ``channel`` is None the block is read directly; otherwise the given
    channel (int) or channels (iterable) are read and averaged over the
    channel axis. The data is normalized with ``vu.normalize``, optionally
    smoothed with a gaussian (``sigma > 0``) and normalized again, and then
    compared against ``threshold``.

    Raises:
        RuntimeError: if ``threshold_mode`` is not one of
            'greater', 'less' or 'equal'.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlock(block_id)
    # the bounding box is computed once (the original computed it twice)
    bb = vu.block_to_bb(block)

    if channel is None:
        input_ = ds_in[bb]
    else:
        channel_ = [channel] if isinstance(channel, int) else channel
        in_shape = (len(channel_),) + tuple(b.stop - b.start for b in bb)
        input_ = np.zeros(in_shape, dtype=ds_in.dtype)
        for chan_id, chan in enumerate(channel_):
            bb_inp = (slice(chan, chan + 1),) + bb
            input_[chan_id] = ds_in[bb_inp].squeeze()
        # reduce the selected channels to a single volume
        input_ = np.mean(input_, axis=0)

    input_ = vu.normalize(input_)
    if sigma > 0:
        input_ = vu.apply_filter(input_, 'gaussianSmoothing', sigma)
        input_ = vu.normalize(input_)

    if threshold_mode == 'greater':
        input_ = input_ > threshold
    elif threshold_mode == 'less':
        input_ = input_ < threshold
    elif threshold_mode == 'equal':
        input_ = input_ == threshold
    else:
        raise RuntimeError("Thresholding Mode %s not supported" % threshold_mode)

    ds_out[bb] = input_.astype('uint8')
    fu.log_block_success(block_id)
def _embedding_distances_block(block_id, blocking, input_datasets, ds, offsets, norm):
    """Compute embedding distances for one block and write them to ``ds``.

    The block is read with a halo equal to the per-axis maximum absolute
    offset. Each dataset in ``input_datasets`` provides one input channel.
    The result of ``compute_embedding_distances`` is cropped to the inner
    block before it is written.
    """
    fu.log("start processing block %i" % block_id)

    halo = np.max(np.abs(offsets), axis=0).tolist()
    block = blocking.getBlockWithHalo(block_id, halo)
    read_bb = vu.block_to_bb(block.outerBlock)
    all_channels = (slice(None),)
    write_bb = all_channels + vu.block_to_bb(block.innerBlock)
    crop_bb = all_channels + vu.block_to_bb(block.innerBlockLocal)

    # TODO support multi-channel input data
    spatial_shape = tuple(sl.stop - sl.start for sl in read_bb)
    embedding = np.zeros((len(input_datasets),) + spatial_shape, dtype='float32')
    for channel_id, dataset in enumerate(input_datasets):
        embedding[channel_id] = dataset[read_bb]

    # TODO support thresholding the embedding before distance calculation
    distances = compute_embedding_distances(embedding, offsets, norm)
    ds[write_bb] = distances[crop_bb]
    fu.log_block_success(block_id)
def _upsample_block(block_id, blocking, halo, ds_in, ds_out, ds_skel,
                    scale_factor, pixel_pitch):
    """Upsample the skeletons overlapping one block and write them to ``ds_out``.

    Args:
        block_id: id of the block to process.
        blocking: blocking object providing ``getBlock`` / ``getBlockWithHalo``.
        halo: halo for the block or None to process the block without halo.
        ds_in: segmentation dataset (read at the outer block).
        ds_out: dataset the upsampled skeletons are written to.
        ds_skel: skeleton dataset, read at the block coordinates divided by
            ``scale_factor``.
        scale_factor: per-axis scale factor between ``ds_skel`` and ``ds_in``.
        pixel_pitch: not used by this function.

    Fixes compared to the previous version:
      * skeleton ids are filtered with ``ids != 0`` instead of dropping the
        first unique value, which discarded a real skeleton whenever the
        block contained no zero;
      * the result is written to ``ds_out`` (previously unused) instead of
        back into the skeleton input ``ds_skel`` at full-resolution
        coordinates.
    """
    fu.log("start processing block %i" % block_id)
    if halo is None:
        block = blocking.getBlock(block_id)
        inner_bb = outer_bb = vu.block_to_bb(block)
        local_bb = np.s_[:]
    else:
        block = blocking.getBlockWithHalo(block_id, halo)
        inner_bb = vu.block_to_bb(block.innerBlock)
        outer_bb = vu.block_to_bb(block.outerBlock)
        local_bb = vu.block_to_bb(block.innerBlockLocal)

    # load the segmentation
    seg = ds_in[outer_bb]
    skels_out = np.zeros_like(seg, dtype='uint64')

    # find the bounding box for the downsampled skeletons
    skel_bb = tuple(slice(b.start // scale, int(ceil(b.stop / scale)))
                    for b, scale in zip(outer_bb, scale_factor))
    skels = ds_skel[skel_bb]

    # ids of the skeletons in this block; 0 is excluded explicitly
    ids = np.unique(skels)
    ids = ids[ids != 0]
    for skel_id in ids:
        skels_out += _upsample_skeleton(skel_id, seg, skels, scale_factor)

    ds_out[inner_bb] = skels_out[local_bb]
    # log block success
    fu.log_block_success(block_id)
def _mws_block_pass1(block_id, blocking, ds_in, ds_out, mask, offsets,
                     strides, randomize_strides, halo, noise_level,
                     max_block_id, tmp_folder):
    """Run the first mutex watershed pass on one block.

    The affinities of the outer block are normalized and segmented with
    ``mutex_watershed``. The segmentation is cropped to the inner block,
    relabeled consecutively starting at ``block_id * prod(blockShape)`` and
    written to ``ds_out``. The segmentation state (edges, weights and
    attractive-edge mask) is serialized to
    ``<tmp_folder>/seg_state_block<block_id>.h5``. For the block with id
    ``max_block_id``, ``_write_nlabels`` is called as well.

    Blocks whose mask is completely empty are skipped.
    """
    fu.log("(Pass1) start processing block %i" % block_id)
    block = blocking.getBlockWithHalo(block_id, halo)
    outer_bb = vu.block_to_bb(block.outerBlock)

    bb_mask = None
    if mask is not None:
        bb_mask = mask[outer_bb].astype('bool')
        # nothing to segment if the mask is empty in this block
        if np.sum(bb_mask) == 0:
            fu.log_block_success(block_id)
            return

    all_channels = (slice(None),)
    affs = vu.normalize(ds_in[all_channels + outer_bb])
    seg = mutex_watershed(affs, offsets, strides=strides, mask=bb_mask,
                          randomize_strides=randomize_strides,
                          noise_level=noise_level)

    inner_bb = vu.block_to_bb(block.innerBlock)
    crop_bb = vu.block_to_bb(block.innerBlockLocal)
    seg = seg[crop_bb]

    # FIXME once vigra supports uint64 or we implement our own ...
    # seg = vigra.analysis.labelVolumeWithBackground(seg)

    # offset the labels by the block id times the number of voxels per block
    offset_id = block_id * np.prod(blocking.blockShape)
    vigra.analysis.relabelConsecutive(seg, start_label=offset_id,
                                      keep_zeros=True, out=seg)
    ds_out[inner_bb] = seg

    # get the state of the segmentation of this block
    # NOTE(review): 'seg' is cropped to the inner block here, but 'bb_mask'
    # still has the shape of the outer block - confirm that this is intended
    grid_graph = compute_grid_graph(seg.shape, mask=bb_mask)
    affs = affs[all_channels + crop_bb]
    # FIXME this function yields incorrect uv-ids !
    state_uvs, state_weights, state_attractive = \
        grid_graph.compute_state_for_segmentation(affs, seg, offsets,
                                                  n_attractive_channels=3,
                                                  ignore_label=True)

    # serialize the states
    save_path = os.path.join(tmp_folder, 'seg_state_block%i.h5' % block_id)
    state = (('edges', state_uvs),
             ('weights', state_weights),
             ('attractive_edge_mask', state_attractive))
    with vu.file_reader(save_path) as f:
        for name, values in state:
            f.create_dataset(name, data=values)

    # write max-id for the last block
    if block_id == max_block_id:
        _write_nlabels(ds_out, seg)
    fu.log_block_success(block_id)
def _minfilter_block(block_id, blocking, halo, ds_in, ds_out, filter_shape):
    """Apply a minimum filter of size ``filter_shape`` to one block.

    The data is read with ``halo``, filtered, and only the inner block of
    the filter response is written to ``ds_out``.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlockWithHalo(block_id, halo)
    read_bb = vu.block_to_bb(block.outerBlock)
    write_bb = vu.block_to_bb(block.innerBlock)
    crop_bb = vu.block_to_bb(block.innerBlockLocal)

    filtered = minimum_filter(ds_in[read_bb], size=filter_shape)
    ds_out[write_bb] = filtered[crop_bb]
    fu.log_block_success(block_id)
def _insert_affinities_block(block_id, blocking, ds_in, ds_out, objects, offsets,
                             erode_by, erode_3d, zero_objects_list, dilate_by):
    """Insert object affinities into the affinities of one block.

    The halo is the per-axis maximum absolute offset, enlarged to at least
    ``erode_by`` (for all axes if ``erode_3d``, otherwise for all axes but
    the first). If the block contains no objects, the input affinities are
    copied unchanged. Otherwise the objects are optionally fitted to the
    first affinity channel (``erode_by > 0``), inserted via
    ``_insert_affinities`` and, for ids listed in ``zero_objects_list``,
    the affinities inside the eroded object mask are set to zero.
    """
    fu.log("start processing block %i" % block_id)

    offset_halo = np.max(np.abs(offsets), axis=0).tolist()
    if erode_3d:
        halo = [max(extent, erode_by) for extent in offset_halo]
    else:
        # the first axis keeps the halo given by the offsets
        halo = [extent if axis == 0 else max(extent, erode_by)
                for axis, extent in enumerate(offset_halo)]

    block = blocking.getBlockWithHalo(block_id, halo)
    all_channels = (slice(None),)
    outer_bb = vu.block_to_bb(block.outerBlock)
    inner_bb = all_channels + vu.block_to_bb(block.innerBlock)
    local_bb = all_channels + vu.block_to_bb(block.innerBlockLocal)

    # load objects and check if we have any in this block
    # catch run-time error for singleton dimension
    try:
        objs = objects[outer_bb]
        obj_sum = objs.sum()
    except RuntimeError:
        obj_sum = 0

    # without objects the affinities are just copied
    if obj_sum == 0:
        ds_out[inner_bb] = ds_in[inner_bb]
        fu.log_block_success(block_id)
        return

    affs = ds_in[all_channels + outer_bb]

    # fit object to hmap derived from affinities via shrinking and watershed
    if erode_by > 0:
        objs, obj_ids = vu.fit_to_hmap(objs, affs[0].copy(), erode_by, erode_3d)
    else:
        obj_ids = np.unique(objs)
        if 0 in obj_ids:
            obj_ids = obj_ids[1:]

    # insert affinities to objs into the original affinities
    affs = _insert_affinities(affs, objs.astype('uint64'), offsets, dilate_by)

    # zero out the affinities of the requested objects
    if zero_objects_list is not None:
        selected = np.in1d(obj_ids, zero_objects_list)
        for zero_id in obj_ids[selected]:
            # erode the mask to avoid ugly boundary artifacts
            zero_mask = binary_erosion(objs == zero_id, iterations=4)
            affs[:, zero_mask] = 0

    ds_out[inner_bb] = affs[local_bb]
    fu.log_block_success(block_id)
def _get_bbs(blocking, block_id, halo):
    """Return ``(in_bb, out_bb, local_bb)`` for a block.

    With a halo, these are the bounding boxes of the outer block, the inner
    block and the inner block in local coordinates. Without a halo
    (``halo is None``), input and output box are identical and ``local_bb``
    selects everything.
    """
    if halo is not None:
        block = blocking.getBlockWithHalo(block_id, halo)
        return (vu.block_to_bb(block.outerBlock),
                vu.block_to_bb(block.innerBlock),
                vu.block_to_bb(block.innerBlockLocal))

    bb = vu.block_to_bb(blocking.getBlock(block_id))
    return bb, bb, np.s_[:]
def _apply_filter(blocking, block_id, ds_in, ds_out, halo,
                  filter_name, sigma, apply_in_2d):
    """Apply the filter ``filter_name`` to one block and store the inner part.

    The input is read with ``halo`` and normalized with ``vu.normalize``
    before ``vu.apply_filter`` is called; only the inner block of the
    response is written to ``ds_out``.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlockWithHalo(block_id, halo)

    normalized = vu.normalize(ds_in[vu.block_to_bb(block.outerBlock)])
    response = vu.apply_filter(normalized, filter_name, sigma, apply_in_2d)

    write_bb = vu.block_to_bb(block.innerBlock)
    crop_bb = vu.block_to_bb(block.innerBlockLocal)
    ds_out[write_bb] = response[crop_bb]
    fu.log_block_success(block_id)
def _upsample_block(blocking, block_id, ds_in, ds_out, scale_factor, sampler):
    """Upsample the input region of one output block with ``sampler``.

    Args:
        blocking: blocking defined on the output (upscaled) volume.
        block_id: id of the block to process.
        ds_in: low resolution input dataset.
        ds_out: output dataset.
        scale_factor: int (same factor for all axes) or a per-axis sequence.
            For a sequence, the block is sampled slice by slice along the
            first axis.
        sampler: callable ``sampler(data, shape=...)`` returning the
            resampled data.

    Blocks whose input contains only zeros are skipped. For uint8 / uint16
    inputs the result is clipped to the dtype range and rounded before it is
    cast back.

    Fixes compared to the previous version: the per-axis branch computed the
    upper bound as ``ceil(stop // sf)``, i.e. it floored before taking the
    ceiling and could read a too small input region; and the re-raised
    ``IndexError`` is now chained to the original exception.

    Raises:
        IndexError: if writing the result to ``ds_out`` fails with an
            IndexError; the message contains the bounding boxes and shape.
    """
    fu.log("start processing block %i" % block_id)

    # load the block (output dataset / upscaled) coordinates
    block = blocking.getBlock(block_id)
    local_bb = np.s_[:]
    out_bb = vu.block_to_bb(block)
    out_shape = block.shape

    # map the output bounding box to input coordinates
    isotropic = isinstance(scale_factor, int)
    factors = (scale_factor,) * len(out_bb) if isotropic else scale_factor
    in_bb = tuple(
        slice(int(ob.start // sf), min(int(ceil(ob.stop / sf)), sh))
        for ob, sf, sh in zip(out_bb, factors, ds_in.shape))
    x = ds_in[in_bb]

    # don't sample empty blocks
    if np.sum(x != 0) == 0:
        fu.log_block_success(block_id)
        return

    dtype = x.dtype
    if np.dtype(dtype) != np.dtype('float32'):
        x = x.astype('float32')

    if isotropic:
        out = sampler(x, shape=out_shape)
    else:
        out = np.zeros(out_shape, dtype='float32')
        for z in range(out_shape[0]):
            out[z] = sampler(x[z], shape=out_shape[1:])

    # keep the values in range before casting back to small integer types
    if np.dtype(dtype) in (np.dtype('uint8'), np.dtype('uint16')):
        max_val = np.iinfo(np.dtype(dtype)).max
        np.clip(out, 0, max_val, out=out)
        np.round(out, out=out)

    try:
        ds_out[out_bb] = out[local_bb].astype(dtype)
    except IndexError as e:
        raise IndexError("%s, %s, %s" % (str(out_bb), str(local_bb), str(out.shape))) from e

    # log block success
    fu.log_block_success(block_id)
def _get_bbs(blocking, block_id, config):
    """Return ``(input_bb, inner_bb, output_bb)`` for a block.

    The halo is read from ``config['halo']`` (default ``[0, 0, 0]``). If its
    sum is not positive, the block is used without halo: input and output
    box are identical and ``inner_bb`` selects everything.
    """
    halo = list(config.get('halo', [0, 0, 0]))

    if sum(halo) <= 0:
        bb = vu.block_to_bb(blocking.getBlock(block_id))
        return bb, np.s_[:], bb

    block = blocking.getBlockWithHalo(block_id, halo)
    return (vu.block_to_bb(block.outerBlock),
            vu.block_to_bb(block.innerBlockLocal),
            vu.block_to_bb(block.innerBlock))
def _cc_block_with_mask(block_id, blocking, ds_in, ds_out, threshold,
                        threshold_mode, mask, channel, sigma):
    """Threshold one block inside ``mask`` and label its connected components.

    The input is read (averaging over ``channel`` if given), normalized,
    optionally smoothed (``sigma > 0``) and thresholded according to
    ``threshold_mode``. Everything outside of the mask is set to zero
    before ``label`` is applied and the components are written to ``ds_out``.

    Returns:
        0 if the mask or the thresholded block is empty, otherwise
        ``components.max() + 1`` as int.

    Raises:
        RuntimeError: if ``threshold_mode`` is not one of
            'greater', 'less' or 'equal'.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlock(block_id)
    # the bounding box is computed once (the original recomputed it below)
    bb = vu.block_to_bb(block)

    # get the mask and check if we have any pixels
    in_mask = mask[bb].astype('bool')
    if np.sum(in_mask) == 0:
        fu.log_block_success(block_id)
        return 0

    if channel is None:
        input_ = ds_in[bb]
    else:
        channel_ = [channel] if isinstance(channel, int) else channel
        in_shape = (len(channel_),) + tuple(b.stop - b.start for b in bb)
        input_ = np.zeros(in_shape, dtype=ds_in.dtype)
        for chan_id, chan in enumerate(channel_):
            bb_inp = (slice(chan, chan + 1),) + bb
            input_[chan_id] = ds_in[bb_inp].squeeze()
        # reduce the selected channels to a single volume
        input_ = np.mean(input_, axis=0)

    input_ = vu.normalize(input_)
    if sigma > 0:
        input_ = vu.apply_filter(input_, 'gaussianSmoothing', sigma)
        input_ = vu.normalize(input_)

    if threshold_mode == 'greater':
        input_ = input_ > threshold
    elif threshold_mode == 'less':
        input_ = input_ < threshold
    elif threshold_mode == 'equal':
        input_ = input_ == threshold
    else:
        raise RuntimeError("Thresholding Mode %s not supported" % threshold_mode)

    # restrict the foreground to the mask
    input_[np.logical_not(in_mask)] = 0
    if np.sum(input_) == 0:
        fu.log_block_success(block_id)
        return 0

    components = label(input_)
    ds_out[bb] = components
    fu.log_block_success(block_id)
    return int(components.max()) + 1
def _transform_block(ds_in, ds_out, transformation, blocking, block_id, mask=None):
    """Apply a transformation to one block and write the result to ``ds_out``.

    Args:
        ds_in: input dataset.
        ds_out: output dataset.
        transformation: if it has exactly two entries, it is used as a single
            transformation with keys 'a' and 'b' for the whole block;
            otherwise it is indexed by the global coordinate along the first
            axis and must provide 'a' and 'b' for every slice.
        blocking: blocking object providing ``getBlock``.
        block_id: id of the block to process.
        mask: optional mask dataset; blocks with an empty mask are skipped.

    Fix compared to the previous version: the per-slice branch indexed the
    mask unconditionally and raised a TypeError when no mask was given.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlock(block_id)
    bb = vu.block_to_bb(block)

    bb_mask = None
    if mask is not None:
        bb_mask = mask[bb].astype('bool')
        if bb_mask.sum() == 0:
            fu.log_block_success(block_id)
            return

    data = ds_in[bb]
    if len(transformation) == 2:
        data = _transform_data(data, transformation['a'], transformation['b'], bb_mask)
    else:
        z_offset = block.begin[0]
        for z in range(data.shape[0]):
            trafo = transformation[z + z_offset]
            # without a mask there is nothing to slice
            slice_mask = None if bb_mask is None else bb_mask[z]
            data[z] = _transform_data(data[z], trafo['a'], trafo['b'], slice_mask)

    ds_out[bb] = data
    fu.log_block_success(block_id)
def _morphology_for_block(block_id, blocking, ds_in, output_path, output_key):
    """Compute and serialize the segment morphology of one block.

    Blocks whose segmentation sums to zero are logged as empty and skipped.
    The chunk id is the block begin divided by the block shape.
    """
    fu.log("start processing block %i" % block_id)

    # read the segmentation in this block
    block = blocking.getBlock(block_id)
    seg = ds_in[vu.block_to_bb(block)]

    # nothing to do for an empty segmentation block
    if seg.sum() == 0:
        fu.log("block %i is empty" % block_id)
        fu.log_block_success(block_id)
        return

    chunk_id = tuple(begin // extent
                     for begin, extent in zip(block.begin, blocking.blockShape))

    # layout of the serialized morphology as given by the original comment:
    # [:, 0]    label id
    # [:, 1]    size of segments
    # [:, 2:5]  center of mass of segments
    # [:, 5:8]  minimum coordinates of segments
    # [:, 8:11] maximum coordinates of segments
    ndist.computeAndSerializeMorphology(seg, block.begin,
                                        output_path, output_key, chunk_id)
    fu.log_block_success(block_id)
def uniques_in_block(block_id, blocking, ds, return_counts):
    """Return the unique labels of one block, optionally with their counts.

    If the labels sum to zero, ``[0]`` (uint64) is returned without calling
    ``np.unique``; with ``return_counts`` the count is the number of voxels
    in the block (int64).
    """
    fu.log("start processing block %i" % block_id)
    bb = vu.block_to_bb(blocking.getBlock(block_id))
    labels = ds[bb]

    if labels.sum() == 0:
        fu.log_block_success(block_id)
        only_zero = np.array([0], dtype='uint64')
        if not return_counts:
            return only_zero
        n_voxels = int(np.prod(tuple(sl.stop - sl.start for sl in bb)))
        return only_zero, np.array([n_voxels], dtype='int64')

    if not return_counts:
        uniques = np.unique(labels)
        fu.log_block_success(block_id)
        return uniques

    uniques, counts = np.unique(labels, return_counts=True)
    fu.log_block_success(block_id)
    return uniques, counts
def _cc_block(block_id, blocking, ds_in, ds_out, threshold, threshold_mode, channel):
    """Threshold one block and label its connected components.

    If ``channel`` is None the block is read directly from ``ds_in``;
    otherwise the given channel (int) or channels (iterable) are summed up.
    The thresholded block is labeled with ``label`` and written to ``ds_out``.

    Fix compared to the previous version: for ``channel is None`` the input
    was read from the not yet defined local ``input_`` instead of ``ds_in``,
    which raised an UnboundLocalError.

    Returns:
        0 if the thresholded block is empty, otherwise
        ``components.max() + 1`` as int.

    Raises:
        RuntimeError: if ``threshold_mode`` is not one of
            'greater', 'less' or 'equal'.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlock(block_id)
    bb = vu.block_to_bb(block)

    if channel is None:
        input_ = ds_in[bb]
    else:
        block_shape = tuple(b.stop - b.start for b in bb)
        input_ = np.zeros(block_shape, dtype=ds_in.dtype)
        channel_ = [channel] if isinstance(channel, int) else channel
        # accumulate the selected channels
        for chan in channel_:
            bb_inp = (slice(chan, chan + 1),) + bb
            input_ += ds_in[bb_inp].squeeze()

    if threshold_mode == 'greater':
        input_ = input_ > threshold
    elif threshold_mode == 'less':
        input_ = input_ < threshold
    elif threshold_mode == 'equal':
        input_ = input_ == threshold
    else:
        raise RuntimeError("Thresholding Mode %s not supported" % threshold_mode)

    if np.sum(input_) == 0:
        fu.log_block_success(block_id)
        return 0

    components = label(input_)
    ds_out[bb] = components
    fu.log_block_success(block_id)
    return int(components.max()) + 1
def _create_multiset_block(blocking, block_id, ds_in, ds_out):
    """Create the label multiset of one block and write it as a chunk of ``ds_out``.

    Voxels carrying the paintera ignore label are set to 0 first. Blocks
    whose labels then sum to zero are logged as empty and skipped.
    """
    fu.log("start processing block %i" % block_id)
    block = blocking.getBlock(block_id)
    labels = ds_in[vu.block_to_bb(block)]

    # we can't encode the paintera ignore label, so it is mapped to 0
    paintera_ignore_label = 18446744073709551615
    ignored = labels == paintera_ignore_label
    if ignored.sum() > 0:
        labels[ignored] = 0

    if labels.sum() == 0:
        fu.log("block %i is empty" % block_id)
        fu.log_block_success(block_id)
        return

    # compute and serialize the multiset for the input labels
    serialized = serialize_multiset(create_multiset_from_labels(labels))

    chunk_id = tuple(begin // chunk
                     for begin, chunk in zip(block.begin, ds_out.chunks))
    ds_out.write_chunk(chunk_id, serialized, True)
    fu.log_block_success(block_id)
def _labels_for_block(block_id, blocking, ds_ws, out_path, labels, ignore_label):
    """Compute and serialize the overlaps of watershed and labels in one block.

    The block is skipped if the watershed sums to zero, or if an
    ``ignore_label`` is given and all labels in the block are equal to it.
    """
    fu.log("start processing block %i" % block_id)

    # read the watershed in this block
    block = blocking.getBlock(block_id)
    bb = vu.block_to_bb(block)
    ws = ds_ws[bb]

    # skip empty watershed blocks
    if ws.sum() == 0:
        fu.log("block %i is empty" % block_id)
        fu.log_block_success(block_id)
        return

    labs = labels[bb].astype('uint64')

    # skip blocks that only contain the ignore label
    have_ignore_label = ignore_label is not None
    if have_ignore_label and np.sum(labs == ignore_label) == labs.size:
        fu.log("labels of block %i is empty" % block_id)
        fu.log_block_success(block_id)
        return

    # serialize the overlaps
    chunk_id = tuple(begin // extent
                     for begin, extent in zip(block.begin, blocking.blockShape))
    ndist.computeAndSerializeLabelOverlaps(
        ws, labs, out_path, chunk_id,
        withIgnoreLabel=True if have_ignore_label else False,
        ignoreLabel=ignore_label if have_ignore_label else 0)
    fu.log_block_success(block_id)
def write_output(output):
    """Write the output of the current block to ``ds_out`` and return ``block_id``.

    ``block_shape``, ``n_channels``, ``blocking``, ``block_id``, ``dtype`` and
    ``ds_out`` are taken from the enclosing scope. The output must have the
    full ``block_shape`` (with a leading channel axis of at least
    ``n_channels`` entries for 4d output). It is cropped to the actual
    extent of the block, which can be smaller at the border of the volume.

    Fixes compared to the previous version:
      * the crop used ``slice(0, block_extent - actual_extent)``, i.e. the
        size of the overhang, instead of ``slice(0, actual_extent)``;
      * for 4d output the actual shape was computed from the bounding box
        including the channel slice, so it was compared against (and zipped
        with) the 3d ``block_shape`` with shifted axes.
    """
    out_shape = output.shape
    if len(out_shape) == 4:
        assert out_shape[1:] == block_shape
        assert out_shape[0] >= n_channels
    else:
        assert out_shape == block_shape

    spatial_bb = vu.block_to_bb(blocking.getBlock(block_id))
    multi_channel = output.ndim == 4
    if multi_channel:
        output = output[:n_channels]

    # check if we need to crop the output (spatial axes only)
    actual_shape = tuple(b.stop - b.start for b in spatial_bb)
    if actual_shape != tuple(block_shape):
        crop_bb = tuple(slice(0, ash) for ash in actual_shape)
        if multi_channel:
            crop_bb = (slice(None),) + crop_bb
        output = output[crop_bb]

    # adjust bounding box to multi-channel output
    if multi_channel:
        bb = (slice(0, n_channels),) + spatial_bb
    else:
        bb = spatial_bb

    # cast to uint8 if necessary
    if dtype == 'uint8':
        output = _to_uint8(output)

    ds_out[bb] = output
    return block_id
def load_input(block_id):
    """Load the input for ``block_id`` via ``_load_input``.

    ``halo``, ``blocking``, ``ds_in`` and ``block_shape`` are taken from the
    enclosing scope. With a halo the outer block is loaded, otherwise the
    plain block.
    """
    block = (blocking.getBlock(block_id) if halo is None
             else blocking.getBlockWithHalo(block_id, halo).outerBlock)
    return _load_input(ds_in, vu.block_to_bb(block), block_shape)
def check_block(block_id, blocking, ds, ds_nodes):
    """Check that the nodes stored for a block match its segmentation.

    The unique ids of the segmentation block are compared with the chunk of
    ``ds_nodes`` that starts at the same position.

    Fixes compared to the previous version:
      * the chunk id is passed as a tuple instead of a generator;
      * a block without stored nodes that only contains 0 is reported as
        valid instead of falling through to ``len(None)``;
      * the ids are compared exactly with ``np.array_equal``;
        ``np.allclose`` applies a relative tolerance and can treat
        different large ids as equal.

    Returns:
        ``block_id`` if the check fails, None otherwise.
    """
    block = blocking.getBlock(block_id)
    bb = vu.block_to_bb(block)
    nodes_seg = np.unique(ds[bb])

    chunk_id = tuple(b.start // ch for b, ch in zip(bb, ds_nodes.chunks))
    nodes = ds_nodes.read_chunk(chunk_id)

    if nodes is None:
        # without stored nodes the block is only valid if it contains just 0
        if len(nodes_seg) != 1 or nodes_seg[0] != 0:
            return block_id
        return None

    if len(nodes_seg) != len(nodes):
        return block_id
    if not np.array_equal(nodes, nodes_seg):
        return block_id
    return None
def _filter_block_inplace(blocking, block_id, ds, filter_ids):
    """Set all voxels of one block with an id in ``filter_ids`` to 0, in place.

    The block is only written back if at least one voxel is filtered.
    """
    fu.log("start processing block %i" % block_id)

    # read the segmentation in this block
    bb = vu.block_to_bb(blocking.getBlock(block_id))
    seg = ds[bb]

    # nothing to filter in an empty block
    if seg.sum() == 0:
        fu.log_block_success(block_id)
        return

    # voxels whose id is in filter_ids
    to_remove = np.in1d(seg, filter_ids).reshape(seg.shape)
    if to_remove.sum() != 0:
        seg[to_remove] = 0
        ds[bb] = seg

    fu.log_block_success(block_id)
def debug_vol():
    """Debug helper: read every block of a hard-coded roi once.

    Opens 'volumes/cilia/segmentation' in '../data.n5', determines the
    blocks (using the dataset chunks as block shape) inside the hard-coded
    roi and reads each of them, printing the progress.
    """
    path = '../data.n5'
    key = 'volumes/cilia/segmentation'
    ds = open_file(path)[key]

    roi_begin = [7216, 12288, 7488]
    roi_end = [8640, 19040, 11392]
    blocks, blocking = blocks_in_volume(ds.shape, ds.chunks, roi_begin, roi_end,
                                        return_blocking=True)
    print("Have", len(blocks), "blocks in roi")

    # check reading all blocks
    for block_id in blocks:
        print("Check block", block_id)
        # the data is only read to check that the block is accessible
        ds[block_to_bb(blocking.getBlock(block_id))]
        print("Have block", block_id)
    print("All checks passsed")
def _copy_blocks(ds_in, ds_out, blocking, block_list):
    """Copy the blocks in ``block_list`` from ``ds_in`` to ``ds_out``."""
    for block_id in block_list:
        fu.log("start processing block %i" % block_id)
        bb = vu.block_to_bb(blocking.getBlock(block_id))
        ds_out[bb] = ds_in[bb]
def _failing_block(block_id, blocking, ds, n_retries):
    """Write 1 to the block, but fail for odd block ids in the first try.

    Raises:
        RuntimeError: if ``n_retries == 0`` and ``block_id`` is odd.
    """
    first_try = n_retries == 0
    odd_block = block_id % 2 == 1
    if first_try and odd_block:
        raise RuntimeError("Fail")

    block = blocking.getBlock(block_id)
    ds[vu.block_to_bb(block)] = 1
    fu.log_block_success(block_id)
def _write_block_res(ds_in, ds_out, block_id, blocking, block_res):
    """Map the ids of one block through ``block_res`` and write them to ``ds_out``.

    The mapping is applied with ``nt.takeDict``.
    """
    fu.log("start processing block %i" % block_id)
    bb = vu.block_to_bb(blocking.getBlock(block_id))
    ds_out[bb] = nt.takeDict(block_res, ds_in[bb])
    fu.log_block_success(block_id)
def stack_block(block_id, blocking, ds_raw, ds_pred, ds_out, dtype):
    """Stack raw data and prediction of one block along the first axis.

    Both inputs are converted with ``cast`` to ``dtype``. The raw data gets
    a new leading axis and is concatenated in front of the prediction, which
    is read with all entries of its leading axis.
    """
    fu.log("start processing block %i" % block_id)
    spatial_bb = vu.block_to_bb(blocking.getBlock(block_id))
    raw = cast(ds_raw[spatial_bb], dtype)

    # prediction and output are addressed with an additional leading axis
    stacked_bb = (slice(None),) + spatial_bb
    pred = cast(ds_pred[stacked_bb], dtype)

    ds_out[stacked_bb] = np.concatenate([raw[None], pred], axis=0)
    fu.log_block_success(block_id)
def _predict_and_serialize_block(block_id, blocking, input_path, input_key,
                                 output_prefix, halo, ilastik_folder,
                                 ilastik_project, ds_out):
    """Predict one block and copy the inner part of the result to ``ds_out``.

    ``_predict_block_impl`` is called for the outer block; its result is
    read from the dataset 'exported_data' in
    ``<output_prefix>_block<block_id>.h5``. A 4d prediction is transposed to
    channel-first if its last axis has ``ds_out.shape[0]`` entries. The
    prediction is cropped to the inner block and converted with
    ``_to_dtype`` before it is written.
    """
    fu.log("Start processing block %i" % block_id)
    block = blocking.getBlockWithHalo(block_id, halo)

    # NOTE: the original code contains a disabled (commented out) check that
    # would skip blocks with empty input

    _predict_block_impl(block_id, block.outerBlock, input_path, input_key,
                        output_prefix, ilastik_folder, ilastik_project)

    out_bb = vu.block_to_bb(block.innerBlock)
    crop_bb = vu.block_to_bb(block.innerBlockLocal)
    n_channels = ds_out.shape[0]

    path = '%s_block%i.h5' % (output_prefix, block_id)
    with vu.file_reader(path, 'r') as f:
        pred = f['exported_data'][:].squeeze()
    assert pred.ndim in (3, 4), '%i' % pred.ndim

    if pred.ndim == 4:
        all_channels = (slice(None),)
        out_bb = all_channels + out_bb
        crop_bb = all_channels + crop_bb
        # check if we need to transpose
        if pred.shape[-1] == n_channels:
            pred = pred.transpose((3, 0, 1, 2))
        else:
            assert pred.shape[0] == n_channels,\
                "Expected first axis to be channel axis with %i channels, but got shape %s" % (n_channels, str(pred.shape))

    ds_out[out_bb] = _to_dtype(pred[crop_bb], ds_out.dtype)
    # NOTE: the temporary file is not removed (removal is disabled in the original code)
    fu.log_block_success(block_id)
def _insert_affinities_block(block_id, blocking, ds, objects, offsets):
    """Add the inverted object affinities of one block onto ``ds`` in place.

    The block is read with a halo equal to the per-axis maximum absolute
    offset. Blocks whose objects sum to zero are skipped. Otherwise
    ``1 - affinities`` of the objects is converted with ``cast`` to the
    dtype of ``ds`` and added to the inner block.
    """
    fu.log("start processing block %i" % block_id)

    halo = np.max(np.abs(offsets), axis=0).tolist()
    block = blocking.getBlockWithHalo(block_id, halo)
    all_channels = (slice(None),)
    read_bb = vu.block_to_bb(block.outerBlock)
    write_bb = all_channels + vu.block_to_bb(block.innerBlock)
    crop_bb = all_channels + vu.block_to_bb(block.innerBlockLocal)

    # load objects and check if we have any in this block
    objs = objects[read_bb]
    if objs.sum() == 0:
        fu.log_block_success(block_id)
        return

    affs, _ = compute_affinities(objs, offsets)
    affs = cast(1. - affs, ds.dtype)

    # read-modify-write of the inner block
    current = ds[write_bb]
    current += affs[crop_bb]
    ds[write_bb] = current
    fu.log_block_success(block_id)
def write_output(inputs):
    """Write the output for one block to the output dataset(s) and return its id.

    ``inputs`` is a pair ``(block_id, output)``. ``ds_out``,
    ``channel_mapping``, ``channel_accumulation``, ``blocking``,
    ``block_shape`` and ``dtype`` are taken from the enclosing scope.
    If ``output`` is None nothing is written; if it is a list or tuple,
    only its first element is used.
    """
    block_id, output = inputs
    if output is None:
        return block_id
    if isinstance(output, (list, tuple)):
        output = output[0]

    # a 3d output can only go to a single dataset
    if len(output.shape) == 3:
        assert len(ds_out) == 1

    bb = vu.block_to_bb(blocking.getBlock(block_id))

    # check if we need to crop the output
    # NOTE this is not cropping the halo, which is done beforehand in the
    # predictor already, but to crop overhanging chunks at the end of the dataset
    actual_shape = [sl.stop - sl.start for sl in bb]
    if actual_shape != block_shape:
        crop_bb = tuple(slice(0, extent) for extent in actual_shape)
        if output.ndim == 4:
            crop_bb = (slice(None),) + crop_bb
        output = output[crop_bb]

    # write the output to our output dataset(s)
    for dso, (chan_start, chan_stop) in zip(ds_out, channel_mapping):
        n_mapped = chan_stop - chan_start

        if dso.ndim == 3:
            if channel_accumulation is None:
                assert n_mapped == 1
            out_bb = bb
        else:
            assert output.ndim == 4
            assert n_mapped == dso.shape[0]
            out_bb = (slice(None),) + bb

        channel_output = output
        if output.ndim == 4:
            channel_output = output[chan_start:chan_stop]
            # drop the channel axis for a single channel
            if channel_output.shape[0] == 1:
                channel_output = channel_output[0]

        # apply channel accumulation if specified
        if channel_accumulation is not None and channel_output.ndim == 4:
            channel_output = channel_accumulation(channel_output, axis=0)

        # cast to uint8 if necessary
        if dtype == 'uint8':
            channel_output = _to_uint8(channel_output)

        dso[out_bb] = channel_output

    return block_id