def overlap():
    run_datetime = datetime.datetime.now(
        pytz.timezone('US/Eastern')).strftime('%Y%m%dT%H%M%S.%f%z')
    temp_dir = os.path.join(config.temp_path, run_datetime)
    os.makedirs(temp_dir)

    fragments = daisy.open_ds(config.fragments_zarr, config.fragments_ds)
    groundtruth = daisy.open_ds(config.groundtruth_zarr, config.groundtruth_ds)
    total_roi = daisy.Roi(offset=config.roi_offset, shape=config.roi_shape)

    start = time.time()
    daisy.run_blockwise(
        total_roi=total_roi,
        read_roi=daisy.Roi(offset=(0, 0, 0), shape=config.block_size),
        write_roi=daisy.Roi(offset=(0, 0, 0), shape=config.block_size),
        process_function=lambda block: overlap_in_block(
            block=block,
            fragments=fragments,
            groundtruth=groundtruth,
            tmp_path=temp_dir),
        fit='shrink',
        num_workers=config.num_workers,
        read_write_conflict=False,
        max_retries=1)

    logger.info(
        f"Blockwise overlap of fragments and ground truth took {time.time() - start:.3f}s")
    logger.debug(
        f"num blocks: {np.prod(np.ceil(np.array(config.roi_shape) / np.array(config.block_size)))}")

    frag_to_gt = overlap_reduce(tmp_path=temp_dir)

    with open(os.path.join(temp_dir, 'frag_to_gt.pickle'), 'wb') as f:
        pickle.dump(frag_to_gt, f)
    return frag_to_gt
Example #2
def open_ds(f, ds):

    try:
        data = daisy.open_ds(f, ds)
    except Exception:
        data = daisy.open_ds(f, ds + '/s0')

    return data
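
A small usage sketch for the fallback helper above, with a hypothetical container name; if 'volumes/raw' is a multiscale group, the helper falls back to its 's0' level.

# hypothetical container and dataset names
raw = open_ds('sample_data.zarr', 'volumes/raw')
print(raw.roi, raw.voxel_size)
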
    def __init__(self, 
                 credentials, 
                 db_name, 
                 collection_name,
                 predict_id,
                 dataset,
                 dx, dy, dz,
                 n_gpus,
                 gpu_id,
                 n_cpus,
                 cpu_id,
                 transform=None):

        self.db = BrainDb(credentials, db_name, collection_name, predict_id)

        start = time.time()
        log.info("Partition DB to workers...")
        self.cursor = self.get_cursor(n_gpus, n_cpus, gpu_id, cpu_id) 
        log.info(f"...took {time.time() - start} seconds")

        self.dataset = dataset
        self.container = dataset.container
        self.dset = dataset.dataset
        self.voxel_size = dataset.voxel_size
        self.data = daisy.open_ds(self.container,
                                  self.dset)
        self.transform = transform

        if dx % 8 != 0 or dy % 8 != 0 or dz % 80 != 0:
            raise ValueError("Roi size must be divisible by two")

        self.dx = dx
        self.dy = dy
        self.dz = dz
Example #4
    def build_source(self):
        data = daisy.open_ds(self.filename, self.key)

        if self.time_window is None:
            source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        else:
            offs = list(data.roi.get_offset())
            offs[1] += self.time_window[0]
            sh = list(data.roi.get_shape())
            sh[1] = self.time_window[1] - self.time_window[0]
            source_roi = gp.Roi(tuple(offs), tuple(sh))

        voxel_size = gp.Coordinate(data.voxel_size)

        return gp.ZarrSource(self.filename,
                             {
                                 self.raw_0: self.key,
                                 self.raw_1: self.key
                             },
                             array_specs={
                                 self.raw_0: gp.ArraySpec(
                                     roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True),
                                 self.raw_1: gp.ArraySpec(
                                     roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True)
                             })
def open_ds_wrapper(path, ds_name):
    """Returns None if ds_name does not exists """
    try:
        return daisy.open_ds(path, ds_name)
    except KeyError:
        print('dataset %s could not be loaded' % ds_name)
        return None
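
A small usage sketch for the wrapper above (paths are hypothetical); since a missing dataset yields None instead of an exception, optional datasets can be probed safely.

# hypothetical container and dataset names
affs = open_ds_wrapper('predictions.zarr', 'volumes/affs')
if affs is not None:
    print(affs.roi, affs.dtype)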
Example #6
def main(config_path):
    # S2: 3280-21435, 2840 22160, 70-1170
    # S4: 1000-14000, 1000 14000, 35-640
    # S5: 500-7000, 500 7000, 15-300
    with open(config_path, 'r') as f:
        configs = json.load(f)
    section = configs['section']
    raw_file = configs["zarr_file"]
    raw_ds = configs['raw_ds']
    now = datetime.now().strftime("%m%d.%H.%M.%S")
    output = configs[
        'output_folder']  # f'/n/groups/htem/Segmentation/xg76/mipmap/mipmap/{now}_{section}_{z_start}_{z_end}'

    coord_begin = configs['coord_begin']
    coord_end = configs['coord_end']
    z_range = range(coord_begin[2], coord_end[2], configs['interval'])
    scale_list = [configs['scale']]
    os.makedirs(output, exist_ok=True)

    print(f'mipmapping: {configs["section"]}')
    cutout_ds = daisy.open_ds(raw_file, raw_ds)

    for z in z_range:
        coord_begin[2] = z
        coord_end[2] = z + 1
        print(f'coord begin: {coord_begin}')
        raw_img_list = get_ndarray_img_from_zarr(coord_begin=coord_begin,
                                                 coord_end=coord_end,
                                                 cutout_ds=cutout_ds)
        img = raw_img_list[0]
        # write_to_tiff(img, os.path.join(output, f'{z}_origin.tif'))
        mipmap = down_sampling_img(img, scale_list)[0]
        write_to_tiff(
            mipmap,
            os.path.join(output, f'{section}_{z}_{scale_list[0]}_mipmap.tiff'))
Example #7
def overlay_segmentation(db_host,
                         db_name,
                         roi_offset,
                         roi_size,
                         selected_attr,
                         solved_attr,
                         edge_collection,
                         segmentation_container,
                         segmentation_dataset,
                         segmentation_number,
                         voxel_size=(40, 4, 4)):

    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    graph_roi = daisy.Roi(roi_offset, roi_size)

    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    for node_id, data in nx_graph.nodes(data=True):
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        nx_graph.nodes[node_id]["segmentation_{}".format(
            segmentation_number)] = segmentation[node_position]

    graph_provider.write_nodes(intersection_roi)
Example #8
def get_ndarray_img_from_zarr(raw_file=None,
                              raw_ds=None,
                              coord_begin=None,
                              coord_end=None,
                              cutout_ds=None):
    """
    Retrieve image from zarr file.
    Return list of images
    """
    if raw_file is None and raw_ds is None and cutout_ds is None:
        raise ValueError('Either raw_file and raw_ds or cutout_ds must be provided')
    elif raw_file is not None and raw_ds is not None and cutout_ds is None:
        cutout_ds = daisy.open_ds(raw_file, raw_ds)
    else:
        print('Using passed in cutout_ds')
    print(f'Voxel size: {cutout_ds.voxel_size}')
    roi = None
    if coord_begin is not None and coord_end is not None:
        voxel_size = cutout_ds.voxel_size
        coord_begin = Coordinate(np.flip(np.array(coord_begin))) * voxel_size
        coord_end = Coordinate(np.flip(np.array(coord_end))) * voxel_size

        roi_offset = coord_begin
        roi_shape = coord_end - coord_begin
        roi = Roi(roi_offset, roi_shape)
    print(f"Getting data from zarr file... ROI: {roi}")
    ndarray = cutout_ds.to_ndarray(roi=roi)
    return ndarray
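
A minimal call sketch with a hypothetical container; coord_begin and coord_end are (x, y, z) voxel coordinates that the function flips to (z, y, x) and scales by the voxel size to build the world-unit ROI.

# hypothetical container and dataset names
raw = get_ndarray_img_from_zarr(raw_file='section.zarr',
                                raw_ds='volumes/raw',
                                coord_begin=[1000, 1000, 35],
                                coord_end=[2000, 2000, 36])
print(raw.shape)  # (z, y, x) in voxels; a single z-slice here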
Example #9
def add_snapshot(
    context,
    snapshot_file,
    name_prefix="",
    volume_paths=["volumes"],
    graph_paths=["points"],
    graph_node_attrs=None,
    graph_edge_attrs=None,
    # mst=["embedding", "fg_maxima"],
    mst=None,
    roi=None,
):
    f = zarr.open(str(snapshot_file.absolute()), "r")
    with f as dataset:
        volumes = []
        for volume in volume_paths:
            volumes += get_volumes(dataset, volume)

        v = None
        for volume in volumes:
            v = daisy.open_ds(str(snapshot_file.absolute()), f"{volume}")
            if roi is not None:
                v = v.intersect(roi)
            if v.dtype == np.int64:
                v.materialize()
                v.data = v.data.astype(np.uint64)
            add_layer(context, v, f"{name_prefix}_{volume}", visible=False)
Example #10
    def get_local_segmentation(self, roi: daisy.Roi, threshold: float):
        # open fragments
        fragments = daisy.open_ds(self.fragments_file, self.fragments_dataset)

        # open RAG DB
        rag_provider = MongoDbRagProvider(
            self.fragments_db,
            host=self.fragments_host,
            mode="r",
            edges_collection=self.edges_collection,
        )

        segmentation = fragments[roi]
        segmentation.materialize()
        ids = [int(id) for id in list(np.unique(segmentation.data))]
        rag = rag_provider.read_rag(ids)

        if len(rag.nodes()) == 0:
            raise Exception('RAG is empty')

        components = rag.get_connected_components(threshold)

        values_map = np.array(
            [[fragment, i] for i in range(1,
                                          len(components) + 1)
             for fragment in components[i - 1]],
            dtype=np.uint64,
        )
        old_values = values_map[:, 0]
        new_values = values_map[:, 1]
        replace_values(segmentation.data, old_values, new_values, inplace=True)

        return segmentation
Example #11
def get_source_roi(data_dir, sample):

    sample_path = os.path.join(data_dir, sample)

    # get absolute paths
    if os.path.isfile(sample_path) or sample.endswith((".zarr", ".n5")):
        sample_dir = os.path.abspath(
            os.path.join(data_dir, os.path.dirname(sample)))
    else:
        sample_dir = os.path.abspath(os.path.join(data_dir, sample))

    if os.path.isfile(os.path.join(sample_dir, 'attributes.json')):

        with open(os.path.join(sample_dir, 'attributes.json'), 'r') as f:
            attributes = json.load(f)
        voxel_size = daisy.Coordinate(attributes['resolution'])
        shape = daisy.Coordinate(attributes['shape'])
        offset = daisy.Coordinate(attributes['offset'])
        source_roi = daisy.Roi(offset, shape * voxel_size)

        return voxel_size, source_roi

    elif os.path.isdir(os.path.join(sample_dir, 'timelapse.zarr')):

        a = daisy.open_ds(os.path.join(sample_dir, 'timelapse.zarr'),
                          'volumes/raw')

        return a.voxel_size, a.roi

    else:

        raise RuntimeError(
            "Can't find attributes.json or timelapse.zarr in %s" % sample_dir)
Example #12
def check_block(block):
    _logger.debug("Checking if block %s is complete...", block.write_roi)
    ds = daisy.open_ds(out_container, affs_dataset)
    center_values = ds[block.write_roi.get_center()]
    s = np.sum(center_values)
    _logger.debug("Sum of center values in %s is %f", block.write_roi, s)
    return s != 0
Example #13
def get_array(data_container, data_set, begin, end, context=(0, 0, 0)):

    context = np.array(context)
    roi = daisy.Roi(begin - context / 2, end - begin + context)
    dataset = daisy.open_ds(data_container, data_set)
    data_array = dataset[roi].to_ndarray()
    return data_array
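
A minimal call sketch, assuming a hypothetical container and numpy-array coordinates in world units; the ROI is grown by half the context on each side, so the resulting offsets still have to line up with the dataset's voxel grid.

import numpy as np

# hypothetical container and dataset names; begin/end in world units (nm)
begin = np.array([400, 4000, 4000])
end = np.array([800, 8000, 8000])
crop = get_array('sample_data.zarr', 'volumes/raw', begin, end,
                 context=(80, 80, 80))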
Example #14
    def query_local_segmentation(self, roi, threshold):
        # open fragments
        fragments = daisy.open_ds(self.fragments_file, self.fragments_dataset)

        # open RAG DB
        rag_provider = lsd.persistence.MongoDbRagProvider(
            self.frag_db_name,
            host=self.frag_db_host,
            mode="r",
            edges_collection=self.edges_collection,
        )

        segmentation = fragments[roi]
        segmentation.materialize()
        rag = rag_provider[roi]

        if len(rag.nodes()) == 0:
            return segmentation

        components = rag.get_connected_components(threshold)

        values_map = np.array(
            [[fragment, i] for i in range(len(components))
             for fragment in components[i]],
            dtype=np.uint64,
        )
        old_values = values_map[:, 0]
        new_values = values_map[:, 1]
        funlib.segment.arrays.replace_values(segmentation.data,
                                             old_values,
                                             new_values,
                                             inplace=True)

        return segmentation
Example #15
def open_dataset(f, ds):
    original_ds = ds
    ds, slices = parse_ds_name(ds)
    slices_str = original_ds[len(ds):]

    try:
        dataset_as = []
        if all(key.startswith("s") for key in zarr.open(f)[ds].keys()):
            raise AttributeError("This group is a multiscale array!")
        for key in zarr.open(f)[ds].keys():
            dataset_as.extend(open_dataset(f, f"{ds}/{key}{slices_str}"))
        return dataset_as
    except AttributeError as e:
        # dataset is an array, not a group
        pass

    print("ds    :", ds)
    print("slices:", slices)
    try:
        zarr.open(f)[ds].keys()
        is_multiscale = True
    except Exception:
        is_multiscale = False

    if not is_multiscale:
        a = daisy.open_ds(f, ds)

        if slices is not None:
            a = slice_dataset(a, slices)

        if a.roi.dims == 2:
            print("ROI is 2D, recruiting next channel to z dimension")
            a.roi = daisy.Roi((0,) + a.roi.get_begin(), (a.shape[-3],) + a.roi.get_shape())
            a.voxel_size = daisy.Coordinate((1,) + a.voxel_size)

        if a.roi.dims == 4:
            print("ROI is 4D, stripping first dimension and treat as channels")
            a.roi = daisy.Roi(a.roi.get_begin()[1:], a.roi.get_shape()[1:])
            a.voxel_size = daisy.Coordinate(a.voxel_size[1:])

        if a.data.dtype == np.int64 or a.data.dtype == np.int16:
            print("Converting dtype in memory...")
            a.data = a.data[:].astype(np.uint64)

        return [(a, ds)]
    else:
        return [([daisy.open_ds(f, f"{ds}/{key}") for key in zarr.open(f)[ds].keys()], ds)]
Example #16
    def test_load_klb(self):
        self.write_test_klb()
        data = daisy.open_ds(self.klb_file, None)
        self.assertEqual(data.roi.get_offset(), daisy.Coordinate((0, 0, 0, 0)))
        self.assertEqual(data.roi.get_shape(), daisy.Coordinate(
            (1, 20, 10, 10)))
        self.assertEqual(data.voxel_size, daisy.Coordinate((1, 2, 1, 1)))
        self.remove_test_klb()
def add_container(
    context,
    snapshot_file,
    name_prefix="",
    volume_paths=[None],
    graph_paths=[None],
    graph_node_attrs=None,
    graph_edge_attrs=None,
    # mst=["embedding", "fg_maxima"],
    mst=None,
    roi=None,
    modify=None,
    dims=3,
    array_offset=None,
    voxel_size=None,
):
    if snapshot_file.name.endswith(".zarr") or snapshot_file.name.endswith(
            ".n5"):
        f = zarr.open(str(snapshot_file.absolute()), "r")
    elif snapshot_file.name.endswith(".h5") or snapshot_file.name.endswith(
            ".hdf"):
        f = h5py.File(str(snapshot_file.absolute()), "r")
    with f as dataset:
        volumes = []
        for volume in volume_paths:
            volumes += get_volumes(dataset, volume)

        v = None
        for volume in volumes:
            v = daisy.open_ds(str(snapshot_file.absolute()), f"{volume}")
            if roi is not None:
                v = v.intersect(roi)
            if v.dtype == np.int64:
                v.materialize()
                v.data = v.data.astype(np.uint64)
            if v.dtype == np.dtype(bool):
                v.materialize()
                v.data = v.data.astype(np.float32)

            v.materialize()
            if modify is not None:
                data = modify(v.data, volume)
            else:
                data = v.data

            if voxel_size is None:
                voxel_size = v.voxel_size[-dims:]
            if array_offset is None:
                array_offset = v.roi.get_offset()[-dims:]
            add_layer(
                context,
                data,
                f"{name_prefix}_{volume}",
                visible=False,
                voxel_size=voxel_size,
                array_offset=array_offset,
            )
Example #18
    def __init__(self, filename,
                 key,
                 density=None,
                 channels=0,
                 shape=(16, 256, 256),
                 time_window=None,
                 add_sparse_mosaic_channel=True,
                 random_rot=False):

        self.filename = filename
        self.key = key
        self.shape = shape
        self.density = density
        self.raw = gp.ArrayKey('RAW_0')
        self.add_sparse_mosaic_channel = add_sparse_mosaic_channel
        self.random_rot = random_rot
        self.channels = channels

        data = daisy.open_ds(filename, key)

        if time_window is None:
            source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        else:
            offs = list(data.roi.get_offset())
            offs[1] += time_window[0]
            sh = list(data.roi.get_shape())
            sh[1] = time_window[1] - time_window[0]
            source_roi = gp.Roi(tuple(offs), tuple(sh))

        voxel_size = gp.Coordinate(data.voxel_size)

        self.pipeline = gp.ZarrSource(
            filename,
            {
                self.raw: key
            },
            array_specs={
                self.raw: gp.ArraySpec(
                    roi=source_roi,
                    voxel_size=voxel_size,
                    interpolatable=True)
            }) + gp.RandomLocation() + IntensityDiffFilter(self.raw, 0, min_distance=0.1, channels=Slice(None))

        # add augmentations
        self.pipeline = self.pipeline + gp.ElasticAugment([40, 40],
                                                          [2, 2],
                                                          [0, math.pi / 2.0],
                                                          prob_slip=-1,
                                                          spatial_dims=2)



        self.pipeline.setup()
        np.random.seed(os.getpid() + int(time.time()))
Example #19
    def test_overwrite_attrs(self):
        self.write_test_klb()
        data = daisy.open_ds(self.klb_file,
                             None,
                             attr_filename=self.attrs_file)
        self.assertEqual(data.roi.get_offset(),
                         daisy.Coordinate((0, 100, 10, 0)))
        self.assertEqual(data.roi.get_shape(), daisy.Coordinate(
            (1, 10, 20, 20)))
        self.assertEqual(data.voxel_size, daisy.Coordinate((1, 1, 2, 2)))
        self.remove_test_klb()
Example #20
def get_raw(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified
    dataset.

    locs(``list of tuple of ints``):

        list of centers of location of interest

    size(``tuple of ints``):
        
        size of the cutout in voxels

    voxel_size(``tuple of ints``):

        size of a voxel

    data_container(``string``):

        path to data container (e.g. zarr file)

    data_set(``string``):

        corresponding data_set name, (e.g. raw)

    """

    raw = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(data_container, data_set)

    for loc in locs:
        offset_nm = loc - (size / 2 * voxel_size)
        roi = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                         mode='closest')

        if roi.get_shape()[0] != size_nm[0]:
            roi.set_shape(size_nm)

        if not dataset.roi.contains(roi):
            logger.warning("Location %s is not fully contained in dataset" %
                           loc)
            return

        raw.append(dataset[roi].to_ndarray())

    raw = np.stack(raw)
    raw = raw.astype(np.float32)
    raw_normalized = raw / 255.0
    raw_normalized = raw_normalized * 2.0 - 1.0
    return raw, raw_normalized
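
A minimal call sketch for get_raw with hypothetical paths, assuming both centers lie well inside the dataset; crops are 16 x 256 x 256 voxels at a 40 x 4 x 4 nm voxel size, so the centers are given in nm.

# hypothetical container and dataset names
locs = [(2000, 4096, 4096), (4000, 8192, 8192)]
raw, raw_normalized = get_raw(locs,
                              size=(16, 256, 256),
                              voxel_size=(40, 4, 4),
                              data_container='sample_data.zarr',
                              data_set='volumes/raw')
print(raw.shape)  # (2, 16, 256, 256)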
    def runTest(self):
        repo = HemiNeuprint()
        test_position = (80, 100, 120)
        [interface_transform] = repo.transform_positions([test_position])

        z = test_position[0]
        y = test_position[1]
        x = test_position[2]
        n5 = "/nrs/flyem/data/tmp/Z0115-22.export.n5"
        ds = daisy.open_ds(n5, "22-34/s0")
        X, Z, Y = ds.shape
        manual_transform = (X - x - 1, z, y) * repo.dataset.voxel_size
        self.assertEqual(interface_transform, tuple(manual_transform))
Example #22
    def __init__(self):
        dataset = Hemi()
        service = Neuprint(
            dataset=dataset,
            server="neuprint.janelia.org",
            neuprint_dataset="hemibrain:v1.2.1",
        )
        self.x_shape = daisy.open_ds(dataset.container,
                                     dataset.dataset).shape[0]
        super().__init__(
            dataset=dataset,
            service=service,
        )
Example #23
def get_raw_parallel(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified
    dataset.

    locs(``list of tuple of ints``):

        list of centers of location of interest

    size(``tuple of ints``):
        
        size of the cutout in voxels

    voxel_size(``tuple of ints``):

        size of a voxel

    data_container(``string``):

        path to data container (e.g. zarr file)

    data_set(``string``):

        corresponding data_set name, (e.g. raw)

    """

    pool = Pool(processes=len(locs))
    raw = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(data_container, data_set)

    raw_workers = [
        pool.apply_async(fetch_from_ds,
                         (dataset, loc, voxel_size, size, size_nm))
        for loc in locs
    ]
    raw = [w.get(timeout=60) for w in raw_workers]
    pool.close()
    pool.join()

    raw = np.stack(raw)
    raw = raw.astype(np.float32)
    raw_normalized = raw / 255.0
    raw_normalized = raw_normalized * 2.0 - 1.0
    return raw, raw_normalized
Example #24
    def get_predictions(self):

        pipeline, request, predictions = self.make_pipeline()

        with gp.build(pipeline):
            try:
                shutil.rmtree(
                    os.path.join(self.curr_log_dir, 'predictions.zarr'))
            except OSError as e:
                pass

            pipeline.request_batch(gp.BatchRequest())
            f = daisy.open_ds(
                os.path.join(self.curr_log_dir, 'predictions.zarr'),
                "predictions")
            return f
Example #25
    def get_crops(self, center_positions, size):
        """
        Args:

            center_positions (list of tuple of ints): [(z,y,x)] in nm

            size (tuple of ints): size of the crop, in voxels
        """

        crops = []
        size = daisy.Coordinate(size)
        voxel_size = daisy.Coordinate(self.voxel_size)
        size_nm = (size * voxel_size)
        dataset = daisy.open_ds(self.container, self.dataset)

        dataset_resolution = None
        try:
            dataset_resolution = dataset.voxel_size
        except AttributeError:
            pass

        if dataset_resolution is not None:
            if not np.all(dataset.voxel_size == self.voxel_size):
                raise ValueError(
                    f"Dataset {dataset} voxel size mismatch: "
                    f"{dataset.voxel_size} vs {self.voxel_size}")

        for position in center_positions:
            position = daisy.Coordinate(tuple(position))
            offset_nm = position - ((size / 2) * voxel_size)
            roi = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                             mode='closest')

            if roi.get_shape()[0] != size_nm[0]:
                roi.set_shape(size_nm)

            if not dataset.roi.contains(roi):
                raise ValueError(
                    f"Location {position} is not fully contained in dataset")

            crops.append(dataset[roi].to_ndarray())

        crops_batched = np.stack(crops)
        crops_batched = crops_batched.astype(np.float32)

        return crops_batched
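
A minimal usage sketch for get_crops; 'ds' stands for a hypothetical, already configured instance of the surrounding class (container, dataset and voxel_size set), and the center positions are in nm as the docstring requires.

crops = ds.get_crops(center_positions=[(2000, 4096, 4096)],
                     size=(16, 256, 256))
print(crops.shape)  # (1, 16, 256, 256), float32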
def add_dacapo_snapshot(
    context,
    snapshot_file,
    name_prefix="snapshot",
    volume_paths=[None],
    graph_paths=[None],
    graph_node_attrs=None,
    graph_edge_attrs=None,
    # mst=["embedding", "fg_maxima"],
    mst=None,
    roi=None,
):
    raw = daisy.open_ds(str(snapshot_file.absolute()), "raw")
    raw_shape = raw.shape[-3:]
    voxel_size = raw.voxel_size[-3:]
    array_offset = raw.roi.get_offset()[-3:]

    def modify(v, name):
        if name == "prediction":
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        elif name == "raw":
            v = reshape_batch_channel(v, 0, 2, raw_shape)
        elif name == "target":
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        elif name == "weights":
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        else:
            print(f"Modifying unknown array {name} with shape {v.shape}")
            v = reshape_batch_channel(v, 1, 3, raw_shape)
        return v

    add_container(
        context,
        snapshot_file,
        name_prefix,
        volume_paths,
        graph_paths,
        graph_node_attrs,
        graph_edge_attrs,
        mst,
        roi,
        modify,
        voxel_size=voxel_size,
        array_offset=array_offset,
    )
Example #27
    def _task_init(self):

        # open dataset
        dataset = daisy.open_ds(self.in_file, self.in_ds_name)

        # define total region of interest (roi)
        total_roi = dataset.roi
        ndims = len(total_roi.get_offset())

        # define block read and write rois
        assert len(self.block_read_size) == ndims,\
            "Read size must have same dimensions as in_file"
        assert len(self.block_write_size) == ndims,\
            "Write size must have same dimensions as in_file"
        block_read_size = daisy.Coordinate(self.block_read_size)
        block_write_size = daisy.Coordinate(self.block_write_size)
        block_read_size *= dataset.voxel_size
        block_write_size *= dataset.voxel_size
        context = (block_read_size - block_write_size) / 2
        block_read_roi = daisy.Roi((0,)*ndims, block_read_size)
        block_write_roi = daisy.Roi(context, block_write_size)

        # prepare output dataset
        output_roi = total_roi.grow(-context, -context)
        if self.out_file is None:
            self.out_file = self.in_file
        if self.out_ds_name is None:
            self.out_ds_name = self.in_ds_name + '_smoothed'

        logger.info(f'Processing data to {self.out_file}/{self.out_ds_name}')

        output_dataset = daisy.prepare_ds(
                self.out_file,
                self.out_ds_name,
                total_roi=output_roi,
                voxel_size=dataset.voxel_size,
                dtype=dataset.dtype,
                write_size=block_write_roi.get_shape())

        # save variables for other functions
        self.total_roi = total_roi
        self.block_read_roi = block_read_roi
        self.block_write_roi = block_write_roi
        self.dataset = dataset
        self.output_dataset = output_dataset
Example #28
    def open_daisy(self):
        """
        Open this dataset as a daisy array.
        """
        data = daisy.open_ds(self.container, self.dataset)

        # Correct for datasets where the container does not have the voxel size
        if data.voxel_size != tuple(self.voxel_size):
            log.warning(
                "Container has different voxel size than dataset: "\
                f"{data.voxel_size} != {self.voxel_size}")
            orig_shape = data.roi.get_shape()
            data = daisy.Array(data.data,
                               daisy.Roi(
                                   data.roi.get_offset(), self.voxel_size *
                                   data.data.shape[-len(self.voxel_size):]),
                               self.voxel_size,
                               chunk_shape=data.chunk_shape)
            log.warning(
                "Reloaded container data with dataset voxel size, changing shape: "\
                f"{orig_shape} => {data.roi.get_shape()}")

        return data
def extract_segmentation(fragments_file,
                         fragments_dataset,
                         edges_collection,
                         threshold,
                         out_file,
                         out_dataset,
                         num_workers,
                         lut_fragment_segment,
                         roi_offset=None,
                         roi_shape=None,
                         run_type=None,
                         **kwargs):

    # open fragments
    fragments = daisy.open_ds(fragments_file, fragments_dataset)

    total_roi = fragments.roi
    if roi_offset is not None:
        assert roi_shape is not None, "If roi_offset is set, roi_shape " \
                                      "also needs to be provided"
        total_roi = daisy.Roi(offset=roi_offset, shape=roi_shape)

    read_roi = daisy.Roi((0, 0, 0), (5000, 5000, 5000))
    write_roi = daisy.Roi((0, 0, 0), (5000, 5000, 5000))

    logging.info("Preparing segmentation dataset...")
    segmentation = daisy.prepare_ds(out_file,
                                    out_dataset,
                                    total_roi,
                                    voxel_size=fragments.voxel_size,
                                    dtype=np.uint64,
                                    write_roi=write_roi)

    lut_filename = 'seg_%s_%d' % (edges_collection, int(threshold * 100))

    lut_dir = os.path.join(fragments_file, lut_fragment_segment)

    if run_type:
        lut_dir = os.path.join(lut_dir, run_type)
        logging.info("Run type set, using luts from %s data" % run_type)

    lut = os.path.join(lut_dir, lut_filename + '.npz')

    assert os.path.exists(lut), "%s does not exist" % lut

    start = time.time()
    logging.info("Reading fragment-segment LUT...")
    lut = np.load(lut)['fragment_segment_lut']
    logging.info("%.3fs" % (time.time() - start))

    logging.info("Found %d fragments in LUT" % len(lut[0]))

    daisy.run_blockwise(total_roi,
                        read_roi,
                        write_roi,
                        lambda b: segment_in_block(
                            b, fragments_file, segmentation, fragments, lut),
                        fit='shrink',
                        num_workers=num_workers,
                        processes=True,
                        read_write_conflict=False)
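
A minimal invocation sketch with hypothetical paths; it assumes the fragment-segment LUT for this threshold has already been written under the given lut_fragment_segment directory by an earlier agglomeration step.

extract_segmentation(
    fragments_file='sample_data.zarr',
    fragments_dataset='volumes/fragments',
    edges_collection='edges_hist_quant_75',
    threshold=0.5,
    out_file='sample_data.zarr',
    out_dataset='volumes/segmentation_50',
    num_workers=8,
    lut_fragment_segment='luts/fragment_segment')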
Example #30
def predict_blockwise(base_dir, experiment, train_number, predict_number,
                      iteration, in_container_spec, in_container, in_dataset,
                      in_offset, in_size, out_container, db_name, db_host,
                      singularity_container, num_cpus, num_cache_workers,
                      num_block_workers, queue, mount_dirs, **kwargs):
    '''Run prediction in parallel blocks. Within blocks, predict in chunks.

    Args:

        experiment (``string``):

            Name of the experiment (cremi, fib19, fib25, ...).

        setup (``string``):

            Name of the setup to predict.

        iteration (``int``):

            Training iteration to predict from.

        raw_file (``string``):
        raw_dataset (``string``):
        auto_file (``string``):
        auto_dataset (``string``):

            Paths to the input autocontext datasets (affs or lsds). Can be None if not needed.

        out_file (``string``):

            Path to directory where zarr should be stored

        **Note:

            out_dataset no longer needed as input, build out_dataset from config
            outputs dictionary generated in mknet.py

        file_name (``string``):

            Name of output file

        block_size_in_chunks (``tuple`` of ``int``):

            The size of one block in chunks (not voxels!). A chunk corresponds
            to the output size of the network.

        num_workers (``int``):

            How many blocks to run in parallel.

        queue (``string``):

            Name of queue to run inference on (e.g. slowpoke, gpu_rtx, gpu_any,
            gpu_tesla, gpu_tesla_large)
    '''

    predict_setup_dir = os.path.join(
        os.path.join(base_dir, experiment),
        "02_predict/setup_t{}_p{}".format(train_number, predict_number))
    train_setup_dir = os.path.join(os.path.join(base_dir, experiment),
                                   "01_train/setup_t{}".format(train_number))

    # from here on, all values are in world units (unless explicitly mentioned)
    # get ROI of source
    source = daisy.open_ds(in_container_spec, in_dataset)
    logger.info('Source dataset has shape %s, ROI %s, voxel size %s' %
                (source.shape, source.roi, source.voxel_size))

    # Read network config
    predict_net_config = os.path.join(predict_setup_dir, 'predict_net.json')
    with open(predict_net_config) as f:
        logger.info('Reading setup config from {}'.format(predict_net_config))
        net_config = json.load(f)
    outputs = net_config['outputs']

    # get chunk size and context
    net_input_size = daisy.Coordinate(
        net_config['input_shape']) * source.voxel_size
    net_output_size = daisy.Coordinate(
        net_config['output_shape']) * source.voxel_size
    context = (net_input_size - net_output_size) / 2
    logger.info('Network context: {}'.format(context))

    # get total input and output ROIs
    input_roi = source.roi.grow(context, context)
    output_roi = source.roi

    # create read and write ROI
    block_read_roi = daisy.Roi((0, 0, 0), net_input_size) - context
    block_write_roi = daisy.Roi((0, 0, 0), net_output_size)

    logger.info('Preparing output dataset...')

    for output_name, val in outputs.items():
        out_dims = val['out_dims']
        out_dtype = val['out_dtype']
        out_dataset = 'volumes/%s' % output_name

        ds = daisy.prepare_ds(out_container,
                              out_dataset,
                              output_roi,
                              source.voxel_size,
                              out_dtype,
                              write_roi=block_write_roi,
                              num_channels=out_dims,
                              compressor={
                                  'id': 'gzip',
                                  'level': 5
                              })

    logger.info('Starting block-wise processing...')

    client = pymongo.MongoClient(db_host)
    db = client[db_name]
    if 'blocks_predicted' not in db.list_collection_names():
        blocks_predicted = db['blocks_predicted']
        blocks_predicted.create_index([('block_id', pymongo.ASCENDING)],
                                      name='block_id')
    else:
        blocks_predicted = db['blocks_predicted']

    # process block-wise
    succeeded = daisy.run_blockwise(
        input_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda: predict_worker(
            train_setup_dir, predict_setup_dir, predict_number, train_number,
            experiment, iteration, in_container, in_dataset, out_container,
            db_host, db_name, queue, singularity_container, num_cpus,
            num_cache_workers, mount_dirs),
        check_function=lambda b: check_block(blocks_predicted, b),
        num_workers=num_block_workers,
        read_write_conflict=False,
        fit='overhang')

    if not succeeded:
        raise RuntimeError("Prediction failed for (at least) one block")