def overlap():
    run_datetime = datetime.datetime.now(
        pytz.timezone('US/Eastern')).strftime('%Y%m%dT%H%M%S.%f%z')
    temp_dir = os.path.join(config.temp_path, run_datetime)
    os.makedirs(temp_dir)

    fragments = daisy.open_ds(config.fragments_zarr, config.fragments_ds)
    groundtruth = daisy.open_ds(config.groundtruth_zarr, config.groundtruth_ds)

    total_roi = daisy.Roi(offset=config.roi_offset, shape=config.roi_shape)

    start = time.time()
    daisy.run_blockwise(
        total_roi=total_roi,
        read_roi=daisy.Roi(offset=(0, 0, 0), shape=config.block_size),
        write_roi=daisy.Roi(offset=(0, 0, 0), shape=config.block_size),
        process_function=lambda block: overlap_in_block(
            block=block,
            fragments=fragments,
            groundtruth=groundtruth,
            tmp_path=temp_dir),
        fit='shrink',
        num_workers=config.num_workers,
        read_write_conflict=False,
        max_retries=1)
    logger.info(
        f"Blockwise overlapping of fragments and ground truth in "
        f"{time.time() - start:.3f}s")
    logger.debug(
        f"num blocks: "
        f"{np.prod(np.ceil(np.array(config.roi_shape) / np.array(config.block_size)))}")

    frag_to_gt = overlap_reduce(tmp_path=temp_dir)
    pickle.dump(frag_to_gt, open(os.path.join(
        temp_dir, 'frag_to_gt.pickle'), 'wb'))
    return frag_to_gt

def open_ds(f, ds):
    # Fall back to the first scale level if the dataset is a multiscale group.
    try:
        data = daisy.open_ds(f, ds)
    except Exception:
        data = daisy.open_ds(f, ds + '/s0')
    return data

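# Hedged usage sketch (not part of the original code base): shows how the
# fallback opener above might be called on a multiscale container. The
# container path 'volumes.zarr' and dataset name 'raw' are placeholder
# assumptions; the data is assumed to live under 'raw/s0' if 'raw' is a group.
def example_open_multiscale():
    raw = open_ds('volumes.zarr', 'raw')  # falls back to 'raw/s0' if needed
    print(raw.roi, raw.voxel_size, raw.dtype)
    return raw
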
def __init__(self, credentials, db_name, collection_name, predict_id, dataset,
             dx, dy, dz, n_gpus, gpu_id, n_cpus, cpu_id, transform=None):
    self.db = BrainDb(credentials, db_name, collection_name, predict_id)

    start = time.time()
    log.info("Partition DB to workers...")
    self.cursor = self.get_cursor(n_gpus, n_cpus, gpu_id, cpu_id)
    log.info(f"...took {time.time() - start} seconds")

    self.dataset = dataset
    self.container = dataset.container
    self.dset = dataset.dataset
    self.voxel_size = dataset.voxel_size
    self.data = daisy.open_ds(self.container, self.dset)
    self.transform = transform

    if dx % 8 != 0 or dy % 8 != 0 or dz % 80 != 0:
        raise ValueError(
            "Roi size must be divisible by 8 in dx/dy and by 80 in dz")
    self.dx = dx
    self.dy = dy
    self.dz = dz

def build_source(self):
    data = daisy.open_ds(self.filename, self.key)

    if self.time_window is None:
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    else:
        offs = list(data.roi.get_offset())
        offs[1] += self.time_window[0]
        sh = list(data.roi.get_shape())
        sh[1] = self.time_window[1] - self.time_window[0]  # window length
        source_roi = gp.Roi(tuple(offs), tuple(sh))

    voxel_size = gp.Coordinate(data.voxel_size)

    return gp.ZarrSource(
        self.filename,
        {
            self.raw_0: self.key,
            self.raw_1: self.key
        },
        array_specs={
            self.raw_0: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True),
            self.raw_1: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True)
        })

def open_ds_wrapper(path, ds_name):
    """Returns None if ds_name does not exist."""
    try:
        return daisy.open_ds(path, ds_name)
    except KeyError:
        print('dataset %s could not be loaded' % ds_name)
        return None

def main(config_path):
    # S2: 3280-21435, 2840-22160, 70-1170
    # S4: 1000-14000, 1000-14000, 35-640
    # S5: 500-7000, 500-7000, 15-300
    with open(config_path, 'r') as f:
        configs = json.load(f)

    section = configs['section']
    raw_file = configs["zarr_file"]
    raw_ds = configs['raw_ds']
    now = datetime.now().strftime("%m%d.%H.%M.%S")
    output = configs['output_folder']
    # e.g. f'/n/groups/htem/Segmentation/xg76/mipmap/mipmap/{now}_{section}_{z_start}_{z_end}'
    coord_begin = configs['coord_begin']
    coord_end = configs['coord_end']
    z_range = range(coord_begin[2], coord_end[2], configs['interval'])
    scale_list = [configs['scale']]

    os.makedirs(output, exist_ok=True)
    print(f'mipmapping: {configs["section"]}')

    cutout_ds = daisy.open_ds(raw_file, raw_ds)
    for z in z_range:
        coord_begin[2] = z
        coord_end[2] = z + 1
        print(f'coord begin: {coord_begin}')
        raw_img_list = get_ndarray_img_from_zarr(coord_begin=coord_begin,
                                                 coord_end=coord_end,
                                                 cutout_ds=cutout_ds)
        img = raw_img_list[0]
        # write_to_tiff(img, os.path.join(output, f'{z}_origin.tif'))
        mipmap = down_sampling_img(img, scale_list)[0]
        write_to_tiff(
            mipmap,
            os.path.join(output, f'{section}_{z}_{scale_list[0]}_mipmap.tiff'))

def overlay_segmentation(db_host, db_name, roi_offset, roi_size, selected_attr,
                         solved_attr, edge_collection, segmentation_container,
                         segmentation_dataset, segmentation_number,
                         voxel_size=(40, 4, 4)):
    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)

    graph_roi = daisy.Roi(roi_offset, roi_size)
    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    for node_id, data in nx_graph.nodes(data=True):
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        nx_graph.nodes[node_id]["segmentation_{}".format(
            segmentation_number)] = segmentation[node_position]

    graph_provider.write_nodes(intersection_roi)

def get_ndarray_img_from_zarr(raw_file=None, raw_ds=None, coord_begin=None,
                              coord_end=None, cutout_ds=None):
    """Retrieve image from zarr file. Return list of images."""
    if raw_file is None and raw_ds is None and cutout_ds is None:
        raise ValueError('No raw file is found')
    elif raw_file is not None and raw_ds is not None and cutout_ds is None:
        cutout_ds = daisy.open_ds(raw_file, raw_ds)
    else:
        print('Using passed in cutout_ds')

    print(f'Voxel size: {cutout_ds.voxel_size}')

    roi = None
    if coord_begin is not None and coord_end is not None:
        voxel_size = cutout_ds.voxel_size
        coord_begin = Coordinate(np.flip(np.array(coord_begin))) * voxel_size
        coord_end = Coordinate(np.flip(np.array(coord_end))) * voxel_size
        roi_offset = coord_begin
        roi_shape = coord_end - coord_begin
        roi = Roi(roi_offset, roi_shape)

    print(f"Getting data from zarr file... ROI: {roi}")
    ndarray = cutout_ds.to_ndarray(roi=roi)
    return ndarray

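# Hedged usage sketch (illustrative only, not from the original code base):
# reads a single section from a hypothetical container. The file name, dataset
# name and x/y/z voxel coordinates below are placeholder assumptions; the
# coordinates are flipped to z/y/x inside get_ndarray_img_from_zarr.
def example_read_section():
    imgs = get_ndarray_img_from_zarr(
        raw_file='cutout.zarr',
        raw_ds='volumes/raw',
        coord_begin=[1000, 1000, 50],   # x, y, z in voxels
        coord_end=[2000, 2000, 51])
    return imgs[0]
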
def add_snapshot(
        context,
        snapshot_file,
        name_prefix="",
        volume_paths=["volumes"],
        graph_paths=["points"],
        graph_node_attrs=None,
        graph_edge_attrs=None,
        # mst=["embedding", "fg_maxima"],
        mst=None,
        roi=None,
):
    f = zarr.open(str(snapshot_file.absolute()), "r")
    with f as dataset:
        volumes = []
        for volume in volume_paths:
            volumes += get_volumes(dataset, volume)

        v = None
        for volume in volumes:
            v = daisy.open_ds(str(snapshot_file.absolute()), f"{volume}")
            if roi is not None:
                v = v.intersect(roi)
            if v.dtype == np.int64:
                v.materialize()
                v.data = v.data.astype(np.uint64)
            add_layer(context, v, f"{name_prefix}_{volume}", visible=False)

def get_local_segmentation(self, roi: daisy.Roi, threshold: float):
    # open fragments
    fragments = daisy.open_ds(self.fragments_file, self.fragments_dataset)

    # open RAG DB
    rag_provider = MongoDbRagProvider(
        self.fragments_db,
        host=self.fragments_host,
        mode="r",
        edges_collection=self.edges_collection,
    )

    segmentation = fragments[roi]
    segmentation.materialize()

    ids = [int(id) for id in list(np.unique(segmentation.data))]
    rag = rag_provider.read_rag(ids)

    if len(rag.nodes()) == 0:
        raise Exception('RAG is empty')

    components = rag.get_connected_components(threshold)

    values_map = np.array(
        [[fragment, i]
         for i in range(1, len(components) + 1)
         for fragment in components[i - 1]],
        dtype=np.uint64,
    )
    old_values = values_map[:, 0]
    new_values = values_map[:, 1]
    replace_values(segmentation.data, old_values, new_values, inplace=True)

    return segmentation

def get_source_roi(data_dir, sample):
    sample_path = os.path.join(data_dir, sample)

    # get absolute paths
    if os.path.isfile(sample_path) or sample.endswith((".zarr", ".n5")):
        sample_dir = os.path.abspath(
            os.path.join(data_dir, os.path.dirname(sample)))
    else:
        sample_dir = os.path.abspath(os.path.join(data_dir, sample))

    if os.path.isfile(os.path.join(sample_dir, 'attributes.json')):
        with open(os.path.join(sample_dir, 'attributes.json'), 'r') as f:
            attributes = json.load(f)
        voxel_size = daisy.Coordinate(attributes['resolution'])
        shape = daisy.Coordinate(attributes['shape'])
        offset = daisy.Coordinate(attributes['offset'])
        source_roi = daisy.Roi(offset, shape * voxel_size)
        return voxel_size, source_roi
    elif os.path.isdir(os.path.join(sample_dir, 'timelapse.zarr')):
        a = daisy.open_ds(os.path.join(sample_dir, 'timelapse.zarr'),
                          'volumes/raw')
        return a.voxel_size, a.roi
    else:
        raise RuntimeError(
            "Can't find attributes.json or timelapse.zarr in %s" % sample_dir)

def check_block(block):
    _logger.debug("Checking if block %s is complete...", block.write_roi)
    ds = daisy.open_ds(out_container, affs_dataset)
    center_values = ds[block.write_roi.get_center()]
    s = np.sum(center_values)
    _logger.debug("Sum of center values in %s is %f", block.write_roi, s)
    return s != 0

def get_array(data_container, data_set, begin, end, context=(0, 0, 0)):
    context = np.array(context)
    roi = daisy.Roi(begin - context / 2, end - begin + context)
    dataset = daisy.open_ds(data_container, data_set)
    data_array = dataset[roi].to_ndarray()
    return data_array

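# Hedged usage sketch (illustrative only): cut a block out of a hypothetical
# container with get_array above. The path, dataset name and world-unit
# coordinates are placeholder assumptions and are assumed to align with the
# dataset's voxel grid.
def example_get_array():
    begin = np.array((400, 4000, 4000))   # z, y, x in world units
    end = np.array((800, 8000, 8000))
    return get_array('volume.zarr', 'volumes/raw', begin, end)
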
def query_local_segmentation(self, roi, threshold):
    # open fragments
    fragments = daisy.open_ds(self.fragments_file, self.fragments_dataset)

    # open RAG DB
    rag_provider = lsd.persistence.MongoDbRagProvider(
        self.frag_db_name,
        host=self.frag_db_host,
        mode="r",
        edges_collection=self.edges_collection,
    )

    segmentation = fragments[roi]
    segmentation.materialize()

    rag = rag_provider[roi]

    if len(rag.nodes()) == 0:
        return segmentation

    components = rag.get_connected_components(threshold)

    values_map = np.array(
        [[fragment, i]
         for i in range(len(components))
         for fragment in components[i]],
        dtype=np.uint64,
    )
    old_values = values_map[:, 0]
    new_values = values_map[:, 1]
    funlib.segment.arrays.replace_values(
        segmentation.data, old_values, new_values, inplace=True)

    return segmentation

def open_dataset(f, ds):
    original_ds = ds
    ds, slices = parse_ds_name(ds)
    slices_str = original_ds[len(ds):]

    try:
        dataset_as = []
        if all(key.startswith("s") for key in zarr.open(f)[ds].keys()):
            raise AttributeError("This group is a multiscale array!")
        for key in zarr.open(f)[ds].keys():
            dataset_as.extend(open_dataset(f, f"{ds}/{key}{slices_str}"))
        return dataset_as
    except AttributeError:
        # dataset is an array, not a group
        pass

    print("ds :", ds)
    print("slices:", slices)

    try:
        zarr.open(f)[ds].keys()
        is_multiscale = True
    except Exception:
        is_multiscale = False

    if not is_multiscale:
        a = daisy.open_ds(f, ds)

        if slices is not None:
            a = slice_dataset(a, slices)

        if a.roi.dims == 2:
            print("ROI is 2D, recruiting next channel to z dimension")
            a.roi = daisy.Roi((0,) + a.roi.get_begin(),
                              (a.shape[-3],) + a.roi.get_shape())
            a.voxel_size = daisy.Coordinate((1,) + a.voxel_size)

        if a.roi.dims == 4:
            print("ROI is 4D, stripping first dimension and treat as channels")
            a.roi = daisy.Roi(a.roi.get_begin()[1:], a.roi.get_shape()[1:])
            a.voxel_size = daisy.Coordinate(a.voxel_size[1:])

        if a.data.dtype == np.int64 or a.data.dtype == np.int16:
            print("Converting dtype in memory...")
            a.data = a.data[:].astype(np.uint64)

        return [(a, ds)]
    else:
        return [([daisy.open_ds(f, f"{ds}/{key}")
                  for key in zarr.open(f)[ds].keys()], ds)]

def test_load_klb(self):
    self.write_test_klb()
    data = daisy.open_ds(self.klb_file, None)
    self.assertEqual(data.roi.get_offset(), daisy.Coordinate((0, 0, 0, 0)))
    self.assertEqual(data.roi.get_shape(),
                     daisy.Coordinate((1, 20, 10, 10)))
    self.assertEqual(data.voxel_size, daisy.Coordinate((1, 2, 1, 1)))
    self.remove_test_klb()

def add_container(
        context,
        snapshot_file,
        name_prefix="",
        volume_paths=[None],
        graph_paths=[None],
        graph_node_attrs=None,
        graph_edge_attrs=None,
        # mst=["embedding", "fg_maxima"],
        mst=None,
        roi=None,
        modify=None,
        dims=3,
        array_offset=None,
        voxel_size=None,
):
    if snapshot_file.name.endswith(".zarr") or snapshot_file.name.endswith(".n5"):
        f = zarr.open(str(snapshot_file.absolute()), "r")
    elif snapshot_file.name.endswith(".h5") or snapshot_file.name.endswith(".hdf"):
        f = h5py.File(str(snapshot_file.absolute()), "r")

    with f as dataset:
        volumes = []
        for volume in volume_paths:
            volumes += get_volumes(dataset, volume)

        v = None
        for volume in volumes:
            v = daisy.open_ds(str(snapshot_file.absolute()), f"{volume}")
            if roi is not None:
                v = v.intersect(roi)
            if v.dtype == np.int64:
                v.materialize()
                v.data = v.data.astype(np.uint64)
            if v.dtype == np.dtype(bool):
                v.materialize()
                v.data = v.data.astype(np.float32)
            v.materialize()

            if modify is not None:
                data = modify(v.data, volume)
            else:
                data = v.data

            if voxel_size is None:
                voxel_size = v.voxel_size[-dims:]
            if array_offset is None:
                array_offset = v.roi.get_offset()[-dims:]

            add_layer(
                context,
                data,
                f"{name_prefix}_{volume}",
                visible=False,
                voxel_size=voxel_size,
                array_offset=array_offset,
            )

def __init__(self, filename, key, density=None, channels=0,
             shape=(16, 256, 256), time_window=None,
             add_sparse_mosaic_channel=True, random_rot=False):
    self.filename = filename
    self.key = key
    self.shape = shape
    self.density = density
    self.raw = gp.ArrayKey('RAW_0')
    self.add_sparse_mosaic_channel = add_sparse_mosaic_channel
    self.random_rot = random_rot
    self.channels = channels

    data = daisy.open_ds(filename, key)

    if time_window is None:
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    else:
        offs = list(data.roi.get_offset())
        offs[1] += time_window[0]
        sh = list(data.roi.get_shape())
        sh[1] = time_window[1] - time_window[0]  # length of the time window
        source_roi = gp.Roi(tuple(offs), tuple(sh))

    voxel_size = gp.Coordinate(data.voxel_size)

    self.pipeline = gp.ZarrSource(
        filename,
        {
            self.raw: key
        },
        array_specs={
            self.raw: gp.ArraySpec(
                roi=source_roi,
                voxel_size=voxel_size,
                interpolatable=True)
        }) + gp.RandomLocation() + IntensityDiffFilter(
            self.raw, 0, min_distance=0.1, channels=Slice(None))

    # add augmentations
    self.pipeline = self.pipeline + gp.ElasticAugment(
        [40, 40], [2, 2], [0, math.pi / 2.0], prob_slip=-1, spatial_dims=2)

    self.pipeline.setup()
    np.random.seed(os.getpid() + int(time.time()))

def test_overwrite_attrs(self):
    self.write_test_klb()
    data = daisy.open_ds(self.klb_file, None, attr_filename=self.attrs_file)
    self.assertEqual(data.roi.get_offset(),
                     daisy.Coordinate((0, 100, 10, 0)))
    self.assertEqual(data.roi.get_shape(),
                     daisy.Coordinate((1, 10, 20, 20)))
    self.assertEqual(data.voxel_size, daisy.Coordinate((1, 1, 2, 2)))
    self.remove_test_klb()

def get_raw(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified dataset.

    locs(``list of tuple of ints``):
        list of centers of locations of interest

    size(``tuple of ints``):
        size of cropout in voxels

    voxel_size(``tuple of ints``):
        size of a voxel

    data_container(``string``):
        path to data container (e.g. zarr file)

    data_set(``string``):
        corresponding data_set name (e.g. raw)
    """
    raw = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(data_container, data_set)

    for loc in locs:
        offset_nm = loc - (size / 2 * voxel_size)
        roi = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                         mode='closest')
        if roi.get_shape()[0] != size_nm[0]:
            roi.set_shape(size_nm)

        if not dataset.roi.contains(roi):
            logger.warning("Location %s is not fully contained in dataset"
                           % loc)
            return

        raw.append(dataset[roi].to_ndarray())

    raw = np.stack(raw)
    raw = raw.astype(np.float32)
    raw_normalized = raw / 255.0
    raw_normalized = raw_normalized * 2.0 - 1.0
    return raw, raw_normalized

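# Hedged usage sketch (illustrative only): fetch two crops around hypothetical
# centers given in world units (nm). The container path, dataset name, crop
# size and voxel size below are placeholder assumptions.
def example_get_raw():
    locs = [daisy.Coordinate((4000, 12000, 12000)),
            daisy.Coordinate((4400, 16000, 16000))]
    raw, raw_normalized = get_raw(locs,
                                  size=(32, 128, 128),
                                  voxel_size=(40, 4, 4),
                                  data_container='raw.zarr',
                                  data_set='volumes/raw')
    return raw, raw_normalized
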
def runTest(self):
    repo = HemiNeuprint()
    test_position = (80, 100, 120)
    [interface_transform] = repo.transform_positions([test_position])

    z = test_position[0]
    y = test_position[1]
    x = test_position[2]

    n5 = "/nrs/flyem/data/tmp/Z0115-22.export.n5"
    ds = daisy.open_ds(n5, "22-34/s0")
    X, Z, Y = ds.shape
    manual_transform = (X - x - 1, z, y) * repo.dataset.voxel_size

    self.assertEqual(interface_transform, tuple(manual_transform))

def __init__(self):
    dataset = Hemi()
    service = Neuprint(
        dataset=dataset,
        server="neuprint.janelia.org",
        neuprint_dataset="hemibrain:v1.2.1",
    )
    self.x_shape = daisy.open_ds(dataset.container, dataset.dataset).shape[0]
    super().__init__(
        dataset=dataset,
        service=service,
    )

def get_raw_parallel(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified dataset.

    locs(``list of tuple of ints``):
        list of centers of locations of interest

    size(``tuple of ints``):
        size of cropout in voxels

    voxel_size(``tuple of ints``):
        size of a voxel

    data_container(``string``):
        path to data container (e.g. zarr file)

    data_set(``string``):
        corresponding data_set name (e.g. raw)
    """
    pool = Pool(processes=len(locs))
    raw = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(data_container, data_set)

    raw_workers = [
        pool.apply_async(fetch_from_ds,
                         (dataset, loc, voxel_size, size, size_nm))
        for loc in locs
    ]
    raw = [w.get(timeout=60) for w in raw_workers]
    pool.close()
    pool.join()

    raw = np.stack(raw)
    raw = raw.astype(np.float32)
    raw_normalized = raw / 255.0
    raw_normalized = raw_normalized * 2.0 - 1.0
    return raw, raw_normalized

def get_predictions(self):
    pipeline, request, predictions = self.make_pipeline()
    with gp.build(pipeline):
        try:
            shutil.rmtree(
                os.path.join(self.curr_log_dir, 'predictions.zarr'))
        except OSError:
            pass
        pipeline.request_batch(gp.BatchRequest())
        f = daisy.open_ds(
            os.path.join(self.curr_log_dir, 'predictions.zarr'),
            "predictions")
        return f

def get_crops(self, center_positions, size):
    """
    Args:
        center_positions (list of tuple of ints): [(z,y,x)] in nm
        size (tuple of ints): size of the crop, in voxels
    """
    crops = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(self.voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(self.container, self.dataset)

    dataset_resolution = None
    try:
        dataset_resolution = dataset.voxel_size
    except AttributeError:
        pass

    if dataset_resolution is not None:
        if not np.all(dataset.voxel_size == self.voxel_size):
            raise ValueError(
                f"Dataset {dataset} resolution mismatch "
                f"{dataset.voxel_size} vs {self.voxel_size}")

    for position in center_positions:
        position = daisy.Coordinate(tuple(position))
        offset_nm = position - ((size / 2) * voxel_size)
        roi = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                         mode='closest')
        if roi.get_shape()[0] != size_nm[0]:
            roi.set_shape(size_nm)

        if not dataset.roi.contains(roi):
            raise Warning(
                f"Location {position} is not fully contained in dataset")

        crops.append(dataset[roi].to_ndarray())

    crops_batched = np.stack(crops)
    crops_batched = crops_batched.astype(np.float32)
    return crops_batched

def add_dacapo_snapshot(
        context,
        snapshot_file,
        name_prefix="snapshot",
        volume_paths=[None],
        graph_paths=[None],
        graph_node_attrs=None,
        graph_edge_attrs=None,
        # mst=["embedding", "fg_maxima"],
        mst=None,
        roi=None,
):
    raw = daisy.open_ds(str(snapshot_file.absolute()), "raw")
    raw_shape = raw.shape[-3:]
    voxel_size = raw.voxel_size[-3:]
    array_offset = raw.roi.get_offset()[-3:]

    def modify(v, name):
        if name == "prediction":
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        elif name == "raw":
            v = reshape_batch_channel(v, 0, 2, raw_shape)
        elif name == "target":
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        elif name == "weights":
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        else:
            print(f"Modifying unknown array {name} with shape {v.shape}")
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        return v

    add_container(
        context,
        snapshot_file,
        name_prefix,
        volume_paths,
        graph_paths,
        graph_node_attrs,
        graph_edge_attrs,
        mst,
        roi,
        modify,
        voxel_size=voxel_size,
        array_offset=array_offset,
    )

def _task_init(self):
    # open dataset
    dataset = daisy.open_ds(self.in_file, self.in_ds_name)

    # define total region of interest (roi)
    total_roi = dataset.roi
    ndims = len(total_roi.get_offset())

    # define block read and write rois
    assert len(self.block_read_size) == ndims,\
        "Read size must have same dimensions as in_file"
    assert len(self.block_write_size) == ndims,\
        "Write size must have same dimensions as in_file"
    block_read_size = daisy.Coordinate(self.block_read_size)
    block_write_size = daisy.Coordinate(self.block_write_size)
    block_read_size *= dataset.voxel_size
    block_write_size *= dataset.voxel_size
    context = (block_read_size - block_write_size) / 2
    block_read_roi = daisy.Roi((0,) * ndims, block_read_size)
    block_write_roi = daisy.Roi(context, block_write_size)

    # prepare output dataset
    output_roi = total_roi.grow(-context, -context)
    if self.out_file is None:
        self.out_file = self.in_file
    if self.out_ds_name is None:
        self.out_ds_name = self.in_ds_name + '_smoothed'

    logger.info(f'Processing data to {self.out_file}/{self.out_ds_name}')

    output_dataset = daisy.prepare_ds(
        self.out_file,
        self.out_ds_name,
        total_roi=output_roi,
        voxel_size=dataset.voxel_size,
        dtype=dataset.dtype,
        write_size=block_write_roi.get_shape())

    # save variables for other functions
    self.total_roi = total_roi
    self.block_read_roi = block_read_roi
    self.block_write_roi = block_write_roi
    self.dataset = dataset
    self.output_dataset = output_dataset

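# Hedged numeric sketch of the block geometry computed in _task_init above.
# All values are placeholder assumptions: with read/write block sizes of
# (36, 136, 136) / (32, 128, 128) voxels and a voxel size of (40, 4, 4) nm,
# the context works out to (80, 16, 16) nm per side, so the write ROI sits
# centered inside the read ROI.
def example_block_geometry():
    voxel_size = daisy.Coordinate((40, 4, 4))
    block_read_size = daisy.Coordinate((36, 136, 136)) * voxel_size
    block_write_size = daisy.Coordinate((32, 128, 128)) * voxel_size
    context = (block_read_size - block_write_size) / 2  # (80, 16, 16)
    return (daisy.Roi((0, 0, 0), block_read_size),
            daisy.Roi(context, block_write_size))
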
def open_daisy(self):
    """Open this dataset as a daisy array."""
    data = daisy.open_ds(self.container, self.dataset)

    # Correct for datasets where the container does not have the voxel size
    if data.voxel_size != tuple(self.voxel_size):
        log.warn(
            "Container has different voxel size than dataset: "
            f"{data.voxel_size} != {self.voxel_size}")
        orig_shape = data.roi.get_shape()
        data = daisy.Array(
            data.data,
            daisy.Roi(
                data.roi.get_offset(),
                self.voxel_size * data.data.shape[-len(self.voxel_size):]),
            self.voxel_size,
            chunk_shape=data.chunk_shape)
        log.warn(
            "Reloaded container data with dataset voxel size, changing shape: "
            f"{orig_shape} => {data.roi.get_shape()}")
    return data

def extract_segmentation(fragments_file,
                         fragments_dataset,
                         edges_collection,
                         threshold,
                         out_file,
                         out_dataset,
                         num_workers,
                         lut_fragment_segment,
                         roi_offset=None,
                         roi_shape=None,
                         run_type=None,
                         **kwargs):
    # open fragments
    fragments = daisy.open_ds(fragments_file, fragments_dataset)

    total_roi = fragments.roi
    if roi_offset is not None:
        assert roi_shape is not None, "If roi_offset is set, roi_shape " \
                                      "also needs to be provided"
        total_roi = daisy.Roi(offset=roi_offset, shape=roi_shape)

    read_roi = daisy.Roi((0, 0, 0), (5000, 5000, 5000))
    write_roi = daisy.Roi((0, 0, 0), (5000, 5000, 5000))

    logging.info("Preparing segmentation dataset...")
    segmentation = daisy.prepare_ds(out_file,
                                    out_dataset,
                                    total_roi,
                                    voxel_size=fragments.voxel_size,
                                    dtype=np.uint64,
                                    write_roi=write_roi)

    lut_filename = 'seg_%s_%d' % (edges_collection, int(threshold * 100))
    lut_dir = os.path.join(fragments_file, lut_fragment_segment)

    if run_type:
        lut_dir = os.path.join(lut_dir, run_type)
        logging.info("Run type set, using luts from %s data" % run_type)

    lut = os.path.join(lut_dir, lut_filename + '.npz')
    assert os.path.exists(lut), "%s does not exist" % lut

    start = time.time()
    logging.info("Reading fragment-segment LUT...")
    lut = np.load(lut)['fragment_segment_lut']
    logging.info("%.3fs" % (time.time() - start))

    logging.info("Found %d fragments in LUT" % len(lut[0]))

    daisy.run_blockwise(total_roi,
                        read_roi,
                        write_roi,
                        lambda b: segment_in_block(
                            b, fragments_file, segmentation, fragments, lut),
                        fit='shrink',
                        num_workers=num_workers,
                        processes=True,
                        read_write_conflict=False)

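# Hedged usage sketch (illustrative only): extract a segmentation at a single
# agglomeration threshold with the function above. Every path, collection name
# and the LUT directory below are placeholder assumptions, not values from the
# original configuration.
def example_extract_segmentation():
    extract_segmentation(
        fragments_file='fragments.zarr',
        fragments_dataset='volumes/fragments',
        edges_collection='edges_hist_quant_75',
        threshold=0.5,
        out_file='segmentation.zarr',
        out_dataset='volumes/segmentation_50',
        num_workers=8,
        lut_fragment_segment='luts/fragment_segment')
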
def predict_blockwise(base_dir, experiment, train_number, predict_number,
                      iteration, in_container_spec, in_container, in_dataset,
                      in_offset, in_size, out_container, db_name, db_host,
                      singularity_container, num_cpus, num_cache_workers,
                      num_block_workers, queue, mount_dirs, **kwargs):
    '''Run prediction in parallel blocks. Within blocks, predict in chunks.

    Args:

        experiment (``string``):

            Name of the experiment (cremi, fib19, fib25, ...).

        setup (``string``):

            Name of the setup to predict.

        iteration (``int``):

            Training iteration to predict from.

        raw_file (``string``):
        raw_dataset (``string``):
        auto_file (``string``):
        auto_dataset (``string``):

            Paths to the input autocontext datasets (affs or lsds). Can be
            None if not needed.

        out_file (``string``):

            Path to directory where zarr should be stored. **Note: out_dataset
            no longer needed as input, build out_dataset from config outputs
            dictionary generated in mknet.py

        file_name (``string``):

            Name of output file

        block_size_in_chunks (``tuple`` of ``int``):

            The size of one block in chunks (not voxels!). A chunk corresponds
            to the output size of the network.

        num_workers (``int``):

            How many blocks to run in parallel.

        queue (``string``):

            Name of queue to run inference on (i.e. slowpoke, gpu_rtx,
            gpu_any, gpu_tesla, gpu_tesla_large)
    '''
    predict_setup_dir = os.path.join(
        os.path.join(base_dir, experiment),
        "02_predict/setup_t{}_p{}".format(train_number, predict_number))
    train_setup_dir = os.path.join(os.path.join(base_dir, experiment),
                                   "01_train/setup_t{}".format(train_number))

    # from here on, all values are in world units (unless explicitly mentioned)

    # get ROI of source
    source = daisy.open_ds(in_container_spec, in_dataset)
    logger.info('Source dataset has shape %s, ROI %s, voxel size %s' %
                (source.shape, source.roi, source.voxel_size))

    # Read network config
    predict_net_config = os.path.join(predict_setup_dir, 'predict_net.json')
    with open(predict_net_config) as f:
        logger.info('Reading setup config from {}'.format(predict_net_config))
        net_config = json.load(f)
    outputs = net_config['outputs']

    # get chunk size and context
    net_input_size = daisy.Coordinate(
        net_config['input_shape']) * source.voxel_size
    net_output_size = daisy.Coordinate(
        net_config['output_shape']) * source.voxel_size
    context = (net_input_size - net_output_size) / 2
    logger.info('Network context: {}'.format(context))

    # get total input and output ROIs
    input_roi = source.roi.grow(context, context)
    output_roi = source.roi

    # create read and write ROI
    block_read_roi = daisy.Roi((0, 0, 0), net_input_size) - context
    block_write_roi = daisy.Roi((0, 0, 0), net_output_size)

    logger.info('Preparing output dataset...')

    for output_name, val in outputs.items():
        out_dims = val['out_dims']
        out_dtype = val['out_dtype']
        out_dataset = 'volumes/%s' % output_name
        ds = daisy.prepare_ds(out_container,
                              out_dataset,
                              output_roi,
                              source.voxel_size,
                              out_dtype,
                              write_roi=block_write_roi,
                              num_channels=out_dims,
                              compressor={'id': 'gzip', 'level': 5})

    logger.info('Starting block-wise processing...')

    client = pymongo.MongoClient(db_host)
    db = client[db_name]
    if 'blocks_predicted' not in db.list_collection_names():
        blocks_predicted = db['blocks_predicted']
        blocks_predicted.create_index([('block_id', pymongo.ASCENDING)],
                                      name='block_id')
    else:
        blocks_predicted = db['blocks_predicted']

    # process block-wise
    succeeded = daisy.run_blockwise(
        input_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda: predict_worker(
            train_setup_dir, predict_setup_dir, predict_number, train_number,
            experiment, iteration, in_container, in_dataset, out_container,
            db_host, db_name, queue, singularity_container, num_cpus,
            num_cache_workers, mount_dirs),
        check_function=lambda b: check_block(blocks_predicted, b),
        num_workers=num_block_workers,
        read_write_conflict=False,
        fit='overhang')

    if not succeeded:
        raise RuntimeError("Prediction failed for (at least) one block")