Example #1
    def _do_predict(self, roi: DataRoi) -> Predictions:
        feature_data = self.feature_extractor(roi)
        # Flatten t/z/y/x into a single "samples" axis so each row is one feature vector
        linear_feature_data = feature_data.raw("tzyxc").reshape(
            (feature_data.shape.t * feature_data.shape.volume, feature_data.shape.c))

        predictions = Array5D.allocate(
            axiskeys="tzyxc",
            interval=self.get_expected_roi(roi),
            dtype=np.dtype('float32'),
            value=0,
        )

        assert predictions.interval == self.get_expected_roi(roi)
        raw_linear_predictions: "ndarray[Any, dtype[float32]]" = predictions.raw("tzyxc").reshape(
            (predictions.shape.t * predictions.shape.volume, predictions.shape.c))

        # Fan out one prediction job per forest and accumulate the partial results
        executor = get_executor(hint="predicting")
        f = partial(_compute_partial_predictions, linear_feature_data)
        futures = [executor.submit(f, forest) for forest in self.forests]
        for partial_predictions_future in futures:
            raw_linear_predictions += partial_predictions_future.result()

        # Dividing by the total tree count implies each partial result is weighted
        # by its forest's number of trees
        raw_linear_predictions /= self.num_trees
        predictions.setflags(write=False)

        return Predictions(
            arr=predictions.raw(predictions.axiskeys),
            axiskeys=predictions.axiskeys,
            location=predictions.location,
        )
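
For context, here is a minimal, self-contained sketch of the same fan-out-and-average pattern using only numpy and concurrent.futures. StubForest, predict_scaled, num_trees and num_classes are stand-ins invented for this sketch, not webilastik's API; the tree-count weighting mirrors what the division by self.num_trees above implies.

import numpy as np
from concurrent.futures import ThreadPoolExecutor

class StubForest:
    # Hypothetical forest whose predictions are scaled by its tree count
    def __init__(self, num_trees: int, num_classes: int, seed: int):
        self.num_trees = num_trees
        self.num_classes = num_classes
        self._rng = np.random.default_rng(seed)

    def predict_scaled(self, samples: np.ndarray) -> np.ndarray:
        probs = self._rng.random((samples.shape[0], self.num_classes)).astype(np.float32)
        probs /= probs.sum(axis=1, keepdims=True)
        return probs * self.num_trees

def average_forest_predictions(forests, samples):
    acc = np.zeros((samples.shape[0], forests[0].num_classes), dtype=np.float32)
    with ThreadPoolExecutor() as executor:
        # One job per forest, accumulated in submission order
        for future in [executor.submit(f.predict_scaled, samples) for f in forests]:
            acc += future.result()
    return acc / sum(f.num_trees for f in forests)  # rows sum to 1 again
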
Example #2
class SimpleSegmenter(Operator[DataRoi, List[Array5D]]):
    def __init__(
        self,
        *,
        preprocessor: Operator[DataRoi, Array5D] = OpRetriever(),
    ) -> None:
        super().__init__()
        self.preprocessor = preprocessor

    def __call__(self, /, roi: DataRoi) -> List[Array5D]:
        data = self.preprocessor(roi)
        # For every voxel, pick the index of the channel with the highest value
        winning_channel_indices = Array5D(
            arr=np.argmax(data.raw(data.axiskeys), axis=data.axiskeys.index("c")),
            axiskeys=data.axiskeys.replace("c", ""),
            location=roi.start,
        )

        segmentations: List[Array5D] = []

        for class_index in range(data.shape.c):
            # Allocate an RGB (3-channel) image and paint the class mask into the red channel
            class_seg = Array5D.allocate(data.interval.updated(c=(0, 3)), dtype=np.dtype("uint8"), value=0)
            red_channel = class_seg.cut(c=0)
            raw_segmentation = (winning_channel_indices.raw("tzyx") == class_index).astype(np.dtype("uint8")) * 255
            red_channel.raw("tzyx")[...] = raw_segmentation
            segmentations.append(class_seg)

        return segmentations
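
The per-class masking reduces to an argmax over the channel axis. A numpy-only sketch of the same idea (the (y, x, c) layout and the red-channel convention are assumptions of this sketch, not SimpleSegmenter's API):

import numpy as np

def segment_by_argmax(probs_yxc: np.ndarray) -> list:
    winners = np.argmax(probs_yxc, axis=-1)  # winning class index per pixel
    segmentations = []
    for class_index in range(probs_yxc.shape[-1]):
        seg = np.zeros((*winners.shape, 3), dtype=np.uint8)  # RGB, all black
        seg[..., 0] = (winners == class_index).astype(np.uint8) * 255  # red channel
        segmentations.append(seg)
    return segmentations

probs = np.random.default_rng(0).random((4, 5, 2)).astype(np.float32)
masks = segment_by_argmax(probs)  # two (4, 5, 3) uint8 images
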
Example #3
    def to_z_slice_pngs(self,
                        class_colors: Sequence[Color]) -> Iterator[io.BytesIO]:
        for z_slice in self.split(self.shape.updated(z=1)):
            print(f"\nz_slice: {z_slice}")
            rendered_rgb = Array5D.allocate(z_slice.shape.updated(c=3),
                                            dtype=np.dtype("float32"),
                                            value=0)
            rendered_rgb_yxc = rendered_rgb.raw("yxc")

            for prediction_channel, color in zip(
                    z_slice.split(z_slice.shape.updated(c=1)), class_colors):
                print(f"\nprediction_channel: {prediction_channel}")

                # Start from an all-ones RGB plane, tint it with the class color,
                # then weight it by the per-pixel prediction strength
                class_rgb = Array5D(
                    np.ones(prediction_channel.shape.updated(c=3).to_tuple("yxc")),
                    axiskeys="yxc")
                class_rgb.raw("yxc")[...] *= np.asarray([color.r, color.g, color.b])
                class_rgb.raw("cyx")[...] *= prediction_channel.raw("yx")

                rendered_rgb_yxc += class_rgb.raw("yxc")

            out_image = PIL.Image.fromarray(
                rendered_rgb.raw("yxc").astype(np.uint8))  # type: ignore
            out_file = io.BytesIO()
            out_image.save(out_file, "png")
            _ = out_file.seek(0)
            yield out_file
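
The color math above is "ones * color * prediction", summed per class. A compact numpy/PIL sketch of that rendering (the (c, y, x) input layout is an assumption of this sketch):

import io
import numpy as np
import PIL.Image

def render_rgb_png(predictions_cyx: np.ndarray, colors) -> io.BytesIO:
    rendered = np.zeros((*predictions_cyx.shape[1:], 3), dtype=np.float32)
    for channel, color in zip(predictions_cyx, colors):
        # Weight the class color by the per-pixel prediction strength
        rendered += channel[..., np.newaxis] * np.asarray(color, dtype=np.float32)
    out_file = io.BytesIO()
    PIL.Image.fromarray(rendered.astype(np.uint8)).save(out_file, "png")
    _ = out_file.seek(0)
    return out_file

preds = np.random.default_rng(0).random((2, 8, 8)).astype(np.float32)
png = render_rgb_png(preds, [(255, 0, 0), (0, 255, 0)])
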
Example #4
    def interpolate_from_points(cls, voxels: Sequence[Point5D],
                                raw_data: DataSource):
        start = Point5D.min_coords(voxels)
        # +1 because slice.stop is exclusive, but the max point is inclusive
        stop = Point5D.max_coords(voxels) + 1
        scribbling_roi = Interval5D.create_from_start_stop(start=start, stop=stop)
        if scribbling_roi.shape.c != 1:
            raise ValueError(
                f"Annotations must not span multiple channels: {voxels}")
        scribblings = Array5D.allocate(scribbling_roi,
                                       dtype=np.dtype(bool),
                                       value=False)

        # Walk the points in order, painting an interpolated line from each one to the next
        anchor = voxels[0]
        for voxel in voxels:
            for interp_voxel in anchor.interpolate_until(voxel):
                scribblings.paint_point(point=interp_voxel, value=True)
            anchor = voxel

        return cls(scribblings._data,
                   axiskeys=scribblings.axiskeys,
                   raw_data=raw_data,
                   location=start)
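
interpolate_until is not shown here; a 2D numpy sketch of the anchor-walking loop, with a hypothetical paint_line standing in for the interpolate-and-paint step:

import numpy as np

def paint_line(canvas: np.ndarray, start, stop) -> None:
    # One step per unit along the longest axis, rounded onto the grid
    steps = max(abs(stop[0] - start[0]), abs(stop[1] - start[1])) + 1
    ys = np.linspace(start[0], stop[0], steps).round().astype(int)
    xs = np.linspace(start[1], stop[1], steps).round().astype(int)
    canvas[ys, xs] = True

scribblings = np.zeros((6, 6), dtype=bool)
anchor = (0, 0)
for voxel in [(2, 5), (5, 5)]:  # consecutive annotation points
    paint_line(scribblings, anchor, voxel)
    anchor = voxel
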
Example #5
class FeatureExtractorCollection(FeatureExtractor):
    def __init__(self, extractors: Iterable[FeatureExtractor]):
        self.extractors = tuple(extractors)
        assert len(self.extractors) > 0
        super().__init__()

    def is_applicable_to(self, datasource: DataSource) -> bool:
        return all(fx.is_applicable_to(datasource) for fx in self.extractors)

    def __call__(self, /, roi: DataRoi) -> FeatureData:
        assert roi.interval.c[0] == 0
        feature_promises: Dict[int, Future[FeatureData]] = {}

        executor = get_executor(hint="feature_extraction",
                                max_workers=len(self.extractors))
        from webilastik.features.ilp_filter import IlpGaussianSmoothing

        # Submit IlpGaussianSmoothing jobs first so they are scheduled ahead of the
        # other extractors; indexing by fx_index preserves the original order
        feature_promises = {
            fx_index: executor.submit(fx, roi)
            for fx_index, fx in enumerate(self.extractors)
            if isinstance(fx, IlpGaussianSmoothing)
        }
        feature_promises.update({
            fx_index: executor.submit(fx, roi)
            for fx_index, fx in enumerate(self.extractors)
            if not isinstance(fx, IlpGaussianSmoothing)
        })
        assert len(feature_promises) == len(self.extractors)
        features = [
            feature_promises[i].result() for i in range(len(self.extractors))
        ]

        out = Array5D.allocate(
            dtype=np.dtype("float32"),
            interval=roi.shape.updated(c=sum(feat.shape.c
                                             for feat in features)),
            axiskeys="tzyxc",
        ).translated(roi.start)

        channel_offset: int = 0
        for feature in features:
            out.set(feature.translated(Point5D.zero(c=channel_offset)))
            channel_offset += feature.shape.c

        return FeatureData(arr=out.raw(out.axiskeys),
                           axiskeys=out.axiskeys,
                           location=out.location)
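
Stripped of the Array5D bookkeeping, the collection runs every extractor concurrently and concatenates along the channel axis. A numpy sketch under that reading (the toy extractors are invented for illustration):

import numpy as np
from concurrent.futures import ThreadPoolExecutor

def extract_all(extractors, roi: np.ndarray) -> np.ndarray:
    with ThreadPoolExecutor(max_workers=len(extractors)) as executor:
        futures = [executor.submit(fx, roi) for fx in extractors]
        features = [f.result() for f in futures]  # preserves extractor order
    return np.concatenate(features, axis=-1)  # stack along the channel axis

blur = lambda a: a[..., np.newaxis] * 0.5      # hypothetical 1-channel feature
edges = lambda a: np.stack([a, -a], axis=-1)   # hypothetical 2-channel feature
out = extract_all([blur, edges], np.ones((4, 4), dtype=np.float32))  # (4, 4, 3)
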
Example #6
    def __call__(self, roi: DataRoi) -> FeatureData:
        haloed_roi = roi.enlarged(self.halo)
        source_data = self.preprocessor(haloed_roi)

        step_shape: Shape5D = Shape5D(
            c=1,
            t=1,
            x=1 if self.axis_2d == "x" else source_data.shape.x,
            y=1 if self.axis_2d == "y" else source_data.shape.y,
            z=1 if self.axis_2d == "z" else source_data.shape.z,
        )

        out = Array5D.allocate(
            interval=roi.updated(c=(roi.c[0] * self.channel_multiplier,
                                    roi.c[1] * self.channel_multiplier)),
            dtype=numpy.dtype("float32"),
            axiskeys=source_data.axiskeys.replace("c", "") + "c",  # fastfilters puts channel last
        )

        for data_slice in source_data.split(step_shape):
            source_axes = "zyx"
            if self.axis_2d:
                source_axes = source_axes.replace(self.axis_2d, "")

            raw_data: "ndarray[Any, dtype[float32]]" = data_slice.raw(
                source_axes).astype(numpy.float32)
            raw_feature_data: "ndarray[Any, dtype[float32]]" = self.filter_fn(
                raw_data)

            feature_data = FeatureData(
                raw_feature_data,
                # the filter appends a trailing channel axis only for multi-channel outputs
                axiskeys=source_axes + "c" if len(raw_feature_data.shape) > len(source_axes) else source_axes,
                location=data_slice.location.updated(c=data_slice.location.c * self.channel_multiplier),
            )
            out.set(feature_data, autocrop=True)
        out.setflags(write=False)
        return FeatureData(
            out.raw(out.axiskeys),
            axiskeys=out.axiskeys,
            location=out.location,
        )
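
The step_shape machinery walks the volume one 2D plane at a time when axis_2d is set. A numpy sketch of that slice-wise application for axis_2d == "z" (gradient_magnitude is a stand-in for self.filter_fn):

import numpy as np

def apply_slicewise(volume_zyx: np.ndarray, filter_fn) -> np.ndarray:
    # Filter each z-slice independently, then restack into a volume
    return np.stack([filter_fn(volume_zyx[z]) for z in range(volume_zyx.shape[0])])

gradient_magnitude = lambda yx: np.hypot(*np.gradient(yx.astype(np.float32)))
vol = np.random.default_rng(0).random((3, 8, 8)).astype(np.float32)
filtered = apply_slicewise(vol, gradient_magnitude)  # (3, 8, 8)
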
Example #7
    def populate_group(self, group: h5py.Group):
        LabelColors: "ndarray[Any, dtype[int64]]" = np.asarray([label.color.rgba for label in self.labels], dtype=int64)

        # expected group keys look like this:
        # ['Bookmarks', 'ClassifierFactory', 'LabelColors', 'LabelNames', 'PmapColors', 'StorageVersion', 'LabelSets', 'ClassifierForests']
        bookmark = group.create_group("Bookmarks").create_dataset("0000", data=np.void(pickle.dumps([], 0))) # empty value is [], serialized with SerialPickleableSlot
        bookmark.attrs["version"] = 1
        group["ClassifierFactory"] = VIGRA_ILP_CLASSIFIER_FACTORY
        group["LabelColors"] = LabelColors
        group["LabelColors"].attrs["isEmpty"] = False
        group["LabelNames"] = [label.name.encode("utf8") for label in self.labels]
        group["LabelNames"].attrs["isEmpty"] = False
        group["PmapColors"] = LabelColors
        group["PmapColors"].attrs["isEmpty"] = False
        group["StorageVersion"] = "0.1".encode("utf8")

        merged_annotation_tiles: Dict[DataSource, Dict[Interval5D, Array5D]] = {}
        for label_class, label in enumerate(self.labels, start=1):
            for annotation in label.annotations:
                datasource = annotation.raw_data
                merged_tiles = merged_annotation_tiles.setdefault(datasource, {})

                for interval in annotation.interval.get_tiles(
                    tile_shape=datasource.tile_shape.updated(c=1), tiles_origin=datasource.interval.start.updated(c=0)
                ):
                    annotation_tile = annotation.cut(interval.clamped(annotation.interval))
                    tile = merged_tiles.setdefault(interval, Array5D.allocate(interval=interval, value=0, dtype=np.dtype("uint8")))
                    tile.set(annotation_tile.colored(np.uint8(label_class)), mask_value=0)

        LabelSets = group.create_group("LabelSets")
        for lane_index, (lane_datasource, blocks) in enumerate(merged_annotation_tiles.items()):
            assert isinstance(lane_datasource, FsDataSource) #FIXME? how do autocontext annotations work? They wouldn't be on FsDataSource
            axiskeys = lane_datasource.c_axiskeys_on_disk
            label_set = LabelSets.create_group(f"labels{lane_index:03}")
            for block_index, block in enumerate(blocks.values()):
                labels_dataset = label_set.create_dataset(f"block{block_index:04d}", data=block.raw(axiskeys))
                labels_dataset.attrs["blockSlice"] = "[" + ",".join(f"{slc.start}:{slc.stop}" for slc in block.interval.updated(c=0).to_slices(axiskeys)) + "]"
                labels_dataset.attrs["axistags"] = vigra.defaultAxistags(axiskeys).toJSON()
        if len(LabelSets.keys()) == 0:
            _ = LabelSets.create_group("labels000")  # empty labels still produce this in classic ilastik

        if self.classifier:
            # ['Forest0000', ..., 'Forest000N', 'feature_names', 'known_labels', 'pickled_type']
            ClassifierForests = group.create_group("ClassifierForests")

            feature_names: List[bytes] = []
            get_feature_extractor_order: Callable[[IlpFilter], int] = lambda ex: self.feature_classes.index(ex.__class__)
            for fe in sorted(self.classifier.feature_extractors, key=get_feature_extractor_order):
                for c in range(self.classifier.num_input_channels * fe.channel_multiplier):
                    feature_names.append(self.make_feature_ilp_name(fe, channel_index=c).encode("utf8"))

            for forest_index, forest_bytes in enumerate(self.classifier.forest_h5_bytes):
                forests_h5_path = dump_to_temp_file(forest_bytes)
                with h5py.File(forests_h5_path, "r") as f:
                    forest_group = f["/"]
                    assert isinstance(forest_group, h5py.Group)
                    ClassifierForests.copy(forest_group, f"Forest{forest_index:04}") # 'Forest0000', ..., 'Forest000N'

            ClassifierForests["feature_names"] = feature_names
            ClassifierForests["known_labels"] = np.asarray(self.classifier.classes).astype(np.uint32)
            ClassifierForests["pickled_type"] = b"clazyflow.classifiers.parallelVigraRfLazyflowClassifier\nParallelVigraRfLazyflowClassifier\np0\n."
Example #8
    def _allocate(self, interval: Union[Shape5D, Interval5D], fill_value: int, axiskeys_hint: str = "tzyxc") -> Array5D:
        return Array5D.allocate(interval, dtype=self.dtype, value=fill_value, axiskeys=axiskeys_hint)
Example #9
    def enlarged(self, radius: Point5D, limits: Interval5D) -> "ConnectedComponents":
        """Enlarges the array by 'radius', clamped to 'limits', and fills the halo with zeros"""
        haloed_roi = self.interval.enlarged(radius).clamped(limits)
        haloed_data = Array5D.allocate(haloed_roi, value=0, dtype=self.dtype)
        haloed_data.set(self)
        return ConnectedComponents.from_array5d(haloed_data, labels=self.labels)
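
For a plain array, the allocate-then-set dance collapses to np.pad. A 2D numpy sketch (clamping to dataset limits is omitted here):

import numpy as np

def enlarged_with_halo(arr_yx: np.ndarray, radius: int) -> np.ndarray:
    # Zero-filled halo of `radius` pixels on every side
    return np.pad(arr_yx, radius, mode="constant", constant_values=0)

components = np.array([[1, 1], [0, 2]], dtype=np.uint32)
haloed = enlarged_with_halo(components, 1)  # (4, 4), zeros around the original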