Example #1
    def testCleanup(self):
        try:
            CacheMemoryManager().disable()

            op = OpBlockedArrayCache(graph=self.opProvider.graph)
            op.Input.connect(self.opProvider.Output)
            s = self.opProvider.Output.meta.shape
            op.innerBlockShape.setValue(s)
            op.outerBlockShape.setValue(s)
            op.fixAtCurrent.setValue(False)
            x = op.Output[...].wait()
            op.Input.disconnect()
            op.cleanUp()

            r = weakref.ref(op)
            del op
            gc.collect()
            ref = r()
            if ref is not None:
                for i, o in enumerate(gc.get_referrers(ref)):
                    print("Object", i, ":", type(o), ":", o)

            assert r() is None, "OpBlockedArrayCache was not cleaned up correctly"
        finally:
            CacheMemoryManager().enable()
Example #2
    def __init__(self, parent=None, update=True):
        QDialog.__init__(self, parent=parent)
        layout = QVBoxLayout()
        self.tree = QTreeWidget()
        layout.addWidget(self.tree)
        self.setLayout(layout)

        self._mgr = CacheMemoryManager()

        self._tracked_caches = {}

        # tree setup code
        self.tree.setHeaderLabels(
            ["cache", "memory", "roi", "dtype", "type", "info", "id"])
        self._idIndex = self.tree.columnCount() - 1
        self.tree.setColumnHidden(self._idIndex, True)
        self.tree.setSortingEnabled(True)
        self.tree.clear()

        self._root = TreeNode()

        # refresh every x seconds (see showEvent())
        self.timer = QTimer(self)
        if update:
            self.timer.timeout.connect(self._updateReport)
    def testCleanup(self):
        try:
            CacheMemoryManager().disable()

            sampleData = np.random.randint(0, 256, size=(50, 30, 10))
            sampleData = sampleData.astype(np.uint8)
            sampleData = vigra.taggedView(sampleData, axistags="xyz")

            graph = Graph()
            opData = OpArrayPiper(graph=graph)
            opData.Input.setValue(sampleData)

            op = OpLabelVolume(graph=graph)
            op.Input.connect(opData.Output)
            x = op.Output[...].wait()
            op.Input.disconnect()
            op.cleanUp()

            r = weakref.ref(op)
            del op
            gc.collect()
            ref = r()
            if ref is not None:
                for i, o in enumerate(gc.get_referrers(ref)):
                    print("Object", i, ":", type(o), ":", o)

            assert r() is None, "OpLabelVolume was not cleaned up correctly"
        finally:
            CacheMemoryManager().enable()
    def teardown_method(self, method):
        # reset cleanup frequency to sane value
        # reset max memory
        Memory.setAvailableRamCaches(-1)
        mgr = CacheMemoryManager()
        mgr.setRefreshInterval(default_refresh_interval)
        mgr.enable()
        Request.reset_thread_pool()
    def testBlockedCacheHandling(self):
        n, k = 10, 5
        vol = np.zeros((n, ) * 5, dtype=np.uint8)
        vol = vigra.taggedView(vol, axistags='txyzc')

        g = Graph()
        pipe = OpArrayPiperWithAccessCount(graph=g)
        cache = OpBlockedArrayCache(graph=g)

        mgr = CacheMemoryManager()

        # restrict cache memory to 0 Byte
        Memory.setAvailableRamCaches(0)

        # set to frequent cleanup
        mgr.setRefreshInterval(.01)
        mgr.enable()

        cache.outerBlockShape.setValue((k, ) * 5)
        cache.Input.connect(pipe.Output)
        pipe.Input.setValue(vol)

        a = pipe.accessCount
        cache.Output[...].wait()
        b = pipe.accessCount
        assert b > a, "did not cache"

        # let the manager clean up
        mgr.enable()
        time.sleep(.5)
        gc.collect()

        cache.Output[...].wait()
        c = pipe.accessCount
        assert c > b, "did not clean up"
    def __init__(self, parent=None, update=True):
        QDialog.__init__(self, parent=parent)
        layout = QVBoxLayout()
        self.tree = QTreeWidget()
        layout.addWidget(self.tree)
        self.setLayout(layout)

        self._mgr = CacheMemoryManager()

        self._tracked_caches = {}

        # tree setup code
        self.tree.setHeaderLabels(
            ["cache", "memory", "roi", "dtype", "type", "info", "id"])
        self._idIndex = self.tree.columnCount() - 1
        self.tree.setColumnHidden(self._idIndex, True)
        self.tree.setSortingEnabled(True)
        self.tree.clear()

        self._root = TreeNode()

        # refresh every x seconds (see showEvent())
        self.timer = QTimer(self)
        if update:
            self.timer.timeout.connect(self._updateReport)
Example #7
        def _configure_lazyflow_settings():
            import lazyflow
            import lazyflow.request
            from lazyflow.utility import Memory
            from lazyflow.operators.cacheMemoryManager import CacheMemoryManager

            if status_interval_secs:
                memory_logger = logging.getLogger(
                    "lazyflow.operators.cacheMemoryManager")
                memory_logger.setLevel(logging.DEBUG)
                CacheMemoryManager().setRefreshInterval(status_interval_secs)

            if n_threads is not None:
                logger.info(f"Resetting lazyflow thread pool with {n_threads} "
                            "threads.")
                lazyflow.request.Request.reset_thread_pool(n_threads)
            if total_ram_mb > 0:
                if total_ram_mb < 500:
                    raise Exception("In your current configuration, RAM is "
                                    f"limited to {total_ram_mb} MB. Remember "
                                    "to specify RAM in MB, not GB.")
                ram = total_ram_mb * 1024**2
                fmt = Memory.format(ram)
                logger.info("Configuring lazyflow RAM limit to {}".format(fmt))
                Memory.setAvailableRam(ram)
    def testBlockedCacheHandling(self):
        n, k = 10, 5
        vol = np.zeros((n,) * 5, dtype=np.uint8)
        vol = vigra.taggedView(vol, axistags="txyzc")

        g = Graph()
        pipe = OpArrayPiperWithAccessCount(graph=g)
        cache = OpBlockedArrayCache(graph=g)

        mgr = CacheMemoryManager()

        # restrict cache memory to 0 Byte
        Memory.setAvailableRamCaches(0)

        # set to frequent cleanup
        mgr.setRefreshInterval(0.01)
        mgr.enable()

        cache.BlockShape.setValue((k,) * 5)
        cache.Input.connect(pipe.Output)
        pipe.Input.setValue(vol)

        a = pipe.accessCount
        cache.Output[...].wait()
        b = pipe.accessCount
        assert b > a, "did not cache"

        # let the manager clean up
        mgr.enable()
        time.sleep(0.5)
        gc.collect()

        cache.Output[...].wait()
        c = pipe.accessCount
        assert c > b, "did not clean up"
    def testCleanup(self):
        try:
            CacheMemoryManager().disable()
            
            op = OpSlicedBlockedArrayCache(graph=self.opProvider.graph)
            op.Input.connect(self.opProvider.Output)
            op.BlockShape.setValue(self.opCache.BlockShape.value)
            op.fixAtCurrent.setValue(False)
            x = op.Output[...].wait()
            op.Input.disconnect()
            op.cleanUp()

            r = weakref.ref(op)
            del op
            gc.collect()
            assert r() is None, "OpSlicedBlockedArrayCache was not cleaned up correctly"
        finally:
            CacheMemoryManager().enable()
    def tearDown(self):
        # reset cleanup frequency to sane value
        # reset max memory
        Memory.setAvailableRamCaches(-1)
        mgr = CacheMemoryManager()
        mgr.setRefreshInterval(default_refresh_interval)
        mgr.enable()
        Request.reset_thread_pool()
    def testCleanup(self):
        try:
            CacheMemoryManager().disable()
            sampleData = numpy.indices((100, 200, 150),
                                       dtype=numpy.float32).sum(0)
            sampleData = vigra.taggedView(sampleData, axistags='xyz')

            graph = Graph()
            opData = OpArrayPiper(graph=graph)
            opData.Input.setValue(sampleData)

            op = OpCompressedCache(graph=graph)
            #logger.debug("Setting block shape...")
            op.BlockShape.setValue([100, 75, 50])
            op.Input.connect(opData.Output)
            x = op.Output[...].wait()
            op.Input.disconnect()
            r = weakref.ref(op)
            del op
            gc.collect()
            assert r() is None, "OpCompressedCache was not cleaned up correctly"
        finally:
            CacheMemoryManager().enable()
Example #12
class MemUsageDialog(QDialog):
    def __init__(self, parent=None, update=True):
        QDialog.__init__(self, parent=parent)
        layout = QVBoxLayout()
        self.tree = QTreeWidget()
        layout.addWidget(self.tree)
        self.setLayout(layout)

        self._mgr = CacheMemoryManager()

        self._tracked_caches = {}

        # tree setup code
        self.tree.setHeaderLabels(
            ["cache", "memory", "roi", "dtype", "type", "info", "id"])
        self._idIndex = self.tree.columnCount() - 1
        self.tree.setColumnHidden(self._idIndex, True)
        self.tree.setSortingEnabled(True)
        self.tree.clear()

        self._root = TreeNode()

        # refresh every x seconds (see showEvent())
        self.timer = QTimer(self)
        if update:
            self.timer.timeout.connect(self._updateReport)

    def _updateReport(self):
        # we keep track of dirty reports so we just have to update the tree
        # instead of reconstructing it
        reports = []
        for c in self._mgr.getFirstClassCaches():
            r = MemInfoNode()
            c.generateReport(r)
            reports.append(r)
        self._root.handleChildrenReports(
            reports, root=self.tree.invisibleRootItem())

    def hideEvent(self, event):
        self.timer.stop()

    def showEvent(self, show):
        # update once so we don't have to wait for initial report
        self._updateReport()
        # update every 5 sec.
        self.timer.start(5*1000)
Example #13
class MemUsageDialog(QDialog):
    def __init__(self, parent=None, update=True):
        QDialog.__init__(self, parent=parent)
        layout = QVBoxLayout()
        self.tree = QTreeWidget()
        layout.addWidget(self.tree)
        self.setLayout(layout)

        self._mgr = CacheMemoryManager()

        self._tracked_caches = {}

        # tree setup code
        self.tree.setHeaderLabels(
            ["cache", "memory", "roi", "dtype", "type", "info", "id"])
        self._idIndex = self.tree.columnCount() - 1
        self.tree.setColumnHidden(self._idIndex, True)
        self.tree.setSortingEnabled(True)
        self.tree.clear()

        self._root = TreeNode()

        # refresh every x seconds (see showEvent())
        self.timer = QTimer(self)
        if update:
            self.timer.timeout.connect(self._updateReport)

    def _updateReport(self):
        # we keep track of dirty reports so we just have to update the tree
        # instead of reconstructing it
        reports = []
        for c in self._mgr.getFirstClassCaches():
            r = MemInfoNode()
            c.generateReport(r)
            reports.append(r)
        self._root.handleChildrenReports(reports,
                                         root=self.tree.invisibleRootItem())

    def hideEvent(self, event):
        self.timer.stop()

    def showEvent(self, show):
        # update once so we don't have to wait for initial report
        self._updateReport()
        # update every 5 sec.
        self.timer.start(5 * 1000)
Example #14
    def testFixAtCurrent(self):
        try:
            CacheMemoryManager().disable()
            opCache = self.opCache
            opProvider = self.opProvider

            # Track dirty notifications
            gotDirtyRois = []

            def handleDirty(slot, roi):
                gotDirtyRois.append((roi.start, roi.stop))

            opCache.Output.notifyDirty(handleDirty)

            opCache.fixAtCurrent.setValue(True)

            oldAccessCount = 0
            assert opProvider.accessCount == oldAccessCount, "Access count={}, expected={}".format(
                opProvider.accessCount, oldAccessCount)

            # Request (no access to provider because fixAtCurrent)
            slicing = make_key[:, 0:50, 15:45, 0:1, :]
            data = opCache.Output(slicing).wait()
            assert opProvider.accessCount == oldAccessCount, "Access count={}, expected={}".format(
                opProvider.accessCount, oldAccessCount)

            # We haven't accessed this data yet,
            # but fixAtCurrent is True so the cache gives us zeros
            assert (data == 0).all()

            opCache.fixAtCurrent.setValue(False)

            def boundingBox(roiA, roiB):
                return (numpy.minimum(roiA[0], roiB[0]),
                        numpy.maximum(roiA[1], roiB[1]))

            # Since we got zeros while the cache was fixed, the requested
            #  tiles are signaled as dirty when the cache becomes unfixed.
            # Our only requirement here is that any dirty rois we got add up to encompass all the tiles we requested.
            dirty_bb = reduce(boundingBox, gotDirtyRois)
            requested_roi = sliceToRoi(slicing, opCache.Output.meta.shape)
            assert (dirty_bb[0] <= requested_roi[0]).all() and (
                dirty_bb[1] >= requested_roi[1]).all()

            # Request again.  Data should match this time.
            oldAccessCount = opProvider.accessCount
            data = opCache.Output(slicing).wait()
            data = data.view(vigra.VigraArray)
            data.axistags = opCache.Output.meta.axistags
            assert (data == self.data[slicing]).all()

            # Our slice intersects 3*3=9 outerBlocks, and a total of 20 innerBlocks
            # Inner caches are allowed to split up the accesses, so there could be as many as 20
            minAccess = oldAccessCount + 9
            maxAccess = oldAccessCount + 20
            assert opProvider.accessCount >= minAccess
            assert opProvider.accessCount <= maxAccess
            oldAccessCount = opProvider.accessCount

            # Request again.  Data should match WITHOUT requesting from the source.
            data = opCache.Output(slicing).wait()
            data = data.view(vigra.VigraArray)
            data.axistags = opCache.Output.meta.axistags
            assert (data == self.data[slicing]).all()
            assert opProvider.accessCount == oldAccessCount, "Access count={}, expected={}".format(
                opProvider.accessCount, oldAccessCount)

            # Freeze it again
            opCache.fixAtCurrent.setValue(True)

            # Clear previous
            gotDirtyRois = []

            # Change some of the input data that ISN'T cached yet and mark it dirty
            dirtykey = make_key[0:1, 90:100, 90:100, 0:1, 0:1]
            self.data[dirtykey] = 0.12345
            opProvider.Input.setDirty(dirtykey)

            # Dirtiness not propagated due to fixAtCurrent
            assert len(gotDirtyRois) == 0

            # Same request.  Data should still match the previous data (not yet refreshed)
            data2 = opCache.Output(slicing).wait()
            data2 = data2.view(vigra.VigraArray)
            data2.axistags = opCache.Output.meta.axistags
            assert opProvider.accessCount == oldAccessCount, "Access count={}, expected={}".format(
                opProvider.accessCount, oldAccessCount)
            assert (data2 == data).all()

            # Unfreeze.
            opCache.fixAtCurrent.setValue(False)

            # Dirty blocks are propagated after the cache is unfixed.
            assert len(gotDirtyRois) > 0

            # Same request.  Data should be updated now that we're unfrozen.
            data = opCache.Output(slicing).wait()
            data = data.view(vigra.VigraArray)
            data.axistags = opCache.Output.meta.axistags
            assert (data == self.data[slicing]).all()

            # Dirty data did not intersect with this request.
            # Data should still be cached (no extra accesses)
            assert opProvider.accessCount == oldAccessCount, "Access count={}, expected={}".format(
                opProvider.accessCount, oldAccessCount)

            ###########################
            # Freeze it again
            opCache.fixAtCurrent.setValue(True)

            # Reset tracked notifications
            gotDirtyRois = []

            # Change some of the input data that IS cached and mark it dirty
            dirtykey = make_key[:, 0:25, 20:40, 0:1, :]
            self.data[dirtykey] = 0.54321
            opProvider.Input.setDirty(dirtykey)

            # Dirtiness not propagated due to fixAtCurrent
            assert len(gotDirtyRois) == 0

            # Same request.  Data should still match the previous data (not yet refreshed)
            data2 = opCache.Output(slicing).wait()
            data2 = data2.view(vigra.VigraArray)
            data2.axistags = opCache.Output.meta.axistags
            assert opProvider.accessCount == oldAccessCount, "Access count={}, expected={}".format(
                opProvider.accessCount, oldAccessCount)
            assert (data2 == data).all()

            # Unfreeze. Previous dirty notifications should now be seen.
            opCache.fixAtCurrent.setValue(False)
            assert len(gotDirtyRois) > 0

            # Same request.  Data should be updated now that we're unfrozen.
            data = opCache.Output(slicing).wait()
            data = data.view(vigra.VigraArray)
            data.axistags = opCache.Output.meta.axistags
            assert (data == self.data[slicing]).all()

            # The dirty data intersected 2 outerBlocks, and a total of 6 innerblocks
            # Inner caches are allowed to split up the accesses, so there could be as many as 6
            minAccess = oldAccessCount + 2
            maxAccess = oldAccessCount + 6
            assert opProvider.accessCount >= minAccess
            assert opProvider.accessCount <= maxAccess

            #####################

            #### Repeat plain dirty test to ensure fixAtCurrent didn't mess up the block states.

            opProvider.accessCount = 0  # Reset
            gotDirtyRois = []

            # Change some of the input data and mark it dirty
            dirtykey = make_key[0:1, 10:11, 20:21, 0:3, 0:1]
            self.data[dirtykey] = 0.54321
            opProvider.Input.setDirty(dirtykey)

            assert len(gotDirtyRois) > 0

            # Should need access again.
            slicing = make_key[:, 0:50, 15:45, 0:10, :]
            data = opCache.Output(slicing).wait()
            data = data.view(vigra.VigraArray)
            data.axistags = opCache.Output.meta.axistags
            assert (data == self.data[slicing]).all()

            # The dirty data intersected 1 outerBlock and a total of 1 innerBlock
            minAccess = 1
            maxAccess = 1
            assert opProvider.accessCount >= minAccess
            assert opProvider.accessCount <= maxAccess, "Too many accesses: {}".format(
                opProvider.accessCount)
            oldAccessCount = opProvider.accessCount

        finally:
            CacheMemoryManager().enable()
    def testAPIConformity(self):
        c = NonRegisteredCache("c")
        mgr = CacheMemoryManager()

        # don't clean up while we're testing
        mgr.disable()

        import weakref
        d = NonRegisteredCache("testwr")
        s = weakref.WeakSet()
        s.add(d)
        del d
        gc.collect()
        l = list(s)
        assert len(l) == 0, l[0].name

        c1 = NonRegisteredCache("c1")
        c1a = c1
        c2 = NonRegisteredCache("c2")

        mgr.addFirstClassCache(c)
        mgr.addCache(c1)
        mgr.addCache(c1a)
        mgr.addCache(c2)

        fcc = mgr.getFirstClassCaches()
        assert len(fcc) == 1, "too many first class caches"
        assert c in fcc, "did not register fcc correctly"
        del fcc

        cs = mgr.getCaches()
        assert len(cs) == 3, "wrong number of caches"
        refcs = [c, c1, c2]
        for cache in refcs:
            assert cache in cs, "{} not stored".format(cache.name)
        del cs
        del refcs
        del cache

        del c1a
        gc.collect()
        cs = mgr.getCaches()
        assert c1 in cs
        assert len(cs) == 3, str([x.name for x in cs])
        del cs

        del c2
        gc.collect()
        cs = mgr.getCaches()
        assert len(cs) == 2, str([x.name for x in cs])
    def testBadMemoryConditions(self):
        """
        TestCacheMemoryManager.testBadMemoryConditions

        This test is a proof of the proposition in
            https://github.com/ilastik/lazyflow/issue/185
        which states that, given certain memory constraints, the cache
        cleanup strategy in use is inefficient. An advanced strategy
        should pass the test.
        """

        mgr = CacheMemoryManager()
        mgr.setRefreshInterval(0.01)
        mgr.enable()

        d = 2
        tags = "xy"

        shape = (999,) * d
        blockshape = (333,) * d

        # restrict memory for computation to one block (including fudge
        # factor 2 of bigRequestStreamer)
        cacheMem = np.prod(shape)
        Memory.setAvailableRam(np.prod(blockshape) * 2 + cacheMem)

        # restrict cache memory to the whole volume
        Memory.setAvailableRamCaches(cacheMem)

        # to ease observation, do everything single threaded
        Request.reset_thread_pool(num_workers=1)

        x = np.zeros(shape, dtype=np.uint8)
        x = vigra.taggedView(x, axistags=tags)

        g = Graph()
        pipe = OpArrayPiperWithAccessCount(graph=g)
        pipe.Input.setValue(x)
        pipe.Output.meta.ideal_blockshape = blockshape

        # simulate BlockedArrayCache behaviour without caching
        # cache = OpSplitRequestsBlockwise(True, graph=g)
        # cache.BlockShape.setValue(blockshape)
        # cache.Input.connect(pipe.Output)

        cache = OpBlockedArrayCache(graph=g)
        cache.Input.connect(pipe.Output)
        cache.BlockShape.setValue(blockshape)

        op = OpEnlarge(graph=g)
        op.Input.connect(cache.Output)

        split = OpSplitRequestsBlockwise(True, graph=g)
        split.BlockShape.setValue(blockshape)
        split.Input.connect(op.Output)

        streamer = BigRequestStreamer(split.Output, [(0,) * len(shape), shape])
        streamer.execute()

        # in the worst case, we have 4*4 + 4*6 + 9 = 49 requests to pipe
        # in the best case, we have 9
        np.testing.assert_equal(pipe.accessCount, 9)
    def testBadMemoryConditions(self):
        """
        TestCacheMemoryManager.testBadMemoryConditions

        This test is a proof of the proposition in 
            https://github.com/ilastik/lazyflow/issue/185
        which states that, given certain memory constraints, the cache
        cleanup strategy in use is inefficient. An advanced strategy
        should pass the test.
        """

        mgr = CacheMemoryManager()
        mgr.setRefreshInterval(.01)
        mgr.enable()

        d = 2
        tags = 'xy'

        shape = (999, ) * d
        blockshape = (333, ) * d

        # restrict memory for computation to one block (including fudge
        # factor 2 of bigRequestStreamer)
        cacheMem = np.prod(shape)
        Memory.setAvailableRam(np.prod(blockshape) * 2 + cacheMem)

        # restrict cache memory to the whole volume
        Memory.setAvailableRamCaches(cacheMem)

        # to ease observation, do everything single threaded
        Request.reset_thread_pool(num_workers=1)

        x = np.zeros(shape, dtype=np.uint8)
        x = vigra.taggedView(x, axistags=tags)

        g = Graph()
        pipe = OpArrayPiperWithAccessCount(graph=g)
        pipe.Input.setValue(x)
        pipe.Output.meta.ideal_blockshape = blockshape

        # simulate BlockedArrayCache behaviour without caching
        # cache = OpSplitRequestsBlockwise(True, graph=g)
        # cache.BlockShape.setValue(blockshape)
        # cache.Input.connect(pipe.Output)

        cache = OpBlockedArrayCache(graph=g)
        cache.Input.connect(pipe.Output)
        cache.outerBlockShape.setValue(blockshape)

        op = OpEnlarge(graph=g)
        op.Input.connect(cache.Output)

        split = OpSplitRequestsBlockwise(True, graph=g)
        split.BlockShape.setValue(blockshape)
        split.Input.connect(op.Output)

        streamer = BigRequestStreamer(split.Output,
                                      [(0, ) * len(shape), shape])
        streamer.execute()

        # in the worst case, we have 4*4 + 4*6 + 9 = 49 requests to pipe
        # in the best case, we have 9
        np.testing.assert_equal(pipe.accessCount, 9)
    def testAPIConformity(self):
        c = NonRegisteredCache("c")
        mgr = CacheMemoryManager()

        # don't clean up while we're testing
        mgr.disable()

        import weakref

        d = NonRegisteredCache("testwr")
        s = weakref.WeakSet()
        s.add(d)
        del d
        gc.collect()
        l = list(s)
        assert len(l) == 0, l[0].name

        c1 = NonRegisteredCache("c1")
        c1a = c1
        c2 = NonRegisteredCache("c2")

        mgr.addFirstClassCache(c)
        mgr.addCache(c1)
        mgr.addCache(c1a)
        mgr.addCache(c2)

        fcc = mgr.getFirstClassCaches()
        assert len(fcc) == 1, "too many first class caches"
        assert c in fcc, "did not register fcc correctly"
        del fcc

        cs = mgr.getCaches()
        assert len(cs) == 3, "wrong number of caches"
        refcs = [c, c1, c2]
        for cache in refcs:
            assert cache in cs, "{} not stored".format(cache.name)
        del cs
        del refcs
        del cache

        del c1a
        gc.collect()
        cs = mgr.getCaches()
        assert c1 in cs
        assert len(cs) == 3, str([x.name for x in cs])
        del cs

        del c2
        gc.collect()
        cs = mgr.getCaches()
        assert len(cs) == 2, str([x.name for x in cs])
Example #19
    def registerWithMemoryManager(self):
        manager = CacheMemoryManager()
        if self.parent is None or not isinstance(self.parent, Cache):
            manager.addFirstClassCache(self)
        else:
            manager.addCache(self)
Example #20
    def registerWithMemoryManager(self):
        manager = CacheMemoryManager()
        if self.parent is None or not isinstance(self.parent, Cache):
            manager.addFirstClassCache(self)
        else:
            manager.addCache(self)
def ilastik_predict_with_array(gray_vol,
                               mask,
                               ilp_path,
                               selected_channels=None,
                               normalize=True,
                               LAZYFLOW_THREADS=1,
                               LAZYFLOW_TOTAL_RAM_MB=None,
                               logfile="/dev/null",
                               extra_cmdline_args=[]):
    """
    Using ilastik's python API, open the given project 
    file and run a prediction on the given raw data array.
    
    Other than the project file, nothing is read or written 
    using the hard disk.
    
    gray_vol: A 3D numpy array with axes zyx

    mask: A binary image where 0 means "no prediction necessary".
         'None' can be given, which means "predict everything".

    ilp_path: Path to the project file.  ilastik also accepts a url to a DVID key-value, which will be downloaded and opened as an ilp
    
    selected_channels: A list of channel indexes to select and return from the prediction results.
                       'None' can also be given, which means "return all prediction channels".
                       You may also pass a *nested* list, in which case groups of channels can be
                       combined (summed) into their respective output channels.
                       For example: selected_channels=[0,3,[2,4],7] means the output will have 4 channels:
                                    0,3,2+4,7 (channels 1, 5, and 6 are simply dropped).
    
    normalize: Renormalize all outputs so the channels sum to 1 everywhere.
               That is, (predictions.sum(axis=-1) == 1.0).all()
               Note: Pixels with 0.0 in all channels will be simply given a value of 1/N in all channels.
    
    LAZYFLOW_THREADS, LAZYFLOW_TOTAL_RAM_MB: Passed to ilastik via environment variables.
    """
    print("ilastik_predict_with_array(): Starting with raw data: dtype={}, shape={}".format(
        str(gray_vol.dtype), gray_vol.shape))

    import os
    from collections import OrderedDict

    import uuid
    import multiprocessing
    import platform
    import psutil
    import vigra

    import ilastik_main
    from ilastik.applets.dataSelection import DatasetInfo
    from lazyflow.operators.cacheMemoryManager import CacheMemoryManager

    import logging
    logging.getLogger(__name__).info('status=ilastik prediction')
    print("ilastik_predict_with_array(): Done with imports")

    if LAZYFLOW_TOTAL_RAM_MB is None:
        # By default, assume our allotted RAM is proportional
        # to the CPUs we've been told to use
        machine_ram = psutil.virtual_memory().total
        machine_ram -= 1024**3  # Leave 1 GB RAM for the OS.

        LAZYFLOW_TOTAL_RAM_MB = LAZYFLOW_THREADS * machine_ram // multiprocessing.cpu_count()

    # Before we start ilastik, prepare the environment variable settings.
    os.environ["LAZYFLOW_THREADS"] = str(LAZYFLOW_THREADS)
    os.environ["LAZYFLOW_TOTAL_RAM_MB"] = str(LAZYFLOW_TOTAL_RAM_MB)
    os.environ["LAZYFLOW_STATUS_MONITOR_SECONDS"] = "10"

    # Prepare ilastik's "command-line" arguments, as if they were already parsed.
    args, extra_workflow_cmdline_args = ilastik_main.parser.parse_known_args(
        extra_cmdline_args)
    args.headless = True
    args.debug = True  # ilastik's 'debug' flag enables special power features, including experimental workflows.
    args.project = str(ilp_path)
    args.readonly = True

    # The process_name argument is prefixed to all log messages.
    # For now, just use the machine name and a uuid
    # FIXME: It would be nice to provide something more descriptive, like the ROI of the current spark job...
    args.process_name = platform.node() + "-" + str(uuid.uuid1())

    # To avoid conflicts between processes, give each process its own logfile to write to.
    if logfile != "/dev/null":
        base, ext = os.path.splitext(logfile)
        logfile = base + '.' + args.process_name + ext

    # By default, all ilastik processes duplicate their console output to ~/.ilastik_log.txt
    # Obviously, having all spark nodes write to a common file is a bad idea.
    # The "/dev/null" setting here is recognized by ilastik and means "Don't write a log file"
    args.logfile = logfile

    print("ilastik_predict_with_array(): Creating shell...")

    # Instantiate the 'shell', (in this case, an instance of ilastik.shell.HeadlessShell)
    # This also loads the project file into shell.projectManager
    shell = ilastik_main.main(args, extra_workflow_cmdline_args)

    ## Need to find a better way to verify the workflow type
    #from ilastik.workflows.pixelClassification import PixelClassificationWorkflow
    #assert isinstance(shell.workflow, PixelClassificationWorkflow)

    # Construct an OrderedDict of role-names -> DatasetInfos
    # (See PixelClassificationWorkflow.ROLE_NAMES)
    raw_data_array = vigra.taggedView(gray_vol, 'zyx')
    role_data_dict = OrderedDict([
        ("Raw Data", [DatasetInfo(preloaded_array=raw_data_array)])
    ])

    if mask is not None:
        # If there's a mask, we might be able to save some computation time.
        mask = vigra.taggedView(mask, 'zyx')
        role_data_dict["Prediction Mask"] = [DatasetInfo(preloaded_array=mask)]

    print("ilastik_predict_with_array(): Starting export...")

    # Sanity checks
    opInteractiveExport = shell.workflow.batchProcessingApplet.dataExportApplet.topLevelOperator.getLane(0)
    selected_result = opInteractiveExport.InputSelection.value
    num_channels = opInteractiveExport.Inputs[selected_result].meta.shape[-1]

    # For convenience, verify the selected channels before we run the export.
    if selected_channels:
        assert isinstance(selected_channels, list)
        for selection in selected_channels:
            if isinstance(selection, list):
                assert all(c < num_channels for c in selection), \
                    "Selected channels ({}) exceed number of prediction classes ({})"\
                    .format( selected_channels, num_channels )
            else:
                assert selection < num_channels, \
                    "Selected channels ({}) exceed number of prediction classes ({})"\
                    .format( selected_channels, num_channels )

    # Run the export via the BatchProcessingApplet
    prediction_list = shell.workflow.batchProcessingApplet.run_export(
        role_data_dict, export_to_array=True)
    assert len(prediction_list) == 1
    predictions = prediction_list[0]

    assert predictions.shape[-1] == num_channels
    selected_predictions = select_channels(predictions, selected_channels)

    if normalize:
        normalize_channels_in_place(selected_predictions)

    # Cleanup: kill cache monitor thread
    CacheMemoryManager().stop()
    CacheMemoryManager.instance = None

    # Cleanup environment
    del os.environ["LAZYFLOW_THREADS"]
    del os.environ["LAZYFLOW_TOTAL_RAM_MB"]
    del os.environ["LAZYFLOW_STATUS_MONITOR_SECONDS"]

    logging.getLogger(__name__).info('status=ilastik prediction finished')
    return selected_predictions
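
A minimal call-site sketch for the helper above, assuming it is importable from the surrounding module; the volume shape, the .ilp path, and the channel selection are illustrative placeholders, not values from the source.

# Hypothetical usage of ilastik_predict_with_array(); paths and parameters are assumptions.
import numpy as np

gray_vol = np.random.randint(0, 256, size=(64, 128, 128)).astype(np.uint8)  # zyx grayscale volume
predictions = ilastik_predict_with_array(
    gray_vol,
    mask=None,                                  # None means "predict everything"
    ilp_path="/path/to/pixel_classifier.ilp",   # placeholder project file
    selected_channels=[0, [1, 2]],              # keep channel 0, sum channels 1 and 2
    normalize=True,
    LAZYFLOW_THREADS=4,
    LAZYFLOW_TOTAL_RAM_MB=4096)
print(predictions.shape)  # zyx axes plus one channel axis for the 2 selected outputs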