Example #1
import numpy
import vigra
from typing import Union
from vigra import AxisTags

from lazyflow.graph import Graph
from lazyflow.operators.opReorderAxes import OpReorderAxes


def reorder_axes(
    input_arr: numpy.ndarray,
    *,
    from_axes_tags: Union[str, AxisTags],
    to_axes_tags: Union[str, AxisTags],
):
    if isinstance(from_axes_tags, AxisTags):
        from_axes_tags = "".join(from_axes_tags.keys())

    if isinstance(to_axes_tags, AxisTags):
        to_axes_tags = "".join(to_axes_tags.keys())

    op = OpReorderAxes(graph=Graph())

    tagged_arr = vigra.VigraArray(input_arr, axistags=vigra.defaultAxistags(from_axes_tags))
    op.Input.setValue(tagged_arr)
    op.AxisOrder.setValue(to_axes_tags)

    return op.Output([]).wait()
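
A minimal usage sketch (hedged: relies only on the imports and the function above; shapes are arbitrary):

vol = numpy.zeros((4, 5, 6), dtype=numpy.uint8)
reordered = reorder_axes(vol, from_axes_tags="xyz", to_axes_tags="zyx")
assert reordered.shape == (6, 5, 4)  # axes transposed from xyz to zyx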
Example #2
class ObjectExtractionTimeComparison(object):
    def __init__(self):
        # Set memory and number of threads here
        #lazyflow.request.Request.reset_thread_pool(2)
        #Memory.setAvailableRam(500*1024**2)

        binary_img = binaryImage()
        raw_img = rawImage()

        g = Graph()

        # Reorder axis operators
        self.op5Raw = OpReorderAxes(graph=g)
        self.op5Raw.AxisOrder.setValue("txyzc")
        # self.op5Raw.Input.connect(self.opReaderRaw.OutputImage)
        self.op5Raw.Input.setValue(raw_img)

        self.op5Binary = OpReorderAxes(graph=g)
        self.op5Binary.AxisOrder.setValue("txyzc")
        #self.op5Binary.Input.connect(self.opReaderBinary.OutputImage)
        self.op5Binary.Input.setValue(binary_img)

        # Cache operators
        self.opCacheRaw = OpBlockedArrayCache(graph=g)
        self.opCacheRaw.Input.connect(self.op5Raw.Output)
        self.opCacheRaw.BlockShape.setValue(
            (1,) + self.op5Raw.Output.meta.shape[1:])

        self.opCacheBinary = OpBlockedArrayCache(graph=g)
        self.opCacheBinary.Input.connect(self.op5Binary.Output)
        self.opCacheBinary.BlockShape.setValue(
            (1, ) + self.op5Binary.Output.meta.shape[1:])

        # Label volume operator
        self.opLabel = OpLabelVolume(graph=g)
        self.opLabel.Input.connect(self.op5Binary.Output)
        #self.opLabel.Input.connect(self.opCacheBinary.Output)

        # Object extraction
        self.opObjectExtraction = OpObjectExtraction(graph=g)
        self.opObjectExtraction.RawImage.connect(self.op5Raw.Output)
        self.opObjectExtraction.BinaryImage.connect(self.op5Binary.Output)
        self.opObjectExtraction.Features.setValue(FEATURES)

        # Simplified object features operator (No overhead)
        self.opObjectFeaturesSimp = OpObjectFeaturesSimplified(graph=g)
        self.opObjectFeaturesSimp.RawVol.connect(self.opCacheRaw.Output)
        self.opObjectFeaturesSimp.BinaryVol.connect(self.opCacheBinary.Output)

    def run(self):

        #         # Load caches beforehand (To remove overhead of reading frames)
        #         with Timer() as timerCaches:
        #             rawVol = self.opCacheRaw.Output([]).wait()
        #             binaryVol = self.opCacheBinary.Output([]).wait()
        #
        #         print "Caches took {} secs".format(timerCaches.seconds())
        #
        #         del rawVol
        #         del binaryVol

        # Profile object extraction simplified
        print(
            "\nStarting object extraction simplified (single-thread, without cache)"
        )

        with Timer() as timerObjectFeaturesSimp:
            featsObjectFeaturesSimp = self.opObjectFeaturesSimp.Features(
                []).wait()

        print("Simplified object extraction took: {} seconds".format(
            timerObjectFeaturesSimp.seconds()))

        # Profile object extraction optimized
        print("\nStarting object extraction (multi-thread, without cache)")

        with Timer() as timerObjectExtraction:
            featsObjectExtraction = self.opObjectExtraction.RegionFeatures(
                []).wait()

        print("Object extraction took: {} seconds".format(
            timerObjectExtraction.seconds()))

        # Profile basic multi-threaded feature computation:
        # a multi-threaded loop that labels volumes and extracts object
        # features directly (no operators, no plugin system, no overhead).
        featsBasicFeatureComp = dict.fromkeys(
            list(range(self.op5Raw.Output.meta.shape[0])), None)

        print("\nStarting basic multi-threaded feature computation")
        pool = RequestPool()
        for t in range(self.op5Raw.Output.meta.shape[0]):
            pool.add(
                Request(
                    partial(self._computeObjectFeatures, t,
                            featsBasicFeatureComp)))

        with Timer() as timerBasicFeatureComp:
            pool.wait()

        print(
            "Basic multi-threaded feature extraction took: {} seconds".format(
                timerBasicFeatureComp.seconds()))

    # Compute object features for single frame
    def _computeObjectFeatures(self, t, result):
        roi = [slice(None) for i in range(len(self.op5Raw.Output.meta.shape))]
        roi[0] = slice(t, t + 1)
        roi = tuple(roi)

        #         rawVol = self.opCacheRaw.Output(roi).wait()
        #         binaryVol = self.opCacheBinary.Output(roi).wait()

        rawVol = self.op5Raw.Output(roi).wait()
        binaryVol = self.op5Binary.Output(roi).wait()

        features = [
            'Count', 'Coord<Minimum>', 'RegionCenter',
            'Coord<Principal<Kurtosis>>', 'Coord<Maximum>'
        ]

        for i in range(t, t + 1):
            labelVol = vigra.analysis.labelImageWithBackground(
                binaryVol[i - t].squeeze(), background_value=int(0))
            res = vigra.analysis.extractRegionFeatures(
                rawVol[i - t].squeeze().astype(np.float32),
                labelVol.squeeze().astype(np.uint32),
                features,
                ignoreLabel=0)

            # Cleanup results (as done in vigra_objfeats)
            local_features = [x for x in features if "Global<" not in x]
            nobj = res[local_features[0]].shape[0]
            result[i] = cleanup(res, nobj, features)
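
A minimal driver sketch for the benchmark above (hedged: rawImage, binaryImage, FEATURES, and cleanup are module-level helpers assumed to come from the original benchmark script):

if __name__ == "__main__":
    comparison = ObjectExtractionTimeComparison()
    comparison.run()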
Example #3
class TikTorchLazyflowClassifier(LazyflowPixelwiseClassifierABC):
    HDF5_GROUP_FILENAME = "pytorch_network_path"

    def __init__(self,
                 tiktorch_net,
                 filename=None,
                 HALO_SIZE=32,
                 BATCH_SIZE=3):
        """
        Args:
            tiktorch_net (tiktorch): tiktorch object to be loaded into this
              classifier object
            filename (None, optional): Save file name for future reference
        """
        self._filename = filename
        if self._filename is None:
            self._filename = ""

        self.HALO_SIZE = HALO_SIZE
        self.BATCH_SIZE = BATCH_SIZE

        if tiktorch_net is None:
            print(self._filename)
            tiktorch_net = TikTorch.unserialize(self._filename)

        # assert tiktorch_net.return_hypercolumns == False

        self._tiktorch_net = tiktorch_net

        self._opReorderAxes = OpReorderAxes(graph=Graph())
        self._opReorderAxes.AxisOrder.setValue("zcyx")

    def predict_probabilities_pixelwise(self,
                                        feature_image,
                                        roi,
                                        axistags=None):
        """
        Implicitly assumes that feature_image is includes the surrounding HALO!
        roi must be chosen accordingly
        """
        logger.info(
            f"predicting using pytorch network for image of shape {feature_image.shape} and roi {roi}"
        )
        logger.info(
            f"Stats of input: min={feature_image.min()}, max={feature_image.max()}, mean={feature_image.mean()}"
        )
        logger.info(
            f"expected pytorch input shape is {self._tiktorch_net.expected_input_shape}"
        )
        logger.info(
            f"expected pytorch output shape is {self._tiktorch_net.expected_output_shape}"
        )

        num_channels = len(self.known_classes)
        expected_shape = [stop - start for start, stop in zip(roi[0], roi[1])] + [num_channels]

        self._opReorderAxes.Input.setValue(
            vigra.VigraArray(feature_image, axistags=axistags))
        self._opReorderAxes.AxisOrder.setValue("zcyx")
        reordered_feature_image = self._opReorderAxes.Output([]).wait()

        # normalizing patch
        # reordered_feature_image = (reordered_feature_image - reordered_feature_image.mean()) / (reordered_feature_image.std() + 0.000001)

        if len(self._tiktorch_net.get("window_size")) == 2:
            exp_input_shape = numpy.array(
                self._tiktorch_net.expected_input_shape)
            exp_input_shape = tuple(numpy.append(1, exp_input_shape))
            print(exp_input_shape)
        else:
            exp_input_shape = self._tiktorch_net.expected_input_shape

        logger.info(
            f"input axistags are {axistags}, "
            f"Shape after reordering input is {reordered_feature_image.shape}, "
            f"axistags are {self._opReorderAxes.Output.meta.axistags}")

        slice_shape = list(reordered_feature_image.shape[1::])  # ignore z axis
        # assuming [z, y, x]
        result_roi = numpy.array(roi)
        if slice_shape != list(exp_input_shape[1::]):
            logger.info(f"Expected input shape is {exp_input_shape[1::]}, "
                        f"but got {slice_shape}, reshaping...")

            # pad the image up to the expected input shape
            # (reflect-padding is applied below via numpy.pad)

            # diff shape: cyx
            diff_shape = numpy.array(
                exp_input_shape[1::]) - numpy.array(slice_shape)
            # diff_shape = numpy.array(self._tiktorch_net.expected_input_shape) - numpy.array(slice_shape)
            # offset shape z, y, x, c for easy indexing, with c = 0, z = 0
            offset = numpy.array([0, 0, 0, 0])
            logger.info(f"Diff_shape {diff_shape}")

            # at least one of y, x (diff_shape[1], diff_shape[2]) should be off
            # let's determine how to adjust the offset -> offset[2] and offset[3]
            # caveat: this code assumes that image requests were tiled in a regular
            # pattern starting from left upper corner.
            # We use a blocked array-cache to achieve that
            # y-offset:
            if diff_shape[1] > 0:
                # was the halo added to the upper side of the feature image?
                # HACK: this only works because we assume the data to be in zyx!!!
                if roi[0][1] == 0:
                    # no, doesn't seem like it
                    offset[1] = self.HALO_SIZE

            # x-offsets:
            if diff_shape[2] > 0:
                # was the halo added to the upper side of the feature image?
                # HACK: this only works because we assume the data to be in zyx!!!
                if roi[0][2] == 0:
                    # no, doesn't seem like it
                    offset[2] = self.HALO_SIZE

            # HACK: still assuming zyxc
            result_roi[0] += offset[0:3]
            result_roi[1] += offset[0:3]
            reorder_feature_image_extents = numpy.array(
                reordered_feature_image.shape)
            # add the offset:
            reorder_feature_image_extents[2:4] += offset[1:3]

            pad_img = numpy.pad(
                reordered_feature_image,
                [
                    (0, 0),
                    (0, 0),
                    (offset[1],
                     exp_input_shape[2] - reorder_feature_image_extents[2]),
                    (offset[2],
                     exp_input_shape[3] - reorder_feature_image_extents[3]),
                ],
                "reflect",
            )

            reordered_feature_image = pad_img

            logger.info(f"New Image shape {reordered_feature_image.shape}")

        result = numpy.zeros([reordered_feature_image.shape[0], num_channels] +
                             list(reordered_feature_image.shape[2:]))

        logger.info(f"forward")

        # we always predict in 2D, per z-slice, so we loop over z
        for z in range(0, reordered_feature_image.shape[0], self.BATCH_SIZE):
            # logger.warning("Dumping to {}".format('"/Users/chaubold/Desktop/dump.h5"'))
            # vigra.impex.writeHDF5(reordered_feature_image[z,...], "data", "/Users/chaubold/Desktop/dump.h5")

            # create batch of desired num slices. Multiple slices can be processed on multiple GPUs!
            batch = [
                reordered_feature_image[zi:zi + 1, ...].reshape(
                    self._tiktorch_net.expected_input_shape)
                for zi in range(
                    z,
                    min(z + self.BATCH_SIZE, reordered_feature_image.shape[0]))
            ]
            logger.info(f"batch info: {[x.shape for x in batch]}")

            print("batch info:", [x.shape for x in batch])

            # if len(self._tiktorch_net.get('window_size')) == 2:
            #     print("BATTCHHHHH", batch.shape)

            result_batch = self._tiktorch_net.forward(batch)
            logger.info(
                f"Resulting slices from {z} to {z + len(batch)} have shape {result_batch[0].shape}"
            )

            for i, zi in enumerate(range(z, (z + len(batch)))):
                result[zi:(zi + 1), ...] = result_batch[i]

        logger.info(f"Obtained a predicted block of shape {result.shape}")

        print("Obtained a predicted block of shape ", result.shape)

        self._opReorderAxes.Input.setValue(
            vigra.VigraArray(result, axistags=vigra.makeAxistags("zcyx")))
        # axistags is vigra.AxisTags, but opReorderAxes expects a string
        self._opReorderAxes.AxisOrder.setValue("".join(axistags.keys()))
        result = self._opReorderAxes.Output([]).wait()
        logger.info(f"Reordered result to shape {result.shape}")

        # FIXME: not needed for real neural net results:
        logger.info(
            f"Stats of result: min={result.min()}, max={result.max()}, mean={result.mean()}"
        )

        # cut out the required roi
        logger.info(f"Roi shape {result_roi}")

        # crop away halo and reorder axes to match "axistags"
        # crop in X and Y:
        cropped_result = result[roiToSlice(*result_roi)]

        logger.info(
            f"cropped the predicted block to shape {cropped_result.shape}")

        return cropped_result

    @property
    def known_classes(self):
        return list(range(self._tiktorch_net.expected_output_shape[0]))

    @property
    def feature_count(self):
        return self._tiktorch_net.expected_input_shape[0]

    def get_halo_shape(self, data_axes="zyxc"):
        if len(data_axes) == 4:
            return (0, self.HALO_SIZE, self.HALO_SIZE, 0)
        # FIXME: assuming 'yxc' !
        elif len(data_axes) == 3:
            return (self.HALO_SIZE, self.HALO_SIZE, 0)

    def serialize_hdf5(self, h5py_group):
        logger.debug("Serializing")
        h5py_group[self.HDF5_GROUP_FILENAME] = self._filename
        h5py_group["pickled_type"] = pickle.dumps(type(self), 0)

        # HACK: can this be done more elegantly?
        with tempfile.TemporaryFile() as f:
            self._tiktorch_net.serialize(f)
            f.seek(0)
            h5py_group["classifier"] = numpy.void(f.read())

    @classmethod
    def deserialize_hdf5(cls, h5py_group):
        # TODO: load from HDF5 instead of hard coded path!
        logger.debug("Deserializing")
        # HACK:
        # filename = PYTORCH_MODEL_FILE_PATH
        filename = h5py_group[cls.HDF5_GROUP_FILENAME]
        logger.debug("Deserializing from {}".format(filename))

        with tempfile.TemporaryFile() as f:
            f.write(h5py_group["classifier"].value)
            f.seek(0)
            loaded_pytorch_net = TikTorch.unserialize(f)

        return TikTorchLazyflowClassifier(loaded_pytorch_net, filename)
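
A short hedged sketch of the halo accessors (my_net is a hypothetical tiktorch network object standing in for a real one):

classifier = TikTorchLazyflowClassifier(my_net)
assert classifier.get_halo_shape("zyxc") == (0, 32, 32, 0)  # default HALO_SIZE is 32
assert classifier.get_halo_shape("yxc") == (32, 32, 0)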
Example #4
class TestOpReorderAxes(unittest.TestCase):
    def setUp(self):
        self.array = None
        self.axis = list('tzyxc')
        self.tests = 20
        graph = Graph()
        self.operator = OpReorderAxes(graph=graph)

    def prepareVolnOp(self, possible_axes='tzyxc', num=5):
        tags = random.sample(possible_axes, random.randint(2, num))
        tagStr = ''.join(tags)
        axisTags = vigra.defaultAxistags(tagStr)

        self.shape = []
        for tag in axisTags:
            self.shape.append(random.randint(20, 30))

        self.array = (numpy.random.rand(*tuple(self.shape)) * 255)
        self.array = (float(250) / 255 * self.array + 5).astype(int)
        self.inArray = vigra.VigraArray(self.array, axistags=axisTags)

        opProvider = OpArrayProvider(graph=self.operator.graph)
        opProvider.Input.setValue(self.inArray)
        self.operator.Input.connect(opProvider.Output)

    def test_Full(self):
        for i in range(self.tests):
            self.prepareVolnOp()
            result = self.operator.Output().wait()
            logger.debug(
                '------------------------------------------------------')
            logger.debug("self.array.shape = " + str(self.array.shape))
            logger.debug("type(input) == " +
                         str(type(self.operator.Input.value)))
            logger.debug("input.shape == " +
                         str(self.operator.Input.meta.shape))
            logger.debug("Input Tags:")
            logger.debug(str(self.operator.Input.meta.axistags))
            logger.debug("Output Tags:")
            logger.debug(str(self.operator.Output.meta.axistags))
            logger.debug("type(result) == " + str(type(result)))
            logger.debug("result.shape == " + str(result.shape))
            logger.debug(
                '------------------------------------------------------')

            # Check the shape
            assert len(result.shape) == 5

            assert not isinstance(result, vigra.VigraArray), \
                "For compatibility with generic code, output should be provided as a plain numpy array."

            # Ensure the result came out in default order
            assert self.operator.Output.meta.axistags == vigra.defaultAxistags(
                'tzyxc')

            # Check the data
            vresult = result.view(vigra.VigraArray)
            vresult.axistags = self.operator.Output.meta.axistags
            reorderedInput = self.inArray.withAxes(
                *[tag.key for tag in vresult.axistags])
            assert numpy.all(vresult == reorderedInput)

    def test_Roi_default_order(self):
        for i in range(self.tests):
            self.prepareVolnOp()
            shape = self.operator.Output.meta.shape
            roi = [None, None]
            roi[1] = [
                numpy.random.randint(2, s) if s != 1 else 1 for s in shape
            ]
            roi[0] = [
                numpy.random.randint(0, roi[1][i]) if s != 1 else 0
                for i, s in enumerate(shape)
            ]
            roi[0] = TinyVector(roi[0])
            roi[1] = TinyVector(roi[1])
            result = self.operator.Output(roi[0], roi[1]).wait()
            logger.debug(
                '------------------------------------------------------')
            logger.debug("self.array.shape = " + str(self.array.shape))
            logger.debug("type(input) == " +
                         str(type(self.operator.Input.value)))
            logger.debug("input.shape == " +
                         str(self.operator.Input.meta.shape))
            logger.debug("Input Tags:")
            logger.debug(str(self.operator.Input.meta.axistags))
            logger.debug("Output Tags:")
            logger.debug(str(self.operator.Output.meta.axistags))
            logger.debug("roi= " + str(roi))
            logger.debug("type(result) == " + str(type(result)))
            logger.debug("result.shape == " + str(result.shape))
            logger.debug(
                '------------------------------------------------------')

            # Check the shape
            assert len(result.shape) == 5
            assert not isinstance(result, vigra.VigraArray), \
                "For compatibility with generic code, output should be provided as a plain numpy array."

            # Ensure the result came out in volumina order
            assert self.operator.Output.meta.axistags == vigra.defaultAxistags(
                'tzyxc')

            # Check the data
            vresult = result.view(vigra.VigraArray)
            vresult.axistags = self.operator.Output.meta.axistags
            reorderedInput = self.inArray.withAxes(
                *[tag.key for tag in self.operator.Output.meta.axistags])
            assert numpy.all(
                vresult == reorderedInput[roiToSlice(roi[0], roi[1])])

    def test_Roi_custom_order(self):
        self._impl_roi_custom_order('cztxy')
        self._impl_roi_custom_order('xyz')

    def _impl_roi_custom_order(self, axisorder):
        for i in range(self.tests):
            self.prepareVolnOp(axisorder, len(axisorder) - 1)

            # Specify a strange order for the output axis tags
            self.operator.AxisOrder.setValue(axisorder)
            shape = self.operator.Output.meta.shape

            roi = [None, None]
            roi[1] = [
                numpy.random.randint(2, s) if s != 1 else 1 for s in shape
            ]
            roi[0] = [
                numpy.random.randint(0, roi[1][i]) if s != 1 else 0
                for i, s in enumerate(shape)
            ]
            roi[0] = TinyVector(roi[0])
            roi[1] = TinyVector(roi[1])
            result = self.operator.Output(roi[0], roi[1]).wait()
            logger.debug(
                '------------------------------------------------------')
            logger.debug("self.array.shape = " + str(self.array.shape))
            logger.debug("type(input) == " +
                         str(type(self.operator.Input.value)))
            logger.debug("input.shape == " +
                         str(self.operator.Input.meta.shape))
            logger.debug("Input Tags:")
            logger.debug(str(self.operator.Input.meta.axistags))
            logger.debug("Output Tags:")
            logger.debug(str(self.operator.Output.meta.axistags))
            logger.debug("roi= " + str(roi))
            logger.debug("type(result) == " + str(type(result)))
            logger.debug("result.shape == " + str(result.shape))
            logger.debug(
                '------------------------------------------------------')

            # Check the shape
            assert len(result.shape) == len(axisorder)

            assert not isinstance(result, vigra.VigraArray), \
                "For compatibility with generic code, output should be provided as a plain numpy array."

            # Ensure the result came out in the same strange order we asked for.
            assert self.operator.Output.meta.axistags == vigra.defaultAxistags(
                axisorder)

            # Check the data
            vresult = result.view(vigra.VigraArray)
            vresult.axistags = self.operator.Output.meta.axistags
            reorderedInput = self.inArray.withAxes(
                *[tag.key for tag in self.operator.Output.meta.axistags])
            assert numpy.all(
                vresult == reorderedInput[roiToSlice(roi[0], roi[1])])

    def test_insert_singleton_axis(self):
        for i in range(self.tests):
            self.prepareVolnOp('xyzc', 4)

            # Specify a strange order for the output axis tags
            self.operator.AxisOrder.setValue('yxtzc')
            shape = self.operator.Output.meta.shape

            roi = [None, None]
            roi[1] = [
                numpy.random.randint(2, s) if s != 1 else 1 for s in shape
            ]
            roi[0] = [
                numpy.random.randint(0, roi[1][i]) if s != 1 else 0
                for i, s in enumerate(shape)
            ]
            roi[0] = TinyVector(roi[0])
            roi[1] = TinyVector(roi[1])
            result = self.operator.Output(roi[0], roi[1]).wait()
            logger.debug(
                '------------------------------------------------------')
            logger.debug("self.array.shape = " + str(self.array.shape))
            logger.debug("type(input) == " +
                         str(type(self.operator.Input.value)))
            logger.debug("input.shape == " +
                         str(self.operator.Input.meta.shape))
            logger.debug("Input Tags:")
            logger.debug(str(self.operator.Input.meta.axistags))
            logger.debug("Output Tags:")
            logger.debug(str(self.operator.Output.meta.axistags))
            logger.debug("roi= " + str(roi))
            logger.debug("type(result) == " + str(type(result)))
            logger.debug("result.shape == " + str(result.shape))
            logger.debug(
                '------------------------------------------------------')

            # Check the shape
            assert len(result.shape) == 5

            assert not isinstance(result, vigra.VigraArray), \
                "For compatibility with generic code, output should be provided as a plain numpy array."

            # Ensure the result came out in the same strange order we asked for.
            assert self.operator.Output.meta.axistags == vigra.defaultAxistags(
                'yxtzc')

            # Check the data
            vresult = result.view(vigra.VigraArray)
            vresult.axistags = self.operator.Output.meta.axistags
            reorderedInput = self.inArray.withAxes(
                *[tag.key for tag in self.operator.Output.meta.axistags])
            assert numpy.all(
                vresult == reorderedInput[roiToSlice(roi[0], roi[1])])

    def test_attempt_drop_nonsingleton_axis(self):
        """
        Attempt to configure the operator with invalid settings by trying to drop a non-singleton axis.
        The execute method should assert in that case.
        """
        data = numpy.zeros((100, 100, 100), dtype=numpy.uint8)
        data = vigra.taggedView(data, vigra.defaultAxistags('xyz'))

        op = OpReorderAxes(graph=Graph())
        op.Input.setValue(data)

        # Attempt to drop some axes that can't be dropped.
        op.AxisOrder.setValue('txc')

        # Make sure this results in an error.
        req = op.Output[:]
        req.notify_failed(
            lambda *args: None
        )  # We expect an exception here, so disable the default fail handler to hide the traceback
        self.assertRaises(AssertionError, req.wait)
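
A standard unittest entry point, in case the test class above is run as a script (minimal sketch):

if __name__ == "__main__":
    unittest.main()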
Example #5
class OpExportMultipageTiff(Operator):
    # The last two non-singleton axes (except 'c') are the axes of the 'pages'.
    # Re-order the axes yourself if you want an alternative slicing direction.
    Input = InputSlot()
    Filepath = InputSlot()

    DEFAULT_BATCH_SIZE = 4

    def __init__(self, *args, **kwargs):
        super(OpExportMultipageTiff, self).__init__(*args, **kwargs)
        self.progressSignal = OrderedSignal()
        self._opReorderAxes = OpReorderAxes(parent=self)
        self._opReorderAxes.Input.connect(self.Input)

    def setupOutputs(self):
        # Always export in tzcyx order (but omit missing axes)
        input_axes = self.Input.meta.getAxisKeys()
        export_axes = "".join(filter(lambda k: k in input_axes, 'tzcyx'))
        if not set("yx").issubset(set(export_axes)):
            # This could potentially be fixed...
            raise Exception(
                "I don't know how to export data without both an X and Y axis")

        self._opReorderAxes.AxisOrder.setValue(export_axes)
        self._export_axes = export_axes

    def run_export(self):
        """
        Request the volume in slices (running in parallel), and write each slice to the correct page.
        Note: We can't use BigRequestStreamer here, because the data for each slice wouldn't be 
              guaranteed to arrive in the correct order.
        """
        # Delete existing image if present
        image_path = self.Filepath.value
        if os.path.exists(image_path):
            os.remove(image_path)

        tagged_shape = self.Input.meta.getTaggedShape()
        export_shape = self._opReorderAxes.Output.meta.shape
        shape_yx = export_shape[-2:]
        stacked_axes_shape = export_shape[:-2]
        num_pages = numpy.prod(stacked_axes_shape)

        def create_slice_req():
            for stacked_axes_ndindex in numpy.ndindex(*stacked_axes_shape):
                roi = numpy.zeros((2, len(export_shape)), dtype=int)
                roi[:, :-2] = stacked_axes_ndindex
                roi[1, :-2] += 1
                roi[1, -2:] = shape_yx
                yield self._opReorderAxes.Output(*roi)

        iter_slice_requests = create_slice_req()

        parallel_requests = self.DEFAULT_BATCH_SIZE

        # If ram usage info is available, make a better guess about how many requests we can launch in parallel
        ram_usage_per_requested_pixel = self.Input.meta.ram_usage_per_requested_pixel
        if ram_usage_per_requested_pixel is not None:
            pixels_per_slice = numpy.prod(shape_yx)
            if 'c' in tagged_shape:
                pixels_per_slice /= tagged_shape['c']

            ram_usage_per_slice = pixels_per_slice * ram_usage_per_requested_pixel

            # Fudge factor: Reduce RAM usage by a bit
            available_ram = psutil.virtual_memory().available
            available_ram *= 0.5

            parallel_requests = int(available_ram / ram_usage_per_slice)

        # Start with a batch of images
        reqs = collections.deque()
        for next_request_index in range(min(parallel_requests, num_pages)):
            reqs.append(next(iter_slice_requests))

        self.progressSignal(0)
        pages_written = 0
        while reqs:
            self.progressSignal(100 * next_request_index / num_pages)
            req = reqs.popleft()
            slice_data = req.wait()
            slice_data = vigra.taggedView(slice_data, self._export_axes)
            next_request_index += 1

            # Add a new request to the batch
            if next_request_index < num_pages:
                reqs.append(next(iter_slice_requests))

            if pages_written == 0:
                xml_description = OpExportMultipageTiff.generate_ome_xml_description(
                    self._opReorderAxes.Output.meta.getAxisKeys(),
                    self._opReorderAxes.Output.meta.shape,
                    self._opReorderAxes.Output.meta.dtype,
                    os.path.split(image_path)[1])
                # Write the first slice with tifffile, which allows us to write the tags.
                with tifffile.TiffWriter(image_path,
                                         software='ilastik',
                                         byteorder='<') as writer:
                    writer.save(slice_data.withAxes('yx'),
                                description=xml_description,
                                planarconfig='planar')
            else:
                # Append a slice to the multipage tiff file
                vigra.impex.writeImage(slice_data.withAxes('yx'),
                                       image_path,
                                       dtype='',
                                       compression='NONE',
                                       mode='a')
            pages_written += 1

        self.progressSignal(100)

    # No output slots...
    def execute(self, slot, subindex, roi, result):
        pass

    def propagateDirty(self, slot, subindex, roi):
        pass

    @classmethod
    def generate_ome_xml_description(cls, axes, shape, dtype, filename=''):
        """
        Generate an OME XML description of the data we're exporting,
        suitable for the image_description tag of the first page.

        axes and shape should be provided in C-order (will be reversed in the XML)
        """
        import uuid
        import xml.etree.ElementTree as ET

        # Normalize the inputs
        axes = "".join(axes)
        shape = tuple(shape)
        if not isinstance(dtype, type):
            dtype = dtype().type

        ome = ET.Element('OME')
        uuid_str = "urn:uuid:" + str(uuid.uuid1())
        ome.set('UUID', uuid_str)
        ome.set('xmlns:xsi', "http://www.w3.org/2001/XMLSchema-instance")
        ome.set(
            'xsi:schemaLocation',
            "http://www.openmicroscopy.org/Schemas/OME/2015-01 "
            "http://www.openmicroscopy.org/Schemas/OME/2015-01/ome.xsd")

        image = ET.SubElement(ome, 'Image')
        image.set('ID', 'Image:0')
        image.set('Name', 'exported-data')

        pixels = ET.SubElement(image, 'Pixels')
        pixels.set('BigEndian', 'true')
        pixels.set('ID', 'Pixels:0')

        fortran_axes = "".join(reversed(axes)).upper()
        pixels.set('DimensionOrder', fortran_axes)

        for axis, dim in zip(axes.upper(), shape):
            pixels.set('Size' + axis, str(dim))

        types = {
            numpy.uint8: 'uint8',
            numpy.uint16: 'uint16',
            numpy.uint32: 'uint32',
            numpy.int8: 'int8',
            numpy.int16: 'int16',
            numpy.int32: 'int32',
            numpy.float32: 'float',
            numpy.float64: 'double',
            numpy.complex64: 'complex',
            numpy.complex128: 'double-complex'
        }

        pixels.set('Type', types[dtype])

        # Omit channel information (is that okay?)
        # channel0 = ET.SubElement(pixels, "Channel")
        # channel0.set("ID", "Channel0:0")
        # channel0.set("SamplesPerPixel", "1")

        assert axes[-2:] == "yx"
        for page_index, page_ndindex in enumerate(numpy.ndindex(*shape[:-2])):
            tiffdata = ET.SubElement(pixels, "TiffData")
            for axis, index in zip(axes[:-2].upper(), page_ndindex):
                tiffdata.set("First" + axis, str(index))
            tiffdata.set("PlaneCount", "1")
            tiffdata.set("IFD", str(page_index))
            uuid_tag = ET.SubElement(tiffdata, "UUID")
            uuid_tag.text = uuid_str
            uuid_tag.set('FileName', filename)

        from textwrap import dedent
        from io import BytesIO
        xml_stream = BytesIO()
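        # NOTE: the ET.Comment below is constructed but never attached to the
        # tree, so it does not appear in the serialized XML.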
        comment = ET.Comment(
            dedent(
                '\
            <!-- Warning: this comment is an OME-XML metadata block, which contains crucial '
                'dimensional parameters and other important metadata. Please edit cautiously '
                '(if at all), and back up the original data before doing so. For more information, '
                'see the OME-TIFF web site: http://ome-xml.org/wiki/OmeTiff. -->'
            ))

        tree = ET.ElementTree(ome)
        tree.write(xml_stream, encoding='utf-8', xml_declaration=True)

        if logger.isEnabledFor(logging.DEBUG):
            import xml.dom.minidom
            reparsed = xml.dom.minidom.parseString(xml_stream.getvalue())
            logger.debug("Generated OME-TIFF metadata:\n" +
                         reparsed.toprettyxml())

        return xml_stream.getvalue().decode('utf-8')
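
A hedged usage sketch of the exporter (assumes lazyflow and vigra are installed; the data values, output path, and the OrderedSignal.subscribe hookup are placeholders/assumptions):

import numpy
import vigra
from lazyflow.graph import Graph

data = numpy.random.randint(0, 255, (2, 3, 1, 16, 16)).astype(numpy.uint8)
data = vigra.taggedView(data, vigra.defaultAxistags('tzcyx'))  # t, z, c, y, x

op = OpExportMultipageTiff(graph=Graph())
op.Input.setValue(data)
op.Filepath.setValue('/tmp/export.ome.tiff')  # placeholder path
op.progressSignal.subscribe(print)            # print progress percentages
op.run_export()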