예제 #1
0
 def get_tile_path(self, tile: Interval5D) -> Path:
     """Return the relative path inside the n5 dataset where ``tile`` is stored.

     The path components are the tile's start, taken relative to
     ``self.location`` and floor-divided by ``self.blockSize``. They are
     emitted in the reverse order of ``self.axiskeys`` and joined with "/".

     Raises:
         ValueError: if ``tile.is_tile`` rejects ``tile`` for this dataset's
             block size and interval (checked with ``clamped=True``).
     """
     is_valid_tile = tile.is_tile(
         tile_shape=self.blockSize,
         full_interval=self.interval,
         clamped=True,
     )
     if not is_valid_tile:
         raise ValueError(
             f"{tile} is not a tile of {json.dumps(self.to_json_data())}")

     # Express the tile start relative to the dataset origin, then convert
     # it from voxel coordinates into block indices.
     offset_from_origin = tile.translated(-self.location).start
     block_index = offset_from_origin // self.blockSize
     reversed_axes = self.axiskeys[::-1]
     return Path("/".join(map(str, block_index.to_tuple(reversed_axes))))
예제 #2
0
 def __init__(
     self,
     *,
     stack_axis: str,
     datasources: Iterable[DataSource],
 ):
     """Combine ``datasources`` into one datasource stacked along ``stack_axis``.

     The datasources are kept sorted by their location on ``stack_axis``.
     The resulting interval is the one enclosing all their intervals, and
     the dtype is taken from the first datasource after sorting.

     Raises:
         ValueError: if the datasources do not share exactly one tile shape,
             or if any datasource's extent along ``stack_axis`` is not a
             multiple of the tile extent on that axis.
     """
     self.stack_axis = stack_axis

     def position_on_stack_axis(source: DataSource):
         return source.location[stack_axis]

     self.datasources = sorted(datasources, key=position_on_stack_axis)

     tile_shapes = set()
     for source in self.datasources:
         tile_shapes.add(source.tile_shape)
     if len(tile_shapes) != 1:
         raise ValueError(
             f"All datasources must have the same tile shape. Tile shapes: {tile_shapes}"
         )
     (tile_shape,) = tile_shapes

     for source in self.datasources:
         if source.shape[stack_axis] % tile_shape[stack_axis] != 0:
             raise ValueError(
                 f"Stacking over axis that are not multiple of the tile size is not supported"
             )

     # Positions along the stack axis, in the same order as self.datasources.
     self.stack_levels = [
         position_on_stack_axis(source) for source in self.datasources
     ]
     # The stack axis comes first, followed by the remaining axis labels.
     self.axiskeys = stack_axis + Point5D.LABELS.replace(stack_axis, "")
     enclosing_interval = Interval5D.enclosing(
         source.interval for source in self.datasources)
     first_source = self.datasources[0]
     super().__init__(dtype=first_source.dtype,
                      interval=enclosing_interval,
                      tile_shape=tile_shape)
예제 #3
0
    def interpolate_from_points(cls, voxels: Sequence[Point5D],
                                raw_data: DataSource):
        """Build an instance whose mask joins consecutive ``voxels``.

        A boolean array covering the bounding interval of ``voxels`` is
        allocated with every value False. Each point yielded by
        ``interpolate_until`` between one voxel and the next is then painted
        True.

        Raises:
            ValueError: if the bounding interval of ``voxels`` covers more
                than one channel.
        """
        start = Point5D.min_coords(voxels)
        # Interval stops are exclusive while the maximum coordinate is
        # inclusive, hence the + 1.
        stop = Point5D.max_coords(voxels) + 1
        scribbling_roi = Interval5D.create_from_start_stop(start=start,
                                                           stop=stop)
        if scribbling_roi.shape.c != 1:
            raise ValueError(
                f"Annotations must not span multiple channels: {voxels}")

        canvas = Array5D.allocate(scribbling_roi,
                                  dtype=np.dtype(bool),
                                  value=False)

        # Pair every voxel with its predecessor; the first voxel is paired
        # with itself so that it gets painted too.
        segment_starts = (voxels[0], *voxels)
        for segment_start, segment_end in zip(segment_starts, voxels):
            for painted in segment_start.interpolate_until(segment_end):
                canvas.paint_point(point=painted, value=True)

        return cls(canvas._data,
                   axiskeys=canvas.axiskeys,
                   raw_data=raw_data,
                   location=start)
예제 #4
0
 def __setstate__(self, data: JsonValue):
     """Restore this object from its JSON form by re-running ``__init__``.

     Reads the "path", "interval" and "filesystem" entries of ``data``.
     Only the start of the deserialized interval is used, as the location.
     """
     fields = ensureJsonObject(data)
     path = PurePosixPath(ensureJsonString(fields.get("path")))
     location = Interval5D.from_json_value(fields.get("interval")).start
     filesystem = JsonableFilesystem.from_json_value(fields.get("filesystem"))
     self.__init__(path=path, location=location, filesystem=filesystem)
예제 #5
0
 def label_at(self, point: Point5D) -> int:
     """Return the label stored at ``point``.

     Raises:
         ValueError: if ``point`` lies outside ``self.interval``, or if the
             value at ``point`` is 0 (no object there).
     """
     single_point_interval = Interval5D.enclosing([point])
     if self.interval.contains(single_point_interval):
         # The cut holds exactly one element, so read it along the "x" axis.
         found = self.cut(single_point_interval).raw("x")[0]
         if found == 0:
             raise ValueError(f"Point {point} is not on top of an object")
         return found
     raise ValueError(f"Point {point} is not inside the labels at {self.interval}")
예제 #6
0
 def __init__(
     self,
     *,
     position: Point5D,
     klass: int,
     datasource: DataSource,
     components_extractor: ConnectedComponentsExtractor,
 ):
     """Create an object label anchored at ``position`` within ``datasource``.

     Args:
         position: an anchor point within ``datasource``.
         klass: the class assigned to the object.
         datasource: the datasource this label is tied to.
         components_extractor: the method used to extract the objects; the
             label is tied to it as well.

     The label is resolved here, in the constructor, so that a bad
     annotation cannot be created: whatever ``label_at`` raises propagates
     to the caller.
     """
     point_interval = Interval5D.enclosing([position])
     position_roi = DataRoi(datasource, **point_interval.to_dict())
     # Grow the single-point roi to the tile containing it, without
     # extending past the datasource bounds.
     enclosing_tile = position_roi.enlarge_to_tiles(
         tile_shape=datasource.tile_shape,
         tiles_origin=datasource.interval.start,
     )
     self.data_tile = enclosing_tile.clamped(datasource.interval)
     self.position = position
     self.klass = klass
     self.datasource = datasource
     self.components_extractor = components_extractor
     components = components_extractor(self.data_tile)
     self.label = components.label_at(position)
     super().__init__()
예제 #7
0
def test_skimage_datasource_tiles():
    """Check the contents of each 2x2 (x, y) tile of the PNG test image.

    The image is expected to split into exactly six tiles, two of which are
    only one pixel wide in x.
    """
    whole_image = DataRoi(SkimageDataSource(path=png_image, filesystem=OsFs("/")))
    expected_by_tile = [
        (Interval5D.zero(x=(0, 2), y=(0, 2)), raw_0_2x0_2y),
        (Interval5D.zero(x=(0, 2), y=(2, 4)), raw_0_2x2_4y),
        (Interval5D.zero(x=(2, 4), y=(0, 2)), raw_2_4x0_2y),
        (Interval5D.zero(x=(2, 4), y=(2, 4)), raw_2_4x2_4y),
        (Interval5D.zero(x=(4, 5), y=(0, 2)), raw_4_5x0_2y),
        (Interval5D.zero(x=(4, 5), y=(2, 4)), raw_4_5x2_4y),
    ]

    num_checked_tiles = 0
    for tile in whole_image.split(Shape5D(x=2, y=2)):
        for expected_roi, candidate_raw in expected_by_tile:
            if tile == expected_roi:
                expected_raw = candidate_raw
                break
        else:
            # No known interval matched this tile.
            raise Exception(f"Unexpected tile {tile}")
        assert (tile.retrieve().raw("yx") == expected_raw).all()
        num_checked_tiles += 1
    assert num_checked_tiles == 6
예제 #8
0
 def get_expected_roi(self, data_slice: Interval5D) -> Interval5D:
     """Return ``data_slice`` with its channel span resized to ``self.num_classes``.

     The channel span keeps the start of ``data_slice.c`` and covers
     ``self.num_classes`` channels from there; the other axes are whatever
     ``data_slice.updated`` leaves in place.
     """
     first_channel = data_slice.c[0]
     channel_span = (first_channel, first_channel + self.num_classes)
     return data_slice.updated(c=channel_span)
    def parse(cls, group: h5py.Group, raw_data_sources: Mapping[int, "FsDataSource | None"]) -> "IlpPixelClassificationGroup":
        """Build an ``IlpPixelClassificationGroup`` from the HDF5 ``group``.

        Reads the following members of ``group``:

        * ``LabelColors`` / ``LabelNames``: one ``Label`` is created per
          (name, color) pair.
        * ``LabelSets/labels<lane>/block<n>``: every non-zero value found in a
          block becomes one ``Annotation`` attached to the label whose
          1-based index equals that value. Each block needs the ``axistags``
          and ``blockSlice`` attributes.
        * ``ClassifierFactory``: must equal ``VIGRA_ILP_CLASSIFIER_FACTORY``.
        * ``ClassifierForests`` (optional): when present, a
          ``VigraPixelClassifier`` is built from its ``Forest*`` members and
          ``feature_names``.

        Args:
            group: the HDF5 group to read from.
            raw_data_sources: maps a lane index to the datasource that the
                lane's annotations refer to. Every lane that contains label
                blocks must have a non-None entry.

        Returns:
            The parsed group. Its ``classifier`` is None when ``group`` has
            no ``ClassifierForests`` member.

        Raises:
            IlpParsingError: if an expected member or attribute is missing,
                has an unexpected type, or cannot be parsed.
        """
        LabelColors = ensure_color_list(group, "LabelColors")
        LabelNames = ensure_encoded_string_list(group, "LabelNames")
        # Values stored in the label blocks are 1-based indices into
        # LabelColors; 0 is treated as background further down.
        class_to_color: Mapping[np.uint8, Color] = {np.uint8(i): color for i, color in enumerate(LabelColors, start=1)}

        label_classes: Dict[Color, Label] = {color: Label(name=name, color=color, annotations=[]) for name, color in zip(LabelNames, LabelColors)}
        LabelSets = ensure_group(group, "LabelSets")
        for lane_key in LabelSets.keys():
            if not lane_key.startswith("labels"):
                continue
            try:
                lane_index = int(lane_key.replace("labels", ""))
            except ValueError as err:
                # Report malformed keys the same way as every other malformed input.
                raise IlpParsingError(f"Could not parse a lane index out of '{lane_key}'") from err
            lane_label_blocks = ensure_group(LabelSets, lane_key)
            if len(lane_label_blocks.keys()) == 0:
                # Lanes without annotations do not need a datasource.
                continue
            raw_data = raw_data_sources.get(lane_index)
            if raw_data is None:
                raise IlpParsingError(f"No datasource for lane {lane_index:03d}")
            for block_name in lane_label_blocks.keys():
                if not block_name.startswith("block"):
                    continue
                block = ensure_dataset(lane_label_blocks, block_name)
                block_data = block[()]
                if not isinstance(block_data, np.ndarray):
                    raise IlpParsingError("Expected annotation block to contain a ndarray")

                raw_axistags = block.attrs.get("axistags")
                if not isinstance(raw_axistags, str):
                    raise IlpParsingError(f"Expected axistags to be a str, found {raw_axistags}")
                axistags = AxisTags.fromJSON(raw_axistags)
                axiskeys = "".join(axistags.keys())

                if "blockSlice" not in block.attrs:
                    raise IlpParsingError(f"Expected 'blockSlice' in attrs from {block.name}")
                blockSlice = block.attrs["blockSlice"]
                if not isinstance(blockSlice, str):
                    raise IlpParsingError(f"Expected 'blockSlice' to be a str, found {blockSlice}")
                # blockSlice is parsed as "<open>start:stop,start:stop,...<close>":
                # the first and last characters are dropped and there is one
                # span per axis, in axiskeys order.
                blockSpans: Sequence[List[str]] = [span_str.split(":") for span_str in blockSlice[1:-1].split(",")]
                try:
                    block_limits = {
                        key: (int(span[0]), int(span[1]))
                        for key, span in zip(axiskeys, blockSpans)
                    }
                except (ValueError, IndexError) as err:
                    raise IlpParsingError(f"Could not parse 'blockSlice' value {blockSlice}") from err
                blockInterval = Interval5D.zero(**block_limits)

                block_5d = Array5D(block_data, axiskeys=axiskeys)
                for color_5d in block_5d.unique_colors().split(shape=Shape5D(x=1, c=block_5d.shape.c)):
                    color_index = np.uint8(color_5d.raw("c")[0])
                    if color_index == np.uint8(0): # background
                        continue
                    color = class_to_color.get(color_index)
                    if color is None:
                        raise IlpParsingError(f"Could not find a label color for index {color_index}")
                    annotation_data: "np.ndarray[Any, np.dtype[np.uint8]]" = block_5d.color_filtered(color=color_5d).raw(axiskeys)
                    annotation = Annotation(
                        annotation_data.astype(np.dtype(bool)),
                        location=blockInterval.start,
                        axiskeys=axiskeys, # FIXME: what if the user changed the axiskeys in the data source?
                        raw_data=raw_data,
                    )

                    label_classes[color].annotations.append(annotation)

        ClassifierFactory = ensure_bytes(group, "ClassifierFactory")
        if ClassifierFactory != VIGRA_ILP_CLASSIFIER_FACTORY:
            raise IlpParsingError(f"Expecting ClassifierFactory to be pickled ParallelVigraRfLazyflowClassifierFactory, found {ClassifierFactory}")
        if "ClassifierForests" in group:
            ClassifierForests = ensure_group(group, "ClassifierForests")
            # Forests are loaded in sorted key order so the result is deterministic.
            forests: List[VigraRandomForest] = [
                VigraRandomForest(group.file.filename, f"{ClassifierForests.name}/{forest_key}")
                for forest_key in sorted(ClassifierForests.keys())
                if forest_key.startswith("Forest")
            ]

            feature_names = ensure_encoded_string_list(ClassifierForests, "feature_names")
            feature_extractors, expected_num_channels = cls.ilp_filters_and_expected_num_channels_from_names(feature_names)

            classifier = VigraPixelClassifier(
                feature_extractors=feature_extractors,
                forest_h5_bytes=[vigra_forest_to_h5_bytes(forest) for forest in forests],
                # Only labels that actually carry annotations count as classes.
                num_classes=sum(1 for label in label_classes.values() if not label.is_empty()),
                num_input_channels=expected_num_channels,
            )
        else:
            classifier = None

        return IlpPixelClassificationGroup(
            classifier=classifier,
            labels=list(label_classes.values()),
        )
예제 #10
0
 def get_tile_path(self, tile: Interval5D) -> Path:
     """Return the path of the chunk holding ``tile``, located under ``self.key``.

     The file name encodes the tile's x, y and z spans as
     ``"<x0>-<x1>_<y0>-<y1>_<z0>-<z1>"``.

     Raises:
         ValueError: if ``tile`` is not a (clamped) tile of ``self.interval``
             for any of the sizes in ``self.chunk_sizes_5d``.
     """
     # Raise explicitly instead of using ``assert``: assertions are stripped
     # when Python runs with -O, which would let a bad tile silently map to
     # a path.
     if not any(
             tile.is_tile(tile_shape=cs, full_interval=self.interval, clamped=True)
             for cs in self.chunk_sizes_5d):
         raise ValueError(f"Bad tile: {tile}")
     return self.key / f"{tile.x[0]}-{tile.x[1]}_{tile.y[0]}-{tile.y[1]}_{tile.z[0]}-{tile.z[1]}"
예제 #11
0
 def _get_tile(self, tile: Interval5D) -> Array5D:
     """Read ``tile`` from ``self._dataset`` and return it as an ``Array5D``.

     ``tile`` is translated by ``-self.location`` before indexing, so the
     dataset is addressed relative to this datasource's own origin. The
     returned array is located at ``tile.start``.

     Raises:
         IOError: if indexing the dataset does not yield an ``ndarray``.
     """
     slices = tile.translated(-self.location).to_slices(self.axiskeys)
     raw = self._dataset[slices]
     if not isinstance(raw, ndarray):
         # The f-prefix is needed so the offending values are interpolated.
         raise IOError(f"Expected ndarray at {slices}, found {raw}")
     return Array5D(raw, axiskeys=self.axiskeys, location=tile.start)
예제 #12
0
 def is_tile(self, tile: Interval5D) -> bool:
     """Tell whether ``tile`` is a tile of ``self.interval`` for ``self.tile_shape``.

     Delegates to ``Interval5D.is_tile`` with ``clamped=True``.
     """
     shape_of_tiles = self.tile_shape
     whole_interval = self.interval
     return tile.is_tile(
         tile_shape=shape_of_tiles,
         full_interval=whole_interval,
         clamped=True,
     )
예제 #13
0
 def interval(self) -> Interval5D:
     """Return an ``Interval5D`` built from this object's t, c, x, y and z spans."""
     axis_spans = {
         "t": self.t,
         "c": self.c,
         "x": self.x,
         "y": self.y,
         "z": self.z,
     }
     return Interval5D(**axis_spans)