Esempio n. 1
0
class Counting3dGui(LabelingGui):
    """
    Labeling GUI of the 3D counting applet.

    On top of the standard LabelingGui this class provides:

    * two fixed paint labels ("Foreground" and "Background") plus any number
      of "box" labels drawn with the box tool,
    * widgets for the training parameters of ``op.opTrain`` (sigma, epsilon,
      over/under multipliers and the selected SVR option),
    * a "density" layer and a display of the summed density (the "count"),
    * a live-update (interactive) mode and a "predict and save" action,
    * optional 3D rendering of selected layers.
    """

    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################
    def centralWidget( self ):
        """Return the widget for the shell's central area: this GUI itself."""
        return self

    def reset(self):
        """
        Reset the GUI: run the base-class reset, then make sure live-update
        (interactive) mode is off and both layer-visibility checkboxes are
        cleared.
        """
        # Base class first
        super(Counting3dGui, self).reset()

        # Ensure that we are NOT in interactive mode
        self.labelingDrawerUi.liveUpdateButton.setChecked(False)
        self._viewerControlUi.checkShowPredictions.setChecked(False)
        self._viewerControlUi.checkShowSegmentation.setChecked(False)
        self.toggleInteractive(False)

    def viewerControlWidget(self):
        """Return the viewer-control widget created by initViewerControlUi()."""
        return self._viewerControlUi

    ###########################################
    ###########################################

    @traceLogged(traceLogger)
    def __init__(self, topLevelOperatorView, shellRequestSignal, guiControlSignal, predictionSerializer ):
        """
        :param topLevelOperatorView: operator (view) whose slots this GUI displays
            and edits; also available as ``self.op``.
        :param shellRequestSignal: signal used to ask the shell to save the project.
        :param guiControlSignal: signal used to disable and re-enable the other
            applets while a full-volume save is running.
        :param predictionSerializer: serializer whose ``predictionStorageEnabled``
            flag decides whether a save writes predictions; ``cancel()`` aborts
            a running save.
        """
        # Set the operator references before the base-class constructor runs:
        # overridden hooks such as setupLayers() and handleLabelSelectionChange()
        # read them, and the base class may invoke those hooks during its init.
        self.op = topLevelOperatorView
        self.topLevelOperatorView = topLevelOperatorView

        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.opLabelArray.deleteLabel
        labelSlots.maxLabelValue = topLevelOperatorView.MaxLabelValue
        labelSlots.labelsAllowed = topLevelOperatorView.LabelsAllowedFlags

        # We provide our own UI file (which adds an extra control for interactive mode).
        # os.path.join avoids producing '/labelingDrawer.ui' when the dirname is empty.
        labelingDrawerUiPath = os.path.join(os.path.split(__file__)[0], 'labelingDrawer.ui')

        # Base class init
        super(Counting3dGui, self).__init__( labelSlots, topLevelOperatorView, labelingDrawerUiPath )

        self.shellRequestSignal = shellRequestSignal
        self.guiControlSignal = guiControlSignal
        self.predictionSerializer = predictionSerializer

        self.interactiveModeActive = False
        self._currentlySavingPredictions = False

        self.labelingDrawerUi.savePredictionsButton.clicked.connect(self.onSavePredictionsButtonClicked)
        self.labelingDrawerUi.savePredictionsButton.setIcon( QIcon(ilastikIcons.Save) )

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon( QIcon(ilastikIcons.Play) )
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelingDrawerUi.liveUpdateButton.toggled.connect( self.toggleInteractive )

        self.topLevelOperatorView.MaxLabelValue.notifyDirty( bind(self.handleLabelSelectionChange) )

        self._initShortcuts()

        try:
            self.render = True
            self._renderedLayers = {} # layer name -> label number in the rendering manager
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except Exception:
            # 3D rendering is optional: if the renderer cannot be set up,
            # continue without it (the 'Toggle 3D rendering' entry is then hidden).
            logger.debug("3D rendering is not available", exc_info=True)
            self.render = False

        self.initCounting()

    def initCounting(self):
        """
        Initialize the counting-specific controls.

        Creates the two default labels (renamed "Foreground" and "Background"),
        sets defaults and ranges of the training-parameter widgets, fills the
        SVR option combo box from ``self.op.options`` and wires every widget to
        the matching ``opTrain`` input slot. The count display is refreshed from
        ``OutputSum`` whenever ``Density`` becomes dirty.
        """
        self._addNewLabel()
        self._addNewLabel()
        self.labelingDrawerUi.SigmaLine.setText("1")
        self.labelingDrawerUi.UnderBox.setRange(0,1000000)
        self.labelingDrawerUi.UnderBox.setValue(1)
        self.labelingDrawerUi.OverBox.setRange(0,1000000)
        self.labelingDrawerUi.OverBox.setValue(1)
        # Only emit valueChanged when editing is finished, not on every keystroke.
        self.labelingDrawerUi.UnderBox.setKeyboardTracking(False)
        self.labelingDrawerUi.OverBox.setKeyboardTracking(False)
        self.labelingDrawerUi.EpsilonBox.setKeyboardTracking(False)
        self.labelingDrawerUi.EpsilonBox.setDecimals(6)
        for option in self.op.options:
            logger.debug("option %r", option)
            # The option dict is stored (wrapped in a tuple) as the item's user data.
            self.labelingDrawerUi.SVROptions.addItem('+'.join(option.values()), (option,))
        self.labelingDrawerUi.DebugButton.pressed.connect(self._debug)
        #self.labelingDrawerUi.TrainButton.pressed.connect(self._train)
        #self.labelingDrawerUi.PredictionButton.pressed.connect(self.updateDensitySum)
        self.labelingDrawerUi.SVROptions.currentIndexChanged.connect(self._updateSVROptions)
        self._updateSVROptions()
        self.labelingDrawerUi.OverBox.valueChanged.connect(self._updateOverMult)
        self.labelingDrawerUi.UnderBox.valueChanged.connect(self._updateUnderMult)
        # Sigma is pushed on editingFinished, but only if the text actually changed.
        self.labelingDrawerUi.SigmaLine.editingFinished.connect(self._updateSigma)
        self.labelingDrawerUi.SigmaLine.textChanged.connect(self._changedSigma)
        self.labelingDrawerUi.EpsilonBox.valueChanged.connect(self._updateEpsilon)
        self.changedSigma = False

        def updateSum(*args, **kw):
            # Show OutputSum / 255 as the count. The same divisor is used in test();
            # presumably it undoes a fixed scaling of the density — TODO confirm.
            logger.debug("updatingSum")
            density = self.op.OutputSum.value / 255.0
            strdensity = "{0:.2f}".format(density)
            self._labelControlUi.CountText.setText(strdensity)

        self.op.Density.notifyDirty(updateSum)

        # Maps box index (label row - 2) to the slicing of the box drawn for it.
        self.boxes = dict()
        self._labelControlUi.labelListModel[0].name = "Foreground"
        self._labelControlUi.labelListModel[1].name = "Background"

    def _updateOverMult(self):
        """Push the OverBox value to ``opTrain.OverMult``."""
        self.op.opTrain.OverMult.setValue(self.labelingDrawerUi.OverBox.value())

    def _updateUnderMult(self):
        """Push the UnderBox value to ``opTrain.UnderMult``."""
        self.op.opTrain.UnderMult.setValue(self.labelingDrawerUi.UnderBox.value())

    def _updateSigma(self):
        """
        Push the sigma value(s) typed into SigmaLine to ``opTrain.Sigma``.

        Acts only if the text changed since the last successful push. The text
        is parsed as whitespace-separated numbers; unparsable or empty input is
        logged/ignored and stays pending so the next edit is retried.
        """
        if not self.changedSigma:
            return
        rawText = self.labelingDrawerUi.SigmaLine.text()
        try:
            # str() normalizes QString/unicode; split() tolerates repeated spaces.
            sigma = [float(n) for n in str(rawText).split()]
        except ValueError:  # also covers UnicodeError
            logger.warning("Invalid sigma %r: expected whitespace-separated numbers", rawText)
            return
        if not sigma:
            return
        self.op.opTrain.Sigma.setValue(sigma)
        self.changedSigma = False

    def _changedSigma(self, text):
        """Remember that SigmaLine was edited so _updateSigma() pushes it."""
        self.changedSigma = True

    def _updateEpsilon(self):
        """Push the EpsilonBox value to ``opTrain.Epsilon``."""
        self.op.opTrain.Epsilon.setValue(self.labelingDrawerUi.EpsilonBox.value())

    def _updateSVROptions(self):
        """Push the SVR option selected in the combo box to ``opTrain.SelectedOption``."""
        index = self.labelingDrawerUi.SVROptions.currentIndex()
        if index < 0:
            # Empty combo box (no options offered): nothing to select.
            return
        option = self.labelingDrawerUi.SVROptions.itemData(index).toPyObject()[0]
        self.op.opTrain.SelectedOption.setValue(option)

    def _debug(self):
        """Developer hook: break into the debugger provided by ``sitecustomize``."""
        import sitecustomize
        sitecustomize.debug_trace()

    @traceLogged(traceLogger)
    def initViewerControlUi(self):
        """Load viewerControls.ui and connect its checkboxes and layer controls."""
        localDir = os.path.split(__file__)[0]
        self._viewerControlUi = uic.loadUi( os.path.join( localDir, "viewerControls.ui" ) )

        # Connect checkboxes
        self._viewerControlUi.checkShowPredictions.clicked.connect( self.handleShowPredictionsClicked )
        self._viewerControlUi.checkShowSegmentation.clicked.connect( self.handleShowSegmentationClicked )

        # The editor's layerstack is in charge of which layer movement buttons are enabled
        model = self.editor.layerStack
        self._viewerControlUi.viewerControls.setupConnections(model)

    def _initShortcuts(self):
        """Register the 'p', 's' and 'l' keyboard shortcuts with the ShortcutManager."""
        mgr = ShortcutManager()
        shortcutGroupName = "Predictions"

        togglePredictions = QShortcut( QKeySequence("p"), self, member=self._viewerControlUi.checkShowPredictions.click )
        mgr.register( shortcutGroupName,
                      "Toggle Prediction Layer Visibility",
                      togglePredictions,
                      self._viewerControlUi.checkShowPredictions )

        toggleSegmentation = QShortcut( QKeySequence("s"), self, member=self._viewerControlUi.checkShowSegmentation.click )
        mgr.register( shortcutGroupName,
                      "Toggle Segmentaton Layer Visibility",
                      toggleSegmentation,
                      self._viewerControlUi.checkShowSegmentation )

        toggleLivePredict = QShortcut( QKeySequence("l"), self, member=self.labelingDrawerUi.liveUpdateButton.toggle )
        mgr.register( shortcutGroupName,
                      "Toggle Live Prediction Mode",
                      toggleLivePredict,
                      self.labelingDrawerUi.liveUpdateButton )

    def _setup_contexts(self, layer):
        """
        Add a 'Toggle 3D rendering' context-menu entry to ``layer`` when 3D
        rendering is available. The entry adds/removes the layer from the
        rendering manager and refreshes the 3D view.
        """
        def callback(pos, clayer=layer):
            name = clayer.name
            if name in self._renderedLayers:
                label = self._renderedLayers.pop(name)
                self._renderMgr.removeObject(label)
            else:
                self._renderedLayers[name] = self._renderMgr.addObject()
            self._update_rendering()

        if self.render:
            layer.contexts.append(('Toggle 3D rendering', callback))

    @traceLogged(traceLogger)
    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.
        """
        from volumina import colortables

        # Base class provides the label layer.
        layers = super(Counting3dGui, self).setupLayers()

        # Density layer(s), shown with a jet colortable once the slot is ready.
        slots = {'density' : self.op.Density}
        for name, slot in slots.items():
            if slot.ready():
                layer = ColortableLayer(LazyflowSource(slot), colorTable = colortables.jet(), normalize = 'auto')
                layer.name = name
                layers.append(layer)

        # Box labels: editable through boxlabelsrc (see test2()).
        boxlabelsrc = LazyflowSinkSource(self.op.BoxLabelImages,self.op.BoxLabelInputs )
        boxlabellayer = ColortableLayer(boxlabelsrc, colorTable = self._colorTable16, direct = False)
        boxlabellayer.name = "boxLabels"
        boxlabellayer.opacity = 0.3
        layers.append(boxlabellayer)
        self.boxlabelsrc = boxlabelsrc

        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot( inputDataSlot )
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0

            def toggleTopToBottom():
                index = self.layerstack.layerIndex( inputLayer )
                self.layerstack.selectRow( index )
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            inputLayer.shortcutRegistration = (
                "Prediction Layers",
                "Bring Input To Top/Bottom",
                QShortcut( QKeySequence("i"), self.viewerControlWidget(), toggleTopToBottom),
                inputLayer )
            layers.append(inputLayer)

        self.handleLabelSelectionChange()
        return layers

    @traceLogged(traceLogger)
    def toggleInteractive(self, checked):
        """
        Switch live-update (interactive) mode on (``checked`` true) or off.

        Enabling is refused, and the live-update button un-checked, when no
        feature images are available. While enabled, predictions are
        unfrozen, the save-predictions button is disabled, prediction layers
        are shown and label deletion is forbidden. Disabling re-freezes
        predictions, re-enables the save button and allows deletion again.
        """
        logger.debug("toggling interactive mode to '%r'" % checked)

        if checked:
            featureImages = self.topLevelOperatorView.FeatureImages
            if not featureImages.ready() or featureImages.meta.shape is None:
                # Un-checking re-enters this method with checked=False.
                self.labelingDrawerUi.liveUpdateButton.setChecked(False)
                mexBox=QMessageBox()
                mexBox.setText("There are no features selected ")
                mexBox.exec_()
                return

        self.labelingDrawerUi.savePredictionsButton.setEnabled(not checked)
        self.topLevelOperatorView.FreezePredictions.setValue( not checked )

        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked( True )
            self.handleShowPredictionsClicked()

        # If we're changing modes, enable/disable our controls accordingly
        if self.interactiveModeActive != checked:
            self.labelingDrawerUi.labelListView.allowDelete = not checked
        self.interactiveModeActive = checked

    def _setLayersVisible(self, nameFragment, visible):
        """Set ``visible`` on every layer whose name contains ``nameFragment``."""
        for layer in self.layerstack:
            if nameFragment in layer.name:
                layer.visible = visible

    def _updateVisibilityCheckbox(self, checkbox, nameFragment):
        """
        Set ``checkbox`` to Unchecked, Checked or PartiallyChecked depending on
        whether none, all or only some of the layers whose name contains
        ``nameFragment`` are visible (Unchecked if there are no such layers).
        """
        matching = [layer for layer in self.layerstack if nameFragment in layer.name]
        visibleCount = sum(1 for layer in matching if layer.visible)
        if visibleCount == 0:
            checkbox.setCheckState(Qt.Unchecked)
        elif visibleCount == len(matching):
            checkbox.setCheckState(Qt.Checked)
        else:
            checkbox.setCheckState(Qt.PartiallyChecked)

    @pyqtSlot()
    @traceLogged(traceLogger)
    def handleShowPredictionsClicked(self):
        """Show/hide all 'Prediction' layers according to the checkbox."""
        checked = self._viewerControlUi.checkShowPredictions.isChecked()
        self._setLayersVisible("Prediction", checked)

    @pyqtSlot()
    @traceLogged(traceLogger)
    def handleShowSegmentationClicked(self):
        """Show/hide all 'Segmentation' layers according to the checkbox."""
        checked = self._viewerControlUi.checkShowSegmentation.isChecked()
        self._setLayersVisible("Segmentation", checked)

    @pyqtSlot()
    @traceLogged(traceLogger)
    def updateShowPredictionCheckbox(self):
        """Sync the 'show predictions' checkbox with the prediction layers' visibility."""
        self._updateVisibilityCheckbox(self._viewerControlUi.checkShowPredictions, "Prediction")

    @pyqtSlot()
    @traceLogged(traceLogger)
    def updateShowSegmentationCheckbox(self):
        """Sync the 'show segmentation' checkbox with the segmentation layers' visibility."""
        self._updateVisibilityCheckbox(self._viewerControlUi.checkShowSegmentation, "Segmentation")

    @pyqtSlot()
    @threadRouted
    @traceLogged(traceLogger)
    def handleLabelSelectionChange(self):
        """
        Enable the prediction controls only when at least two labels exist and
        the cached feature images have a non-empty shape.
        """
        enabled = False
        if self.topLevelOperatorView.MaxLabelValue.ready():
            enabled = True
            enabled &= self.topLevelOperatorView.MaxLabelValue.value >= 2
            enabled &= numpy.all(numpy.asarray(self.topLevelOperatorView.CachedFeatureImages.meta.shape) > 0)
            # FIXME: also check that each label has scribbles?

        self.labelingDrawerUi.savePredictionsButton.setEnabled(enabled)
        self.labelingDrawerUi.liveUpdateButton.setEnabled(enabled)
        self._viewerControlUi.checkShowPredictions.setEnabled(enabled)
        self._viewerControlUi.checkShowSegmentation.setEnabled(enabled)

    @pyqtSlot()
    @traceLogged(traceLogger)
    def onSavePredictionsButtonClicked(self):
        """
        The user clicked "Train and Predict".
        Handle this event by asking the topLevelOperatorView for a prediction over the entire output region.

        While a save is running the same button acts as a cancel button.
        The save runs on a background thread; GUI and applet state are
        restored afterwards even if the save fails.
        """
        # The button does double-duty as a cancel button while predictions are being stored
        if self._currentlySavingPredictions:
            self.predictionSerializer.cancel()
            return

        # Compute new predictions as needed
        predictionsFrozen = self.topLevelOperatorView.FreezePredictions.value
        self.topLevelOperatorView.FreezePredictions.setValue(False)
        self._currentlySavingPredictions = True

        originalButtonText = "Full Volume Predict and Save"
        self.labelingDrawerUi.savePredictionsButton.setText("Cancel Full Predict")

        @traceLogged(traceLogger)
        def saveThreadFunc():
            logger.info("Starting full volume save...")
            # Disable all other applets
            self.guiControlSignal.emit( ControlCommand.DisableUpstream )
            self.guiControlSignal.emit( ControlCommand.DisableDownstream )

            def disableAllInWidgetButName(widget, exceptName):
                # Disable every child that does not contain the named button;
                # recurse into the ones that do.
                for child in widget.children():
                    if child.findChild( QPushButton, exceptName) is None:
                        child.setEnabled(False)
                    else:
                        disableAllInWidgetButName(child, exceptName)

            def enableAll(widget):
                for child in widget.children():
                    if isinstance( child, QWidget ):
                        child.setEnabled(True)
                        enableAll(child)

            # NOTE(review): the enable/disable calls here run on this worker
            # thread, not the GUI thread; only setText is routed through
            # thunkEventHandler. Consider routing these through it as well.

            # Disable everything in our drawer *except* the cancel button
            disableAllInWidgetButName(self.labelingDrawerUi, "savePredictionsButton")

            # But allow the user to cancel the save
            self.labelingDrawerUi.savePredictionsButton.setEnabled(True)

            try:
                # First, do a regular save.
                # During a regular save, predictions are not saved to the project file.
                # (It takes too much time if the user only needs the classifier.)
                self.shellRequestSignal.emit( ShellRequest.RequestSave )

                # Enable prediction storage and ask the shell to save the project again.
                # (This way the second save will occupy the whole progress bar.)
                self.predictionSerializer.predictionStorageEnabled = True
                self.shellRequestSignal.emit( ShellRequest.RequestSave )
            finally:
                # Always restore state so a failed save does not leave the
                # drawer and the other applets disabled.
                self.predictionSerializer.predictionStorageEnabled = False

                # Restore original states (must use events for UI calls)
                self.thunkEventHandler.post(self.labelingDrawerUi.savePredictionsButton.setText, originalButtonText)
                self.topLevelOperatorView.FreezePredictions.setValue(predictionsFrozen)
                self._currentlySavingPredictions = False

                # Re-enable our controls
                enableAll(self.labelingDrawerUi)

                # Re-enable all other applets
                self.guiControlSignal.emit( ControlCommand.Pop )
                self.guiControlSignal.emit( ControlCommand.Pop )
            logger.info("Finished full volume save.")

        saveThread = threading.Thread(target=saveThreadFunc)
        saveThread.start()

    def _getNext(self, slot, parentFun, transform=None):
        """
        Return the default property for the next label to be created.

        If ``slot.value`` (a per-label list) already has an entry at index
        "current number of labels", return it (passed through ``transform`` if
        given); otherwise fall back to the base-class default ``parentFun()``.
        """
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def _onLabelChanged(self, parentFun, mapf, slot):
        """
        Run the base-class handler ``parentFun``, then store ``mapf`` applied
        to every label into ``slot`` (merged via ``_listReplace``).
        """
        parentFun()
        new = map(mapf, self.labelListData)
        old = slot.value
        slot.setValue(_listReplace(old, new))

    def _onLabelRemoved(self, parent, start, end):
        """
        Keep LabelNames, LabelColors and PmapColors in sync when rows
        ``start``..``end`` (inclusive, as reported by the model) are removed.
        """
        super(Counting3dGui, self)._onLabelRemoved(parent, start, end)
        op = self.topLevelOperatorView
        for slot in (op.LabelNames, op.LabelColors, op.PmapColors):
            # Copy first: mutating slot.value in place and setting the same
            # object back may not be recognized as a change by the slot.
            value = list(slot.value)
            del value[start:end + 1]
            slot.setValue(value)

    def getNextLabelName(self):
        """Return the stored name for the next label, or the base-class default."""
        return self._getNext(self.topLevelOperatorView.LabelNames,
                             super(Counting3dGui, self).getNextLabelName)

    def getNextLabelColor(self):
        """Return the stored (r, g, b) brush color for the next label as a QColor."""
        return self._getNext(
            self.topLevelOperatorView.LabelColors,
            super(Counting3dGui, self).getNextLabelColor,
            lambda x: QColor(*x)
        )

    def getNextPmapColor(self):
        """Return the stored (r, g, b) prediction-map color for the next label as a QColor."""
        return self._getNext(
            self.topLevelOperatorView.PmapColors,
            super(Counting3dGui, self).getNextPmapColor,
            lambda x: QColor(*x)
        )

    def onLabelNameChanged(self):
        """Store all label names in ``LabelNames``."""
        self._onLabelChanged(super(Counting3dGui, self).onLabelNameChanged,
                             lambda l: l.name,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        """Store all label brush colors as (r, g, b) tuples in ``LabelColors``."""
        self._onLabelChanged(super(Counting3dGui, self).onLabelColorChanged,
                             lambda l: (l.brushColor().red(),
                                        l.brushColor().green(),
                                        l.brushColor().blue()),
                             self.topLevelOperatorView.LabelColors)

    def onPmapColorChanged(self):
        """Store all prediction-map colors as (r, g, b) tuples in ``PmapColors``."""
        self._onLabelChanged(super(Counting3dGui, self).onPmapColorChanged,
                             lambda l: (l.pmapColor().red(),
                                        l.pmapColor().green(),
                                        l.pmapColor().blue()),
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        """
        Rebuild the 3D rendering volume from the layers selected for rendering.

        Every non-zero voxel of a rendered layer (at the current time point,
        channel 0) is written with that layer's rendering label.
        """
        if not self.render:
            return
        # Spatial part of the input shape (indices 1..3).
        shape = self.topLevelOperatorView.InputImages.meta.shape[1:4]
        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        # Forget rendered layers that no longer exist in the layer stack.
        layernames = set(layer.name for layer in self.layerstack)
        self._renderedLayers = dict((k, v) for k, v in self._renderedLayers.iteritems()
                                if k in layernames)

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                indices = numpy.where(vol != 0)
                newvolume[indices] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Give each rendered object the tint color of its layer (RGB in [0, 1])."""
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            color = layer.tintColor
            color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
            self._renderMgr.setColor(label, color)

    def _gui_setNavigation(self):
        """
        Switch the GUI to the navigation (arrow) tool and route left-click
        releases to handleBoxQuery().
        """
        self._labelControlUi.brushSizeComboBox.setEnabled(False)
        self._labelControlUi.brushSizeCaption.setEnabled(False)
        self._labelControlUi.arrowToolButton.setChecked(True)
        logger.debug("setNavigation")
        if not hasattr(self, "rubberbandClickReporter"):
            # Created lazily once and reused afterwards.
            self.rubberbandClickReporter = ClickReportingInterpreter(
                self.editor.navInterpret, self.editor.posModel, self.centralWidget() )
            self.rubberbandClickReporter.leftClickReleased.connect( self.handleBoxQuery )
        self.editor.setNavigationInterpreter(self.rubberbandClickReporter)

    def _gui_setBox(self):
        """Switch the GUI to the box tool (brush size controls disabled)."""
        logger.debug("setBox")
        self._labelControlUi.brushSizeComboBox.setEnabled(False)
        self._labelControlUi.brushSizeCaption.setEnabled(False)
        self._labelControlUi.boxToolButton.setChecked(True)

    def _changeInteractionMode( self, toolId ):
        """
        Implement the GUI's response to the user selecting a new tool.

        Besides the base tools (navigation, paint, erase) this handles
        ``Tool.Box``, which uses the editor's "navigation" mode.
        """
        # Uncheck all the other buttons
        for tool, button in self.toolButtons.items():
            if tool != toolId:
                button.setChecked(False)

        # If we have no editor, we can't do anything yet
        if self.editor is None:
            return

        # The volume editor expects one of two specific names
        modeNames = { Tool.Navigation   : "navigation",
                      Tool.Paint        : "brushing",
                      Tool.Erase        : "brushing",
                      Tool.Box          : "navigation"
                    }

        # If the user can't label this image, disable the button and say why its disabled
        labelsAllowed = False

        labelsAllowedSlot = self._labelingSlots.labelsAllowed
        if labelsAllowedSlot.ready():
            labelsAllowed = labelsAllowedSlot.value

            if hasattr(self._labelControlUi, "AddLabelButton"):
                self._labelControlUi.AddLabelButton.setEnabled(labelsAllowed and self.maxLabelNumber > self._labelControlUi.labelListModel.rowCount())
                if labelsAllowed:
                    self._labelControlUi.AddLabelButton.setText("Add Label")
                else:
                    self._labelControlUi.AddLabelButton.setText("(Labeling Not Allowed)")

        e = labelsAllowed & (self._labelControlUi.labelListModel.rowCount() > 0)
        self._gui_enableLabeling(e)

        if labelsAllowed:
            # Update the applet bar caption
            if toolId == Tool.Navigation:
                # update GUI
                self._gui_setNavigation()

            elif toolId == Tool.Paint:
                # If necessary, tell the brushing model to stop erasing
                if self.editor.brushingModel.erasing:
                    self.editor.brushingModel.disableErasing()
                # Set the brushing size
                brushSize = self.brushSizes[self.paintBrushSizeIndex]
                self.editor.brushingModel.setBrushSize(brushSize)
                # update GUI
                self._gui_setBrushing()

            elif toolId == Tool.Erase:
                # If necessary, tell the brushing model to start erasing
                if not self.editor.brushingModel.erasing:
                    self.editor.brushingModel.setErasing()
                # Set the brushing size
                eraserSize = self.brushSizes[self.eraserSizeIndex]
                self.editor.brushingModel.setBrushSize(eraserSize)
                # update GUI
                self._gui_setErasing()
            elif toolId == Tool.Box:
                self._gui_setBox()

        self.editor.setInteractionMode( modeNames[toolId] )
        self._toolId = toolId

    def _initLabelUic(self, drawerUiPath):
        """Extend the base drawer setup with the box tool button and 'add box' button."""
        super(Counting3dGui, self)._initLabelUic(drawerUiPath)
        self._labelControlUi.boxToolButton.setCheckable(True)
        self._labelControlUi.boxToolButton.clicked.connect( lambda checked: self._handleToolButtonClicked(checked,
                                                                                                          Tool.Box) )
        self.toolButtons[Tool.Box] = self._labelControlUi.boxToolButton
        if hasattr(self._labelControlUi, "AddBoxButton"):
            self._labelControlUi.AddBoxButton.setIcon( QIcon(ilastikIcons.AddSel) )
            self._labelControlUi.AddBoxButton.clicked.connect( bind(self._addNewBox) )

    def _addNewBox(self):
        """Append a new "Box: " label to the label list and select it."""
        label = Label( "Box: ", self.getNextLabelColor(),
                       pmapColor=self.getNextPmapColor(),
                   )
        label.nameChanged.connect(self._updateLabelShortcuts)
        label.nameChanged.connect(self.onLabelNameChanged)
        label.colorChanged.connect(self.onLabelColorChanged)

        newRow = self._labelControlUi.labelListModel.rowCount()
        self._labelControlUi.labelListModel.insertRow( newRow, label )
        newColorIndex = self._labelControlUi.labelListModel.index(newRow, 0)
        self.onLabelListDataChanged(newColorIndex, newColorIndex) # Make sure label layer colortable is in sync with the new color

        # Call the 'changed' callbacks immediately to initialize any listeners
        self.onLabelNameChanged()
        self.onLabelColorChanged()
        self.onPmapColorChanged()

        # Make the new label selected
        nlabels = self._labelControlUi.labelListModel.rowCount()
        selectedRow = nlabels-1
        self._labelControlUi.labelListModel.select(selectedRow)

        self._updateLabelShortcuts()

        e = self._labelControlUi.labelListModel.rowCount() > 0
        self._gui_enableLabeling(e)

    def _onLabelSelected(self, row):
        """
        React to the selection of label ``row``.

        Rows 0 and 1 are the two paint labels created in initCounting() and
        activate the paint tool; later rows are box labels and activate the
        box tool with ``activeBox = row - 2``. The brush draws ``row + 1``.
        """
        logger.debug("switching to label=%r" % (self._labelControlUi.labelListModel[row]))

        # If the user is selecting a label, he probably wants to be in paint mode
        self._changeInteractionMode(Tool.Paint)

        #+1 because first is transparent
        #FIXME: shouldn't be just row+1 here
        if row >= 2:
            self.toolButtons[Tool.Paint].setEnabled(False)
            self.toolButtons[Tool.Box].setEnabled(True)
            self.toolButtons[Tool.Box].click()
            self.activeBox = row - 2
        else:
            self.toolButtons[Tool.Paint].setEnabled(True)
            self.toolButtons[Tool.Box].setEnabled(False)
            self.toolButtons[Tool.Paint].click()

        self.editor.brushingModel.setDrawnNumber(row+1)
        brushColor = self._labelControlUi.labelListModel[row].brushColor()
        self.editor.brushingModel.setBrushColor( brushColor )

    def handleBoxQuery(self, position5d_start, position5d_stop):
        """
        Handle a rubber-band selection: with the arrow tool show the density
        sum of the region (test()); with the box tool store it as a box (test2()).
        """
        if self._labelControlUi.arrowToolButton.isChecked():
            self.test(position5d_start, position5d_stop)
        elif self._labelControlUi.boxToolButton.isChecked():
            self.test2(position5d_start, position5d_stop)

    def test2(self, position5d_start, position5d_stop):
        """
        Store the region between the two 5D corners as the active box and
        write its label value into the box-label layer.
        """
        logger.debug("test2")

        roi = SubRegion(self.op.LabelInputs, position5d_start,
                                       position5d_stop)
        # Normalize slices so start <= stop even if the box was dragged backwards.
        newKey = []
        for k in roi.toSlice():
            if k.stop < k.start:
                k = slice(k.stop, k.start)
            newKey.append(k)
        newKey = tuple(newKey)
        self.boxes[self.activeBox] = newKey
        # Use absolute extents so a backwards drag does not yield a negative
        # shape (numpy.ones would raise); forward drags are unchanged.
        labelShape = tuple(abs(position5d_stop[i] - position5d_start[i]) + 1 for i in range(5))
        # activeBox + 3 equals the drawn number of the box's label row
        # (row = activeBox + 2, drawn number = row + 1 in _onLabelSelected).
        labels = numpy.ones((labelShape), dtype = numpy.uint8) * (self.activeBox + 3)
        self.boxlabelsrc.put(newKey, labels)

    def test(self, position5d_start, position5d_stop):
        """
        Show the density sum (divided by 255) of the region between the two
        5D corners in the count display. Failures are logged and ignored.
        """
        logger.debug("test")
        roi = SubRegion(self.op.Density, position5d_start,
                                       position5d_stop)
        newKey = []
        for k in roi.toSlice():
            # Drop dimensions whose slice is exactly 0:0.
            if k == slice(0, 0, None):
                continue
            if k.stop < k.start:
                k = slice(k.stop, k.start)
            newKey.append(k)
        newKey = tuple(newKey)
        try:
            density = numpy.sum(self.op.Density[newKey].wait()) / 255.0
            strdensity = "{0:.2f}".format(density)
            self._labelControlUi.CountText.setText(strdensity)
        except Exception:
            # Best effort: the region may be invalid or the density not ready.
            logger.debug("Could not compute the density sum for %r", newKey, exc_info=True)
Esempio n. 2
0
class CarvingGui(LabelingGui):
    """Labeling GUI for the carving applet.

    The user paints two seed labels (background and foreground), runs the
    interactive segmentation, and saves the result as a named object. Saved
    objects can be browsed, loaded and deleted, shown in an optional 3D view
    (only set up when the input's x, y and z extents are all > 1), and exported
    as ``.obj`` meshes.
    """

    def __init__(self, parentApplet, topLevelOperatorView, drawerUiPath=None):
        """Create the carving GUI and wire up its drawer controls and shortcuts.

        Args:
            parentApplet: The owning applet. Passed to the base class and also
                used for busy/progress signalling during mesh export.
            topLevelOperatorView: Carving operator view whose slots drive the GUI.
            drawerUiPath: Optional custom drawer ``.ui`` file; defaults to
                ``carvingDrawer.ui`` in this module's directory.
        """
        self.topLevelOperatorView = topLevelOperatorView
        # Need this flag in carvingApplet where initialization is terminated with label selection
        self.isInitialized = False

        # members
        self._doneSegmentationLayer = None
        self._showSegmentationIn3D = False
        # self._showUncertaintyLayer = False
        # end: members

        labelingSlots = LabelingGui.LabelingSlots()
        labelingSlots.labelInput = topLevelOperatorView.WriteSeeds
        labelingSlots.labelOutput = topLevelOperatorView.opLabelArray.Output
        labelingSlots.labelEraserValue = topLevelOperatorView.opLabelArray.EraserLabelValue
        labelingSlots.labelNames = topLevelOperatorView.LabelNames
        labelingSlots.labelDelete = topLevelOperatorView.opLabelArray.DeleteLabel
        labelingSlots.maxLabelValue = topLevelOperatorView.opLabelArray.MaxLabelValue

        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        if drawerUiPath is None:
            drawerUiPath = os.path.join(directory, "carvingDrawer.ui")
        self.dialogdirCOM = os.path.join(directory, "carvingObjectManagement.ui")
        self.dialogdirSAD = os.path.join(directory, "saveAsDialog.ui")

        # Add 3DWidget only if the data is 3D
        is_3d = self._is_3d()

        super(CarvingGui, self).__init__(
            parentApplet,
            labelingSlots,
            topLevelOperatorView,
            drawerUiPath,
            is_3d_widget_visible=is_3d,
        )

        self.parentApplet = parentApplet
        self.labelingDrawerUi.currentObjectLabel.setText("<not saved yet>")

        # Init special base class members: carving uses exactly two label rows
        self.minLabelNumber = 2
        self.maxLabelNumber = 2

        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo

        # set up keyboard shortcuts
        mgr.register(
            "3",
            ActionInfo(
                "Carving",
                "Run interactive segmentation",
                "Run interactive segmentation",
                self.labelingDrawerUi.segment.click,
                self.labelingDrawerUi.segment,
                self.labelingDrawerUi.segment,
            ),
        )

        # Disable 3D view by default
        self.render = False
        if is_3d:
            try:
                self._renderMgr = RenderingManager(self.editor.view3d)
                self._shownObjects3D = {}
                self.render = True
            except Exception:
                # 3D rendering is optional: fall back to a GUI without the 3D view,
                # but leave a trace of why instead of failing silently.
                logger.warning("Could not set up 3D rendering; 3D view disabled.", exc_info=True)
                self.render = False

        # Segmentation is toggled on by default in _after_init, below.
        # (We can't enable it until the layers are all present.)
        self._showSegmentationIn3D = False
        self._segmentation_3d_label = None

        self.labelingDrawerUi.segment.clicked.connect(self.onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        self.topLevelOperatorView.Segmentation.notifyDirty(bind(self._segmentation_dirty))
        self.topLevelOperatorView.HasSegmentation.notifyValueChanged(bind(self._updateGui))

        ## uncertainty

        # self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        # self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)

        # def onUncertaintyFGButton():
        #    logger.debug( "uncertFG button clicked" )
        #    pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=2)
        #    self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        # self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(onUncertaintyFGButton)

        # def onUncertaintyBGButton():
        #    logger.debug( "uncertBG button clicked" )
        #    pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=1)
        #    self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        # self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(onUncertaintyBGButton)

        # def onUncertaintyCombo(value):
        #    if value == 0:
        #        value = "none"
        #        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        #        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)
        #        self._showUncertaintyLayer = False
        #    else:
        #        if value == 1:
        #            value = "localMargin"
        #        elif value == 2:
        #            value = "exchangeCount"
        #        elif value == 3:
        #            value = "gabow"
        #        else:
        #            raise RuntimeError("unhandled case '%r'" % value)
        #        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(True)
        #        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(True)
        #        self._showUncertaintyLayer = True
        #        logger.debug( "uncertainty changed to %r" % value )
        #    self.topLevelOperatorView.UncertaintyType.setValue(value)
        #    self.updateAllLayers() #make sure that an added/deleted uncertainty layer is recognized
        # self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(onUncertaintyCombo)

        self.labelingDrawerUi.objPrefix.setText(self.objectPrefix)
        self.labelingDrawerUi.objPrefix.textChanged.connect(self.setObjectPrefix)

        ## save

        self.labelingDrawerUi.save.clicked.connect(self.onSaveButton)

        ## clear

        self.labelingDrawerUi.clear.clicked.connect(self._onClearAction)

        ## object names

        self.labelingDrawerUi.namesButton.clicked.connect(self.onShowObjectNames)
        # Custom drawer UIs may not provide this button.
        if hasattr(self.labelingDrawerUi, "exportAllMeshesButton"):
            self.labelingDrawerUi.exportAllMeshesButton.clicked.connect(self._exportAllObjectMeshes)

        self.labelingDrawerUi.labelListView.allowDelete = False
        self._labelControlUi.labelListModel.allowRemove(False)

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)

        def addLayerToggleShortcut(layername, shortcut):
            # Register `shortcut` to select the named layer and flip its visibility.
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()

            mgr.register(
                shortcut,
                ActionInfo(
                    "Carving",
                    "Toggle layer %s" % layername,
                    "Toggle layer %s" % layername,
                    toggle,
                    self.viewerControlWidget(),
                    None,
                ),
            )

        # TODO
        addLayerToggleShortcut("Completed segments (unicolor)", "d")
        addLayerToggleShortcut("Segmentation", "s")
        addLayerToggleShortcut("Input Data", "r")

        def makeColortable():
            # 256 entries: transparent at 0, 254 random colors, pure green last.
            self._doneSegmentationColortable = [QColor(0, 0, 0, 0).rgba()]
            for _ in range(254):
                r, g, b = numpy.random.randint(0, 255), numpy.random.randint(0, 255), numpy.random.randint(0, 255)
                # ensure colors have sufficient distance to pure red and pure green
                while (255 - r) + g + b < 128 or r + (255 - g) + b < 128:
                    r, g, b = numpy.random.randint(0, 255), numpy.random.randint(0, 255), numpy.random.randint(0, 255)
                self._doneSegmentationColortable.append(QColor(r, g, b).rgba())
            self._doneSegmentationColortable.append(QColor(0, 255, 0).rgba())

        makeColortable()
        self._updateGui()

    @property
    def objectPrefix(self):
        """Prefix used when generating default names for saved objects."""
        return self.topLevelOperatorView.ObjectPrefix.value

    def setObjectPrefix(self, value):
        """Store a new object-name prefix in the operator's ``ObjectPrefix`` slot."""
        self.topLevelOperatorView.ObjectPrefix.setValue(value)

    def _is_3d(self):
        """Return True if the input's x, y and z extents are all > 1 (missing axes count as 1)."""
        tagged_shape = defaultdict(lambda: 1)
        tagged_shape.update(self.topLevelOperatorView.InputData.meta.getTaggedShape())
        return tagged_shape["x"] > 1 and tagged_shape["y"] > 1 and tagged_shape["z"] > 1

    def _after_init(self):
        """Finish base-class init, then show the editing segmentation in 3D if rendering is on."""
        super(CarvingGui, self)._after_init()
        if self.render:
            self._toggleSegmentation3D()

    def _updateGui(self):
        """Enable the Save button only while the operator reports storable data."""
        self.labelingDrawerUi.save.setEnabled(self.topLevelOperatorView.dataIsStorable())

    def onSegmentButton(self):
        """Push the drawer's parameters to the operator and trigger a new segmentation."""
        logger.debug("segment button clicked")
        bkPriorityValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
        self.topLevelOperatorView.BackgroundPriority.setValue(bkPriorityValue)
        biasValue = self.labelingDrawerUi.noBiasBelowSpin.value()
        self.topLevelOperatorView.NoBiasBelow.setValue(biasValue)
        self.topLevelOperatorView.Trigger.setDirty(slice(None))

    def getObjectNames(self):
        """Return the names of all saved objects."""
        return self.topLevelOperatorView.AllObjectNames[:].wait()

    def findNextPrefixNumber(self):
        """Return 1 + the largest number directly following the prefix in any object name.

        E.g. with prefix ``"obj"`` and names ``["obj1", "obj7", "other"]`` this
        returns 8; it returns 1 when no name matches.
        """
        # The prefix is user-editable text: escape it so characters such as "." or "("
        # are matched literally instead of being interpreted as regex syntax.
        pattern = re.compile(rf"^{re.escape(self.objectPrefix)}(?P<suffix>\d+)")
        last = 0
        for n in self.getObjectNames():
            match = pattern.match(n)
            if match:
                last = max(last, int(match.group("suffix")))
        return last + 1

    def saveAsDialog(self, name=""):
        """Ask the user for an object name.

        Special functionality: names already given to other objects are rejected
        (OK is disabled and a warning shown while the edited text collides).

        Args:
            name: Initial text; if empty, ``<prefix><next number>`` is proposed.

        Returns:
            The chosen name as ``str``, or None if the dialog was cancelled.
        """
        namesInUse = self.getObjectNames()

        def generateObjectName():
            return f"{self.objectPrefix}{self.findNextPrefixNumber()}"

        name = name or generateObjectName()

        dialog = uic.loadUi(self.dialogdirSAD)
        dialog.lineEdit.setText(name)
        dialog.lineEdit.selectAll()
        dialog.warning.setVisible(False)
        dialog.Ok.clicked.connect(dialog.accept)
        dialog.Cancel.clicked.connect(dialog.reject)
        dialog.isDisabled = False

        def validate():
            name = dialog.lineEdit.text()
            if name in namesInUse:
                dialog.Ok.setEnabled(False)
                dialog.warning.setVisible(True)
                dialog.isDisabled = True
            elif dialog.isDisabled:
                dialog.Ok.setEnabled(True)
                dialog.warning.setVisible(False)
                dialog.isDisabled = False

        dialog.lineEdit.textChanged.connect(validate)
        result = dialog.exec_()
        if result:
            return str(dialog.lineEdit.text())

    def onSaveButton(self):
        """Save the current segmentation as a named object, prompting for the name.

        Saving under a new name deletes the previously loaded object's old name.
        Re-saving under the same name drops its stale 3D representation, when 3D
        rendering is enabled.
        """
        logger.info("save object as?")
        if self.topLevelOperatorView.dataIsStorable():
            prevName = ""
            if self.topLevelOperatorView.hasCurrentObject():
                prevName = self.topLevelOperatorView.currentObjectName()
            if prevName == "<not saved yet>":
                prevName = ""
            name = self.saveAsDialog(name=prevName)
            if name is None:
                return
            namesInUse = self.getObjectNames()
            if name in namesInUse and name != prevName:
                QMessageBox.critical(
                    self,
                    "Save Object As",
                    "An object with name '%s' already exists.\nPlease choose a different name." % name,
                )
                return
            self.topLevelOperatorView.saveObjectAs(name)
            logger.info("save object as %s", name)
            if prevName != name and prevName != "":
                self.topLevelOperatorView.deleteObject(prevName)
            elif prevName == name and self.render:
                # The render manager and _shownObjects3D only exist when 3D rendering
                # was set up; without this guard saving 2D data raised AttributeError.
                self._renderMgr.removeObject(prevName)
                self._renderMgr.invalidateObject(prevName)
                self._shownObjects3D.pop(prevName, None)
        else:
            msgBox = QMessageBox(self)
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(2)  # 2 == QMessageBox.Warning
            msgBox.exec_()
            logger.error("object not saved due to faulty data.")

    def onShowObjectNames(self):
        """show object names and allow user to load/delete them"""
        dialog = uic.loadUi(self.dialogdirCOM)
        names = self.getObjectNames()
        dialog.objectNames.addItems(sorted(names, key=_humansort_key))

        def loadSelection():
            selected = [str(name.text()) for name in dialog.objectNames.selectedItems()]
            dialog.close()
            for objectname in selected:
                self.topLevelOperatorView.loadObject(objectname)

        def deleteSelection():
            items = dialog.objectNames.selectedItems()
            if self.confirmAndDelete([str(name.text()) for name in items]):
                for name in items:
                    name.setHidden(True)
            dialog.close()

        dialog.loadButton.clicked.connect(loadSelection)
        dialog.deleteButton.clicked.connect(deleteSelection)
        dialog.cancelButton.clicked.connect(dialog.close)
        dialog.exec_()

    def confirmAndDelete(self, namelist):
        """Ask for confirmation, then delete every object in ``namelist``.

        Returns:
            True if the user confirmed (and the objects were deleted), else False.
        """
        logger.info("confirmAndDelete: {}".format(namelist))
        objectlist = "".join("\n  " + str(i) for i in namelist)
        confirmed = QMessageBox.question(
            self,
            "Delete Object",
            "Do you want to delete these objects?" + objectlist,
            QMessageBox.Yes | QMessageBox.Cancel,
            defaultButton=QMessageBox.Yes,
        )

        if confirmed == QMessageBox.Yes:
            for name in namelist:
                self.topLevelOperatorView.deleteObject(name)
            return True
        return False

    def labelingContextMenu(self, names, op, position5d):
        """Build the right-click menu for a viewer position.

        Args:
            names: Names of saved objects at the clicked position; each gets a
                submenu (load, delete, show/remove in 3D, export mesh).
            op: The carving operator view.
            position5d: Clicked position; elements 1..3 are shown in the title.

        Returns:
            The populated ``QMenu``.
        """
        menu = QMenu(self)
        menu.setObjectName("carving_context_menu")
        posItem = menu.addAction("position %d %d %d" % (position5d[1], position5d[2], position5d[3]))
        posItem.setEnabled(False)
        menu.addSeparator()
        for name in names:
            submenu = QMenu(name, menu)

            # Load
            loadAction = submenu.addAction("Load %s" % name)
            loadAction.triggered.connect(partial(op.loadObject, name))

            # Delete
            def onDelAction(_name):
                self.confirmAndDelete([_name])
                if self.render and self._renderMgr.ready:
                    self._update_rendering()

            delAction = submenu.addAction("Delete %s" % name)
            delAction.triggered.connect(partial(onDelAction, name))

            if self.render:
                if name in self._shownObjects3D:
                    # Remove
                    def onRemove3D(_name):
                        label = self._shownObjects3D.pop(_name)
                        self._renderMgr.removeObject(label)
                        self._update_rendering()

                    removeAction = submenu.addAction("Remove %s from 3D view" % name)
                    removeAction.triggered.connect(partial(onRemove3D, name))
                else:
                    # Show
                    def onShow3D(_name):
                        label = self._renderMgr.addObject()
                        self._shownObjects3D[_name] = label
                        self._update_rendering()

                    showAction = submenu.addAction("Show 3D %s" % name)
                    showAction.triggered.connect(partial(onShow3D, name))

            # Export mesh

            exportAction = submenu.addAction("Export mesh for %s" % name)
            exportAction.triggered.connect(partial(self._onContextMenuExportMesh, name))

            menu.addMenu(submenu)

        if names:
            menu.addSeparator()

        menu.addSeparator()
        if self.render:
            showSeg3DAction = menu.addAction("Show Editing Segmentation in 3D")
            showSeg3DAction.setCheckable(True)
            showSeg3DAction.setChecked(self._showSegmentationIn3D)
            showSeg3DAction.triggered.connect(self._toggleSegmentation3D)

        if op.dataIsStorable():
            menu.addAction("Save object").triggered.connect(self.onSaveButton)
        menu.addAction("Browse objects").triggered.connect(self.onShowObjectNames)
        menu.addAction("Segment").triggered.connect(self.onSegmentButton)
        menu.addAction("Clear").triggered.connect(self._onClearAction)
        return menu

    def _onClearAction(self):
        """After confirmation, clear all brush strokes of the current labeling."""
        confirm = QMessageBox.warning(
            self, "Really Clear?", "Clear all brushstrokes?", QMessageBox.Ok | QMessageBox.Cancel
        )
        if confirm == QMessageBox.Ok:
            self.topLevelOperatorView.clearCurrentLabeling()

    def _clearLabelListGui(self):
        """Remove label rows from the list until at most two remain."""
        # Remove rows until we have the right number
        while self._labelControlUi.labelListModel.rowCount() > 2:
            self._removeLastLabel()

    def _onContextMenuExportMesh(self, _name):
        """
        Export a single object mesh to a user-specified filename.
        """
        recent_dir = preferences.get("carving", "recent export mesh directory")
        # Default to "<name>.obj" in the last export directory, or in the home directory.
        base_dir = os.path.expanduser("~") if recent_dir is None else recent_dir
        defaultPath = os.path.join(base_dir, "{}.obj".format(_name))
        filepath, _filter = QFileDialog.getSaveFileName(
            self, "Save meshes for object '{}'".format(_name), defaultPath, "OBJ Files (*.obj)"
        )
        if not filepath:
            return
        obj_filepath = str(filepath)
        preferences.set("carving", "recent export mesh directory", os.path.split(obj_filepath)[0])

        self._exportMeshes([_name], [obj_filepath])

    def _exportAllObjectMeshes(self):
        """
        Export all objects in the project as separate .obj files, stored to a user-specified directory.
        """
        mst = self.topLevelOperatorView.MST.value
        if not list(mst.object_lut.keys()):
            QMessageBox.critical(self, "Can't Export", "You have no saved objects, so there are no meshes to export.")
            return

        recent_dir = preferences.get("carving", "recent export mesh directory")
        defaultPath = os.path.expanduser("~") if recent_dir is None else recent_dir
        export_dir = QFileDialog.getExistingDirectory(self, "Select export directory for mesh files", defaultPath)
        if not export_dir:
            return
        export_dir = str(export_dir)
        preferences.set("carving", "recent export mesh directory", export_dir)

        # One "<object name>.obj" file per object in the chosen directory.
        object_names = list(mst.object_lut.keys())
        obj_filepaths = [os.path.join(export_dir, "{}.obj".format(object_name)) for object_name in object_names]

        if object_names:
            self._exportMeshes(object_names, obj_filepaths)

    def _exportMeshes(self, object_names: List[str], obj_filepaths: List[str]) -> Request:
        """Save objects in the mst to .obj files

        The work runs in a submitted lazyflow ``Request``. The parent applet is
        marked busy meanwhile and gets progress updates.

        Args:
            object_names: Names of the objects in the mst
            obj_filepaths: One path for each object in object_names

        Returns:
            Returns the request object, used in testing
        """

        def get_label_volume_from_mst(mst, object_name):
            # Volume with 1 on the object's supervoxels and 0 elsewhere.
            object_supervoxels = mst.object_lut[object_name]
            object_lut = numpy.zeros(mst.nodeNum + 1, dtype=numpy.int32)
            object_lut[object_supervoxels] = 1
            supervoxel_volume = mst.supervoxelUint32
            return object_lut[supervoxel_volume]

        mst = self.topLevelOperatorView.MST.value

        def exportMeshes(object_names, obj_filepaths):
            try:
                n_objects = len(object_names)
                # Equal progress share per object. Computed inside the try so an empty
                # list cannot raise ZeroDivisionError and leave the applet stuck busy.
                progress_update = 100 / n_objects if n_objects else 100
                for obj_n, (obj, obj_path) in enumerate(zip(object_names, obj_filepaths)):
                    object_volume = get_label_volume_from_mst(mst, obj)
                    unique_ids = len(numpy.unique(object_volume))

                    if unique_ids <= 1:
                        logger.info(f"No voxels found for {obj}, skipping")
                        continue
                    elif unique_ids > 2:
                        logger.info(f"Supervoxel segmentation not unique for {obj}, skipping, got {unique_ids}")
                        continue

                    logger.info(f"Generating mesh for {obj}")
                    _, mesh_data = list(labeling_to_mesh(object_volume, [1]))[0]
                    self.parentApplet.progressSignal((obj_n + 0.5) * progress_update)
                    logger.info(f"Mesh generation for {obj} complete.")

                    logger.info(f"Saving mesh for {obj} to {obj_path}")
                    mesh_to_obj(mesh_data, obj_path, obj)
                    self.parentApplet.progressSignal((obj_n + 1) * progress_update)
            finally:
                self.parentApplet.busy = False
                self.parentApplet.progressSignal(100)
                self.parentApplet.appletStateUpdateRequested()

        self.parentApplet.busy = True
        self.parentApplet.progressSignal(-1)
        self.parentApplet.appletStateUpdateRequested()

        req = Request(partial(exportMeshes, object_names, obj_filepaths))
        req.submit()
        return req

    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Show the context menu for the saved objects at the clicked position."""
        names = self.topLevelOperatorView.doneObjectNamesForPosition(position5d[1:4])
        op = self.topLevelOperatorView

        # (Subclasses may override menu)
        menu = self.labelingContextMenu(names, op, position5d)
        if menu is not None:
            menu.exec_(globalWindowCoordinate)

    def _toggleSegmentation3D(self):
        """Show or hide the editing segmentation in the 3D view. Only call this when rendering is enabled."""
        self._showSegmentationIn3D = not self._showSegmentationIn3D
        if self._showSegmentationIn3D:
            self._segmentation_3d_label = self._renderMgr.addObject()
        else:
            self._renderMgr.removeObject(self._segmentation_3d_label)
            self._segmentation_3d_label = None
        self._update_rendering()

    def _segmentation_dirty(self):
        """Discard the cached 3D segmentation object and re-render."""
        if self.render:
            self._renderMgr.invalidateObject(CURRENT_SEGMENTATION_NAME)
            self._renderMgr.removeObject(CURRENT_SEGMENTATION_NAME)

        self._update_rendering()

    def _update_rendering(self):
        """Push the current 3D label volume and colors to the render manager.

        No-op when 3D rendering is disabled. A supervoxel -> render-label lookup
        table is built from the objects shown in 3D (and the editing segmentation,
        if enabled) and then applied to the supervoxel volume.
        """
        if not self.render:
            return

        op = self.topLevelOperatorView
        if not self._renderMgr.ready:
            self._renderMgr.setup(op.InputData.meta.shape[1:4])

        mst = op.MST.value

        # remove nonexistent objects (set lookup instead of a list scan per entry)
        existing_names = set(mst.object_lut.keys())
        self._shownObjects3D = {
            name: label for name, label in self._shownObjects3D.items() if name in existing_names
        }

        lut = numpy.zeros(mst.nodeNum + 1, dtype=numpy.int32)
        # Two-way mapping: render label -> object name and object name -> render label.
        label_name_map = {}
        for name, label in self._shownObjects3D.items():
            lut[mst.object_lut[name]] = label
            label_name_map[label] = name
            label_name_map[name] = label

        if self._showSegmentationIn3D:
            # Supervoxels whose segmentation value is 2 get the editing segmentation's label.
            label_name_map[self._segmentation_3d_label] = CURRENT_SEGMENTATION_NAME
            label_name_map[CURRENT_SEGMENTATION_NAME] = self._segmentation_3d_label
            lut[:] = numpy.where(mst.getSuperVoxelSeg() == 2, self._segmentation_3d_label, lut)

        # (Advanced indexing) maps every supervoxel id in the volume through the lut.
        self._renderMgr.volume = lut[mst.supervoxelUint32], label_name_map
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Update colors of objects in 3D viewport"""
        op = self.topLevelOperatorView

        # Saved objects use the per-object colortable of the completed-segments layer.
        # That layer only exists if DoneSegmentation was ready in setupLayers, so skip
        # instead of raising AttributeError on None.
        if self._shownObjects3D and self._doneSegmentationLayer is not None:
            ctable = self._doneSegmentationLayer.colorTable
            object_names = op.MST.value.object_names
            for name, label in self._shownObjects3D.items():
                color = QColor(ctable[object_names[name]])
                self._renderMgr.setColor(label, (color.red() / 255, color.green() / 255, color.blue() / 255))

        if self._showSegmentationIn3D and self._segmentation_3d_label is not None:
            # color of the foreground label from label list data
            labels = self.labelListData
            assert len(labels) == 2
            fg_label = labels[1]  # second row is the foreground label
            color = fg_label.pmapColor()
            self._renderMgr.setColor(
                self._segmentation_3d_label, (color.red() / 255, color.green() / 255, color.blue() / 255)
            )

    def _getNext(self, slot, parentFun, transform=None):
        """Return the entry of ``slot.value`` at the current label count.

        The result is optionally passed through ``transform``. If the slot has no
        entry at that index, ``parentFun()`` is returned instead.
        """
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        return parentFun()

    def getNextLabelName(self):
        """Take the next label name from the operator's ``LabelNames``, else from the base class."""
        return self._getNext(self.topLevelOperatorView.LabelNames, super(CarvingGui, self).getNextLabelName)

    def appletDrawers(self):
        return [("Carving", self._labelControlUi)]

    def setupLayers(self):
        """Create the viewer layers, ordered top to bottom.

        The stack is: labels, segmentation, completed segments (unicolor and
        per-object), supervoxels, overlay, input data, filtered input. Layers
        whose source slot is not ready are left out.
        """
        logger.debug("setupLayers")

        layers = []

        def onButtonsEnabled(slot, roi):
            currObj = self.topLevelOperatorView.CurrentObjectName.value
            hasSeg = self.topLevelOperatorView.HasSegmentation.value

            self.labelingDrawerUi.currentObjectLabel.setText(currObj)
            self.labelingDrawerUi.save.setEnabled(hasSeg)

        self.topLevelOperatorView.CurrentObjectName.notifyDirty(onButtonsEnabled)
        self.topLevelOperatorView.HasSegmentation.notifyDirty(onButtonsEnabled)
        self.topLevelOperatorView.opLabelArray.NonzeroBlocks.notifyDirty(onButtonsEnabled)

        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            labellayer._allowToggleVisible = False
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        # uncertainty
        # if self._showUncertaintyLayer:
        #    uncert = self.topLevelOperatorView.Uncertainty
        #    if uncert.ready():
        #        colortable = []
        #        for i in range(256-len(colortable)):
        #            r,g,b,a = i,0,0,i
        #            colortable.append(QColor(r,g,b,a).rgba())
        #        layer = ColortableLayer(createDataSource(uncert), colortable, direct=True)
        #        layer.name = "Uncertainty"
        #        layer.visible = True
        #        layer.opacity = 0.3
        #        layers.append(layer)

        # segmentation
        seg = self.topLevelOperatorView.Segmentation

        # seg = self.topLevelOperatorView.MST.value.segmentation
        # temp = self._done_lut[self.MST.value.supervoxelUint32[sl[1:4]]]
        if seg.ready():
            # source = RelabelingArraySource(seg)
            # source.setRelabeling(numpy.arange(256, dtype=numpy.uint8))

            # assign to the object label color, 0 is transparent, 1 is background
            # NOTE(review): this reads labellayer even though it is None-checked above;
            # confirm createLabelLayer cannot return None here.
            colortable = [QColor(0, 0, 0, 0).rgba(), QColor(0, 0, 0, 0).rgba(), labellayer._colorTable[2]]
            for _ in range(256 - len(colortable)):
                r, g, b = numpy.random.randint(0, 255), numpy.random.randint(0, 255), numpy.random.randint(0, 255)
                colortable.append(QColor(r, g, b).rgba())

            layer = ColortableLayer(createDataSource(seg), colortable, direct=True)
            layer.name = "Segmentation"
            layer.setToolTip(
                "This layer displays the <i>current</i> segmentation. Simply add foreground and background "
                "labels, then press <i>Segment</i>."
            )
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        # done
        doneSeg = self.topLevelOperatorView.DoneSegmentation
        if doneSeg.ready():
            # FIXME: if the user segments more than 255 objects, those with indices that divide by 255 will be shown as transparent
            # both here and in the _doneSegmentationColortable
            colortable = 254 * [QColor(230, 25, 75).rgba()]
            colortable.insert(0, QColor(0, 0, 0, 0).rgba())

            # have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(createDataSource(doneSeg), colortable, direct=True)
            layer.name = "Completed segments (unicolor)"
            layer.setToolTip(
                "In order to keep track of which objects you have already completed, this layer "
                "shows <b>all completed object</b> in one color (<b>red</b>). "
                "The reason for only one color is that for finding out which "
                "objects to label next, the identity of already completed objects is unimportant "
                "and distracting."
            )
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

            layer = ColortableLayer(createDataSource(doneSeg), self._doneSegmentationColortable, direct=True)
            layer.name = "Completed segments (one color per object)"
            layer.setToolTip(
                "<html>In order to keep track of which objects you have already completed, this layer "
                "shows <b>all completed object</b>, each with a random color.</html>"
            )
            layer.visible = False
            layer.opacity = 0.5
            layer.colortableIsRandom = True
            self._doneSegmentationLayer = layer
            layers.append(layer)

        # supervoxel
        sv = self.topLevelOperatorView.Supervoxels
        if sv.ready():
            colortable = []
            for _ in range(256):
                r, g, b = numpy.random.randint(0, 255), numpy.random.randint(0, 255), numpy.random.randint(0, 255)
                colortable.append(QColor(r, g, b).rgba())
            layer = ColortableLayer(createDataSource(sv), colortable, direct=True)
            layer.name = "Supervoxels"
            layer.setToolTip(
                "<html>This layer shows the partitioning of the input image into <b>supervoxels</b>. The carving "
                "algorithm uses these tiny puzzle-pieces to piece together the segmentation of an "
                "object. Sometimes, supervoxels are too large and straddle two distinct objects "
                "(undersegmentation). In this case, it will be impossible to achieve the desired "
                "segmentation. This layer helps you to understand these cases.</html>"
            )
            layer.visible = False
            layer.colortableIsRandom = True
            layer.opacity = 0.5
            layers.append(layer)

        # Visual overlay (just for easier labeling)
        overlaySlot = self.topLevelOperatorView.OverlayData
        if overlaySlot.ready():
            overlay5D = overlaySlot.value
            layer = GrayscaleLayer(ArraySource(overlay5D), direct=True)
            layer.visible = True
            layer.name = "Overlay"
            layer.opacity = 1.0
            # if the flag window_leveling is set the contrast
            # of the layer is adjustable
            layer.window_leveling = True
            self.labelingDrawerUi.thresToolButton.show()
            layers.append(layer)
            del layer

        inputSlot = self.topLevelOperatorView.InputData
        if inputSlot.ready():
            layer = GrayscaleLayer(createDataSource(inputSlot), direct=True)
            layer.name = "Input Data"
            layer.setToolTip("<html>The data originally loaded into ilastik (unprocessed).</html>")
            # layer.visible = not rawSlot.ready()
            layer.visible = True
            layer.opacity = 1.0

            # Window leveling is already active on the Overlay,
            # but if no overlay was provided, then activate window_leveling on the raw data instead.
            if not overlaySlot.ready():
                # if the flag window_leveling is set the contrast
                # of the layer is adjustable
                layer.window_leveling = True
                self.labelingDrawerUi.thresToolButton.show()

            layers.append(layer)
            del layer

        filteredSlot = self.topLevelOperatorView.FilteredInputData
        if filteredSlot.ready():
            layer = GrayscaleLayer(createDataSource(filteredSlot))
            layer.name = "Filtered Input"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        return layers
class PixelClassificationGui(LabelingGui):
    """
    Labeling GUI for the pixel classification applet.

    Extends LabelingGui with an interactive ("live update") prediction mode,
    prediction / segmentation / uncertainty layers, an "Advanced" menu
    (classifier selection, variable importance, label import, label-block
    printing), a feature-suggestion dialog and optional (currently disabled)
    3D rendering of segmentations.
    """

    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################
    def centralWidget( self ):
        return self

    def stopAndCleanUp(self):
        """Run every registered cleanup callback, then let the base class clean up."""
        for fn in self.__cleanup_fns:
            fn()

        # Base class
        super(PixelClassificationGui, self).stopAndCleanUp()

    def viewerControlWidget(self):
        return self._viewerControlUi

    def menus( self ):
        """
        Return the base-class menus plus an "Advanced" menu.

        The "Advanced" menu offers a classifier selection dialog, a variable
        importance table, a "Labels" submenu (label import and label-block
        printing) and, in debug mode, a bookmarks window.
        """
        menus = super( PixelClassificationGui, self ).menus()

        advanced_menu = QMenu("Advanced", parent=self)

        def handleClassifierAction():
            dlg = ClassifierSelectionDlg(self.topLevelOperatorView, parent=self)
            dlg.exec_()

        classifier_action = advanced_menu.addAction("Classifier...")
        classifier_action.triggered.connect( handleClassifierAction )

        def showVarImpDlg():
            varImpDlg = VariableImportanceDialog(self.topLevelOperatorView.Classifier.value.named_importances, parent=self)
            varImpDlg.exec_()

        advanced_menu.addAction("Variable Importance Table").triggered.connect(showVarImpDlg)

        def handleImportLabelsAction():
            # Start browsing in the directory of the most recently opened image file
            mostRecentImageFile = PreferencesManager().get( 'DataSelection', 'recent image' )
            if mostRecentImageFile is not None:
                defaultDirectory = os.path.split(mostRecentImageFile)[0]
            else:
                defaultDirectory = os.path.expanduser('~')
            fileNames = DataSelectionGui.getImageFileNamesToOpen(self, defaultDirectory)
            fileNames = list(map(str, fileNames))

            # For now, we require a single hdf5 file
            if len(fileNames) > 1:
                QMessageBox.critical(self, "Too many files",
                                     "Labels must be contained in a single hdf5 volume.")
                return
            if not fileNames:
                # user cancelled
                return

            file_path = fileNames[0]
            internal_paths = DataSelectionGui.getPossibleInternalPaths(file_path)
            if len(internal_paths) == 0:
                QMessageBox.critical(self, "No volumes in file",
                                     "Couldn't find a suitable dataset in your hdf5 file.")
                return
            if len(internal_paths) == 1:
                internal_path = internal_paths[0]
            else:
                dlg = H5VolumeSelectionDlg(internal_paths, self)
                if dlg.exec_() == QDialog.Rejected:
                    return
                selected_index = dlg.combo.currentIndex()
                internal_path = str(internal_paths[selected_index])

            path_components = PathComponents(file_path)
            path_components.internalPath = str(internal_path)

            top_op = self.topLevelOperatorView
            # Pre-bind both operators so the cleanup below never touches an
            # unbound name (which would mask the original exception).
            opReader = None
            op5 = None
            try:
                opReader = OpInputDataReader(parent=top_op.parent)
                opReader.FilePath.setValue( path_components.totalPath() )

                # Reorder the axes to match the label input slot
                op5 = OpReorderAxes(parent=top_op.parent)
                op5.AxisOrder.setValue( top_op.LabelInputs.meta.getAxisKeys() )
                op5.Input.connect( opReader.Output )

                # Finally, import the labels
                top_op.importLabels( top_op.current_view_index(), op5.Output )

            finally:
                # Clean up downstream first, then the reader it is connected to.
                if op5 is not None:
                    op5.cleanUp()
                if opReader is not None:
                    opReader.cleanUp()

        def print_label_blocks(sorted_axis):
            # Print every non-zero label block, sorted by the start of `sorted_axis`.
            top_op = self.topLevelOperatorView
            axis_keys = top_op.InputImages.meta.getAxisKeys()
            sorted_column = axis_keys.index(sorted_axis)

            input_shape = top_op.InputImages.meta.shape
            label_block_slicings = top_op.NonzeroLabelBlocks.value

            # Omit channel (the last axis) from both the header and the slicing.
            header = "".join( axis_keys )[:-1].upper() + ": "
            for slicing in sorted(label_block_slicings, key=lambda s: s[sorted_column]):
                print(header + slicing_to_string( slicing[:-1], input_shape ))

        labels_submenu = QMenu("Labels")
        self.labels_submenu = labels_submenu # Must retain this reference or else it gets auto-deleted.

        import_labels_action = labels_submenu.addAction("Import Labels...")
        import_labels_action.triggered.connect( handleImportLabelsAction )

        self.print_labels_submenu = QMenu("Print Label Blocks")
        labels_submenu.addMenu(self.print_labels_submenu)

        for axis in self.topLevelOperatorView.InputImages.meta.getAxisKeys()[:-1]:
            self.print_labels_submenu\
                .addAction("Sort by {}".format( axis.upper() ))\
                .triggered.connect( partial(print_label_blocks, axis) )

        advanced_menu.addMenu(labels_submenu)

        if ilastik_config.getboolean('ilastik', 'debug'):
            def showBookmarksWindow():
                self._bookmarks_window.show()
            advanced_menu.addAction("Bookmarks...").triggered.connect(showBookmarksWindow)

        menus += [advanced_menu]

        return menus

    ###########################################
    ###########################################

    def __init__(self, parentApplet, topLevelOperatorView, labelingDrawerUiPath=None ):
        """
        :param parentApplet: the owning applet; its appletStateUpdateRequested()
            is called whenever interactive mode is toggled.
        :param topLevelOperatorView: the pixel classification operator (view)
            whose slots this GUI displays and edits.
        :param labelingDrawerUiPath: optional path to a custom drawer .ui file;
            defaults to 'labelingDrawer.ui' next to this module.
        """
        self.parentApplet = parentApplet
        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.DeleteLabel
        labelSlots.labelNames = topLevelOperatorView.LabelNames

        # Callbacks executed by stopAndCleanUp() (used to unregister dirty notifications).
        self.__cleanup_fns = []

        # We provide our own UI file (which adds an extra control for interactive mode)
        if labelingDrawerUiPath is None:
            labelingDrawerUiPath = os.path.split(__file__)[0] + '/labelingDrawer.ui'

        # Base class init
        super(PixelClassificationGui, self).__init__( parentApplet, labelSlots, topLevelOperatorView, labelingDrawerUiPath )

        self.topLevelOperatorView = topLevelOperatorView

        # The interactive state is applied once, at the end of __init__.
        # (An extra early toggleInteractive() call used to be made here; it was
        #  redundant and could show the "no features" message box twice.)
        self.interactiveModeActive = False

        self._currentlySavingPredictions = False

        self.labelingDrawerUi.labelListView.support_merges = True

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon( QIcon(ilastikIcons.Play) )
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelingDrawerUi.liveUpdateButton.toggled.connect( self.toggleInteractive )

        self.initFeatSelDlg()
        self.labelingDrawerUi.suggestFeaturesButton.clicked.connect(self.show_feature_selection_dialog)
        self.featSelDlg.accepted.connect(self.update_features_from_dialog)
        self.labelingDrawerUi.suggestFeaturesButton.setEnabled(False)

        self.topLevelOperatorView.LabelNames.notifyDirty( bind(self.handleLabelSelectionChange) )
        self.__cleanup_fns.append( partial( self.topLevelOperatorView.LabelNames.unregisterDirty, bind(self.handleLabelSelectionChange) ) )

        self._initShortcuts()

        self._bookmarks_window = BookmarksWindow(self, self.topLevelOperatorView)

        # FIXME: We MUST NOT enable the render manager by default,
        #        since it will drastically slow down the app for large volumes.
        #        For now, we leave it off by default.
        #        To re-enable rendering, we need to allow the user to render a segmentation
        #        and then initialize the render manager on-the-fly.
        #        (We might want to warn the user if her volume is not small.)
        self.render = False
        self._renderMgr = None
        self._renderedLayers = {} # (layer name, label number)

        # Always off for now (see note above)
        if self.render:
            try:
                self._renderMgr = RenderingManager( self.editor.view3d )
            except Exception:
                # Best effort: fall back to no 3D rendering, but leave a trace.
                logger.warning("Could not initialize 3D rendering; disabling it.", exc_info=True)
                self.render = False

        # toggle interactive mode according to freezePredictions.value
        self.toggleInteractive(not self.topLevelOperatorView.FreezePredictions.value)
        def FreezePredDirty():
            self.toggleInteractive(not self.topLevelOperatorView.FreezePredictions.value)
        # listen to freezePrediction changes
        self.topLevelOperatorView.FreezePredictions.notifyDirty( bind(FreezePredDirty) )
        self.__cleanup_fns.append( partial( self.topLevelOperatorView.FreezePredictions.unregisterDirty, bind(FreezePredDirty) ) )

    def _getFeatureSelectionOperator(self):
        """
        Return the feature selection operator that feeds this classifier.

        The operator is located by this operator's name:
        "OpPixelClassification" uses the workflow's single
        `featureSelectionApplet`; "OpPixelClassification0".."3" use the
        corresponding entry of `featureSelectionApplets`. In both cases the
        first inner operator of the applet's top-level operator is returned.

        :raises NotImplementedError: for any other operator name.
        """
        name = self.topLevelOperatorView.name
        indexed_names = ("OpPixelClassification0", "OpPixelClassification1",
                         "OpPixelClassification2", "OpPixelClassification3")
        if name == "OpPixelClassification":
            applet = self.topLevelOperatorView.parent.featureSelectionApplet
        elif name in indexed_names:
            applet = self.topLevelOperatorView.parent.featureSelectionApplets[indexed_names.index(name)]
        else:
            raise NotImplementedError
        return applet.topLevelOperator.innerOperators[0]

    def initFeatSelDlg(self):
        """Create the feature suggestion dialog for the matching feature selection operator."""
        thisOpFeatureSelection = self._getFeatureSelectionOperator()
        self.featSelDlg = FeatureSelectionDialog(thisOpFeatureSelection, self.topLevelOperatorView)

    def show_feature_selection_dialog(self):
        """Show the feature suggestion dialog modally."""
        self.featSelDlg.exec_()

    def update_features_from_dialog(self):
        """Push the feature matrix chosen in the suggestion dialog into the feature selection operator."""
        thisOpFeatureSelection = self._getFeatureSelectionOperator()

        thisOpFeatureSelection.SelectionMatrix.setValue(self.featSelDlg.selected_features_matrix)
        thisOpFeatureSelection.SelectionMatrix.setDirty()
        thisOpFeatureSelection.setupOutputs()

    def initViewerControlUi(self):
        """Load the viewer-control widget and wire its show-predictions/segmentation checkboxes."""
        localDir = os.path.split(__file__)[0]
        self._viewerControlUi = uic.loadUi( os.path.join( localDir, "viewerControls.ui" ) )

        # Clicking a checkbox flips between checked and unchecked; the
        # update*Checkbox slots may separately put it into PartiallyChecked.
        def nextCheckState(checkbox):
            checkbox.setChecked( not checkbox.isChecked() )
        self._viewerControlUi.checkShowPredictions.nextCheckState = partial(nextCheckState, self._viewerControlUi.checkShowPredictions)
        self._viewerControlUi.checkShowSegmentation.nextCheckState = partial(nextCheckState, self._viewerControlUi.checkShowSegmentation)

        self._viewerControlUi.checkShowPredictions.clicked.connect( self.handleShowPredictionsClicked )
        self._viewerControlUi.checkShowSegmentation.clicked.connect( self.handleShowSegmentationClicked )

        # The editor's layerstack is in charge of which layer movement buttons are enabled
        model = self.editor.layerStack
        self._viewerControlUi.viewerControls.setupConnections(model)

    def _initShortcuts(self):
        """Register the 'p', 's' and 'l' keyboard shortcuts for prediction/segmentation/live mode."""
        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo
        shortcutGroupName = "Predictions"

        mgr.register( "p", ActionInfo( shortcutGroupName,
                                       "Toggle Prediction",
                                       "Toggle Prediction Layer Visibility",
                                       self._viewerControlUi.checkShowPredictions.click,
                                       self._viewerControlUi.checkShowPredictions,
                                       self._viewerControlUi.checkShowPredictions ) )

        mgr.register( "s", ActionInfo( shortcutGroupName,
                                       "Toggle Segmentaton",
                                       "Toggle Segmentaton Layer Visibility",
                                       self._viewerControlUi.checkShowSegmentation.click,
                                       self._viewerControlUi.checkShowSegmentation,
                                       self._viewerControlUi.checkShowSegmentation ) )

        mgr.register( "l", ActionInfo( shortcutGroupName,
                                       "Live Prediction",
                                       "Toggle Live Prediction Mode",
                                       self.labelingDrawerUi.liveUpdateButton.toggle,
                                       self.labelingDrawerUi.liveUpdateButton,
                                       self.labelingDrawerUi.liveUpdateButton ) )

    def _setup_contexts(self, layer):
        """
        Add a 'Toggle 3D rendering' context action to `layer` (only if rendering is enabled).

        Triggering the action adds the layer to, or removes it from, the set of
        layers drawn by the render manager, then refreshes the rendering.
        """
        def callback(pos, clayer=layer):
            # `pos` receives the signal's first argument and is not used.
            name = clayer.name
            if name in self._renderedLayers:
                label = self._renderedLayers.pop(name)
                self._renderMgr.removeObject(label)
            else:
                label = self._renderMgr.addObject()
                self._renderedLayers[name] = label
            self._update_rendering()

        if self.render:
            layer.contexts.append( QAction('Toggle 3D rendering', None, triggered=callback) )

    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.

        Layers are appended top-to-bottom after the base-class label layers:
        label projection (debug only), mask, uncertainty, one segmentation and
        one prediction layer per label class, and finally the input data.
        """
        # Base class provides the label layer.
        layers = super(PixelClassificationGui, self).setupLayers()

        ActionInfo = ShortcutManager.ActionInfo

        if ilastik_config.getboolean('ilastik', 'debug'):

            # Add the label projection layer.
            labelProjectionSlot = self.topLevelOperatorView.opLabelPipeline.opLabelArray.Projection2D
            if labelProjectionSlot.ready():
                projectionSrc = LazyflowSource(labelProjectionSlot)
                try:
                    # This colortable requires matplotlib
                    from volumina.colortables import jet
                    projectionLayer = ColortableLayer( projectionSrc,
                                                       colorTable=[QColor(0,0,0,128).rgba()]+jet(N=255),
                                                       normalize=(0.0, 1.0) )
                except (ImportError, RuntimeError):
                    # Best effort: this debug-only layer is skipped when the colortable is unavailable.
                    logger.debug("Label projection layer unavailable", exc_info=True)
                else:
                    projectionLayer.name = "Label Projection"
                    projectionLayer.visible = False
                    projectionLayer.opacity = 1.0
                    layers.append(projectionLayer)

        # Show the mask over everything except labels
        maskSlot = self.topLevelOperatorView.PredictionMasks
        if maskSlot.ready():
            maskLayer = self._create_binary_mask_layer_from_slot( maskSlot )
            maskLayer.name = "Mask"
            maskLayer.visible = True
            maskLayer.opacity = 1.0
            layers.append( maskLayer )

        # Add the uncertainty estimate layer
        uncertaintySlot = self.topLevelOperatorView.UncertaintyEstimate
        if uncertaintySlot.ready():
            uncertaintySrc = LazyflowSource(uncertaintySlot)
            uncertaintyLayer = AlphaModulatedLayer( uncertaintySrc,
                                                    tintColor=QColor( Qt.cyan ),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
            uncertaintyLayer.name = "Uncertainty"
            uncertaintyLayer.visible = False
            uncertaintyLayer.opacity = 1.0
            uncertaintyLayer.shortcutRegistration = ( "u", ActionInfo( "Prediction Layers",
                                                                       "Uncertainty",
                                                                       "Show/Hide Uncertainty",
                                                                       uncertaintyLayer.toggleVisible,
                                                                       self.viewerControlWidget(),
                                                                       uncertaintyLayer ) )
            layers.append(uncertaintyLayer)

        labels = self.labelListData

        # Add each of the segmentations (one per label class that exists in the label list)
        for channel, segmentationSlot in enumerate(self.topLevelOperatorView.SegmentationChannels):
            if segmentationSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                segsrc = LazyflowSource(segmentationSlot)
                segLayer = AlphaModulatedLayer( segsrc,
                                                tintColor=ref_label.pmapColor(),
                                                range=(0.0, 1.0),
                                                normalize=(0.0, 1.0) )

                segLayer.opacity = 1
                segLayer.visible = False
                segLayer.visibleChanged.connect(self.updateShowSegmentationCheckbox)

                def setLayerColor(c, segLayer_=segLayer, initializing=False):
                    if not initializing and segLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    segLayer_.tintColor = c
                    self._update_rendering()

                def setSegLayerName(n, segLayer_=segLayer, initializing=False):
                    if not initializing and segLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    oldname = segLayer_.name
                    newName = "Segmentation (%s)" % n
                    segLayer_.name = newName
                    if not self.render:
                        return
                    # Keep the 3D-render bookkeeping keyed by the new layer name.
                    if oldname in self._renderedLayers:
                        label = self._renderedLayers.pop(oldname)
                        self._renderedLayers[newName] = label

                setSegLayerName(ref_label.name, initializing=True)

                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setSegLayerName)
                # Only offer the "Toggle 3D" option for 4D slots: the renderer
                # cuts off the last shape dimension, leaving a 3D volume.
                if len(segmentationSlot.meta.shape) == 4:
                    self._setup_contexts(segLayer)
                layers.append(segLayer)

        # Add each of the predictions (one per label class that exists in the label list)
        for channel, predictionSlot in enumerate(self.topLevelOperatorView.PredictionProbabilityChannels):
            if predictionSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                predictsrc = LazyflowSource(predictionSlot)
                predictLayer = AlphaModulatedLayer( predictsrc,
                                                    tintColor=ref_label.pmapColor(),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
                predictLayer.opacity = 0.25
                predictLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked()
                predictLayer.visibleChanged.connect(self.updateShowPredictionCheckbox)

                def setLayerColor(c, predictLayer_=predictLayer, initializing=False):
                    if not initializing and predictLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    predictLayer_.tintColor = c

                def setPredLayerName(n, predictLayer_=predictLayer, initializing=False):
                    if not initializing and predictLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    newName = "Prediction for %s" % n
                    predictLayer_.name = newName

                setPredLayerName(ref_label.name, initializing=True)
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setPredLayerName)
                layers.append(predictLayer)

        # Add the raw data last (on the bottom)
        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot( inputDataSlot )
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0
            # the flag window_leveling is used to determine if the contrast
            # of the layer is adjustable (only for grayscale display)
            inputLayer.window_leveling = isinstance( inputLayer, GrayscaleLayer )

            def toggleTopToBottom():
                index = self.layerstack.layerIndex( inputLayer )
                self.layerstack.selectRow( index )
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            inputLayer.shortcutRegistration = ( "i", ActionInfo( "Prediction Layers",
                                                                 "Bring Input To Top/Bottom",
                                                                 "Bring Input To Top/Bottom",
                                                                 toggleTopToBottom,
                                                                 self.viewerControlWidget(),
                                                                 inputLayer ) )
            layers.append(inputLayer)

            # The thresholding button can only be used if the data is displayed as grayscale.
            if inputLayer.window_leveling:
                self.labelingDrawerUi.thresToolButton.show()
            else:
                self.labelingDrawerUi.thresToolButton.hide()

        self.handleLabelSelectionChange()
        return layers

    def toggleInteractive(self, checked):
        """
        Switch live-prediction ("interactive") mode on or off.

        Turning it on is refused (the live-update button is unchecked and a
        message box is shown) while no feature images are available. On a real
        mode change, label editing and the feature-suggestion button are
        disabled (on) or re-enabled (off). FreezePredictions is always set to
        the inverse of `checked`, and the parent applet is asked to refresh.
        """
        logger.debug("toggling interactive mode to '%r'", checked)

        featureImages = self.topLevelOperatorView.FeatureImages
        if checked and (not featureImages.ready() or featureImages.meta.shape is None):
            self.labelingDrawerUi.liveUpdateButton.setChecked(False)
            self.labelingDrawerUi.suggestFeaturesButton.setEnabled(False)
            mexBox=QMessageBox()
            mexBox.setText("There are no features selected ")
            mexBox.exec_()
            return

        # If we're changing modes, enable/disable our controls and other applets accordingly
        if self.interactiveModeActive != checked:
            if checked:
                self.labelingDrawerUi.suggestFeaturesButton.setEnabled(False)
                self.labelingDrawerUi.labelListView.allowDelete = False
                self.labelingDrawerUi.AddLabelButton.setEnabled( False )
            else:
                num_label_classes = self._labelControlUi.labelListModel.rowCount()
                self.labelingDrawerUi.labelListView.allowDelete = ( num_label_classes > self.minLabelNumber )
                self.labelingDrawerUi.AddLabelButton.setEnabled( ( num_label_classes < self.maxLabelNumber ) )
                self.labelingDrawerUi.suggestFeaturesButton.setEnabled(True)

        # Record the new mode before setChecked(), which may re-enter this slot via toggled.
        self.interactiveModeActive = checked

        self.topLevelOperatorView.FreezePredictions.setValue( not checked )
        self.labelingDrawerUi.liveUpdateButton.setChecked(checked)
        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked( True )
            self.handleShowPredictionsClicked()

        # Notify the workflow that some applets may have changed state now.
        # (For example, the downstream pixel classification applet can
        #  be used now that there are features selected)
        self.parentApplet.appletStateUpdateRequested()

    def _setLayersVisible(self, keyword, visible):
        """Set the visibility of every layer whose name contains `keyword`."""
        for layer in self.layerstack:
            if keyword in layer.name:
                layer.visible = visible

    def _syncCheckboxWithLayers(self, checkbox, keyword):
        """
        Mirror the visibility of the layers whose name contains `keyword` in `checkbox`:
        unchecked if none are visible, checked if all are, partially checked otherwise.
        """
        matching = [layer for layer in self.layerstack if keyword in layer.name]
        visibleCount = sum(1 for layer in matching if layer.visible)

        if visibleCount == 0:
            checkbox.setCheckState(Qt.Unchecked)
        elif visibleCount == len(matching):
            checkbox.setCheckState(Qt.Checked)
        else:
            checkbox.setCheckState(Qt.PartiallyChecked)

    @pyqtSlot()
    def handleShowPredictionsClicked(self):
        """Show/hide all prediction layers according to the show-predictions checkbox."""
        checked = self._viewerControlUi.checkShowPredictions.isChecked()
        self._setLayersVisible("Prediction", checked)

    @pyqtSlot()
    def handleShowSegmentationClicked(self):
        """Show/hide all segmentation layers according to the show-segmentation checkbox."""
        checked = self._viewerControlUi.checkShowSegmentation.isChecked()
        self._setLayersVisible("Segmentation", checked)

    @pyqtSlot()
    def updateShowPredictionCheckbox(self):
        """Update the show-predictions checkbox from the prediction layers' visibility."""
        self._syncCheckboxWithLayers(self._viewerControlUi.checkShowPredictions, "Prediction")

    @pyqtSlot()
    def updateShowSegmentationCheckbox(self):
        """Update the show-segmentation checkbox from the segmentation layers' visibility."""
        self._syncCheckboxWithLayers(self._viewerControlUi.checkShowSegmentation, "Segmentation")

    @pyqtSlot()
    @threadRouted
    def handleLabelSelectionChange(self):
        """
        Enable the prediction controls only when they can be used.

        The live-update and suggest-features buttons and the show-predictions /
        show-segmentation checkboxes are enabled when LabelNames is ready, at
        least two label classes exist, and the cached feature images have a
        known, non-empty shape. Otherwise they are unchecked, the prediction
        and segmentation layers are hidden, and the controls are disabled.
        """
        top_op = self.topLevelOperatorView
        enabled = False
        if top_op.LabelNames.ready():
            featureShape = top_op.CachedFeatureImages.meta.shape
            # Guard against an unknown (None) shape, which would otherwise raise TypeError.
            enabled = ( len(top_op.LabelNames.value) >= 2
                        and featureShape is not None
                        and bool(numpy.all(numpy.asarray(featureShape) > 0)) )
            # FIXME: also check that each label has scribbles?

        if not enabled:
            self.labelingDrawerUi.liveUpdateButton.setChecked(False)
            self._viewerControlUi.checkShowPredictions.setChecked(False)
            self._viewerControlUi.checkShowSegmentation.setChecked(False)
            self.handleShowPredictionsClicked()
            self.handleShowSegmentationClicked()

        self.labelingDrawerUi.liveUpdateButton.setEnabled(enabled)
        self.labelingDrawerUi.suggestFeaturesButton.setEnabled(enabled)
        self._viewerControlUi.checkShowPredictions.setEnabled(enabled)
        self._viewerControlUi.checkShowSegmentation.setEnabled(enabled)

    def _getNext(self, slot, parentFun, transform=None):
        """
        Return the stored per-label value for the next label to be created.

        The index of the next label is the current number of labels. If `slot`'s
        value has an entry at that index it is returned (passed through
        `transform` if given); otherwise `parentFun()` supplies the value.
        """
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def _onLabelChanged(self, parentFun, mapf, slot):
        """
        Call `parentFun`, then write `mapf(label)` for every label into `slot`,
        combined with the slot's previous value via `_listReplace`.
        """
        parentFun()
        new = list(map(mapf, self.labelListData))
        old = slot.value
        slot.setValue(_listReplace(old, new))

    def _onLabelRemoved(self, parent, start, end):
        """Update the operator for removed labels and drop the matching stored colors."""
        # Call the base class to update the operator.
        super(PixelClassificationGui, self)._onLabelRemoved(parent, start, end)

        # Keep colors in sync with names
        # (If we deleted a name, delete its corresponding colors, too.)
        op = self.topLevelOperatorView
        if len(op.PmapColors.value) > len(op.LabelNames.value):
            for slot in (op.LabelColors, op.PmapColors):
                value = slot.value
                value.pop(start)
                # Force dirty propagation even though the list id is unchanged.
                slot.setValue(value, check_changed=False)

    def getNextLabelName(self):
        return self._getNext(self.topLevelOperatorView.LabelNames,
                             super(PixelClassificationGui, self).getNextLabelName)

    def getNextLabelColor(self):
        return self._getNext(
            self.topLevelOperatorView.LabelColors,
            super(PixelClassificationGui, self).getNextLabelColor,
            lambda x: QColor(*x)
        )

    def getNextPmapColor(self):
        return self._getNext(
            self.topLevelOperatorView.PmapColors,
            super(PixelClassificationGui, self).getNextPmapColor,
            lambda x: QColor(*x)
        )

    def onLabelNameChanged(self):
        self._onLabelChanged(super(PixelClassificationGui, self).onLabelNameChanged,
                             lambda l: l.name,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        self._onLabelChanged(super(PixelClassificationGui, self).onLabelColorChanged,
                             lambda l: (l.brushColor().red(),
                                        l.brushColor().green(),
                                        l.brushColor().blue()),
                             self.topLevelOperatorView.LabelColors)

    def onPmapColorChanged(self):
        self._onLabelChanged(super(PixelClassificationGui, self).onPmapColorChanged,
                             lambda l: (l.pmapColor().red(),
                                        l.pmapColor().green(),
                                        l.pmapColor().blue()),
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        """
        Rebuild the label volume drawn by the 3D render manager.

        No-op unless rendering is enabled and the input images are 5D. For
        every layer toggled for 3D rendering, the non-zero voxels of its data at
        the current time point are written into a uint8 volume using that
        layer's render label; colors are then synced and the renderer updated.
        """
        if not self.render:
            return
        fullShape = self.topLevelOperatorView.InputImages.meta.shape
        # Check the full shape *before* slicing: the sliced spatial shape has at
        # most 3 entries, so testing it against 5 always bailed out.
        if len(fullShape) != 5:
            #this might be a 2D image, no need for updating any 3D stuff
            return
        # Drop the first and last axes (presumably time and channel, matching
        # the [time, ..., 0] indexing below).
        shape = fullShape[1:4]

        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        # Forget rendered layers that are no longer in the layer stack.
        layernames = {layer.name for layer in self.layerstack}
        self._renderedLayers = {k: v for k, v in self._renderedLayers.items()
                                if k in layernames}

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                newvolume[vol != 0] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Push each rendered layer's tint color (as 0..1 RGB floats) to the render manager."""
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            color = layer.tintColor
            color = (old_div(color.red(), 255.0), old_div(color.green(), 255.0), old_div(color.blue(), 255.0))
            self._renderMgr.setColor(label, color)
Esempio n. 4
0
class CarvingGui(LabelingGui):
    """
    Labeling GUI for the carving applet.

    On top of LabelingGui this class wires up the carving drawer:
    - a *Segment* button, also bound to the "3" shortcut;
    - spin boxes kept in sync with the operator's BackgroundPriority and NoBiasBelow slots;
    - saving, browsing, loading and deleting named objects;
    - optional 3D rendering of saved objects and of the current segmentation;
    - export of object meshes to .obj files.
    """

    def __init__(self, parentApplet, topLevelOperatorView, drawerUiPath=None ):
        """
        :param parentApplet: the owning applet (forwarded to LabelingGui).
        :param topLevelOperatorView: carving top-level operator (view) whose slots drive this GUI.
        :param drawerUiPath: optional drawer .ui file; defaults to 'carvingDrawer.ui' in this
                             module's directory.
        """
        self.topLevelOperatorView = topLevelOperatorView

        #members
        self._doneSegmentationLayer = None
        self._showSegmentationIn3D = False
        #self._showUncertaintyLayer = False
        # setupLayers() may run more than once; its slot callbacks must only be registered once.
        self._buttonCallbacksRegistered = False
        #end: members

        labelingSlots = LabelingGui.LabelingSlots()
        labelingSlots.labelInput       = topLevelOperatorView.WriteSeeds
        labelingSlots.labelOutput      = topLevelOperatorView.opLabelArray.Output
        labelingSlots.labelEraserValue = topLevelOperatorView.opLabelArray.EraserLabelValue
        labelingSlots.labelNames       = topLevelOperatorView.LabelNames
        labelingSlots.labelDelete      = topLevelOperatorView.opLabelArray.DeleteLabel
        labelingSlots.maxLabelValue    = topLevelOperatorView.opLabelArray.MaxLabelValue
        
        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        if drawerUiPath is None:
            drawerUiPath = os.path.join(directory, 'carvingDrawer.ui')
        self.dialogdirCOM = os.path.join(directory, 'carvingObjectManagement.ui')
        self.dialogdirSAD = os.path.join(directory, 'saveAsDialog.ui')

        super(CarvingGui, self).__init__(parentApplet, labelingSlots, topLevelOperatorView, drawerUiPath)
        
        self.labelingDrawerUi.currentObjectLabel.setText("<not saved yet>")

        # Init special base class members
        self.minLabelNumber = 2
        self.maxLabelNumber = 2
        
        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo
        
        #set up keyboard shortcuts
        mgr.register( "3", ActionInfo( "Carving", 
                                       "Run interactive segmentation", 
                                       "Run interactive segmentation", 
                                       self.labelingDrawerUi.segment.click,
                                       self.labelingDrawerUi.segment,
                                       self.labelingDrawerUi.segment  ) )

        
        # Disable 3D view by default
        self.render = False
        tagged_shape = defaultdict(lambda: 1)
        tagged_shape.update( topLevelOperatorView.InputData.meta.getTaggedShape() )
        is_3d = (tagged_shape['x'] > 1 and tagged_shape['y'] > 1 and tagged_shape['z'] > 1)

        if is_3d:
            try:
                self._renderMgr = RenderingManager( self.editor.view3d )
                self._shownObjects3D = {}
                self.render = True
            except Exception:
                # 3D rendering is optional: fall back to 2D-only, but leave a trace of why.
                logger.warning( "Could not initialize 3D rendering; the 3D view is disabled.", exc_info=True )
                self.render = False

        # Segmentation is toggled on by default in _after_init, below.
        # (We can't enable it until the layers are all present.)
        self._showSegmentationIn3D = False
        self._segmentation_3d_label = None
                
        self.labelingDrawerUi.segment.clicked.connect(self.onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        self.topLevelOperatorView.Segmentation.notifyDirty( bind( self._update_rendering ) )
        self.topLevelOperatorView.HasSegmentation.notifyValueChanged( bind( self._updateGui ) )

        ## uncertainty

        #self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        #self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)

        #def onUncertaintyFGButton():
        #    logger.debug( "uncertFG button clicked" )
        #    pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=2)
        #    self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        #self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(onUncertaintyFGButton)

        #def onUncertaintyBGButton():
        #    logger.debug( "uncertBG button clicked" )
        #    pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=1)
        #    self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        #self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(onUncertaintyBGButton)

        #def onUncertaintyCombo(value):
        #    if value == 0:
        #        value = "none"
        #        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        #        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)
        #        self._showUncertaintyLayer = False
        #    else:
        #        if value == 1:
        #            value = "localMargin"
        #        elif value == 2:
        #            value = "exchangeCount"
        #        elif value == 3:
        #            value = "gabow"
        #        else:
        #            raise RuntimeError("unhandled case '%r'" % value)
        #        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(True)
        #        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(True)
        #        self._showUncertaintyLayer = True
        #        logger.debug( "uncertainty changed to %r" % value )
        #    self.topLevelOperatorView.UncertaintyType.setValue(value)
        #    self.updateAllLayers() #make sure that an added/deleted uncertainty layer is recognized
        #self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(onUncertaintyCombo)

        ## background priority
        
        def onBackgroundPrioritySpin(value):
            logger.debug( "background priority changed to %f" % value )
            self.topLevelOperatorView.BackgroundPriority.setValue(value)
        self.labelingDrawerUi.backgroundPrioritySpin.valueChanged.connect(onBackgroundPrioritySpin)

        # Keep the spin box in sync when the slot changes from elsewhere (e.g. project load);
        # only touch the widget if the value differs, to avoid a feedback loop with the handler above.
        def onBackgroundPriorityDirty(slot, roi):
            oldValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
            newValue = self.topLevelOperatorView.BackgroundPriority.value
            if  newValue != oldValue:
                self.labelingDrawerUi.backgroundPrioritySpin.setValue(newValue)
        self.topLevelOperatorView.BackgroundPriority.notifyDirty(onBackgroundPriorityDirty)
        
        ## bias
        
        def onNoBiasBelowDirty(slot, roi):
            oldValue = self.labelingDrawerUi.noBiasBelowSpin.value()
            newValue = self.topLevelOperatorView.NoBiasBelow.value
            if  newValue != oldValue:
                self.labelingDrawerUi.noBiasBelowSpin.setValue(newValue)
        self.topLevelOperatorView.NoBiasBelow.notifyDirty(onNoBiasBelowDirty)
        
        def onNoBiasBelowSpin(value):
            logger.debug( "no bias below changed to %f" % value )
            self.topLevelOperatorView.NoBiasBelow.setValue(value)
        self.labelingDrawerUi.noBiasBelowSpin.valueChanged.connect(onNoBiasBelowSpin)
        
        ## save

        self.labelingDrawerUi.save.clicked.connect(self.onSaveButton)

        ## clear

        self.labelingDrawerUi.clear.clicked.connect(self._onClearAction)
        
        ## object names
        
        self.labelingDrawerUi.namesButton.clicked.connect(self.onShowObjectNames)
        if hasattr( self.labelingDrawerUi, 'exportAllMeshesButton' ):
            self.labelingDrawerUi.exportAllMeshesButton.clicked.connect(self._exportAllObjectMeshes)

        # Carving uses exactly two fixed labels; they must not be deleted.
        self.labelingDrawerUi.labelListView.allowDelete = False
        self._labelControlUi.labelListModel.allowRemove(False)

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)
        
        def addLayerToggleShortcut(layername, shortcut):
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()

            mgr.register(shortcut, ActionInfo( "Carving", 
                                               "Toggle layer %s" % layername, 
                                               "Toggle layer %s" % layername, 
                                               toggle,
                                               self.viewerControlWidget(),
                                               None ) )

        #TODO
        addLayerToggleShortcut("Completed segments (unicolor)", "d")
        addLayerToggleShortcut("Segmentation", "s")
        addLayerToggleShortcut("Input Data", "r")

        def makeColortable():
            # Entry 0 is transparent, entries 1..254 are random colors, entry 255 is pure green.
            self._doneSegmentationColortable = [QColor(0,0,0,0).rgba()]
            for i in range(254):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                # ensure colors have sufficient distance to pure red and pure green
                while (255 - r)+g+b<128 or r+(255-g)+b<128:
                    r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                self._doneSegmentationColortable.append(QColor(r,g,b).rgba())
            self._doneSegmentationColortable.append(QColor(0,255,0).rgba())
        makeColortable()
        # Currently unused: its button connection below is commented out.
        def onRandomizeColors():
            if self._doneSegmentationLayer is not None:
                logger.debug( "randomizing colors ..." )
                makeColortable()
                self._doneSegmentationLayer.colorTable = self._doneSegmentationColortable
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
        #self.labelingDrawerUi.randomizeColors.clicked.connect(onRandomizeColors)
        self._updateGui()
    
    def _after_init(self):
        """Finish base-class initialization, then show the editing segmentation in 3D if rendering is available."""
        super(CarvingGui, self)._after_init()
        if self.render:self._toggleSegmentation3D()
        
        
    def _updateGui(self):
        """Enable the Save button only when the operator reports the current data as storable."""
        self.labelingDrawerUi.save.setEnabled( self.topLevelOperatorView.dataIsStorable() )
        
    def onSegmentButton(self):
        """Request a new interactive segmentation by marking the operator's Trigger slot dirty."""
        logger.debug( "segment button clicked" )
        self.topLevelOperatorView.Trigger.setDirty(slice(None))
    
    def saveAsDialog(self, name=""):
        '''
        special functionality: reject names given to other objects

        Shows the save-as dialog pre-filled with ``name``. The OK button is disabled (and a
        warning shown) while the entered name matches an existing object name.

        :returns: the chosen name as ``str``, or None if the dialog was cancelled.
        '''
        dialog = uic.loadUi(self.dialogdirSAD)
        dialog.lineEdit.setText(name)
        dialog.warning.setVisible(False)
        dialog.Ok.clicked.connect(dialog.accept)
        dialog.Cancel.clicked.connect(dialog.reject)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.isDisabled = False
        def validate():
            name = dialog.lineEdit.text()
            if name in listOfItems:
                dialog.Ok.setEnabled(False)
                dialog.warning.setVisible(True)
                dialog.isDisabled = True
            elif dialog.isDisabled:
                dialog.Ok.setEnabled(True)
                dialog.warning.setVisible(False)
                dialog.isDisabled = False
        dialog.lineEdit.textChanged.connect(validate)
        result = dialog.exec_()
        if result:
            return str(dialog.lineEdit.text())
    
    def onSaveButton(self):
        """
        Save the current segmentation as a named object.

        Asks for a name (pre-filled with the current object's name, if any). It refuses names
        already used by another object. When an existing object is saved under a new name, the
        object stored under the old name is deleted (i.e. a rename). If the data is not
        storable, a warning box is shown instead.
        """
        logger.info( "save object as?" )
        if self.topLevelOperatorView.dataIsStorable():
            prevName = ""
            if self.topLevelOperatorView.hasCurrentObject():
                prevName = self.topLevelOperatorView.currentObjectName()
            if prevName == "<not saved yet>":
                prevName = ""
            name = self.saveAsDialog(name=prevName)
            if name is None:
                return
            objects = self.topLevelOperatorView.AllObjectNames[:].wait()
            if name in objects and name != prevName:
                QMessageBox.critical(self, "Save Object As", "An object with name '%s' already exists.\nPlease choose a different name." % name)
                return
            self.topLevelOperatorView.saveObjectAs(name)
            logger.info( "save object as %s" % name )
            if prevName != name and prevName != "":
                self.topLevelOperatorView.deleteObject(prevName)
        else:
            msgBox = QMessageBox(self)
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.exec_()
            logger.error( "object not saved due to faulty data." )
    
    def onShowObjectNames(self):
        '''show object names and allow user to load/delete them'''
        dialog = uic.loadUi(self.dialogdirCOM)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.objectNames.addItems(sorted(listOfItems))
        
        def loadSelection():
            selected = [str(item.text()) for item in dialog.objectNames.selectedItems()]
            dialog.close()
            for objectname in selected:
                self.topLevelOperatorView.loadObject(objectname)
        
        def deleteSelection():
            items = dialog.objectNames.selectedItems()
            if self.confirmAndDelete([str(name.text()) for name in items]):
                for name in items:
                    name.setHidden(True)
            dialog.close()
        
        dialog.loadButton.clicked.connect(loadSelection)
        dialog.deleteButton.clicked.connect(deleteSelection)
        dialog.cancelButton.clicked.connect(dialog.close)
        dialog.exec_()
    
    def confirmAndDelete(self,namelist):
        """
        Ask the user to confirm, then delete every object named in ``namelist``.

        :returns: True if the user confirmed and the objects were deleted, False otherwise.
        """
        logger.info( "confirmAndDelete: {}".format( namelist ) )
        objectlist = "".join("\n  "+str(i) for i in namelist)
        confirmed = QMessageBox.question(self, "Delete Object",
                                         "Do you want to delete these objects?"+objectlist,
                                         QMessageBox.Yes | QMessageBox.Cancel,
                                         defaultButton=QMessageBox.Yes)
            
        if confirmed == QMessageBox.Yes:
            for name in namelist:
                self.topLevelOperatorView.deleteObject(name)
            return True
        return False
    
    def labelingContextMenu(self,names,op,position5d):
        """
        Build the right-click context menu for the editor.

        :param names: names of completed objects present at the clicked position.
        :param op: the carving top-level operator (view).
        :param position5d: clicked position; entries 1..3 are shown as the spatial position.
        :returns: the populated QMenu.
        """
        menu = QMenu(self)
        menu.setObjectName("carving_context_menu")
        posItem = menu.addAction("position %d %d %d" % (position5d[1], position5d[2], position5d[3]))
        posItem.setEnabled(False)
        menu.addSeparator()
        for name in names:
            submenu = QMenu(name,menu)
            
            # Load
            loadAction = submenu.addAction("Load %s" % name)
            loadAction.triggered.connect( partial(op.loadObject, name) )
            
            # Delete
            def onDelAction(_name):
                self.confirmAndDelete([_name])
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
            delAction = submenu.addAction("Delete %s" % name)
            delAction.triggered.connect( partial(onDelAction, name) )

            if self.render:
                if name in self._shownObjects3D:
                    # Remove
                    def onRemove3D(_name):
                        label = self._shownObjects3D.pop(_name)
                        self._renderMgr.removeObject(label)
                        self._update_rendering()
                    removeAction = submenu.addAction("Remove %s from 3D view" % name)
                    removeAction.triggered.connect( partial(onRemove3D, name) )
                else:
                    # Show
                    def onShow3D(_name):
                        label = self._renderMgr.addObject()
                        self._shownObjects3D[_name] = label
                        self._update_rendering()
                    showAction = submenu.addAction("Show 3D %s" % name)
                    showAction.triggered.connect( partial(onShow3D, name ) )
            
            # Export mesh
            if _have_vtk:
                exportAction = submenu.addAction("Export mesh for %s" % name)
                exportAction.triggered.connect( partial(self._onContextMenuExportMesh, name) )
                        
            menu.addMenu(submenu)

        menu.addSeparator()
        if self.render:
            showSeg3DAction = menu.addAction( "Show Editing Segmentation in 3D" )
            showSeg3DAction.setCheckable(True)
            showSeg3DAction.setChecked( self._showSegmentationIn3D )
            showSeg3DAction.triggered.connect( self._toggleSegmentation3D )
        
        if op.dataIsStorable():
            menu.addAction("Save object").triggered.connect( self.onSaveButton )
        menu.addAction("Browse objects").triggered.connect( self.onShowObjectNames )
        menu.addAction("Segment").triggered.connect( self.onSegmentButton )
        menu.addAction("Clear").triggered.connect( self._onClearAction )
        return menu

    def _onClearAction(self):
        """After confirmation, clear all brush strokes of the current labeling."""
        confirm = QMessageBox.warning(self, "Really Clear?", "Clear all brushstrokes?", QMessageBox.Ok | QMessageBox.Cancel)
        if confirm == QMessageBox.Ok:
            self.topLevelOperatorView.clearCurrentLabeling()

    def _onContextMenuExportMesh(self, _name):
        """
        Export a single object mesh to a user-specified filename.
        """
        recent_dir = PreferencesManager().get( 'carving', 'recent export mesh directory' )
        if recent_dir is None:
            defaultPath = os.path.join( os.path.expanduser('~'), '{}.obj'.format(_name) )
        else:
            defaultPath = os.path.join( recent_dir, '{}.obj'.format(_name) )
        filepath = QFileDialog.getSaveFileName(self, 
                                               "Save meshes for object '{}'".format(_name),
                                               defaultPath,
                                               "OBJ Files (*.obj)")
        if filepath.isNull():
            return
        obj_filepath = str(filepath)
        PreferencesManager().set( 'carving', 'recent export mesh directory', os.path.split(obj_filepath)[0] )
        
        self._exportMeshes([_name], [obj_filepath])

    def _exportAllObjectMeshes(self):
        """
        Export all objects in the project as separate .obj files, stored to a user-specified directory.
        """
        mst = self.topLevelOperatorView.MST.value
        if not mst.object_lut:
            QMessageBox.critical(self, "Can't Export", "You have no saved objects, so there are no meshes to export.")
            return
        
        recent_dir = PreferencesManager().get( 'carving', 'recent export mesh directory' )
        if recent_dir is None:
            defaultPath = os.path.expanduser('~')
        else:
            defaultPath = recent_dir
        export_dir = QFileDialog.getExistingDirectory( self, 
                                                       "Select export directory for mesh files",
                                                       defaultPath)
        if export_dir.isNull():
            return
        export_dir = str(export_dir)
        PreferencesManager().set( 'carving', 'recent export mesh directory', export_dir )

        # One "<object name>.obj" file per saved object.
        object_names = list(mst.object_lut.keys())
        obj_filepaths = [os.path.join( export_dir, "{}.obj".format( object_name ) )
                         for object_name in object_names]
        
        if object_names:
            self._exportMeshes( object_names, obj_filepaths )

    def _exportMeshes(self, object_names, obj_filepaths):
        """
        Export a mesh .obj file for each object in the object_names list to the corresponding file name from the obj_filepaths list.
        This function is pseudo-recursive. It works like this:
        1) Pop the first name/file from the args
        2) Kick off the export by launching the export mesh dlg
        3) return from this function to allow the eventloop to resume while the export is running
        4) When the export dlg is finished, create the mesh file (by writing a temporary .vtk file and converting it into a .obj file)
        5) If there are still more items in the object_names list to process, repeat this function.

        Both lists are consumed in place. Does nothing if object_names is empty.
        """
        if not object_names:
            return

        # Pop the first object off the list
        object_name = object_names.pop(0)
        obj_filepath = obj_filepaths.pop(0)
        
        # Construct a volume with only this object.
        # We might be tempted to get the object directly from opCarving.DoneObjects, 
        #  but that won't be correct for overlapping objects.
        mst = self.topLevelOperatorView.MST.value
        object_supervoxels = mst.object_lut[object_name]
        object_lut = numpy.zeros(mst.nodeNum+1, dtype=numpy.int32)
        object_lut[object_supervoxels] = 1
        supervoxel_volume = mst.supervoxelUint32
        # Map every supervoxel id to 1 (part of the object) or 0 (background).
        object_volume = object_lut[supervoxel_volume]

        # Run the mesh extractor
        window = MeshExtractorDialog(parent=self)
        
        def onMeshesComplete():
            """
            Called when mesh extraction is complete.
            Writes the extracted mesh to an .obj file
            """
            logger.info( "Mesh generation complete." )
            mesh_count = len( window.extractor.meshes )

            # Mesh count can sometimes be 0 for the '<not saved yet>' object...
            if mesh_count > 0:
                assert mesh_count == 1, \
                    "Found {} meshes processing object '{}',"\
                    "(only expected 1)".format( mesh_count, object_name )
                mesh = next(iter(window.extractor.meshes.values()))
                logger.info( "Saving meshes to {}".format( obj_filepath ) )
    
                # Use VTK to write to a temporary .vtk file; always remove the temp dir afterwards.
                import shutil
                tmpdir = tempfile.mkdtemp()
                try:
                    vtkpoly_path = os.path.join(tmpdir, 'meshes.vtk')
                    w = vtkPolyDataWriter()
                    w.SetFileTypeToASCII()
                    w.SetInput(mesh)
                    w.SetFileName(vtkpoly_path)
                    w.Write()
                    
                    # Now convert the file to .obj format.
                    convertVTPtoOBJ(vtkpoly_path, obj_filepath)
                finally:
                    shutil.rmtree(tmpdir, ignore_errors=True)
    
            # Cleanup: We don't need the window anymore.
            window.setParent(None)

            # If there are still objects left to process,
            #   start again with the remainder of the list.
            if object_names:
                self._exportMeshes(object_names, obj_filepaths)
            
        window.finished.connect( onMeshesComplete )

        # Kick off the save process and exit to the event loop
        window.show()
        QTimer.singleShot(0, partial(window.run, object_volume, [0]))

    
    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Show the carving context menu for the completed objects at the clicked position."""
        names = self.topLevelOperatorView.doneObjectNamesForPosition(position5d[1:4])
        op = self.topLevelOperatorView

        # (Subclasses may override menu)
        menu = self.labelingContextMenu(names,op,position5d)
        if menu is not None:
            menu.exec_(globalWindowCoordinate)

    def _toggleSegmentation3D(self):
        """Toggle 3D display of the current (editing) segmentation and re-render."""
        self._showSegmentationIn3D = not self._showSegmentationIn3D
        if self._showSegmentationIn3D:
            self._segmentation_3d_label = self._renderMgr.addObject()
        else:
            self._renderMgr.removeObject(self._segmentation_3d_label)
            self._segmentation_3d_label = None
        self._update_rendering()
    
    def _update_rendering(self):
        """
        Rebuild the 3D volume from the shown objects and, optionally, the current segmentation.

        Builds a supervoxel-id -> render-label lookup table and then maps the supervoxel volume
        through it. Shown objects that no longer exist are dropped first. Supervoxels with
        current-segmentation value 2 get the segmentation's render label when it is shown in 3D.
        """
        if not self.render:
            return

        op = self.topLevelOperatorView
        if not self._renderMgr.ready:
            self._renderMgr.setup(op.InputData.meta.shape[1:4])

        mst = op.MST.value

        # remove nonexistent objects
        self._shownObjects3D = dict((k, v) for k, v in self._shownObjects3D.items()
                                    if k in mst.object_lut)

        lut = numpy.zeros(mst.nodeNum+1, dtype=numpy.int32)
        for name, label in self._shownObjects3D.items():
            objectSupervoxels = mst.objects[name]
            lut[objectSupervoxels] = label

        if self._showSegmentationIn3D:
            # Add segmentation as label, which is green
            lut[:] = numpy.where( mst.getSuperVoxelSeg() == 2, self._segmentation_3d_label, lut )
        self._renderMgr.volume = lut[mst.supervoxelUint32] # (Advanced indexing)
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """
        Assign render colors: each shown object gets its color from the done-segmentation
        color table (indexed by the MST's object_names entry), and the editing segmentation
        is green.
        """
        op = self.topLevelOperatorView

        # The color table lives on the done-segmentation layer, which only exists once its slot
        # was ready in setupLayers(); without it there is nothing to color objects with.
        if self._shownObjects3D and self._doneSegmentationLayer is not None:
            ctable = self._doneSegmentationLayer.colorTable
            object_names = op.MST.value.object_names
            for name, label in self._shownObjects3D.items():
                color = QColor(ctable[object_names[name]])
                color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
                self._renderMgr.setColor(label, color)

        if self._showSegmentationIn3D and self._segmentation_3d_label is not None:
            self._renderMgr.setColor(self._segmentation_3d_label, (0.0, 1.0, 0.0)) # Green

    def _getNext(self, slot, parentFun, transform=None):
        """
        Return the next per-label value from ``slot``.

        Uses ``slot.value[n]``, where n is the current number of labels (optionally passed
        through ``transform``). Falls back to ``parentFun()`` once the slot's list is exhausted.
        """
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def getNextLabelName(self):
        """Take the next label name from the LabelNames slot, else defer to LabelingGui."""
        return self._getNext(self.topLevelOperatorView.LabelNames,
                             super(CarvingGui, self).getNextLabelName)

    def appletDrawers(self):
        return [ ("Carving", self._labelControlUi) ]

    def setupLayers( self ):
        """
        Build the viewer layer list, omitting layers whose slot is not ready.

        Layers, top to bottom: labels, current segmentation, completed segments (unicolor),
        completed segments (one color per object), supervoxels, optional overlay, input data,
        filtered input.
        """
        logger.debug( "setupLayers" )
        
        layers = []

        def onButtonsEnabled(slot, roi):
            currObj = self.topLevelOperatorView.CurrentObjectName.value
            hasSeg  = self.topLevelOperatorView.HasSegmentation.value
            
            self.labelingDrawerUi.currentObjectLabel.setText(currObj)
            self.labelingDrawerUi.save.setEnabled(hasSeg)

        # setupLayers() may run more than once; registering on every call would stack up
        # duplicate handlers on these slots, so only register the first time.
        if not getattr(self, '_buttonCallbacksRegistered', False):
            self.topLevelOperatorView.CurrentObjectName.notifyDirty(onButtonsEnabled)
            self.topLevelOperatorView.HasSegmentation.notifyDirty(onButtonsEnabled)
            self.topLevelOperatorView.opLabelArray.NonzeroBlocks.notifyDirty(onButtonsEnabled)
            self._buttonCallbacksRegistered = True
        
        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            labellayer._allowToggleVisible = False
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        #uncertainty
        #if self._showUncertaintyLayer:
        #    uncert = self.topLevelOperatorView.Uncertainty
        #    if uncert.ready():
        #        colortable = []
        #        for i in range(256-len(colortable)):
        #            r,g,b,a = i,0,0,i
        #            colortable.append(QColor(r,g,b,a).rgba())
        #        layer = ColortableLayer(LazyflowSource(uncert), colortable, direct=True)
        #        layer.name = "Uncertainty"
        #        layer.visible = True
        #        layer.opacity = 0.3
        #        layers.append(layer)
       
        #segmentation 
        seg = self.topLevelOperatorView.Segmentation
        
        #seg = self.topLevelOperatorView.MST.value.segmentation
        #temp = self._done_lut[self.MST.value.supervoxelUint32[sl[1:4]]]
        if seg.ready():
            #source = RelabelingArraySource(seg)
            #source.setRelabeling(numpy.arange(256, dtype=numpy.uint8))
            # Values 0 and 1 transparent, 2 (foreground) green, the rest random.
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,0,0).rgba(), QColor(0,255,0).rgba()]
            for i in range(256-len(colortable)):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())

            layer = ColortableLayer(LazyflowSource(seg), colortable, direct=True)
            layer.name = "Segmentation"
            layer.setToolTip("This layer displays the <i>current</i> segmentation. Simply add foreground and background " \
                             "labels, then press <i>Segment</i>.")
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)
        
        #done 
        done = self.topLevelOperatorView.DoneObjects
        if done.ready(): 
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,255).rgba()]
            #have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(LazyflowSource(done), colortable, direct=True)
            layer.name = "Completed segments (unicolor)"
            layer.setToolTip("In order to keep track of which objects you have already completed, this layer " \
                             "shows <b>all completed object</b> in one color (<b>blue</b>). " \
                             "The reason for only one color is that for finding out which " \
                              "objects to label next, the identity of already completed objects is unimportant " \
                              "and distracting.")
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

        #done seg
        doneSeg = self.topLevelOperatorView.DoneSegmentation
        if doneSeg.ready():
            layer = ColortableLayer(LazyflowSource(doneSeg), self._doneSegmentationColortable, direct=True)
            layer.name = "Completed segments (one color per object)"
            layer.setToolTip("<html>In order to keep track of which objects you have already completed, this layer " \
                             "shows <b>all completed object</b>, each with a random color.</html>")
            layer.visible = False
            layer.opacity = 0.5
            self._doneSegmentationLayer = layer
            layers.append(layer)

        #supervoxel
        sv = self.topLevelOperatorView.Supervoxels
        if sv.ready():
            colortable = []
            for i in range(256):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())
            layer = ColortableLayer(LazyflowSource(sv), colortable, direct=True)
            layer.name = "Supervoxels"
            layer.setToolTip("<html>This layer shows the partitioning of the input image into <b>supervoxels</b>. The carving " \
                             "algorithm uses these tiny puzzle-pieces to piece together the segmentation of an " \
                             "object. Sometimes, supervoxels are too large and straddle two distinct objects " \
                             "(undersegmentation). In this case, it will be impossible to achieve the desired " \
                             "segmentation. This layer helps you to understand these cases.</html>")
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        # Visual overlay (just for easier labeling)
        overlaySlot = self.topLevelOperatorView.OverlayData
        if overlaySlot.ready():
            overlay5D = self.topLevelOperatorView.OverlayData.value
            layer = GrayscaleLayer(ArraySource(overlay5D), direct=True)
            layer.visible = True
            layer.name = 'Overlay'
            layer.opacity = 1.0
            # if the flag window_leveling is set the contrast 
            # of the layer is adjustable
            layer.window_leveling = True
            self.labelingDrawerUi.thresToolButton.show()
            layers.append(layer)
            del layer

        inputSlot = self.topLevelOperatorView.InputData
        if inputSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(inputSlot), direct=True )
            layer.name = "Input Data"
            layer.setToolTip("<html>The data originally loaded into ilastik (unprocessed).</html>")
            #layer.visible = not rawSlot.ready()
            layer.visible = True
            layer.opacity = 1.0

            # Window leveling is already active on the Overlay,
            # but if no overlay was provided, then activate window_leveling on the raw data instead.
            if not overlaySlot.ready():
                # if the flag window_leveling is set the contrast 
                # of the layer is adjustable
                layer.window_leveling = True
                self.labelingDrawerUi.thresToolButton.show()

            layers.append(layer)
            del layer

        filteredSlot = self.topLevelOperatorView.FilteredInputData
        if filteredSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(filteredSlot) )
            layer.name = "Filtered Input"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        return layers
Esempio n. 5
0
class CarvingGui(LabelingGui):
    """Labeling GUI for the carving applet.

    Wires the carving drawer controls (segment, save/save-as, clear, object
    management, background-priority / no-bias-below spin boxes), keyboard
    shortcuts and the carving layer stack, plus a 3D rendering of selected
    finished objects when a rendering manager can be created.
    """
    def __init__(self, labelingSlots, topLevelOperatorView, drawerUiPath=None, rawInputSlot=None ):
        """
        :param labelingSlots: slots description handed to LabelingGui.
        :param topLevelOperatorView: operator view exposing ``opCarving`` and ``RawData``.
        :param drawerUiPath: accepted for interface compatibility but ignored;
            ``carvingDrawer.ui`` next to this module is always used.
        :param rawInputSlot: passed through to LabelingGui.
        """
        self.topLevelOperatorView = topLevelOperatorView

        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        carvingDrawerUiPath = os.path.join(directory, 'carvingDrawer.ui')

        super(CarvingGui, self).__init__(labelingSlots, topLevelOperatorView, carvingDrawerUiPath, rawInputSlot)

        mgr = ShortcutManager()

        #set up keyboard shortcuts
        segmentShortcut = QShortcut(QKeySequence("3"), self, member=self.labelingDrawerUi.segment.click,
                                    ambiguousMember=self.labelingDrawerUi.segment.click)
        mgr.register("Carving", "Run interactive segmentation", segmentShortcut, self.labelingDrawerUi.segment)

        self._doneSegmentationLayer = None

        #volume rendering
        # Best effort: if the rendering manager cannot be created, 3D rendering
        # is disabled (self.render = False) instead of aborting GUI construction.
        try:
            self.render = True
            self._shownObjects3D = {}
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except Exception:
            self.render = False

        def onSegmentButton():
            print("segment button clicked")
            self.topLevelOperatorView.opCarving.Trigger.setDirty(slice(None))
        self.labelingDrawerUi.segment.clicked.connect(onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        def onUncertaintyFGButton():
            print("uncertFG button clicked")
            pos = self.topLevelOperatorView.opCarving.getMaxUncertaintyPos(label=2)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(onUncertaintyFGButton)
        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(True)

        def onUncertaintyBGButton():
            print("uncertBG button clicked")
            pos = self.topLevelOperatorView.opCarving.getMaxUncertaintyPos(label=1)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(onUncertaintyBGButton)
        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(True)


        def onBackgroundPrioritySpin(value):
            print("background priority changed to %f" % value)
            self.topLevelOperatorView.opCarving.BackgroundPriority.setValue(value)
        self.labelingDrawerUi.backgroundPrioritySpin.valueChanged.connect(onBackgroundPrioritySpin)

        # Combo box index -> uncertainty type understood by opCarving.
        # Unknown indices are passed through unchanged.
        uncertaintyTypes = {0: "none", 1: "localMargin", 2: "exchangeCount", 3: "gabow"}
        def onuncertaintyCombo(value):
            value = uncertaintyTypes.get(value, value)
            print("uncertainty changed to %r" % value)
            self.topLevelOperatorView.opCarving.UncertaintyType.setValue(value)
        self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(onuncertaintyCombo)

        # Keep the spin boxes in sync when the operator values change elsewhere;
        # only touch the widget if the value differs, to avoid feedback loops.
        def onBackgroundPriorityDirty(slot, roi):
            oldValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
            newValue = self.topLevelOperatorView.opCarving.BackgroundPriority.value
            if  newValue != oldValue:
                self.labelingDrawerUi.backgroundPrioritySpin.setValue(newValue)
        self.topLevelOperatorView.opCarving.BackgroundPriority.notifyDirty(onBackgroundPriorityDirty)

        def onNoBiasBelowDirty(slot, roi):
            oldValue = self.labelingDrawerUi.noBiasBelowSpin.value()
            newValue = self.topLevelOperatorView.opCarving.NoBiasBelow.value
            if  newValue != oldValue:
                self.labelingDrawerUi.noBiasBelowSpin.setValue(newValue)
        self.topLevelOperatorView.opCarving.NoBiasBelow.notifyDirty(onNoBiasBelowDirty)

        def onNoBiasBelowSpin(value):
            print("no bias below changed to %f" % value)
            self.topLevelOperatorView.opCarving.NoBiasBelow.setValue(value)
        self.labelingDrawerUi.noBiasBelowSpin.valueChanged.connect(onNoBiasBelowSpin)

        def onSaveAsButton():
            print("save object as?")
            if self.topLevelOperatorView.opCarving.dataIsStorable():
                name, ok = QInputDialog.getText(self, 'Save Object As', 'object name') 
                name = str(name)
                if not ok:
                    return
                objects = self.topLevelOperatorView.opCarving.AllObjectNames[:].wait()
                if name in objects:
                    QMessageBox.critical(self, "Save Object As", "An object with name '%s' already exists.\nPlease choose a different name." % name)
                    return
                self.topLevelOperatorView.opCarving.saveObjectAs(name)
                print("save object as %s" % name)
            else:
                msgBox = QMessageBox(self)
                msgBox.setText("The data does not seem fit to be stored.")
                msgBox.setWindowTitle("Problem with Data")
                msgBox.setIcon(2)
                msgBox.exec_()
                print("object not saved due to faulty data.")

        self.labelingDrawerUi.saveAs.clicked.connect(onSaveAsButton)

        def onSaveButton():
            if self.topLevelOperatorView.opCarving.dataIsStorable():
                if self.topLevelOperatorView.opCarving.hasCurrentObject():
                    name = self.topLevelOperatorView.opCarving.currentObjectName()
                    self.topLevelOperatorView.opCarving.saveObjectAs( name )
                else:
                    onSaveAsButton()
            else:
                msgBox = QMessageBox(self)
                msgBox.setText("The data does not seem fit to be stored.")
                msgBox.setWindowTitle("Lousy Data")
                msgBox.setIcon(2)
                msgBox.exec_()
                print("object not saved due to faulty data.")
        self.labelingDrawerUi.save.clicked.connect(onSaveButton)
        self.labelingDrawerUi.save.setEnabled(False) #initially, the user need to use "Save As"

        def onClearButton():
            self.topLevelOperatorView.opCarving._clear()
            self.topLevelOperatorView.opCarving.clearCurrentLabeling()
            # trigger a re-computation
            self.topLevelOperatorView.opCarving.Trigger.setDirty(slice(None))
        self.labelingDrawerUi.clear.clicked.connect(onClearButton)
        self.labelingDrawerUi.clear.setEnabled(True)

        def onShowObjectNames():
            '''show object names and allow user to load/delete them'''
            dialog = uic.loadUi(os.path.join(directory, 'carvingObjectManagement.ui'))
            listOfItems = self.topLevelOperatorView.opCarving.AllObjectNames[:].wait()
            dialog.objectNames.addItems(sorted(listOfItems))

            def loadSelection():
                for name in dialog.objectNames.selectedItems():
                    objectname = str(name.text())
                    self.topLevelOperatorView.opCarving.loadObject(objectname)

            def deleteSelection():
                for name in dialog.objectNames.selectedItems():
                    objectname = str(name.text())
                    self.topLevelOperatorView.opCarving.deleteObject(objectname)
                    name.setHidden(True)

            dialog.loadButton.clicked.connect(loadSelection)
            dialog.deleteButton.clicked.connect(deleteSelection)
            dialog.cancelButton.clicked.connect(dialog.close)
            dialog.exec_()

        self.labelingDrawerUi.namesButton.clicked.connect(onShowObjectNames)

        def labelBackground():
            self.selectLabel(0)
        def labelObject():
            self.selectLabel(1)

        self._labelControlUi.labelListModel.allowRemove(False)

        bg = QShortcut(QKeySequence("1"), self, member=labelBackground, ambiguousMember=labelBackground)
        mgr.register("Carving", "Select background label", bg)
        fg = QShortcut(QKeySequence("2"), self, member=labelObject, ambiguousMember=labelObject)
        mgr.register("Carving", "Select object label", fg)

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)

        def addLayerToggleShortcut(layername, shortcut):
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()
            shortcut = QShortcut(QKeySequence(shortcut), self, member=toggle, ambiguousMember=toggle)
            mgr.register("Carving", "Toggle layer %s" % layername, shortcut)

        addLayerToggleShortcut("done", "d")
        addLayerToggleShortcut("segmentation", "s")
        addLayerToggleShortcut("raw", "r")
        addLayerToggleShortcut("pmap", "v")
        addLayerToggleShortcut("hints","h")

        '''
        def updateLayerTimings():
            s = "Layer timings:\n"
            for l in self.layerstack:
                s += "%s: %f sec.\n" % (l.name, l.averageTimePerTile)
            self.labelingDrawerUi.layerTimings.setText(s)
        t = QTimer(self)
        t.setInterval(1*1000) # 1 second
        t.start()
        t.timeout.connect(updateLayerTimings)
        '''

        # Colour table for the "done seg" layer: transparent background,
        # random colours, with entries 1..16 replaced by the default palette.
        def makeColortable():
            self._doneSegmentationColortable = [QColor(0,0,0,0).rgba()]
            for i in range(254):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                self._doneSegmentationColortable.append(QColor(r,g,b).rgba())
            self._doneSegmentationColortable[1:17] = colortables.default16
        makeColortable()
        def onRandomizeColors():
            if self._doneSegmentationLayer is not None:
                print("randomizing colors ...")
                makeColortable()
                self._doneSegmentationLayer.colorTable = self._doneSegmentationColortable
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
        #self.labelingDrawerUi.randomizeColors.clicked.connect(onRandomizeColors)

    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Show a context menu for the finished objects at ``position5d`` and run the chosen action.

        Each object reported by ``doneObjectNamesForPosition`` can be edited or
        deleted and, when 3D rendering is available, shown in / removed from
        the 3D view.
        """
        names = self.topLevelOperatorView.opCarving.doneObjectNamesForPosition(position5d[1:4])

        op = self.topLevelOperatorView.opCarving

        menu = QMenu(self)
        menu.addAction("position %d %d %d" % (position5d[1], position5d[2], position5d[3]))
        for name in names:
            menu.addAction("edit %s" % name)
            menu.addAction("delete %s" % name)
            if self.render:
                if name in self._shownObjects3D:
                    menu.addAction("remove %s from 3D view" % name)
                else:
                    menu.addAction("show 3D %s" % name)

        act = menu.exec_(globalWindowCoordinate)
        if act is None:
            # menu dismissed without a choice
            return
        for name in names:
            if act.text() == "edit %s" %name:
                op.loadObject(name)
            elif act.text() =="delete %s" % name:
                op.deleteObject(name)
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
            elif act.text() == "show 3D %s" % name:
                label = self._renderMgr.addObject()
                self._shownObjects3D[name] = label
                self._update_rendering()
            elif act.text() == "remove %s from 3D view" % name:
                label = self._shownObjects3D.pop(name)
                self._renderMgr.removeObject(label)
                self._update_rendering()

    def _update_rendering(self):
        """Rebuild the 3D label volume for the currently shown objects and redraw.

        Does nothing when 3D rendering is disabled. Objects no longer present
        in the MST's object table are dropped from the shown set.
        """
        if not self.render:
            return

        op = self.topLevelOperatorView.opCarving
        mst = op.MST.value
        if not self._renderMgr.ready:
            self._renderMgr.setup(mst.raw.shape)

        # remove nonexistent objects
        existingNames = mst.object_lut.keys()
        self._shownObjects3D = dict((k, v) for k, v in self._shownObjects3D.iteritems()
                                    if k in existingNames)

        # supervoxel id -> render label (0 = not shown)
        lut = numpy.zeros(len(mst.objects.lut), dtype=numpy.int32)
        for name, label in self._shownObjects3D.iteritems():
            objectSupervoxels = mst.object_lut[name]
            lut[objectSupervoxels] = label

        self._renderMgr.volume = lut[mst.regionVol]
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Give every shown 3D object the colour it has in the done-segmentation colour table."""
        op = self.topLevelOperatorView.opCarving
        # The "done seg" layer only exists once its slot was ready; it is built
        # from (and kept in sync with) self._doneSegmentationColortable, so use
        # that table when the layer has not been created yet.
        if self._doneSegmentationLayer is not None:
            ctable = self._doneSegmentationLayer.colorTable
        else:
            ctable = self._doneSegmentationColortable
        objectNames = op.MST.value.object_names

        for name, label in self._shownObjects3D.iteritems():
            color = QColor(ctable[objectNames[name]])
            color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
            self._renderMgr.setColor(label, color)


    def getNextLabelName(self):
        """Name the first label "Background" and every later one "Object"."""
        l = len(self._labelControlUi.labelListModel)
        if l == 0:
            return "Background"
        else:
            return "Object"

    def appletDrawers(self):
        return [ ("Carving", self._labelControlUi) ]

    def setupLayers( self ):
        """Create the carving layer stack (topmost first) and hook up button-state updates."""
        layers = []

        def onButtonsEnabled(slot, roi):
            currObj = self.topLevelOperatorView.opCarving.CurrentObjectName.value
            hasSeg  = self.topLevelOperatorView.opCarving.HasSegmentation.value
            nzLB    = self.topLevelOperatorView.opCarving.opLabeling.NonzeroLabelBlocks[:].wait()[0]
            
            self.labelingDrawerUi.currentObjectLabel.setText("current object: %s" % currObj)
            self.labelingDrawerUi.save.setEnabled(currObj != "" and hasSeg)
            self.labelingDrawerUi.saveAs.setEnabled(currObj == "" and hasSeg)
            #rethink this
            #self.labelingDrawerUi.segment.setEnabled(len(nzLB) > 0)
            #self.labelingDrawerUi.clear.setEnabled(len(nzLB) > 0)
        self.topLevelOperatorView.opCarving.CurrentObjectName.notifyDirty(onButtonsEnabled)
        self.topLevelOperatorView.opCarving.HasSegmentation.notifyDirty(onButtonsEnabled)
        self.topLevelOperatorView.opCarving.opLabeling.NonzeroLabelBlocks.notifyDirty(onButtonsEnabled)

        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        #uncertainty
        uncert = self.topLevelOperatorView.opCarving.Uncertainty
        if uncert.ready():
            # red ramp whose alpha grows with the value
            colortable = [QColor(i, 0, 0, i).rgba() for i in range(256)]

            layer = ColortableLayer(LazyflowSource(uncert), colortable, direct=True)
            layer.name = "uncertainty"
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        #segmentation 
        seg = self.topLevelOperatorView.opCarving.Segmentation
        
        #seg = self.topLevelOperatorView.opCarving.MST.value.segmentation
        #temp = self._done_lut[self.MST.value.regionVol[sl[1:4]]]
        if seg.ready():
            #source = RelabelingArraySource(seg)
            #source.setRelabeling(numpy.arange(256, dtype=numpy.uint8))
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,0,0).rgba(), QColor(0,255,0).rgba()]
            for i in range(256-len(colortable)):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())

            layer = ColortableLayer(LazyflowSource(seg), colortable, direct=True)
            layer.name = "segmentation"
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)
        
        #done 
        done = self.topLevelOperatorView.opCarving.DoneObjects
        if done.ready(): 
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,255).rgba()]
            for i in range(254-len(colortable)):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())
            #have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(LazyflowSource(done), colortable, direct=True)
            layer.name = "done"
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

        #hints
        useLazyflow = True
        # Colour tables are lists of rgba() ints (as for every other layer here),
        # so convert the random colours too instead of storing QColor objects.
        ctable = [QColor(0,0,0,0).rgba(), QColor(255,0,0).rgba()]
        ctable.extend( [QColor(int(255*random.random()), int(255*random.random()), int(255*random.random())).rgba() for x in range(254)] )
        if useLazyflow:
            hints = self.topLevelOperatorView.opCarving.HintOverlay
            layer = ColortableLayer(LazyflowSource(hints), ctable, direct=True)
        else:
            hints = self.topLevelOperatorView.opCarving._hints
            layer = ColortableLayer(ArraySource(hints), ctable, direct=True)
        if not useLazyflow or hints.ready():
            layer.name = "hints"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)
            
        #pmaps
        pmaps = self.topLevelOperatorView.opCarving._pmap
        if pmaps is not None:
            layer = GrayscaleLayer(ArraySource(pmaps), direct=True)
            layer.name = "pmap"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        #done seg
        # The layer is created once and reused so that onRandomizeColors can
        # update its colour table later.
        doneSeg = self.topLevelOperatorView.opCarving.DoneSegmentation
        if doneSeg.ready():
            if self._doneSegmentationLayer is None:
                layer = ColortableLayer(LazyflowSource(doneSeg), self._doneSegmentationColortable, direct=True)
                layer.name = "done seg"
                layer.visible = False
                layer.opacity = 0.5
                self._doneSegmentationLayer = layer
                layers.append(layer)
            else:
                layers.append(self._doneSegmentationLayer)

        #supervoxel
        sv = self.topLevelOperatorView.opCarving.Supervoxels
        if sv.ready():
            # Build a fresh table: previously this appended to whatever
            # `colortable` an earlier layer had left behind (or raised
            # UnboundLocalError when none of those layers was ready).
            colortable = []
            for i in range(256):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())
            layer = ColortableLayer(LazyflowSource(sv), colortable, direct=True)
            layer.name = "supervoxels"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        #raw data
        #(here we load the actual raw data from an ArraySource rather than from a LazyflowSource for speed reasons)
        raw5D = self.topLevelOperatorView.RawData.value
        layer = GrayscaleLayer(ArraySource(raw5D), direct=True)
        layer.name = "raw"
        layer.visible = True
        layer.opacity = 1.0
        layers.append(layer)

        return layers
Esempio n. 6
0
class CarvingGui(LabelingGui):
    def __init__(self, parentApplet, topLevelOperatorView, drawerUiPath=None ):
        """Build the carving GUI on top of LabelingGui.

        :param parentApplet: applet handed to LabelingGui.
        :param topLevelOperatorView: carving operator view providing the seed,
            label, segmentation and object-management slots wired up below.
        :param drawerUiPath: optional drawer .ui file; defaults to
            ``carvingDrawer.ui`` next to this module.
        """
        self.topLevelOperatorView = topLevelOperatorView

        #members
        self._doneSegmentationLayer = None
        self._showSegmentationIn3D = False
        #self._showUncertaintyLayer = False
        #end: members

        labelingSlots = LabelingGui.LabelingSlots()
        labelingSlots.labelInput       = topLevelOperatorView.WriteSeeds
        labelingSlots.labelOutput      = topLevelOperatorView.opLabelArray.Output
        labelingSlots.labelEraserValue = topLevelOperatorView.opLabelArray.EraserLabelValue
        labelingSlots.labelNames       = topLevelOperatorView.LabelNames
        labelingSlots.labelDelete      = topLevelOperatorView.opLabelArray.DeleteLabel
        labelingSlots.maxLabelValue    = topLevelOperatorView.opLabelArray.MaxLabelValue
        
        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        if drawerUiPath is None:
            drawerUiPath = os.path.join(directory, 'carvingDrawer.ui')
        self.dialogdirCOM = os.path.join(directory, 'carvingObjectManagement.ui')
        self.dialogdirSAD = os.path.join(directory, 'saveAsDialog.ui')

        # Add 3DWidget only if the data is 3D
        is_3d = self._is_3d()

        super(CarvingGui, self).__init__(parentApplet, labelingSlots, topLevelOperatorView, drawerUiPath,
                                         is_3d_widget_visible=is_3d)

        self.labelingDrawerUi.currentObjectLabel.setText("<not saved yet>")

        # Init special base class members
        self.minLabelNumber = 2
        self.maxLabelNumber = 2

        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo

        #set up keyboard shortcuts
        mgr.register( "3", ActionInfo( "Carving",
                                       "Run interactive segmentation",
                                       "Run interactive segmentation",
                                       self.labelingDrawerUi.segment.click,
                                       self.labelingDrawerUi.segment,
                                       self.labelingDrawerUi.segment  ) )


        # Disable 3D view by default
        self.render = False
        if is_3d:
            # Best effort: a failure to create the rendering manager leaves
            # 3D rendering disabled rather than breaking the GUI.
            try:
                self._renderMgr = RenderingManager( self.editor.view3d )
                self._shownObjects3D = {}
                self.render = True
            except Exception:
                self.render = False

        # Segmentation is toggled on by default in _after_init, below.
        # (We can't enable it until the layers are all present.)
        self._showSegmentationIn3D = False
        self._segmentation_3d_label = None
                
        self.labelingDrawerUi.segment.clicked.connect(self.onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        self.topLevelOperatorView.Segmentation.notifyDirty( bind( self._segmentation_dirty ) )
        self.topLevelOperatorView.HasSegmentation.notifyValueChanged( bind( self._updateGui ) )

        ## uncertainty

        #self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        #self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)

        #def onUncertaintyFGButton():
        #    logger.debug( "uncertFG button clicked" )
        #    pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=2)
        #    self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        #self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(onUncertaintyFGButton)

        #def onUncertaintyBGButton():
        #    logger.debug( "uncertBG button clicked" )
        #    pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=1)
        #    self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        #self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(onUncertaintyBGButton)

        #def onUncertaintyCombo(value):
        #    if value == 0:
        #        value = "none"
        #        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        #        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)
        #        self._showUncertaintyLayer = False
        #    else:
        #        if value == 1:
        #            value = "localMargin"
        #        elif value == 2:
        #            value = "exchangeCount"
        #        elif value == 3:
        #            value = "gabow"
        #        else:
        #            raise RuntimeError("unhandled case '%r'" % value)
        #        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(True)
        #        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(True)
        #        self._showUncertaintyLayer = True
        #        logger.debug( "uncertainty changed to %r" % value )
        #    self.topLevelOperatorView.UncertaintyType.setValue(value)
        #    self.updateAllLayers() #make sure that an added/deleted uncertainty layer is recognized
        #self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(onUncertaintyCombo)

        ## background priority
        
        def onBackgroundPrioritySpin(value):
            logger.debug( "background priority changed to %f" % value )
            self.topLevelOperatorView.BackgroundPriority.setValue(value)
        self.labelingDrawerUi.backgroundPrioritySpin.valueChanged.connect(onBackgroundPrioritySpin)

        # Mirror operator-side changes into the spin box; only update when the
        # value differs to avoid a feedback loop with onBackgroundPrioritySpin.
        def onBackgroundPriorityDirty(slot, roi):
            oldValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
            newValue = self.topLevelOperatorView.BackgroundPriority.value
            if  newValue != oldValue:
                self.labelingDrawerUi.backgroundPrioritySpin.setValue(newValue)
        self.topLevelOperatorView.BackgroundPriority.notifyDirty(onBackgroundPriorityDirty)
        
        ## bias
        
        def onNoBiasBelowDirty(slot, roi):
            oldValue = self.labelingDrawerUi.noBiasBelowSpin.value()
            newValue = self.topLevelOperatorView.NoBiasBelow.value
            if  newValue != oldValue:
                self.labelingDrawerUi.noBiasBelowSpin.setValue(newValue)
        self.topLevelOperatorView.NoBiasBelow.notifyDirty(onNoBiasBelowDirty)
        
        def onNoBiasBelowSpin(value):
            logger.debug( "no bias below changed to %f" % value )
            self.topLevelOperatorView.NoBiasBelow.setValue(value)
        self.labelingDrawerUi.noBiasBelowSpin.valueChanged.connect(onNoBiasBelowSpin)
        
        ## save

        self.labelingDrawerUi.save.clicked.connect(self.onSaveButton)

        ## clear

        self.labelingDrawerUi.clear.clicked.connect(self._onClearAction)
        
        ## object names
        
        self.labelingDrawerUi.namesButton.clicked.connect(self.onShowObjectNames)
        if hasattr( self.labelingDrawerUi, 'exportAllMeshesButton' ):
            self.labelingDrawerUi.exportAllMeshesButton.clicked.connect(self._exportAllObjectMeshes)

        self.labelingDrawerUi.labelListView.allowDelete = False
        self._labelControlUi.labelListModel.allowRemove(False)

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)
        
        def addLayerToggleShortcut(layername, shortcut):
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()

            mgr.register(shortcut, ActionInfo( "Carving", 
                                               "Toggle layer %s" % layername, 
                                               "Toggle layer %s" % layername, 
                                               toggle,
                                               self.viewerControlWidget(),
                                               None ) )

        #TODO
        addLayerToggleShortcut("Completed segments (unicolor)", "d")
        addLayerToggleShortcut("Segmentation", "s")
        addLayerToggleShortcut("Input Data", "r")

        # 256-entry table: transparent 0, 254 random colours kept away from
        # pure red and pure green, and pure green as the last entry.
        def makeColortable():
            self._doneSegmentationColortable = [QColor(0,0,0,0).rgba()]
            for i in range(254):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                # ensure colors have sufficient distance to pure red and pure green
                while (255 - r)+g+b<128 or r+(255-g)+b<128:
                    r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                self._doneSegmentationColortable.append(QColor(r,g,b).rgba())
            self._doneSegmentationColortable.append(QColor(0,255,0).rgba())
        makeColortable()
        self._updateGui()

    def _is_3d(self):
        """Return True if the input data spans more than one voxel along x, y and z.

        Axes missing from the tagged shape count as size 1.
        """
        axisSizes = dict(self.topLevelOperatorView.InputData.meta.getTaggedShape())
        return all(axisSizes.get(axis, 1) > 1 for axis in ('x', 'y', 'z'))

    def _after_init(self):
        """Run the base-class post-init, then turn on the 3D segmentation view if rendering is available."""
        super(CarvingGui, self)._after_init()
        if not self.render:
            return
        self._toggleSegmentation3D()
        
    def _updateGui(self):
        """Enable the save button exactly when the operator says the data can be stored."""
        saveButton = self.labelingDrawerUi.save
        saveButton.setEnabled(self.topLevelOperatorView.dataIsStorable())
        
    def onSegmentButton(self):
        """Request an interactive segmentation by marking the whole Trigger slot dirty."""
        logger.debug( "segment button clicked" )
        trigger = self.topLevelOperatorView.Trigger
        trigger.setDirty(slice(None))
    
    def saveAsDialog(self, name=""):
        '''Ask the user for an object name; return it as str, or None if cancelled.

        special functionality: reject names given to other objects. The initial
        ``name`` (the object currently being edited) stays acceptable, so an
        object can be re-saved under its own name, which onSaveButton supports.
        '''
        dialog = uic.loadUi(self.dialogdirSAD)
        dialog.lineEdit.setText(name)
        dialog.warning.setVisible(False)
        dialog.Ok.clicked.connect(dialog.accept)
        dialog.Cancel.clicked.connect(dialog.reject)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        # tracks whether validate() has disabled the Ok button
        dialog.isDisabled = False
        def validate():
            newName = dialog.lineEdit.text()
            if newName in listOfItems and newName != name:
                dialog.Ok.setEnabled(False)
                dialog.warning.setVisible(True)
                dialog.isDisabled = True
            elif dialog.isDisabled:
                dialog.Ok.setEnabled(True)
                dialog.warning.setVisible(False)
                dialog.isDisabled = False
        dialog.lineEdit.textChanged.connect(validate)
        result = dialog.exec_()
        if result:
            return str(dialog.lineEdit.text())
    
    def onSaveButton(self):
        """Save the current segmentation as a named object.

        Prompts for a name, which is pre-filled with the current object's name
        if there is one. Refuses names of other existing objects. When the
        object is saved under a new name, the object stored under its previous
        name is deleted. Shows an error box if the data is not storable.
        """
        logger.info( "save object as?" )
        if self.topLevelOperatorView.dataIsStorable():
            prevName = ""
            if self.topLevelOperatorView.hasCurrentObject():
                prevName = self.topLevelOperatorView.currentObjectName()
            if prevName == "<not saved yet>":
                prevName = ""
            name = self.saveAsDialog(name=prevName)
            if name is None:
                return
            objects = self.topLevelOperatorView.AllObjectNames[:].wait()
            if name in objects and name != prevName:
                QMessageBox.critical(self, "Save Object As", "An object with name '%s' already exists.\nPlease choose a different name." % name)
                return
            self.topLevelOperatorView.saveObjectAs(name)
            logger.info( "save object as %s" % name )
            if prevName != name and prevName != "":
                self.topLevelOperatorView.deleteObject(prevName)
            elif prevName == name:
                # Re-saved under the same name: drop the stale 3D mesh.
                # _renderMgr / _shownObjects3D only exist when 3D rendering
                # was set up in __init__, so guard on self.render.
                if self.render:
                    self._renderMgr.removeObject(prevName)
                    self._renderMgr.invalidateObject(prevName)
                    self._shownObjects3D.pop(prevName, None)
        else:
            msgBox = QMessageBox(self)
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(2)
            msgBox.exec_()
            logger.error( "object not saved due to faulty data." )
    
    def onShowObjectNames(self):
        '''Open the object-management dialog listing all saved objects (sorted).

        "Load" closes the dialog and then loads each selected object. "Delete"
        asks for confirmation, hides the deleted entries, and closes the dialog
        either way. "Cancel" just closes it.
        '''
        dialog = uic.loadUi(self.dialogdirCOM)
        savedNames = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.objectNames.addItems(sorted(savedNames))

        def loadSelection():
            # read the selection before the dialog goes away
            chosenNames = [str(item.text()) for item in dialog.objectNames.selectedItems()]
            dialog.close()
            for chosenName in chosenNames:
                self.topLevelOperatorView.loadObject(chosenName)

        def deleteSelection():
            chosenItems = dialog.objectNames.selectedItems()
            chosenNames = [str(item.text()) for item in chosenItems]
            if self.confirmAndDelete(chosenNames):
                for item in chosenItems:
                    item.setHidden(True)
            dialog.close()

        for button, handler in ((dialog.loadButton, loadSelection),
                                (dialog.deleteButton, deleteSelection),
                                (dialog.cancelButton, dialog.close)):
            button.clicked.connect(handler)
        dialog.exec_()
    
    def confirmAndDelete(self, namelist):
        """Ask the user to confirm, then delete every object named in *namelist*.

        :param namelist: names of the objects to delete.
        :returns: True if the user confirmed (and the objects were deleted),
                  False otherwise.
        """
        logger.info("confirmAndDelete: {}".format(namelist))
        bulletList = "".join("\n  " + str(objName) for objName in namelist)
        answer = QMessageBox.question(
            self,
            "Delete Object",
            "Do you want to delete these objects?" + bulletList,
            QMessageBox.Yes | QMessageBox.Cancel,
            defaultButton=QMessageBox.Yes,
        )
        if answer != QMessageBox.Yes:
            return False
        for objName in namelist:
            self.topLevelOperatorView.deleteObject(objName)
        return True
    
    def labelingContextMenu(self,names,op,position5d):
        """Build the right-click context menu of the carving editor.

        :param names: names of the saved objects at the clicked position. Each
                      gets a submenu to load, delete, show/remove in 3D (only
                      when 3D rendering is enabled) and export its mesh.
        :param op: top-level operator view the actions act on.
        :param position5d: clicked position; entries 1-3 are shown in the
                           disabled header item.
        :returns: the populated QMenu.
        """
        menu = QMenu(self)
        menu.setObjectName("carving_context_menu")
        posItem = menu.addAction("position %d %d %d" % (position5d[1], position5d[2], position5d[3]))
        posItem.setEnabled(False)
        menu.addSeparator()

        # The per-object callbacks receive the object name via partial(), so
        # each action refers to its own object (not the last loop value).
        def onDelAction(_name):
            # Only refresh the 3D view if the user actually confirmed the deletion.
            if self.confirmAndDelete([_name]) and self.render and self._renderMgr.ready:
                self._update_rendering()

        def onRemove3D(_name):
            label = self._shownObjects3D.pop(_name)
            self._renderMgr.removeObject(label)
            self._update_rendering()

        def onShow3D(_name):
            label = self._renderMgr.addObject()
            self._shownObjects3D[_name] = label
            self._update_rendering()

        for name in names:
            submenu = QMenu(name,menu)

            # Load
            loadAction = submenu.addAction("Load %s" % name)
            loadAction.triggered.connect( partial(op.loadObject, name) )

            # Delete
            delAction = submenu.addAction("Delete %s" % name)
            delAction.triggered.connect( partial(onDelAction, name) )

            # Show in / remove from the 3D view
            if self.render:
                if name in self._shownObjects3D:
                    removeAction = submenu.addAction("Remove %s from 3D view" % name)
                    removeAction.triggered.connect( partial(onRemove3D, name) )
                else:
                    showAction = submenu.addAction("Show 3D %s" % name)
                    showAction.triggered.connect( partial(onShow3D, name ) )

            # Export mesh
            exportAction = submenu.addAction("Export mesh for %s" % name)
            exportAction.triggered.connect( partial(self._onContextMenuExportMesh, name) )

            menu.addMenu(submenu)

        # A single separator between the object submenus and the global actions
        # (the header already has its own separator when there are no objects).
        if names:
            menu.addSeparator()

        if self.render:
            showSeg3DAction = menu.addAction( "Show Editing Segmentation in 3D" )
            showSeg3DAction.setCheckable(True)
            showSeg3DAction.setChecked( self._showSegmentationIn3D )
            showSeg3DAction.triggered.connect( self._toggleSegmentation3D )

        if op.dataIsStorable():
            menu.addAction("Save object").triggered.connect( self.onSaveButton )
        menu.addAction("Browse objects").triggered.connect( self.onShowObjectNames )
        menu.addAction("Segment").triggered.connect( self.onSegmentButton )
        menu.addAction("Clear").triggered.connect( self._onClearAction )
        return menu

    def _onClearAction(self):
        """Ask for confirmation, then clear all brush strokes of the current labeling."""
        confirm = QMessageBox.warning(self, "Really Clear?", "Clear all brushstrokes?", QMessageBox.Ok | QMessageBox.Cancel)
        if confirm == QMessageBox.Ok:
            self.topLevelOperatorView.clearCurrentLabeling()

    def _onContextMenuExportMesh(self, _name):
        """
        Export a single object mesh to a user-specified filename.

        The save dialog starts in the most recently used mesh-export directory
        (or the user's home directory), pre-filled with ``<name>.obj``. The
        chosen directory is remembered for the next export. Cancelling the
        dialog does nothing.

        :param _name: name of the saved object whose mesh is exported.
        """
        recent_dir = PreferencesManager().get( 'carving', 'recent export mesh directory' )
        if recent_dir is None:
            recent_dir = os.path.expanduser('~')
        # Both the remembered and the home-directory default use '<name>.obj'.
        defaultPath = os.path.join( recent_dir, '{}.obj'.format(_name) )
        filepath, _filter = QFileDialog.getSaveFileName(self,
                                               "Save meshes for object '{}'".format(_name),
                                               defaultPath,
                                               "OBJ Files (*.obj)")
        if not filepath:
            # Dialog cancelled.
            return
        obj_filepath = str(filepath)
        PreferencesManager().set( 'carving', 'recent export mesh directory', os.path.dirname(obj_filepath) )

        self._exportMeshes([_name], [obj_filepath])

    def _exportAllObjectMeshes(self):
        """
        Export all objects in the project as separate .obj files, stored to a user-specified directory.

        Each object ``name`` is written to ``<export_dir>/<name>.obj``. Shows an
        error dialog if there are no saved objects and does nothing if the
        directory dialog is cancelled. The chosen directory is remembered as
        the default for the next mesh export.
        """
        mst = self.topLevelOperatorView.MST.value
        if not list(mst.object_lut.keys()):
            QMessageBox.critical(self, "Can't Export", "You have no saved objects, so there are no meshes to export.")
            return

        recent_dir = PreferencesManager().get( 'carving', 'recent export mesh directory' )
        defaultPath = os.path.expanduser('~') if recent_dir is None else recent_dir
        export_dir = QFileDialog.getExistingDirectory( self,
                                                       "Select export directory for mesh files",
                                                       defaultPath)
        if not export_dir:
            return
        export_dir = str(export_dir)
        PreferencesManager().set( 'carving', 'recent export mesh directory', export_dir )

        # Pair every object name (read after the dialog) with its output file.
        object_names = list(mst.object_lut.keys())
        obj_filepaths = [os.path.join(export_dir, "{}.obj".format(object_name))
                         for object_name in object_names]

        if object_names:
            self._exportMeshes( object_names, obj_filepaths )

    def _exportMeshes(self, object_names, obj_filepaths):
        """
        Export a mesh .obj file for each object in the object_names list to the corresponding file name from the obj_filepaths list.
        This function is pseudo-recursive. It works like this:
        1) Pop the first name/file from the args
        2) Kick off the export by launching the export mesh dlg
        3) return from this function to allow the eventloop to resume while the export is running
        4) When the export dlg is finished, create the mesh file
        5) If there are still more items in the object_names list to process, repeat this function.

        Objects whose volume holds only a single value (nothing to mesh) are
        skipped with a warning. The function returns immediately if no items
        are left. Both lists are consumed (mutated) in place.

        :param object_names: names of saved objects (keys of ``MST.object_lut``).
        :param obj_filepaths: output .obj paths, parallel to ``object_names``.
        """
        mst = self.topLevelOperatorView.MST.value
        supervoxel_volume = mst.supervoxelUint32

        # Pop items until one yields a meshable volume. Skipped objects are
        # handled in this loop rather than by recursing once per object.
        while object_names:
            object_name = object_names.pop(0)
            obj_filepath = obj_filepaths.pop(0)

            # Construct a volume with only this object.
            # We might be tempted to get the object directly from opCarving.DoneObjects,
            #  but that won't be correct for overlapping objects.
            object_lut = numpy.zeros(mst.nodeNum+1, dtype=numpy.int32)
            object_lut[mst.object_lut[object_name]] = 1
            object_volume = object_lut[supervoxel_volume]  # (Advanced indexing)

            if len(numpy.unique(object_volume)) > 1:
                break
            logger.warning("Object '%s' yields a constant volume; skipping mesh export to %s",
                           object_name, obj_filepath)
        else:
            # Nothing (left) to export.
            return

        # Run the mesh extractor
        window = MeshGeneratorDialog(self)

        def onMeshesComplete(mesh):
            """
            Called when mesh extraction is complete.
            Writes the extracted mesh to an .obj file
            """
            logger.info( "Mesh generation complete." )

            # FIXME: the old VTK-based writer guarded against a mesh count of 0
            # for the '<not saved yet>' object; unclear whether mesh_to_obj
            # still needs such a guard.
            logger.info("Saving meshes to {}".format(obj_filepath))
            mesh_to_obj(mesh, obj_filepath, object_name)
            # Cleanup: We don't need the window anymore.
            window.setParent(None)

            # If there are still objects left to process,
            #   start again with the remainder of the list.
            if object_names:
                self._exportMeshes(object_names, obj_filepaths)

        window.finished.connect( onMeshesComplete )

        # Kick off the save process and exit to the event loop
        window.show()
        QTimer.singleShot(0, partial(window.run, object_volume))


    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Open the context menu for a right click in the volume editor.

        :param position5d: clicked position; elements 1-3 are used to look up
                           the saved objects at that spot.
        :param globalWindowCoordinate: screen position where the menu opens.
        """
        op = self.topLevelOperatorView
        hitNames = op.doneObjectNamesForPosition(position5d[1:4])

        # labelingContextMenu may be overridden by subclasses and may return None.
        contextMenu = self.labelingContextMenu(hitNames, op, position5d)
        if contextMenu is None:
            return
        contextMenu.exec_(globalWindowCoordinate)

    def _toggleSegmentation3D(self):
        self._showSegmentationIn3D = not self._showSegmentationIn3D
        if self._showSegmentationIn3D:
            self._segmentation_3d_label = self._renderMgr.addObject()
        else:
            self._renderMgr.removeObject(self._segmentation_3d_label)
            self._segmentation_3d_label = None
        self._update_rendering()

    def _segmentation_dirty(self):
        if self.render:
            self._renderMgr.invalidateObject(CURRENT_SEGMENTATION_NAME)
            self._renderMgr.removeObject(CURRENT_SEGMENTATION_NAME)

        self._update_rendering()

    def _update_rendering(self):
        if not self.render:
            return

        op = self.topLevelOperatorView
        if not self._renderMgr.ready:
            shape = op.InputData.meta.shape[1:4]
            self._renderMgr.setup(op.InputData.meta.shape[1:4])

        # remove nonexistent objects
        self._shownObjects3D = dict((k, v) for k, v in self._shownObjects3D.items()
                                    if k in list(op.MST.value.object_lut.keys()))

        lut = numpy.zeros(op.MST.value.nodeNum+1, dtype=numpy.int32)
        label_name_map = {}
        for name, label in self._shownObjects3D.items():
            objectSupervoxels = op.MST.value.object_lut[name]
            lut[objectSupervoxels] = label
            label_name_map[label] = name
            label_name_map[name] = label

        if self._showSegmentationIn3D:
            # Add segmentation as label, which is green
            label_name_map[self._segmentation_3d_label] = CURRENT_SEGMENTATION_NAME
            label_name_map[CURRENT_SEGMENTATION_NAME] = self._segmentation_3d_label
            lut[:] = numpy.where( op.MST.value.getSuperVoxelSeg() == 2, self._segmentation_3d_label, lut )

        self._renderMgr.volume = lut[op.MST.value.supervoxelUint32], label_name_map  # (Advanced indexing)
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        op = self.topLevelOperatorView
        ctable = self._doneSegmentationLayer.colorTable

        for name, label in self._shownObjects3D.items():
            color = QColor(ctable[op.MST.value.object_names[name]])
            color = (old_div(color.red(), 255.0), old_div(color.green(), 255.0), old_div(color.blue(), 255.0))
            self._renderMgr.setColor(label, color)

        if self._showSegmentationIn3D and self._segmentation_3d_label is not None:
            self._renderMgr.setColor(self._segmentation_3d_label, (0.0, 1.0, 0.0)) # Green

    def _getNext(self, slot, parentFun, transform=None):
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def getNextLabelName(self):
        """Return the next label name from the LabelNames slot, else the base-class default."""
        namesSlot = self.topLevelOperatorView.LabelNames
        fallback = super(CarvingGui, self).getNextLabelName
        return self._getNext(namesSlot, fallback)

    def appletDrawers(self):
        """Return the (title, widget) pairs shown in this applet's drawer area."""
        drawer = ("Carving", self._labelControlUi)
        return [drawer]

    def setupLayers( self ):
        """Create the layers shown in the carving volume editor.

        The returned list is ordered top to bottom:
        1. labels
        2. current segmentation
        3. completed objects (one color)
        4. completed objects (one color per object)
        5. supervoxels
        6. optional overlay
        7. input data
        8. filtered input

        Each layer is only created when its slot is ready. Also hooks the
        drawer's "current object" label and save button to the operator.

        :returns: list of volumina layers.
        """
        logger.debug( "setupLayers" )

        op = self.topLevelOperatorView
        layers = []

        def randomColortable(count):
            # `count` random opaque colors as rgba ints; each channel drawn from [0, 255).
            colortable = []
            for _ in range(count):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())
            return colortable

        def onButtonsEnabled(slot, roi):
            # Keep the drawer's current-object label and save button in sync.
            currObj = self.topLevelOperatorView.CurrentObjectName.value
            hasSeg  = self.topLevelOperatorView.HasSegmentation.value

            self.labelingDrawerUi.currentObjectLabel.setText(currObj)
            self.labelingDrawerUi.save.setEnabled(hasSeg)

        # NOTE(review): these callbacks are registered on every setupLayers call
        # and never unregistered here; if setupLayers runs repeatedly, duplicate
        # callbacks accumulate — confirm and consider registering once.
        op.CurrentObjectName.notifyDirty(onButtonsEnabled)
        op.HasSegmentation.notifyDirty(onButtonsEnabled)
        op.opLabelArray.NonzeroBlocks.notifyDirty(onButtonsEnabled)

        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            labellayer._allowToggleVisible = False
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        # NOTE: an uncertainty layer (op.Uncertainty) was prototyped here and is disabled.

        # Current segmentation: values 0 and 1 transparent, 2 green, others random.
        seg = op.Segmentation
        if seg.ready():
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,0,0).rgba(), QColor(0,255,0).rgba()]
            colortable += randomColortable(256 - len(colortable))

            layer = ColortableLayer(LazyflowSource(seg), colortable, direct=True)
            layer.name = "Segmentation"
            layer.setToolTip("This layer displays the <i>current</i> segmentation. Simply add foreground and background "
                             "labels, then press <i>Segment</i>.")
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        # Completed objects, all in one color (blue)
        done = op.DoneObjects
        if done.ready():
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,255).rgba()]
            #have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(LazyflowSource(done), colortable, direct=True)
            layer.name = "Completed segments (unicolor)"
            layer.setToolTip("In order to keep track of which objects you have already completed, this layer "
                             "shows <b>all completed objects</b> in one color (<b>blue</b>). "
                             "The reason for only one color is that for finding out which "
                             "objects to label next, the identity of already completed objects is unimportant "
                             "and distracting.")
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

        # Completed objects, one color per object
        doneSeg = op.DoneSegmentation
        if doneSeg.ready():
            layer = ColortableLayer(LazyflowSource(doneSeg), self._doneSegmentationColortable, direct=True)
            layer.name = "Completed segments (one color per object)"
            layer.setToolTip("<html>In order to keep track of which objects you have already completed, this layer "
                             "shows <b>all completed objects</b>, each with a random color.</html>")
            layer.visible = False
            layer.opacity = 0.5
            # Kept so the 3D view can reuse this layer's per-object colors.
            self._doneSegmentationLayer = layer
            layers.append(layer)

        # Supervoxels, random colors
        sv = op.Supervoxels
        if sv.ready():
            colortable = randomColortable(256)
            layer = ColortableLayer(LazyflowSource(sv), colortable, direct=True)
            layer.name = "Supervoxels"
            layer.setToolTip("<html>This layer shows the partitioning of the input image into <b>supervoxels</b>. The carving "
                             "algorithm uses these tiny puzzle-pieces to piece together the segmentation of an "
                             "object. Sometimes, supervoxels are too large and straddle two distinct objects "
                             "(undersegmentation). In this case, it will be impossible to achieve the desired "
                             "segmentation. This layer helps you to understand these cases.</html>")
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        # Visual overlay (just for easier labeling)
        overlaySlot = op.OverlayData
        if overlaySlot.ready():
            overlay5D = overlaySlot.value
            layer = GrayscaleLayer(ArraySource(overlay5D), direct=True)
            layer.visible = True
            layer.name = 'Overlay'
            layer.opacity = 1.0
            # if the flag window_leveling is set the contrast
            # of the layer is adjustable
            layer.window_leveling = True
            self.labelingDrawerUi.thresToolButton.show()
            layers.append(layer)

        inputSlot = op.InputData
        if inputSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(inputSlot), direct=True )
            layer.name = "Input Data"
            layer.setToolTip("<html>The data originally loaded into ilastik (unprocessed).</html>")
            layer.visible = True
            layer.opacity = 1.0

            # Window leveling is already active on the Overlay,
            # but if no overlay was provided, then activate window_leveling on the raw data instead.
            if not overlaySlot.ready():
                # if the flag window_leveling is set the contrast
                # of the layer is adjustable
                layer.window_leveling = True
                self.labelingDrawerUi.thresToolButton.show()

            layers.append(layer)

        filteredSlot = op.FilteredInputData
        if filteredSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(filteredSlot) )
            layer.name = "Filtered Input"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        return layers
Esempio n. 7
0
class PixelClassificationGui(LabelingGui):

    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################
    def centralWidget(self):
        """Return the widget for the applet's central area, which is this GUI itself."""
        widget = self
        return widget

    def stopAndCleanUp(self):
        """Run the registered cleanup callbacks, then let the base class clean up."""
        # The callbacks undo the slot notifications registered in __init__.
        cleanups = self.__cleanup_fns
        for cleanup in cleanups:
            cleanup()

        # Base class last
        super(PixelClassificationGui, self).stopAndCleanUp()

    def viewerControlWidget(self):
        """Return the viewer-control panel created by initViewerControlUi()."""
        panel = self._viewerControlUi
        return panel

    def menus(self):
        """Return the base-class menus, plus an "Advanced" menu in debug mode.

        The "Advanced" menu offers classifier selection, which for now is only
        available when ilastik runs with the debug flag set.
        """
        menuList = super(PixelClassificationGui, self).menus()
        if not ilastik_config.getboolean('ilastik', 'debug'):
            return menuList

        def openClassifierDialog():
            ClassifierSelectionDlg(self.topLevelOperatorView, parent=self).exec_()

        advancedMenu = QMenu("Advanced", parent=self)
        advancedMenu.addAction("Classifier...").triggered.connect(openClassifierDialog)
        menuList += [advancedMenu]
        return menuList

    ###########################################
    ###########################################

    def __init__(self, parentApplet, topLevelOperatorView):
        """Set up the pixel-classification GUI on top of LabelingGui.

        :param parentApplet: owning applet; stored on self and passed to the
            base class.
        :param topLevelOperatorView: operator view whose slots feed the label
            pipeline, label names and the FreezePredictions flag that controls
            interactive (live-update) mode.
        """
        self.parentApplet = parentApplet
        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.DeleteLabel
        labelSlots.labelNames = topLevelOperatorView.LabelNames
        labelSlots.labelsAllowed = topLevelOperatorView.LabelsAllowedFlags

        # Callables executed by stopAndCleanUp() to undo the slot subscriptions made below.
        self.__cleanup_fns = []

        # We provide our own UI file (which adds an extra control for interactive mode)
        labelingDrawerUiPath = os.path.split(
            __file__)[0] + '/labelingDrawer.ui'

        # Base class init
        super(PixelClassificationGui,
              self).__init__(parentApplet, labelSlots, topLevelOperatorView,
                             labelingDrawerUiPath)

        self.topLevelOperatorView = topLevelOperatorView

        self.interactiveModeActive = False
        # Immediately update our interactive state
        # NOTE(review): toggleInteractive is called again with the same argument
        # further below (after the live-update button and shortcuts are set up);
        # one of the two calls may be redundant — confirm before removing either.
        self.toggleInteractive(
            not self.topLevelOperatorView.FreezePredictions.value)

        self._currentlySavingPredictions = False

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon(QIcon(
            ilastikIcons.Play))
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(
            Qt.ToolButtonTextBesideIcon)
        # Toggling the live-update button switches interactive mode on/off.
        self.labelingDrawerUi.liveUpdateButton.toggled.connect(
            self.toggleInteractive)

        # NOTE(review): unregistering relies on a fresh bind(...) comparing equal
        # to the one passed to notifyDirty — confirm lazyflow's bind defines __eq__.
        self.topLevelOperatorView.LabelNames.notifyDirty(
            bind(self.handleLabelSelectionChange))
        self.__cleanup_fns.append(
            partial(self.topLevelOperatorView.LabelNames.unregisterDirty,
                    bind(self.handleLabelSelectionChange)))

        self._initShortcuts()

        # FIXME: We MUST NOT enable the render manager by default,
        #        since it will drastically slow down the app for large volumes.
        #        For now, we leave it off by default.
        #        To re-enable rendering, we need to allow the user to render a segmentation
        #        and then initialize the render manager on-the-fly.
        #        (We might want to warn the user if her volume is not small.)
        self.render = False
        self._renderMgr = None
        self._renderedLayers = {}  # (layer name, label number)

        # Always off for now (see note above)
        if self.render:
            # NOTE(review): the bare except below also swallows KeyboardInterrupt /
            # SystemExit; currently unreachable because self.render is False.
            try:
                self._renderMgr = RenderingManager(self.editor.view3d)
            except:
                self.render = False

        # toggle interactive mode according to freezePredictions.value
        self.toggleInteractive(
            not self.topLevelOperatorView.FreezePredictions.value)

        # Keep interactive mode in sync whenever FreezePredictions changes.
        def FreezePredDirty():
            self.toggleInteractive(
                not self.topLevelOperatorView.FreezePredictions.value)

        # listen to freezePrediction changes
        self.topLevelOperatorView.FreezePredictions.notifyDirty(
            bind(FreezePredDirty))
        self.__cleanup_fns.append(
            partial(
                self.topLevelOperatorView.FreezePredictions.unregisterDirty,
                bind(FreezePredDirty)))

    def initViewerControlUi(self):
        """Load the viewer-control panel and wire its checkboxes and layer buttons."""
        uiPath = os.path.join(os.path.split(__file__)[0], "viewerControls.ui")
        self._viewerControlUi = uic.loadUi(uiPath)
        controls = self._viewerControlUi
        predictionsBox = controls.checkShowPredictions
        segmentationBox = controls.checkShowSegmentation

        def nextCheckState(checkbox):
            # Simple checked/unchecked flip; assigned per instance below, presumably
            # to override the widget's own nextCheckState — TODO confirm intent.
            checkbox.setChecked(not checkbox.isChecked())

        predictionsBox.nextCheckState = partial(nextCheckState, predictionsBox)
        segmentationBox.nextCheckState = partial(nextCheckState, segmentationBox)

        predictionsBox.clicked.connect(self.handleShowPredictionsClicked)
        segmentationBox.clicked.connect(self.handleShowSegmentationClicked)

        # The editor's layerstack is in charge of which layer movement buttons are enabled
        layerModel = self.editor.layerStack
        controls.viewerControls.setupConnections(layerModel)

    def _initShortcuts(self):
        """Register the keyboard shortcuts of the "Predictions" group.

        The shortcuts are:
        - p clicks the show-predictions checkbox;
        - s clicks the show-segmentation checkbox;
        - l toggles the live-update button.
        """
        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo
        group = "Predictions"
        predictionsBox = self._viewerControlUi.checkShowPredictions
        segmentationBox = self._viewerControlUi.checkShowSegmentation
        liveButton = self.labelingDrawerUi.liveUpdateButton

        # (key, name, description, trigger, widget). "Segmentaton" is reproduced
        # verbatim — NOTE(review): these names may identify shortcuts in saved
        # preferences; confirm before correcting the spelling.
        shortcutSpecs = (
            ("p", "Toggle Prediction", "Toggle Prediction Layer Visibility",
             predictionsBox.click, predictionsBox),
            ("s", "Toggle Segmentaton", "Toggle Segmentaton Layer Visibility",
             segmentationBox.click, segmentationBox),
            ("l", "Live Prediction", "Toggle Live Prediction Mode",
             liveButton.toggle, liveButton),
        )
        for key, title, description, trigger, widget in shortcutSpecs:
            mgr.register(key, ActionInfo(group, title, description, trigger, widget, widget))

    def _setup_contexts(self, layer):
        """Add a "Toggle 3D rendering" context-menu entry to *layer* when rendering is enabled."""
        if not self.render:
            return

        def toggle3D(pos, clayer=layer):
            # `pos` comes from the context-menu caller and is not used here.
            layerName = clayer.name
            if layerName not in self._renderedLayers:
                self._renderedLayers[clayer.name] = self._renderMgr.addObject()
            else:
                self._renderMgr.removeObject(self._renderedLayers.pop(layerName))
            self._update_rendering()

        layer.contexts.append(('Toggle 3D rendering', toggle3D))

    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.
        """
        # Base class provides the label layer.
        layers = super(PixelClassificationGui, self).setupLayers()

        ActionInfo = ShortcutManager.ActionInfo

        # Add the uncertainty estimate layer
        uncertaintySlot = self.topLevelOperatorView.UncertaintyEstimate
        if uncertaintySlot.ready():
            uncertaintySrc = LazyflowSource(uncertaintySlot)
            uncertaintyLayer = AlphaModulatedLayer(uncertaintySrc,
                                                   tintColor=QColor(Qt.cyan),
                                                   range=(0.0, 1.0),
                                                   normalize=(0.0, 1.0))
            uncertaintyLayer.name = "Uncertainty"
            uncertaintyLayer.visible = False
            uncertaintyLayer.opacity = 1.0
            uncertaintyLayer.shortcutRegistration = (
                "u",
                ActionInfo("Prediction Layers", "Uncertainty",
                           "Show/Hide Uncertainty",
                           uncertaintyLayer.toggleVisible,
                           self.viewerControlWidget(), uncertaintyLayer))
            layers.append(uncertaintyLayer)

        labels = self.labelListData

        # Add each of the segmentations
        for channel, segmentationSlot in enumerate(
                self.topLevelOperatorView.SegmentationChannels):
            if segmentationSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                segsrc = LazyflowSource(segmentationSlot)
                segLayer = AlphaModulatedLayer(segsrc,
                                               tintColor=ref_label.pmapColor(),
                                               range=(0.0, 1.0),
                                               normalize=(0.0, 1.0))

                segLayer.opacity = 1
                segLayer.visible = False  #self.labelingDrawerUi.liveUpdateButton.isChecked()
                segLayer.visibleChanged.connect(
                    self.updateShowSegmentationCheckbox)

                def setLayerColor(c, segLayer_=segLayer, initializing=False):
                    if not initializing and segLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    segLayer_.tintColor = c
                    self._update_rendering()

                def setSegLayerName(n, segLayer_=segLayer, initializing=False):
                    if not initializing and segLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    oldname = segLayer_.name
                    newName = "Segmentation (%s)" % n
                    segLayer_.name = newName
                    if not self.render:
                        return
                    if oldname in self._renderedLayers:
                        label = self._renderedLayers.pop(oldname)
                        self._renderedLayers[newName] = label

                setSegLayerName(ref_label.name, initializing=True)

                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setSegLayerName)
                #check if layer is 3d before adding the "Toggle 3D" option
                #this check is done this way to match the VolumeRenderer, in
                #case different 3d-axistags should be rendered like t-x-y
                #_axiskeys = segmentationSlot.meta.getAxisKeys()
                if len(segmentationSlot.meta.shape) == 4:
                    #the Renderer will cut out the last shape-dimension, so
                    #we're checking for 4 dimensions
                    self._setup_contexts(segLayer)
                layers.append(segLayer)

        # Add each of the predictions
        for channel, predictionSlot in enumerate(
                self.topLevelOperatorView.PredictionProbabilityChannels):
            if predictionSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                predictsrc = LazyflowSource(predictionSlot)
                predictLayer = AlphaModulatedLayer(
                    predictsrc,
                    tintColor=ref_label.pmapColor(),
                    range=(0.0, 1.0),
                    normalize=(0.0, 1.0))
                predictLayer.opacity = 0.25
                predictLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked(
                )
                predictLayer.visibleChanged.connect(
                    self.updateShowPredictionCheckbox)

                def setLayerColor(c,
                                  predictLayer_=predictLayer,
                                  initializing=False):
                    if not initializing and predictLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    predictLayer_.tintColor = c

                def setPredLayerName(n,
                                     predictLayer_=predictLayer,
                                     initializing=False):
                    if not initializing and predictLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    newName = "Prediction for %s" % n
                    predictLayer_.name = newName

                setPredLayerName(ref_label.name, initializing=True)
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setPredLayerName)
                layers.append(predictLayer)

        # Add the raw data last (on the bottom)
        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot(inputDataSlot)
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0
            # the flag window_leveling is used to determine if the contrast
            # of the layer is adjustable
            if isinstance(inputLayer, GrayscaleLayer):
                inputLayer.window_leveling = True
            else:
                inputLayer.window_leveling = False

            def toggleTopToBottom():
                index = self.layerstack.layerIndex(inputLayer)
                self.layerstack.selectRow(index)
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            inputLayer.shortcutRegistration = ("i",
                                               ActionInfo(
                                                   "Prediction Layers",
                                                   "Bring Input To Top/Bottom",
                                                   "Bring Input To Top/Bottom",
                                                   toggleTopToBottom,
                                                   self.viewerControlWidget(),
                                                   inputLayer))
            layers.append(inputLayer)

            # The thresholding button can only be used if the data is displayed as grayscale.
            if inputLayer.window_leveling:
                self.labelingDrawerUi.thresToolButton.show()
            else:
                self.labelingDrawerUi.thresToolButton.hide()

        self.handleLabelSelectionChange()
        return layers

    def toggleInteractive(self, checked):
        """
        Enter or leave interactive (live prediction) mode.

        :param checked: True to enable live prediction updates, False to freeze them.

        Enabling is refused when the operator's FeatureImages slot is not
        ready or has no shape: the live-update button is reset to unchecked,
        a message box is shown and the method returns early.  Otherwise the
        label-list controls, the FreezePredictions slot, the live-update
        button and (when enabling) the prediction-layer visibility are
        synchronised, and the parent applet is asked to refresh applet states.
        """
        logger.debug("toggling interactive mode to '%r'", checked)

        if checked:
            featureSlot = self.topLevelOperatorView.FeatureImages
            if not featureSlot.ready() or featureSlot.meta.shape is None:
                # Live prediction needs features: undo the button press and tell the user.
                self.labelingDrawerUi.liveUpdateButton.setChecked(False)
                msgBox = QMessageBox()
                msgBox.setText("There are no features selected ")
                msgBox.exec_()
                return

        # If we're changing modes, enable/disable our controls and other applets accordingly
        if self.interactiveModeActive != checked:
            # Labels may not be added or deleted while predictions are live.
            self.labelingDrawerUi.labelListView.allowDelete = not checked
            self.labelingDrawerUi.AddLabelButton.setEnabled(not checked)
        self.interactiveModeActive = checked

        self.topLevelOperatorView.FreezePredictions.setValue(not checked)
        self.labelingDrawerUi.liveUpdateButton.setChecked(checked)
        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked(True)
            self.handleShowPredictionsClicked()

        # Notify the workflow that some applets may have changed state now.
        # (For example, the downstream pixel classification applet can
        #  be used now that there are features selected)
        self.parentApplet.appletStateUpdateRequested.emit()

    @pyqtSlot()
    def handleShowPredictionsClicked(self):
        """Show or hide every layer whose name contains "Prediction", following the checkbox."""
        show = self._viewerControlUi.checkShowPredictions.isChecked()
        for candidate in self.layerstack:
            if "Prediction" not in candidate.name:
                continue
            candidate.visible = show

    @pyqtSlot()
    def handleShowSegmentationClicked(self):
        """Show or hide every layer whose name contains "Segmentation", following the checkbox."""
        show = self._viewerControlUi.checkShowSegmentation.isChecked()
        for candidate in self.layerstack:
            if "Segmentation" not in candidate.name:
                continue
            candidate.visible = show

    @pyqtSlot()
    def updateShowPredictionCheckbox(self):
        """
        Mirror prediction-layer visibility in the "show predictions" checkbox:
        unchecked if none is visible, checked if all are, partially checked otherwise.
        """
        visibilities = [layer.visible for layer in self.layerstack
                        if "Prediction" in layer.name]
        numVisible = sum(1 for isVisible in visibilities if isVisible)

        if numVisible == 0:
            state = Qt.Unchecked
        elif numVisible == len(visibilities):
            state = Qt.Checked
        else:
            state = Qt.PartiallyChecked
        self._viewerControlUi.checkShowPredictions.setCheckState(state)

    @pyqtSlot()
    def updateShowSegmentationCheckbox(self):
        """
        Mirror segmentation-layer visibility in the "show segmentation" checkbox:
        unchecked if none is visible, checked if all are, partially checked otherwise.
        """
        visibilities = [layer.visible for layer in self.layerstack
                        if "Segmentation" in layer.name]
        numVisible = sum(1 for isVisible in visibilities if isVisible)

        if numVisible == 0:
            state = Qt.Unchecked
        elif numVisible == len(visibilities):
            state = Qt.Checked
        else:
            state = Qt.PartiallyChecked
        self._viewerControlUi.checkShowSegmentation.setCheckState(state)

    @pyqtSlot()
    @threadRouted
    def handleLabelSelectionChange(self):
        """
        Enable the live-update button and the show-predictions /
        show-segmentation checkboxes only when LabelNames is ready, holds at
        least two names, and every entry of the cached feature images' shape
        is positive.  When not enabled, all three are unchecked and the
        prediction and segmentation layers are hidden.
        """
        op = self.topLevelOperatorView
        enabled = False
        if op.LabelNames.ready():
            # Both conditions are always evaluated (no short-circuit).
            haveTwoLabels = len(op.LabelNames.value) >= 2
            featureShape = numpy.asarray(op.CachedFeatureImages.meta.shape)
            enabled = haveTwoLabels & numpy.all(featureShape > 0)
            # FIXME: also check that each label has scribbles?

        liveButton = self.labelingDrawerUi.liveUpdateButton
        predCheck = self._viewerControlUi.checkShowPredictions
        segCheck = self._viewerControlUi.checkShowSegmentation

        if not enabled:
            liveButton.setChecked(False)
            predCheck.setChecked(False)
            segCheck.setChecked(False)
            self.handleShowPredictionsClicked()
            self.handleShowSegmentationClicked()

        for widget in (liveButton, predCheck, segCheck):
            widget.setEnabled(enabled)

    def _getNext(self, slot, parentFun, transform=None):
        """
        Return the value stored in ``slot.value`` for the next label to be
        created, i.e. at index ``labelListData.rowCount()``, optionally passed
        through ``transform``.  If the slot holds no entry at that index,
        return ``parentFun()`` instead.
        """
        index = self.labelListData.rowCount()
        stored = slot.value
        if index >= len(stored):
            return parentFun()
        item = stored[index]
        return item if transform is None else transform(item)

    def _onLabelChanged(self, parentFun, mapf, slot):
        """
        Run the base-class change handler, then store the per-label values in ``slot``.

        :param parentFun: base-class handler, called first.
        :param mapf: maps one entry of ``self.labelListData`` to the value kept in ``slot``.
        :param slot: operator slot holding one value per label; its new value is
            ``_listReplace(old, new)``.
        """
        parentFun()
        # Build a real list: on Python 3, map() would hand _listReplace a one-shot iterator.
        new = [mapf(label) for label in self.labelListData]
        old = slot.value
        slot.setValue(_listReplace(old, new))

    def _onLabelRemoved(self, parent, start, end):
        """
        Let the base class remove the label, then, if PmapColors now has more
        entries than LabelNames, pop index ``start`` from both LabelColors and
        PmapColors so colors stay aligned with names.
        """
        # Call the base class to update the operator.
        super(PixelClassificationGui, self)._onLabelRemoved(parent, start, end)

        op = self.topLevelOperatorView
        if len(op.PmapColors.value) <= len(op.LabelNames.value):
            return
        for colorSlot in (op.LabelColors, op.PmapColors):
            colors = colorSlot.value
            colors.pop(start)
            # Same list object as before, so force dirty propagation explicitly.
            colorSlot.setValue(colors, check_changed=False)

    def getNextLabelName(self):
        """Return the stored name for the next label, or the base class's default."""
        namesSlot = self.topLevelOperatorView.LabelNames
        baseImpl = super(PixelClassificationGui, self).getNextLabelName
        return self._getNext(namesSlot, baseImpl)

    def getNextLabelColor(self):
        """Return the stored brush color (as a QColor) for the next label, or the base default."""
        colorSlot = self.topLevelOperatorView.LabelColors
        baseImpl = super(PixelClassificationGui, self).getNextLabelColor

        def toQColor(components):
            return QColor(*components)

        return self._getNext(colorSlot, baseImpl, toQColor)

    def getNextPmapColor(self):
        """Return the stored prediction-map color (as a QColor) for the next label, or the base default."""
        colorSlot = self.topLevelOperatorView.PmapColors
        baseImpl = super(PixelClassificationGui, self).getNextPmapColor

        def toQColor(components):
            return QColor(*components)

        return self._getNext(colorSlot, baseImpl, toQColor)

    def onLabelNameChanged(self):
        """Propagate the current label names into the LabelNames slot."""
        def labelName(label):
            return label.name

        baseImpl = super(PixelClassificationGui, self).onLabelNameChanged
        self._onLabelChanged(baseImpl, labelName,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        """Propagate each label's brush color, as an (r, g, b) tuple, into the LabelColors slot."""
        def brushRgb(label):
            color = label.brushColor()
            return (color.red(), color.green(), color.blue())

        baseImpl = super(PixelClassificationGui, self).onLabelColorChanged
        self._onLabelChanged(baseImpl, brushRgb,
                             self.topLevelOperatorView.LabelColors)

    def onPmapColorChanged(self):
        """Propagate each label's prediction-map color, as an (r, g, b) tuple, into the PmapColors slot."""
        def pmapRgb(label):
            color = label.pmapColor()
            return (color.red(), color.green(), color.blue())

        baseImpl = super(PixelClassificationGui, self).onPmapColorChanged
        self._onLabelChanged(baseImpl, pmapRgb,
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        """
        Rebuild the label volume handed to the 3D render manager.

        Returns immediately unless ``self.render`` is set and the input image
        has exactly 5 axes.  Axis 0 is indexed with the current time
        coordinate and the last axis with channel 0, so ``shape[1:4]`` is used
        as the render volume's shape.  Entries of ``self._renderedLayers``
        whose layer is no longer in the layer stack are discarded; for every
        remaining layer, the nonzero voxels of its data sources at the current
        time point are written into the volume with that layer's render label.
        """
        if not self.render:
            return
        fullShape = self.topLevelOperatorView.InputImages.meta.shape
        # Check the rank of the *full* shape: the sliced shape[1:4] can never
        # have 5 entries, so testing it made this method always return early.
        if len(fullShape) != 5:
            #this might be a 2D image, no need for updating any 3D stuff
            return
        shape = fullShape[1:4]

        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        # Forget render labels of layers that have been removed from the stack.
        layernames = set(layer.name for layer in self.layerstack)
        self._renderedLayers = dict(
            (k, v) for k, v in self._renderedLayers.items()
            if k in layernames)

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                indices = numpy.where(vol != 0)
                newvolume[indices] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Send each rendered layer's tint color, scaled to floats in [0, 1], to the render manager."""
        rendered = self._renderedLayers
        for layer in self.layerstack:
            if layer.name not in rendered:
                continue
            label = rendered[layer.name]
            tint = layer.tintColor
            rgb = tuple(component / 255.0
                        for component in (tint.red(), tint.green(), tint.blue()))
            self._renderMgr.setColor(label, rgb)
Esempio n. 8
0
class PixelClassificationGui(LabelingGui):

    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################
    def centralWidget( self ):
        """Return the widget for the main display area, which is this GUI object itself."""
        widget = self
        return widget

    def stopAndCleanUp(self):
        """Run every registered cleanup callable, then let the base class clean up."""
        cleanupFns = self.__cleanup_fns
        for cleanup in cleanupFns:
            cleanup()

        # Base class
        super(PixelClassificationGui, self).stopAndCleanUp()

    def viewerControlWidget(self):
        """Return the viewer-control widget loaded by initViewerControlUi()."""
        controls = self._viewerControlUi
        return controls

    def menus( self ):
        """
        Return the base-class menus, plus an "Advanced" menu with a
        "Classifier..." action when the ilastik config has debug enabled.
        """
        menuList = super( PixelClassificationGui, self ).menus()

        # For now classifier selection is only available in debug mode
        if not ilastik_config.getboolean('ilastik', 'debug'):
            return menuList

        def handleClassifierAction():
            # Open the classifier selection dialog modally.
            dialog = ClassifierSelectionDlg(self.topLevelOperatorView, parent=self)
            dialog.exec_()

        advanced_menu = QMenu("Advanced", parent=self)
        classifierAction = advanced_menu.addAction("Classifier...")
        classifierAction.triggered.connect( handleClassifierAction )
        menuList += [advanced_menu]
        return menuList

    ###########################################
    ###########################################

    def __init__(self, parentApplet, topLevelOperatorView ):
        """
        Create the pixel classification GUI.

        :param parentApplet: the owning applet; its ``appletStateUpdateRequested``
            signal is emitted by ``toggleInteractive``.
        :param topLevelOperatorView: the operator view whose label, feature,
            prediction and segmentation slots this GUI displays.
        """
        self.parentApplet = parentApplet
        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.DeleteLabel
        labelSlots.labelNames = topLevelOperatorView.LabelNames
        labelSlots.labelsAllowed = topLevelOperatorView.LabelsAllowedFlags

        # Callables run by stopAndCleanUp() to unregister our slot callbacks.
        self.__cleanup_fns = []

        # We provide our own UI file (which adds an extra control for interactive mode)
        labelingDrawerUiPath = os.path.split(__file__)[0] + '/labelingDrawer.ui'

        # Base class init
        super(PixelClassificationGui, self).__init__( parentApplet, labelSlots, topLevelOperatorView, labelingDrawerUiPath )
        
        self.topLevelOperatorView = topLevelOperatorView

        self.interactiveModeActive = False
        # Sync interactive mode with FreezePredictions exactly once, here, before
        # the live-update button's toggled signal is connected below.  Calling it
        # again later in __init__ would only repeat the same work and, when no
        # features are selected, show the "no features" message box twice.
        self.toggleInteractive( not self.topLevelOperatorView.FreezePredictions.value )

        self._currentlySavingPredictions = False

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon( QIcon(ilastikIcons.Play) )
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelingDrawerUi.liveUpdateButton.toggled.connect( self.toggleInteractive )

        # Re-evaluate the live-update controls whenever the label names change.
        self.topLevelOperatorView.LabelNames.notifyDirty( bind(self.handleLabelSelectionChange) )
        self.__cleanup_fns.append( partial( self.topLevelOperatorView.LabelNames.unregisterDirty, bind(self.handleLabelSelectionChange) ) )
        
        self._initShortcuts()

        # FIXME: We MUST NOT enable the render manager by default,
        #        since it will drastically slow down the app for large volumes.
        #        For now, we leave it off by default.
        #        To re-enable rendering, we need to allow the user to render a segmentation 
        #        and then initialize the render manager on-the-fly. 
        #        (We might want to warn the user if her volume is not small.)
        self.render = False
        self._renderMgr = None
        self._renderedLayers = {} # (layer name, label number)
        
        # Always off for now (see note above)
        if self.render:
            try:
                self._renderMgr = RenderingManager( self.editor.view3d )
            except Exception:
                # 3D rendering is optional: record why it failed and continue without it.
                logger.warning("3D rendering could not be initialized; disabling it.", exc_info=True)
                self.render = False

        # Keep interactive mode in sync with later FreezePredictions changes.
        def FreezePredDirty():
            self.toggleInteractive(not self.topLevelOperatorView.FreezePredictions.value)
        # listen to freezePrediction changes
        self.topLevelOperatorView.FreezePredictions.notifyDirty( bind(FreezePredDirty) )
        self.__cleanup_fns.append( partial( self.topLevelOperatorView.FreezePredictions.unregisterDirty, bind(FreezePredDirty) ) )

    def initViewerControlUi(self):
        """Load viewerControls.ui, wire its two visibility checkboxes, and connect the layer controls."""
        uiPath = os.path.join( os.path.split(__file__)[0], "viewerControls.ui" )
        self._viewerControlUi = uic.loadUi( uiPath )
        ui = self._viewerControlUi

        # Replace each checkbox's nextCheckState so that a click simply inverts
        # isChecked() via setChecked().
        def nextCheckState(checkbox):
            checkbox.setChecked( not checkbox.isChecked() )

        wiring = ( (ui.checkShowPredictions, self.handleShowPredictionsClicked),
                   (ui.checkShowSegmentation, self.handleShowSegmentationClicked) )
        for checkbox, onClicked in wiring:
            checkbox.nextCheckState = partial(nextCheckState, checkbox)
            checkbox.clicked.connect( onClicked )

        # The editor's layerstack is in charge of which layer movement buttons are enabled
        ui.viewerControls.setupConnections( self.editor.layerStack )
       
    def _initShortcuts(self):
        """Register the "p", "s" and "l" keyboard shortcuts in the "Predictions" group."""
        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo
        group = "Predictions"

        predCheck = self._viewerControlUi.checkShowPredictions
        segCheck = self._viewerControlUi.checkShowSegmentation
        liveButton = self.labelingDrawerUi.liveUpdateButton

        # (key, name, description, callback, widget); each widget also serves
        # as the action's parent.
        # NOTE(review): "Segmentaton" is misspelled, but the name may act as a
        # key for stored shortcut preferences -- confirm before correcting it.
        shortcuts = (
            ( "p", "Toggle Prediction", "Toggle Prediction Layer Visibility", predCheck.click, predCheck ),
            ( "s", "Toggle Segmentaton", "Toggle Segmentaton Layer Visibility", segCheck.click, segCheck ),
            ( "l", "Live Prediction", "Toggle Live Prediction Mode", liveButton.toggle, liveButton ),
        )
        for key, name, description, callback, widget in shortcuts:
            mgr.register( key, ActionInfo( group, name, description, callback, widget, widget ) )

    def _setup_contexts(self, layer):
        """Add a 'Toggle 3D rendering' context entry to ``layer`` when rendering is enabled."""
        def callback(pos, clayer=layer):
            # Start or stop rendering this layer, tracked by its current name.
            name = clayer.name
            if name in self._renderedLayers:
                removed = self._renderedLayers.pop(name)
                self._renderMgr.removeObject(removed)
            else:
                added = self._renderMgr.addObject()
                self._renderedLayers[clayer.name] = added
            self._update_rendering()

        if not self.render:
            return
        layer.contexts.append(('Toggle 3D rendering', callback))

    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.

        Layers are returned in stacking order, top first: the base class's
        label layers, then (for slots that are ready) the uncertainty layer,
        one segmentation layer and one prediction layer per label channel,
        and finally the raw input data at the bottom.
        """
        # Base class provides the label layer.
        layers = super(PixelClassificationGui, self).setupLayers()

        ActionInfo = ShortcutManager.ActionInfo

        # Add the uncertainty estimate layer
        uncertaintySlot = self.topLevelOperatorView.UncertaintyEstimate
        if uncertaintySlot.ready():
            uncertaintySrc = LazyflowSource(uncertaintySlot)
            uncertaintyLayer = AlphaModulatedLayer( uncertaintySrc,
                                                    tintColor=QColor( Qt.cyan ),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
            uncertaintyLayer.name = "Uncertainty"
            uncertaintyLayer.visible = False
            uncertaintyLayer.opacity = 1.0
            uncertaintyLayer.shortcutRegistration = ( "u", ActionInfo( "Prediction Layers",
                                                                       "Uncertainty",
                                                                       "Show/Hide Uncertainty",
                                                                       uncertaintyLayer.toggleVisible,
                                                                       self.viewerControlWidget(),
                                                                       uncertaintyLayer ) )
            layers.append(uncertaintyLayer)

        labels = self.labelListData

        # Add each of the segmentations
        for channel, segmentationSlot in enumerate(self.topLevelOperatorView.SegmentationChannels):
            # Only channels that have a matching label entry get a layer.
            if segmentationSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                segsrc = LazyflowSource(segmentationSlot)
                segLayer = AlphaModulatedLayer( segsrc,
                                                tintColor=ref_label.pmapColor(),
                                                range=(0.0, 1.0),
                                                normalize=(0.0, 1.0) )

                segLayer.opacity = 1
                segLayer.visible = False #self.labelingDrawerUi.liveUpdateButton.isChecked()
                segLayer.visibleChanged.connect(self.updateShowSegmentationCheckbox)

                # These callbacks are connected to label signals and run after this
                # loop has finished, so they must use the layer bound as a default
                # argument (segLayer_), never the loop variable segLayer, which by
                # then refers to the last segmentation layer created.
                def setLayerColor(c, segLayer_=segLayer, initializing=False):
                    if not initializing and segLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    segLayer_.tintColor = c
                    self._update_rendering()

                def setSegLayerName(n, segLayer_=segLayer, initializing=False):
                    if not initializing and segLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    oldname = segLayer_.name
                    newName = "Segmentation (%s)" % n
                    segLayer_.name = newName
                    if not self.render:
                        return
                    # Keep the 3D render label attached to the renamed layer.
                    if oldname in self._renderedLayers:
                        label = self._renderedLayers.pop(oldname)
                        self._renderedLayers[newName] = label

                setSegLayerName(ref_label.name, initializing=True)

                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setSegLayerName)
                #check if layer is 3d before adding the "Toggle 3D" option
                #this check is done this way to match the VolumeRenderer, in
                #case different 3d-axistags should be rendered like t-x-y
                #_axiskeys = segmentationSlot.meta.getAxisKeys()
                if len(segmentationSlot.meta.shape) == 4:
                    #the Renderer will cut out the last shape-dimension, so
                    #we're checking for 4 dimensions
                    self._setup_contexts(segLayer)
                layers.append(segLayer)
        
        # Add each of the predictions
        for channel, predictionSlot in enumerate(self.topLevelOperatorView.PredictionProbabilityChannels):
            if predictionSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                predictsrc = LazyflowSource(predictionSlot)
                predictLayer = AlphaModulatedLayer( predictsrc,
                                                    tintColor=ref_label.pmapColor(),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
                predictLayer.opacity = 0.25
                predictLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked()
                predictLayer.visibleChanged.connect(self.updateShowPredictionCheckbox)

                def setLayerColor(c, predictLayer_=predictLayer, initializing=False):
                    if not initializing and predictLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    predictLayer_.tintColor = c

                def setPredLayerName(n, predictLayer_=predictLayer, initializing=False):
                    if not initializing and predictLayer_ not in self.layerstack:
                        # This layer has been removed from the layerstack already.
                        # Don't touch it.
                        return
                    newName = "Prediction for %s" % n
                    predictLayer_.name = newName

                setPredLayerName(ref_label.name, initializing=True)
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setPredLayerName)
                layers.append(predictLayer)

        # Add the raw data last (on the bottom)
        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot( inputDataSlot )
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0
            # the flag window_leveling is used to determine if the contrast 
            # of the layer is adjustable
            if isinstance( inputLayer, GrayscaleLayer ):
                inputLayer.window_leveling = True
            else:
                inputLayer.window_leveling = False

            def toggleTopToBottom():
                index = self.layerstack.layerIndex( inputLayer )
                self.layerstack.selectRow( index )
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            inputLayer.shortcutRegistration = ( "i", ActionInfo( "Prediction Layers",
                                                                 "Bring Input To Top/Bottom",
                                                                 "Bring Input To Top/Bottom",
                                                                 toggleTopToBottom,
                                                                 self.viewerControlWidget(),
                                                                 inputLayer ) )
            layers.append(inputLayer)
            
            # The thresholding button can only be used if the data is displayed as grayscale.
            if inputLayer.window_leveling:
                self.labelingDrawerUi.thresToolButton.show()
            else:
                self.labelingDrawerUi.thresToolButton.hide()
        
        self.handleLabelSelectionChange()
        return layers

    def toggleInteractive(self, checked):
        """
        Enter or leave interactive (live prediction) mode.

        :param checked: True to enable live prediction updates, False to freeze them.

        Enabling is refused when the operator's FeatureImages slot is not
        ready or has no shape: the live-update button is reset to unchecked,
        a message box is shown and the method returns early.  Otherwise the
        label-list controls, the FreezePredictions slot, the live-update
        button and (when enabling) the prediction-layer visibility are
        synchronised, and the parent applet is asked to refresh applet states.
        """
        logger.debug("toggling interactive mode to '%r'", checked)

        if checked:
            featureSlot = self.topLevelOperatorView.FeatureImages
            if not featureSlot.ready() or featureSlot.meta.shape is None:
                # Live prediction needs features: undo the button press and tell the user.
                self.labelingDrawerUi.liveUpdateButton.setChecked(False)
                msgBox = QMessageBox()
                msgBox.setText("There are no features selected ")
                msgBox.exec_()
                return

        # If we're changing modes, enable/disable our controls and other applets accordingly
        if self.interactiveModeActive != checked:
            # Labels may not be added or deleted while predictions are live.
            self.labelingDrawerUi.labelListView.allowDelete = not checked
            self.labelingDrawerUi.AddLabelButton.setEnabled( not checked )
        self.interactiveModeActive = checked

        self.topLevelOperatorView.FreezePredictions.setValue( not checked )
        self.labelingDrawerUi.liveUpdateButton.setChecked(checked)
        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked( True )
            self.handleShowPredictionsClicked()

        # Notify the workflow that some applets may have changed state now.
        # (For example, the downstream pixel classification applet can 
        #  be used now that there are features selected)
        self.parentApplet.appletStateUpdateRequested.emit()

    @pyqtSlot()
    def handleShowPredictionsClicked(self):
        """Show or hide every layer whose name contains "Prediction", following the checkbox."""
        show = self._viewerControlUi.checkShowPredictions.isChecked()
        for candidate in self.layerstack:
            if "Prediction" not in candidate.name:
                continue
            candidate.visible = show

    @pyqtSlot()
    def handleShowSegmentationClicked(self):
        """Show or hide every layer whose name contains "Segmentation", following the checkbox."""
        show = self._viewerControlUi.checkShowSegmentation.isChecked()
        for candidate in self.layerstack:
            if "Segmentation" not in candidate.name:
                continue
            candidate.visible = show

    @pyqtSlot()
    def updateShowPredictionCheckbox(self):
        """
        Mirror prediction-layer visibility in the "show predictions" checkbox:
        unchecked if none is visible, checked if all are, partially checked otherwise.
        """
        visibilities = [layer.visible for layer in self.layerstack
                        if "Prediction" in layer.name]
        numVisible = sum(1 for isVisible in visibilities if isVisible)

        if numVisible == 0:
            state = Qt.Unchecked
        elif numVisible == len(visibilities):
            state = Qt.Checked
        else:
            state = Qt.PartiallyChecked
        self._viewerControlUi.checkShowPredictions.setCheckState(state)

    @pyqtSlot()
    def updateShowSegmentationCheckbox(self):
        """
        Mirror segmentation-layer visibility in the "show segmentation" checkbox:
        unchecked if none is visible, checked if all are, partially checked otherwise.
        """
        visibilities = [layer.visible for layer in self.layerstack
                        if "Segmentation" in layer.name]
        numVisible = sum(1 for isVisible in visibilities if isVisible)

        if numVisible == 0:
            state = Qt.Unchecked
        elif numVisible == len(visibilities):
            state = Qt.Checked
        else:
            state = Qt.PartiallyChecked
        self._viewerControlUi.checkShowSegmentation.setCheckState(state)

    @pyqtSlot()
    @threadRouted
    def handleLabelSelectionChange(self):
        """Enable the live-update and show-predictions/segmentation controls only when usable.

        The controls are usable once LabelNames is ready, at least two labels exist,
        and every entry of CachedFeatureImages' shape is positive. When they are not
        usable they are also unchecked, and the matching layers are hidden.
        """
        op = self.topLevelOperatorView
        enabled = False
        if op.LabelNames.ready():
            hasTwoLabels = len(op.LabelNames.value) >= 2
            featuresNonEmpty = numpy.all(numpy.asarray(op.CachedFeatureImages.meta.shape) > 0)
            enabled = hasTwoLabels & featuresNonEmpty
            # FIXME: also check that each label has scribbles?

        drawer = self.labelingDrawerUi
        viewer = self._viewerControlUi
        if not enabled:
            drawer.liveUpdateButton.setChecked(False)
            viewer.checkShowPredictions.setChecked(False)
            viewer.checkShowSegmentation.setChecked(False)
            self.handleShowPredictionsClicked()
            self.handleShowSegmentationClicked()

        for control in (drawer.liveUpdateButton,
                        viewer.checkShowPredictions,
                        viewer.checkShowSegmentation):
            control.setEnabled(enabled)

    def _getNext(self, slot, parentFun, transform=None):
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def _onLabelChanged(self, parentFun, mapf, slot):
        """Run the base-class handler, then write per-label values back into *slot*.

        *mapf* extracts one value from each entry of labelListData. The resulting
        values are combined with the slot's current value via _listReplace.
        """
        parentFun()
        updated = map(mapf, self.labelListData)
        current = slot.value
        merged = _listReplace(current, updated)
        slot.setValue(merged)

    def _onLabelRemoved(self, parent, start, end):
        """After the base class removes a label, drop its entry from the color slots.

        This keeps colors in sync with names: when PmapColors holds more entries than
        LabelNames, index *start* is removed from both LabelColors and PmapColors.
        NOTE(review): only index *start* is removed, even if *end* > *start*.
        """
        # Call the base class to update the operator.
        super(PixelClassificationGui, self)._onLabelRemoved(parent, start, end)

        op = self.topLevelOperatorView
        if len(op.PmapColors.value) <= len(op.LabelNames.value):
            return
        for colorSlot in (op.LabelColors, op.PmapColors):
            colors = colorSlot.value
            del colors[start]
            # The list is mutated in place (same object), so skip the change check
            # to force dirty propagation.
            colorSlot.setValue(colors, check_changed=False)

    def getNextLabelName(self):
        """Name for a new label: the stored LabelNames entry if present, else the base-class default."""
        namesSlot = self.topLevelOperatorView.LabelNames
        fallback = super(PixelClassificationGui, self).getNextLabelName
        return self._getNext(namesSlot, fallback)

    def getNextLabelColor(self):
        """Brush color for a new label: the stored LabelColors entry as a QColor, else the base-class default."""
        def toQColor(components):
            return QColor(*components)

        colorsSlot = self.topLevelOperatorView.LabelColors
        fallback = super(PixelClassificationGui, self).getNextLabelColor
        return self._getNext(colorsSlot, fallback, toQColor)

    def getNextPmapColor(self):
        """Prediction-map color for a new label: the stored PmapColors entry as a QColor, else the base-class default."""
        def toQColor(components):
            return QColor(*components)

        colorsSlot = self.topLevelOperatorView.PmapColors
        fallback = super(PixelClassificationGui, self).getNextPmapColor
        return self._getNext(colorsSlot, fallback, toQColor)

    def onLabelNameChanged(self):
        """Propagate edited label names into the LabelNames slot."""
        def labelName(entry):
            return entry.name

        baseHandler = super(PixelClassificationGui, self).onLabelNameChanged
        self._onLabelChanged(baseHandler, labelName,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        """Propagate edited label brush colors into LabelColors as (red, green, blue) tuples."""
        def brushRgb(entry):
            color = entry.brushColor()
            return (color.red(), color.green(), color.blue())

        baseHandler = super(PixelClassificationGui, self).onLabelColorChanged
        self._onLabelChanged(baseHandler, brushRgb,
                             self.topLevelOperatorView.LabelColors)


    def onPmapColorChanged(self):
        """Propagate edited prediction-map colors into PmapColors as (red, green, blue) tuples."""
        def pmapRgb(entry):
            color = entry.pmapColor()
            return (color.red(), color.green(), color.blue())

        baseHandler = super(PixelClassificationGui, self).onPmapColorChanged
        self._onLabelChanged(baseHandler, pmapRgb,
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        if not self.render:
            return
        shape = self.topLevelOperatorView.InputImages.meta.shape[1:4]
        if len(shape) != 5:
            #this might be a 2D image, no need for updating any 3D stuff 
            return
        
        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        layernames = set(layer.name for layer in self.layerstack)
        self._renderedLayers = dict((k, v) for k, v in self._renderedLayers.iteritems()
                                if k in layernames)

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                indices = numpy.where(vol != 0)
                newvolume[indices] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            color = layer.tintColor
            color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
            self._renderMgr.setColor(label, color)
Esempio n. 9
0
class CarvingGui(LabelingGui):
    """Labeling GUI for the carving applet.

    Connects the carving drawer controls (segment, save, clear, uncertainty,
    background priority, bias, object browser) and registers keyboard shortcuts.
    Manages optional 3D rendering of saved objects and of the current
    segmentation, and builds the carving layer stack.
    """

    def __init__(self, parentApplet, topLevelOperatorView, drawerUiPath=None ):
        """
        :param parentApplet: applet that owns this GUI (forwarded to LabelingGui).
        :param topLevelOperatorView: carving operator view whose slots drive this GUI.
        :param drawerUiPath: path of the drawer .ui file; defaults to
            'carvingDrawer.ui' in this module's directory.
        """
        self.topLevelOperatorView = topLevelOperatorView

        #members
        self._doneSegmentationLayer = None
        self._showSegmentationIn3D = False
        self._showUncertaintyLayer = False
        #end: members

        labelingSlots = LabelingGui.LabelingSlots()
        labelingSlots.labelInput       = topLevelOperatorView.WriteSeeds
        labelingSlots.labelOutput      = topLevelOperatorView.opLabelArray.Output
        labelingSlots.labelEraserValue = topLevelOperatorView.opLabelArray.EraserLabelValue
        labelingSlots.labelNames       = topLevelOperatorView.LabelNames
        labelingSlots.labelDelete      = topLevelOperatorView.opLabelArray.DeleteLabel
        labelingSlots.maxLabelValue    = topLevelOperatorView.opLabelArray.MaxLabelValue
        labelingSlots.labelsAllowed    = topLevelOperatorView.LabelsAllowed

        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        if drawerUiPath is None:
            drawerUiPath = os.path.join(directory, 'carvingDrawer.ui')
        self.dialogdirCOM = os.path.join(directory, 'carvingObjectManagement.ui')
        self.dialogdirSAD = os.path.join(directory, 'saveAsDialog.ui')

        rawInputSlot = topLevelOperatorView.RawData
        super(CarvingGui, self).__init__(parentApplet, labelingSlots, topLevelOperatorView, drawerUiPath, rawInputSlot)

        self.labelingDrawerUi.currentObjectLabel.setText("<not saved yet>")

        # Init special base class members: exactly two labels
        # (background and object, see the "1"/"2" shortcuts below).
        self.minLabelNumber = 2
        self.maxLabelNumber = 2

        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo

        #set up keyboard shortcuts
        mgr.register( "3", ActionInfo( "Carving",
                                       "Run interactive segmentation",
                                       "Run interactive segmentation",
                                       self.labelingDrawerUi.segment.click,
                                       self.labelingDrawerUi.segment,
                                       self.labelingDrawerUi.segment  ) )

        # 3D rendering is optional. If the renderer cannot be created from the
        # editor's 3D view, continue without it instead of failing the whole GUI.
        try:
            self.render = True
            self._shownObjects3D = {}
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except Exception:
            logger.debug( "3D rendering disabled: could not create RenderingManager", exc_info=True )
            self.render = False

        # Segmentation is toggled on by default in _after_init, below.
        # (We can't enable it until the layers are all present.)
        self._showSegmentationIn3D = False
        self._segmentation_3d_label = None

        self.labelingDrawerUi.segment.clicked.connect(self.onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        self.topLevelOperatorView.Segmentation.notifyDirty( bind( self._update_rendering ) )
        self.topLevelOperatorView.HasSegmentation.notifyValueChanged( bind( self._updateGui ) )

        ## uncertainty

        # The jump-to-uncertainty buttons only make sense once an uncertainty type is chosen.
        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)

        def onUncertaintyFGButton():
            # Move the viewer to the position of maximal uncertainty for label 2.
            logger.debug( "uncertFG button clicked" )
            pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=2)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(onUncertaintyFGButton)

        def onUncertaintyBGButton():
            # Move the viewer to the position of maximal uncertainty for label 1.
            logger.debug( "uncertBG button clicked" )
            pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=1)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(onUncertaintyBGButton)

        def onUncertaintyCombo(value):
            # Map the combo index to the operator's UncertaintyType string.
            if value == 0:
                value = "none"
                self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
                self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)
                self._showUncertaintyLayer = False
            else:
                if value == 1:
                    value = "localMargin"
                elif value == 2:
                    value = "exchangeCount"
                elif value == 3:
                    value = "gabow"
                else:
                    raise RuntimeError("unhandled case '%r'" % value)
                self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(True)
                self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(True)
                self._showUncertaintyLayer = True
                logger.debug( "uncertainty changed to %r" % value )
            self.topLevelOperatorView.UncertaintyType.setValue(value)
            self.updateAllLayers() #make sure that an added/deleted uncertainty layer is recognized
        self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(onUncertaintyCombo)

        ## background priority

        def onBackgroundPrioritySpin(value):
            logger.debug( "background priority changed to %f" % value )
            self.topLevelOperatorView.BackgroundPriority.setValue(value)
        self.labelingDrawerUi.backgroundPrioritySpin.valueChanged.connect(onBackgroundPrioritySpin)

        def onBackgroundPriorityDirty(slot, roi):
            # Keep the spin box in sync when the slot is changed from elsewhere;
            # only set it on a real change to avoid a valueChanged feedback loop.
            oldValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
            newValue = self.topLevelOperatorView.BackgroundPriority.value
            if  newValue != oldValue:
                self.labelingDrawerUi.backgroundPrioritySpin.setValue(newValue)
        self.topLevelOperatorView.BackgroundPriority.notifyDirty(onBackgroundPriorityDirty)

        ## bias

        def onNoBiasBelowDirty(slot, roi):
            # Same synchronisation as for the background priority spin box.
            oldValue = self.labelingDrawerUi.noBiasBelowSpin.value()
            newValue = self.topLevelOperatorView.NoBiasBelow.value
            if  newValue != oldValue:
                self.labelingDrawerUi.noBiasBelowSpin.setValue(newValue)
        self.topLevelOperatorView.NoBiasBelow.notifyDirty(onNoBiasBelowDirty)

        def onNoBiasBelowSpin(value):
            logger.debug( "no bias below changed to %f" % value )
            self.topLevelOperatorView.NoBiasBelow.setValue(value)
        self.labelingDrawerUi.noBiasBelowSpin.valueChanged.connect(onNoBiasBelowSpin)

        ## save

        self.labelingDrawerUi.save.clicked.connect(self.onSaveButton)

        ## clear

        self.labelingDrawerUi.clear.clicked.connect(self.topLevelOperatorView.clearCurrentLabeling)

        ## object names

        self.labelingDrawerUi.namesButton.clicked.connect(self.onShowObjectNames)

        def labelBackground():
            self.selectLabel(0)
        def labelObject():
            self.selectLabel(1)

        # The two carving labels are fixed; the user may not remove them.
        self._labelControlUi.labelListModel.allowRemove(False)

        bgToolTipObject = LabelListModel.EntryToolTipAdapter(self._labelControlUi.labelListModel, 0)
        mgr.register( "1", ActionInfo( "Carving",
                                       "Select background label",
                                       "Select background label",
                                       labelBackground,
                                       self.viewerControlWidget(),
                                       bgToolTipObject ) )

        fgToolTipObject = LabelListModel.EntryToolTipAdapter(self._labelControlUi.labelListModel, 1)
        mgr.register( "2", ActionInfo( "Carving",
                                       "Select object label",
                                       "Select object label",
                                       labelObject,
                                       self.viewerControlWidget(),
                                       fgToolTipObject ) )

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)

        def addLayerToggleShortcut(layername, shortcut):
            # Register a shortcut that selects the named layer and flips its visibility.
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()

            mgr.register(shortcut, ActionInfo( "Carving",
                                               "Toggle layer %s" % layername,
                                               "Toggle layer %s" % layername,
                                               toggle,
                                               self.viewerControlWidget(),
                                               None ) )

        #TODO
        addLayerToggleShortcut("Completed segments (unicolor)", "d")
        addLayerToggleShortcut("Segmentation", "s")
        addLayerToggleShortcut("Input Data", "r")

        '''
        def updateLayerTimings():
            s = "Layer timings:\n"
            for l in self.layerstack:
                s += "%s: %f sec.\n" % (l.name, l.averageTimePerTile)
            self.labelingDrawerUi.layerTimings.setText(s)
        t = QTimer(self)
        t.setInterval(1*1000) # 10 seconds
        t.start()
        t.timeout.connect(updateLayerTimings)
        '''

        def makeColortable():
            # Entry 0 is fully transparent. Then 254 random colors are added,
            # and pure green is appended as the last entry.
            self._doneSegmentationColortable = [QColor(0,0,0,0).rgba()]
            for i in range(254):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                # ensure colors have sufficient distance to pure red and pure green
                while (255 - r)+g+b<128 or r+(255-g)+b<128:
                    r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                self._doneSegmentationColortable.append(QColor(r,g,b).rgba())
            self._doneSegmentationColortable.append(QColor(0,255,0).rgba())
        makeColortable()
        def onRandomizeColors():
            if self._doneSegmentationLayer is not None:
                logger.debug( "randomizing colors ..." )
                makeColortable()
                self._doneSegmentationLayer.colorTable = self._doneSegmentationColortable
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
        #self.labelingDrawerUi.randomizeColors.clicked.connect(onRandomizeColors)
        self._updateGui()

    def _after_init(self):
        """After base-class setup (layers present), show the current segmentation in 3D by default."""
        super(CarvingGui, self)._after_init()
        if self.render:
            self._toggleSegmentation3D()

    def _updateGui(self):
        """Enable the Save button only when the operator reports the current data as storable."""
        self.labelingDrawerUi.save.setEnabled( self.topLevelOperatorView.dataIsStorable() )

    def onSegmentButton(self):
        """Request a (re)segmentation by marking the operator's Trigger slot dirty."""
        logger.debug( "segment button clicked" )
        self.topLevelOperatorView.Trigger.setDirty(slice(None))

    def saveAsDialog(self, name=""):
        '''Ask for an object name; names already used by other objects are rejected.

        :param name: initial text of the name field.
        :returns: the entered name as ``str``, or ``None`` if the dialog was cancelled.
        '''
        dialog = uic.loadUi(self.dialogdirSAD)
        dialog.lineEdit.setText(name)
        dialog.warning.setVisible(False)
        dialog.Ok.clicked.connect(dialog.accept)
        dialog.Cancel.clicked.connect(dialog.reject)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.isDisabled = False
        def validate():
            # Disable OK and show the warning while the typed name is already taken.
            name = dialog.lineEdit.text()
            if name in listOfItems:
                dialog.Ok.setEnabled(False)
                dialog.warning.setVisible(True)
                dialog.isDisabled = True
            elif dialog.isDisabled:
                dialog.Ok.setEnabled(True)
                dialog.warning.setVisible(False)
                dialog.isDisabled = False
        dialog.lineEdit.textChanged.connect(validate)
        result = dialog.exec_()
        if result:
            return str(dialog.lineEdit.text())

    def onSaveButton(self):
        """Save the current segmentation as a named object, replacing the previous name on rename.

        If the operator reports the data as not storable, an error box is shown instead.
        """
        logger.info( "save object as?" )
        if self.topLevelOperatorView.dataIsStorable():
            prevName = ""
            if self.topLevelOperatorView.hasCurrentObject():
                prevName = self.topLevelOperatorView.currentObjectName()
            if prevName == "<not saved yet>":
                prevName = ""
            name = self.saveAsDialog(name=prevName)
            if name is None:
                return
            objects = self.topLevelOperatorView.AllObjectNames[:].wait()
            if name in objects and name != prevName:
                QMessageBox.critical(self, "Save Object As", "An object with name '%s' already exists.\nPlease choose a different name." % name)
                return
            self.topLevelOperatorView.saveObjectAs(name)
            logger.info( "save object as %s" % name )
            # Saving under a new name is a rename: drop the object stored under the old name.
            if prevName != name and prevName != "":
                self.topLevelOperatorView.deleteObject(prevName)
        else:
            msgBox = QMessageBox(self)
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.exec_()
            logger.error( "object not saved due to faulty data." )

    def onShowObjectNames(self):
        '''Show the saved object names and let the user load or delete the selected ones.'''
        dialog = uic.loadUi(self.dialogdirCOM)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.objectNames.addItems(sorted(listOfItems))

        def loadSelection():
            # Read the selection before closing the dialog, then load each object by name.
            selected = [str(item.text()) for item in dialog.objectNames.selectedItems()]
            dialog.close()
            for objectname in selected:
                self.topLevelOperatorView.loadObject(objectname)

        def deleteSelection():
            items = dialog.objectNames.selectedItems()
            if self.confirmAndDelete([str(name.text()) for name in items]):
                for name in items:
                    name.setHidden(True)
            dialog.close()

        dialog.loadButton.clicked.connect(loadSelection)
        dialog.deleteButton.clicked.connect(deleteSelection)
        dialog.cancelButton.clicked.connect(dialog.close)
        dialog.exec_()

    def confirmAndDelete(self,namelist):
        """Ask the user to confirm, then delete every object in *namelist*.

        :returns: True if the user confirmed (and the objects were deleted), else False.
        """
        logger.info( "confirmAndDelete: {}".format( namelist ) )
        objectlist = "".join("\n  "+str(i) for i in namelist)
        confirmed = QMessageBox.question(self, "Delete Object", \
                    "Do you want to delete these objects?"+objectlist, \
                    QMessageBox.Yes | QMessageBox.Cancel, \
                    defaultButton=QMessageBox.Yes)

        if confirmed == QMessageBox.Yes:
            for name in namelist:
                self.topLevelOperatorView.deleteObject(name)
            return True
        return False

    def labelingContextMenu(self,names,op,position5d):
        """Build the right-click menu for the objects (*names*) found at *position5d*.

        Each object gets a submenu (load, delete, and show/remove in 3D when rendering
        is available), followed by global carving actions.
        """
        menu = QMenu(self)
        menu.setObjectName("carving_context_menu")
        posItem = menu.addAction("position %d %d %d" % (position5d[1], position5d[2], position5d[3]))
        posItem.setEnabled(False)
        menu.addSeparator()
        for name in names:
            submenu = QMenu(name,menu)

            # Load
            loadAction = submenu.addAction("Load %s" % name)
            loadAction.triggered.connect( partial(op.loadObject, name) )

            # Delete
            def onDelAction(_name):
                self.confirmAndDelete([_name])
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
            delAction = submenu.addAction("Delete %s" % name)
            delAction.triggered.connect( partial(onDelAction, name) )

            if self.render:
                if name in self._shownObjects3D:
                    # Remove
                    def onRemove3D(_name):
                        label = self._shownObjects3D.pop(_name)
                        self._renderMgr.removeObject(label)
                        self._update_rendering()
                    removeAction = submenu.addAction("Remove %s from 3D view" % name)
                    removeAction.triggered.connect( partial(onRemove3D, name) )
                else:
                    # Show
                    def onShow3D(_name):
                        label = self._renderMgr.addObject()
                        self._shownObjects3D[_name] = label
                        self._update_rendering()
                    showAction = submenu.addAction("Show 3D %s" % name)
                    showAction.triggered.connect( partial(onShow3D, name ) )

            menu.addMenu(submenu)

        if names:
            menu.addSeparator()

        menu.addSeparator()
        if self.render:
            showSeg3DAction = menu.addAction( "Show Editing Segmentation in 3D" )
            showSeg3DAction.setCheckable(True)
            showSeg3DAction.setChecked( self._showSegmentationIn3D )
            showSeg3DAction.triggered.connect( self._toggleSegmentation3D )

        if op.dataIsStorable():
            menu.addAction("Save object").triggered.connect( self.onSaveButton )
        menu.addAction("Browse objects").triggered.connect( self.onShowObjectNames )
        menu.addAction("Segment").triggered.connect( self.onSegmentButton )
        menu.addAction("Clear").triggered.connect( self.topLevelOperatorView.clearCurrentLabeling )
        return menu

    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Show the carving context menu for the objects at the clicked position."""
        names = self.topLevelOperatorView.doneObjectNamesForPosition(position5d[1:4])
        op = self.topLevelOperatorView

        # (Subclasses may override menu)
        menu = self.labelingContextMenu(names,op,position5d)
        if menu is not None:
            menu.exec_(globalWindowCoordinate)

    def _toggleSegmentation3D(self):
        """Toggle 3D display of the current (editing) segmentation; needs a RenderingManager."""
        self._showSegmentationIn3D = not self._showSegmentationIn3D
        if self._showSegmentationIn3D:
            self._segmentation_3d_label = self._renderMgr.addObject()
        else:
            self._renderMgr.removeObject(self._segmentation_3d_label)
            self._segmentation_3d_label = None
        self._update_rendering()

    def _update_rendering(self):
        """Recompute the 3D label volume from the shown objects and (optionally) the current segmentation."""
        if not self.render:
            return

        op = self.topLevelOperatorView
        if not self._renderMgr.ready:
            self._renderMgr.setup(op.InputData.meta.shape[1:4])

        # remove nonexistent objects
        self._shownObjects3D = dict((k, v) for k, v in self._shownObjects3D.iteritems()
                                    if k in op.MST.value.object_lut.keys())

        # Build a supervoxel -> render-label lookup table, then map the region volume through it.
        lut = numpy.zeros(len(op.MST.value.objects.lut), dtype=numpy.int32)
        for name, label in self._shownObjects3D.iteritems():
            objectSupervoxels = op.MST.value.object_lut[name]
            lut[objectSupervoxels] = label

        if self._showSegmentationIn3D:
            # Add segmentation as label, which is green
            lut[:] = numpy.where( op.MST.value.segmentation.lut == 2, self._segmentation_3d_label, lut )

        self._renderMgr.volume = lut[op.MST.value.regionVol] # (Advanced indexing)
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Color each 3D object like the "done segmentation" layer; the editing segmentation is green."""
        op = self.topLevelOperatorView

        if self._shownObjects3D:
            # Use the layer's current color table when the layer exists; otherwise fall
            # back to the table it is created from, instead of failing on a None layer.
            if self._doneSegmentationLayer is not None:
                ctable = self._doneSegmentationLayer.colorTable
            else:
                ctable = self._doneSegmentationColortable

            for name, label in self._shownObjects3D.iteritems():
                color = QColor(ctable[op.MST.value.object_names[name]])
                color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
                self._renderMgr.setColor(label, color)

        if self._showSegmentationIn3D and self._segmentation_3d_label is not None:
            self._renderMgr.setColor(self._segmentation_3d_label, (0.0, 1.0, 0.0)) # Green

    def _getNext(self, slot, parentFun, transform=None):
        """Return slot.value[<current label count>] (optionally transformed), or parentFun() if absent."""
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def getNextLabelName(self):
        """Name for a new label: the stored LabelNames entry if present, else the base-class default."""
        return self._getNext(self.topLevelOperatorView.LabelNames,
                             super(CarvingGui, self).getNextLabelName)

    def appletDrawers(self):
        """Return this applet's drawer(s) as (title, widget) pairs."""
        return [ ("Carving", self._labelControlUi) ]

    def setupLayers( self ):
        """Build the carving layer list.

        Layers are added in this order: labels, uncertainty (if enabled),
        segmentation, completed objects (unicolor and per-object colors),
        supervoxels, input data and filtered input.
        """
        logger.debug( "setupLayers" )

        layers = []

        def onButtonsEnabled(slot, roi):
            # Refresh the current-object label and the Save button from the operator.
            currObj = self.topLevelOperatorView.CurrentObjectName.value
            hasSeg  = self.topLevelOperatorView.HasSegmentation.value

            self.labelingDrawerUi.currentObjectLabel.setText(currObj)
            self.labelingDrawerUi.save.setEnabled(hasSeg)

        # setupLayers() presumably runs again whenever the layer stack is rebuilt
        # (e.g. updateAllLayers() after an uncertainty-mode change). Register these
        # callbacks only once so they do not accumulate on every rebuild.
        if not getattr(self, '_buttonCallbacksRegistered', False):
            self.topLevelOperatorView.CurrentObjectName.notifyDirty(onButtonsEnabled)
            self.topLevelOperatorView.HasSegmentation.notifyDirty(onButtonsEnabled)
            self.topLevelOperatorView.opLabelArray.NonzeroBlocks.notifyDirty(onButtonsEnabled)
            self._buttonCallbacksRegistered = True

        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            labellayer._allowToggleVisible = False
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        #uncertainty
        if self._showUncertaintyLayer:
            uncert = self.topLevelOperatorView.Uncertainty
            if uncert.ready():
                # Red ramp whose alpha grows with the value.
                colortable = []
                for i in range(256-len(colortable)):
                    r,g,b,a = i,0,0,i
                    colortable.append(QColor(r,g,b,a).rgba())

                layer = ColortableLayer(LazyflowSource(uncert), colortable, direct=True)
                layer.name = "Uncertainty"
                layer.visible = True
                layer.opacity = 0.3
                layers.append(layer)

        #segmentation
        seg = self.topLevelOperatorView.Segmentation

        #seg = self.topLevelOperatorView.MST.value.segmentation
        #temp = self._done_lut[self.MST.value.regionVol[sl[1:4]]]
        if seg.ready():
            #source = RelabelingArraySource(seg)
            #source.setRelabeling(numpy.arange(256, dtype=numpy.uint8))
            # Values 0 and 1 are transparent, 2 is green, the rest get random colors.
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,0,0).rgba(), QColor(0,255,0).rgba()]
            for i in range(256-len(colortable)):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())

            layer = ColortableLayer(LazyflowSource(seg), colortable, direct=True)
            layer.name = "Segmentation"
            layer.setToolTip("This layer displays the <i>current</i> segmentation. Simply add foreground and background " \
                             "labels, then press <i>Segment</i>.")
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        #done
        done = self.topLevelOperatorView.DoneObjects
        if done.ready():
            colortable = [QColor(0,0,0,0).rgba(), QColor(0,0,255).rgba()]
            #have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(LazyflowSource(done), colortable, direct=True)
            layer.name = "Completed segments (unicolor)"
            layer.setToolTip("In order to keep track of which objects you have already completed, this layer " \
                             "shows <b>all completed objects</b> in one color (<b>blue</b>). " \
                             "The reason for only one color is that for finding out which " \
                              "objects to label next, the identity of already completed objects is unimportant " \
                              "and distracting.")
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

        #done seg
        doneSeg = self.topLevelOperatorView.DoneSegmentation
        if doneSeg.ready():
            layer = ColortableLayer(LazyflowSource(doneSeg), self._doneSegmentationColortable, direct=True)
            layer.name = "Completed segments (one color per object)"
            layer.setToolTip("<html>In order to keep track of which objects you have already completed, this layer " \
                             "shows <b>all completed objects</b>, each with a random color.</html>")
            layer.visible = False
            layer.opacity = 0.5
            self._doneSegmentationLayer = layer
            layers.append(layer)

        #supervoxel
        sv = self.topLevelOperatorView.Supervoxels
        if sv.ready():
            colortable = []
            for i in range(256):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colortable.append(QColor(r,g,b).rgba())
            layer = ColortableLayer(LazyflowSource(sv), colortable, direct=True)
            layer.name = "Supervoxels"
            layer.setToolTip("<html>This layer shows the partitioning of the input image into <b>supervoxels</b>. The carving " \
                             "algorithm uses these tiny puzzle-pieces to piece together the segmentation of an " \
                             "object. Sometimes, supervoxels are too large and straddle two distinct objects " \
                             "(undersegmentation). In this case, it will be impossible to achieve the desired " \
                             "segmentation. This layer helps you to understand these cases.</html>")
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        #raw data
        '''
        rawSlot = self.topLevelOperatorView.RawData
        if rawSlot.ready():
            raw5D = self.topLevelOperatorView.RawData.value
            layer = GrayscaleLayer(ArraySource(raw5D), direct=True)
            #layer = GrayscaleLayer( LazyflowSource(rawSlot) )
            layer.visible = True
            layer.name = 'raw'
            layer.opacity = 1.0
            layers.append(layer)
        '''

        inputSlot = self.topLevelOperatorView.InputData
        if inputSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(inputSlot), direct=True )
            layer.name = "Input Data"
            layer.setToolTip("<html>The data originally loaded into ilastik (unprocessed).</html>")
            #layer.visible = not rawSlot.ready()
            layer.visible = True
            layer.opacity = 1.0
            layers.append(layer)

        filteredSlot = self.topLevelOperatorView.FilteredInputData
        if filteredSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(filteredSlot) )
            layer.name = "Filtered Input"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        return layers
Esempio n. 10
0
class CarvingGui(LabelingGui):
    """Labeling GUI for the carving workflow.

    Adds carving-specific drawer controls (segment, save, clear, browse
    objects, uncertainty navigation, background priority, no-bias-below),
    keyboard shortcuts, a right-click context menu for completed objects and,
    when a 3D view can be set up, 3D rendering of selected objects and of the
    current segmentation.
    """

    def __init__(self, topLevelOperatorView, drawerUiPath=None):
        """
        :param topLevelOperatorView: carving top-level operator (view) whose
            slots drive this GUI.
        :param drawerUiPath: optional path of the drawer ``.ui`` file; defaults
            to ``carvingDrawer.ui`` located next to this module.
        """
        self.topLevelOperatorView = topLevelOperatorView

        #members
        self._doneSegmentationLayer = None
        self._showSegmentationIn3D = False
        self._showUncertaintyLayer = False
        # setupLayers() subscribes to a few slots; this flag makes sure that
        # happens only once no matter how often setupLayers() runs.
        self._buttonCallbacksRegistered = False
        #end: members

        labelingSlots = LabelingGui.LabelingSlots()
        labelingSlots.labelInput = topLevelOperatorView.WriteSeeds
        labelingSlots.labelOutput = topLevelOperatorView.opLabelArray.Output
        labelingSlots.labelEraserValue = topLevelOperatorView.opLabelArray.EraserLabelValue
        labelingSlots.LabelNames = topLevelOperatorView.LabelNames
        labelingSlots.labelDelete = topLevelOperatorView.opLabelArray.DeleteLabel
        labelingSlots.maxLabelValue = topLevelOperatorView.opLabelArray.MaxLabelValue
        labelingSlots.labelsAllowed = topLevelOperatorView.LabelsAllowed

        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        if drawerUiPath is None:
            drawerUiPath = os.path.join(directory, 'carvingDrawer.ui')
        self.dialogdirCOM = os.path.join(directory, 'carvingObjectManagement.ui')
        self.dialogdirSAD = os.path.join(directory, 'saveAsDialog.ui')

        rawInputSlot = topLevelOperatorView.RawData
        super(CarvingGui, self).__init__(labelingSlots, topLevelOperatorView,
                                         drawerUiPath, rawInputSlot)

        self.labelingDrawerUi.currentObjectLabel.setText("<not saved yet>")

        # Init special base class members
        self.minLabelNumber = 2
        self.maxLabelNumber = 2

        mgr = ShortcutManager()

        #set up keyboard shortcuts
        segmentShortcut = QShortcut(
            QKeySequence("3"),
            self,
            member=self.labelingDrawerUi.segment.click,
            ambiguousMember=self.labelingDrawerUi.segment.click)
        mgr.register("Carving", "Run interactive segmentation",
                     segmentShortcut, self.labelingDrawerUi.segment)

        # 3D rendering is optional: if the renderer cannot be created we keep
        # going without it. Every later 3D operation is gated on self.render.
        self._shownObjects3D = {}
        try:
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
            self.render = True
        except Exception:
            self.render = False

        # Segmentation is toggled on by default in _after_init, below.
        # (We can't enable it until the layers are all present.)
        self._showSegmentationIn3D = False
        self._segmentation_3d_label = None

        self.labelingDrawerUi.segment.clicked.connect(self.onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        self.topLevelOperatorView.Segmentation.notifyDirty(
            bind(self._update_rendering))
        self.topLevelOperatorView.HasSegmentation.notifyValueChanged(
            bind(self._updateGui))

        ## uncertainty

        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(False)
        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(False)

        def jumpToMaxUncertainty(label):
            # Move the slicing position to the operator's max-uncertainty
            # position for the given label.
            pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=label)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])

        def onUncertaintyFGButton():
            print("uncertFG button clicked")
            jumpToMaxUncertainty(2)

        def onUncertaintyBGButton():
            print("uncertBG button clicked")
            jumpToMaxUncertainty(1)

        self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(
            onUncertaintyFGButton)
        self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(
            onUncertaintyBGButton)

        # Combo index -> UncertaintyType value (index 0 means "none").
        uncertaintyTypes = {1: "localMargin", 2: "exchangeCount", 3: "gabow"}

        def onUncertaintyCombo(index):
            if index == 0:
                value = "none"
                enabled = False
            elif index in uncertaintyTypes:
                value = uncertaintyTypes[index]
                enabled = True
            else:
                raise RuntimeError("unhandled case '%r'" % index)
            self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(enabled)
            self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(enabled)
            self._showUncertaintyLayer = enabled
            if enabled:
                print("uncertainty changed to %r" % value)
            self.topLevelOperatorView.UncertaintyType.setValue(value)
            # make sure that an added/deleted uncertainty layer is recognized
            self.updateAllLayers()

        self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(
            onUncertaintyCombo)

        ## background priority

        def onBackgroundPrioritySpin(value):
            print("background priority changed to %f" % value)
            self.topLevelOperatorView.BackgroundPriority.setValue(value)

        self.labelingDrawerUi.backgroundPrioritySpin.valueChanged.connect(
            onBackgroundPrioritySpin)

        def onBackgroundPriorityDirty(slot, roi):
            # Keep the spin box in sync with the slot; only set on change to
            # avoid a valueChanged -> setValue feedback loop.
            oldValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
            newValue = self.topLevelOperatorView.BackgroundPriority.value
            if newValue != oldValue:
                self.labelingDrawerUi.backgroundPrioritySpin.setValue(newValue)

        self.topLevelOperatorView.BackgroundPriority.notifyDirty(
            onBackgroundPriorityDirty)

        ## bias

        def onNoBiasBelowDirty(slot, roi):
            oldValue = self.labelingDrawerUi.noBiasBelowSpin.value()
            newValue = self.topLevelOperatorView.NoBiasBelow.value
            if newValue != oldValue:
                self.labelingDrawerUi.noBiasBelowSpin.setValue(newValue)

        self.topLevelOperatorView.NoBiasBelow.notifyDirty(onNoBiasBelowDirty)

        def onNoBiasBelowSpin(value):
            print("no bias below changed to %f" % value)
            self.topLevelOperatorView.NoBiasBelow.setValue(value)

        self.labelingDrawerUi.noBiasBelowSpin.valueChanged.connect(
            onNoBiasBelowSpin)

        ## save

        self.labelingDrawerUi.save.clicked.connect(self.onSaveButton)

        ## clear

        self.labelingDrawerUi.clear.clicked.connect(
            self.topLevelOperatorView.clearCurrentLabeling)

        ## object names

        self.labelingDrawerUi.namesButton.clicked.connect(
            self.onShowObjectNames)

        def labelBackground():
            self.selectLabel(0)

        def labelObject():
            self.selectLabel(1)

        self._labelControlUi.labelListModel.allowRemove(False)

        bg = QShortcut(QKeySequence("1"),
                       self,
                       member=labelBackground,
                       ambiguousMember=labelBackground)
        bgToolTipObject = LabelListModel.EntryToolTipAdapter(
            self._labelControlUi.labelListModel, 0)
        mgr.register("Carving", "Select background label", bg, bgToolTipObject)
        fg = QShortcut(QKeySequence("2"),
                       self,
                       member=labelObject,
                       ambiguousMember=labelObject)
        fgToolTipObject = LabelListModel.EntryToolTipAdapter(
            self._labelControlUi.labelListModel, 1)
        mgr.register("Carving", "Select object label", fg, fgToolTipObject)

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)

        def addLayerToggleShortcut(layername, shortcut):
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()

            shortcut = QShortcut(QKeySequence(shortcut),
                                 self,
                                 member=toggle,
                                 ambiguousMember=toggle)
            mgr.register("Carving", "Toggle layer %s" % layername, shortcut)

        # TODO: these names do not match the layer names created in
        # setupLayers() ("Segmentation", "Input Data", ...), so the shortcuts
        # currently cannot find their layers.
        addLayerToggleShortcut("done", "d")
        addLayerToggleShortcut("segmentation", "s")
        addLayerToggleShortcut("raw", "r")
        addLayerToggleShortcut("pmap", "v")
        addLayerToggleShortcut("hints", "t")

        def makeColortable():
            # Entry 0 transparent, 254 random colors kept away from pure
            # red/green, and a final pure-green entry.
            self._doneSegmentationColortable = (
                [QColor(0, 0, 0, 0).rgba()] +
                [self._randomColorAvoidingRedGreen().rgba()
                 for _ in range(254)] +
                [QColor(0, 255, 0).rgba()])

        makeColortable()

        # NOTE: currently not connected to any widget (see the commented-out
        # connect below); kept so the feature can be re-enabled.
        def onRandomizeColors():
            if self._doneSegmentationLayer is not None:
                print("randomizing colors ...")
                makeColortable()
                self._doneSegmentationLayer.colorTable = self._doneSegmentationColortable
                if self.render and self._renderMgr.ready:
                    self._update_rendering()

        #self.labelingDrawerUi.randomizeColors.clicked.connect(onRandomizeColors)
        self._updateGui()

    @staticmethod
    def _randomColor():
        """Return an opaque QColor with r, g, b each drawn via randint(0, 255)."""
        r, g, b = (numpy.random.randint(0, 255) for _ in range(3))
        return QColor(r, g, b)

    @staticmethod
    def _randomColorAvoidingRedGreen():
        """Return a random QColor redrawn until it is not too close to pure
        red or pure green (sum of per-channel distances must be >= 128)."""
        while True:
            r, g, b = (numpy.random.randint(0, 255) for _ in range(3))
            # ensure colors have sufficient distance to pure red and pure green
            if (255 - r) + g + b >= 128 and r + (255 - g) + b >= 128:
                return QColor(r, g, b)

    def _after_init(self):
        """Finish base-class setup, then enable 3D segmentation display if
        rendering is available (needs the layers to exist)."""
        super(CarvingGui, self)._after_init()
        if self.render:
            self._toggleSegmentation3D()

    def _updateGui(self):
        """Enable the Save button iff the operator reports storable data."""
        self.labelingDrawerUi.save.setEnabled(
            self.topLevelOperatorView.dataIsStorable())

    def onSegmentButton(self):
        """Trigger a segmentation run by marking the Trigger slot dirty."""
        print("segment button clicked")
        self.topLevelOperatorView.Trigger.setDirty(slice(None))

    def saveAsDialog(self, name=""):
        '''Ask the user for an object name, rejecting names of other objects.

        :param name: initial text of the name field.
        :returns: the chosen name as ``str``, or None if the dialog was
            cancelled.
        '''
        dialog = uic.loadUi(self.dialogdirSAD)
        dialog.lineEdit.setText(name)
        dialog.warning.setVisible(False)
        dialog.Ok.clicked.connect(dialog.accept)
        dialog.Cancel.clicked.connect(dialog.reject)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.isDisabled = False

        def validate():
            # Disable OK and show the warning while the typed name collides
            # with an existing object; re-enable once it no longer does.
            name = dialog.lineEdit.text()
            if name in listOfItems:
                dialog.Ok.setEnabled(False)
                dialog.warning.setVisible(True)
                dialog.isDisabled = True
            elif dialog.isDisabled:
                dialog.Ok.setEnabled(True)
                dialog.warning.setVisible(False)
                dialog.isDisabled = False

        dialog.lineEdit.textChanged.connect(validate)
        result = dialog.exec_()
        if result:
            return str(dialog.lineEdit.text())

    def onSaveButton(self):
        """Save the current object under a user-chosen name.

        If the current object already had a different name, the old object is
        deleted after saving (i.e. this acts as a rename). Shows a warning box
        if the operator reports the data as not storable.
        """
        print("save object as?")
        if self.topLevelOperatorView.dataIsStorable():
            prevName = ""
            if self.topLevelOperatorView.hasCurrentObject():
                prevName = self.topLevelOperatorView.currentObjectName()
            if prevName == "<not saved yet>":
                prevName = ""
            name = self.saveAsDialog(name=prevName)
            if name is None:
                return
            objects = self.topLevelOperatorView.AllObjectNames[:].wait()
            if name in objects and name != prevName:
                QMessageBox.critical(
                    self, "Save Object As",
                    "An object with name '%s' already exists.\nPlease choose a different name."
                    % name)
                return
            self.topLevelOperatorView.saveObjectAs(name)
            print("save object as %s" % name)
            if prevName != name and prevName != "":
                self.topLevelOperatorView.deleteObject(prevName)
        else:
            msgBox = QMessageBox(self)
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.exec_()
            print("object not saved due to faulty data.")

    def onShowObjectNames(self):
        '''Show object names and allow the user to load/delete them.'''
        dialog = uic.loadUi(self.dialogdirCOM)
        listOfItems = self.topLevelOperatorView.AllObjectNames[:].wait()
        dialog.objectNames.addItems(sorted(listOfItems))

        def loadSelection():
            selected = [
                str(item.text())
                for item in dialog.objectNames.selectedItems()
            ]
            dialog.close()
            # Load each selected object by its own name (previously the loop
            # variable was overwritten with a stale comprehension variable).
            for objectname in selected:
                self.topLevelOperatorView.loadObject(objectname)

        def deleteSelection():
            items = dialog.objectNames.selectedItems()
            if self.confirmAndDelete([str(item.text()) for item in items]):
                for item in items:
                    item.setHidden(True)
            dialog.close()

        dialog.loadButton.clicked.connect(loadSelection)
        dialog.deleteButton.clicked.connect(deleteSelection)
        dialog.cancelButton.clicked.connect(dialog.close)
        dialog.exec_()

    def confirmAndDelete(self, namelist):
        """Ask for confirmation, then delete every object in ``namelist``.

        :returns: True if the user confirmed (objects deleted), else False.
        """
        print(namelist)
        objectlist = "".join("\n  " + str(i) for i in namelist)
        confirmed = QMessageBox.question(self, "Delete Object",
                    "Do you want to delete these objects?" + objectlist,
                    QMessageBox.Yes | QMessageBox.Cancel,
                    defaultButton=QMessageBox.Yes)

        if confirmed == QMessageBox.Yes:
            for name in namelist:
                self.topLevelOperatorView.deleteObject(name)
            return True
        return False

    def labelingContextMenu(self, names, op, position5d):
        """Build the right-click menu for a position.

        :param names: names of completed objects at the clicked position; each
            gets a submenu with load/delete (and 3D show/remove) actions.
        :param op: the top-level operator.
        :param position5d: clicked position; indices 1..3 are shown.
        :returns: the QMenu to display.
        """
        menu = QMenu(self)
        menu.setObjectName("carving_context_menu")
        posItem = menu.addAction("position %d %d %d" %
                                 (position5d[1], position5d[2], position5d[3]))
        posItem.setEnabled(False)
        menu.addSeparator()

        def onDelAction(_name):
            self.confirmAndDelete([_name])
            if self.render and self._renderMgr.ready:
                self._update_rendering()

        def onRemove3D(_name):
            label = self._shownObjects3D.pop(_name)
            self._renderMgr.removeObject(label)
            self._update_rendering()

        def onShow3D(_name):
            label = self._renderMgr.addObject()
            self._shownObjects3D[_name] = label
            self._update_rendering()

        for name in names:
            submenu = QMenu(name, menu)

            # Load
            loadAction = submenu.addAction("Load %s" % name)
            loadAction.triggered.connect(partial(op.loadObject, name))

            # Delete
            delAction = submenu.addAction("Delete %s" % name)
            delAction.triggered.connect(partial(onDelAction, name))

            if self.render:
                if name in self._shownObjects3D:
                    # Remove
                    removeAction = submenu.addAction("Remove %s from 3D view" %
                                                     name)
                    removeAction.triggered.connect(partial(onRemove3D, name))
                else:
                    # Show
                    showAction = submenu.addAction("Show 3D %s" % name)
                    showAction.triggered.connect(partial(onShow3D, name))

            menu.addMenu(submenu)

        menu.addSeparator()
        if self.render:
            showSeg3DAction = menu.addAction("Show Editing Segmentation in 3D")
            showSeg3DAction.setCheckable(True)
            showSeg3DAction.setChecked(self._showSegmentationIn3D)
            showSeg3DAction.triggered.connect(self._toggleSegmentation3D)

        if op.dataIsStorable():
            menu.addAction("Save object").triggered.connect(self.onSaveButton)
        menu.addAction("Browse objects").triggered.connect(
            self.onShowObjectNames)
        menu.addAction("Segment").triggered.connect(self.onSegmentButton)
        menu.addAction("Clear").triggered.connect(
            self.topLevelOperatorView.clearCurrentLabeling)
        return menu

    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Show the context menu for the completed objects at ``position5d``."""
        names = self.topLevelOperatorView.doneObjectNamesForPosition(
            position5d[1:4])
        op = self.topLevelOperatorView

        # (Subclasses may override menu)
        menu = self.labelingContextMenu(names, op, position5d)
        if menu is not None:
            menu.exec_(globalWindowCoordinate)

    def _toggleSegmentation3D(self):
        """Flip whether the current segmentation is rendered in 3D."""
        self._showSegmentationIn3D = not self._showSegmentationIn3D
        if self._showSegmentationIn3D:
            self._segmentation_3d_label = self._renderMgr.addObject()
        else:
            self._renderMgr.removeObject(self._segmentation_3d_label)
            self._segmentation_3d_label = None
        self._update_rendering()

    def _update_rendering(self):
        """Rebuild the 3D label volume from the shown objects and (optionally)
        the current segmentation, then refresh colors and the renderer."""
        if not self.render:
            return

        op = self.topLevelOperatorView
        if not self._renderMgr.ready:
            self._renderMgr.setup(op.InputData.meta.shape[1:4])

        mst = op.MST.value  # fetched once; used several times below

        # remove nonexistent objects
        existing = set(mst.object_lut.keys())
        self._shownObjects3D = dict(
            (k, v) for k, v in self._shownObjects3D.items() if k in existing)

        # Per-supervoxel render label; 0 means "not rendered".
        lut = numpy.zeros(len(mst.objects.lut), dtype=numpy.int32)
        for name, label in self._shownObjects3D.items():
            lut[mst.object_lut[name]] = label

        if self._showSegmentationIn3D:
            # Add segmentation as label, which is green
            lut[:] = numpy.where(mst.segmentation.lut == 2,
                                 self._segmentation_3d_label, lut)

        self._renderMgr.volume = lut[mst.regionVol]  # (Advanced indexing)
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Assign each rendered object its color from the done-segmentation
        color table, and pure green to the 3D segmentation label."""
        if self._doneSegmentationLayer is not None:
            ctable = self._doneSegmentationLayer.colorTable
        else:
            # The layer may not exist yet (DoneSegmentation not ready); the
            # layer is created from this same table anyway.
            ctable = self._doneSegmentationColortable

        if self._shownObjects3D:
            objectNames = self.topLevelOperatorView.MST.value.object_names
            for name, label in self._shownObjects3D.items():
                color = QColor(ctable[objectNames[name]])
                color = (color.red() / 255.0, color.green() / 255.0,
                         color.blue() / 255.0)
                self._renderMgr.setColor(label, color)

        if self._showSegmentationIn3D and self._segmentation_3d_label is not None:
            self._renderMgr.setColor(self._segmentation_3d_label,
                                     (0.0, 1.0, 0.0))  # Green

    def _getNext(self, slot, parentFun, transform=None):
        """Return ``slot.value[n]`` (optionally transformed) where n is the
        current number of labels, or ``parentFun()`` if the slot value has no
        such entry."""
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        return parentFun()

    def getNextLabelName(self):
        """Next label name from the LabelNames slot, else the base default."""
        return self._getNext(self.topLevelOperatorView.LabelNames,
                             super(CarvingGui, self).getNextLabelName)

    def appletDrawers(self):
        return [("Carving", self._labelControlUi)]

    def setupLayers(self):
        """Create the viewer layers: labels, optional uncertainty, current
        segmentation, completed objects (uni- and multicolor), supervoxels,
        input data and filtered input."""
        print("setupLayers")

        layers = []

        # setupLayers() can run more than once (presumably via
        # updateAllLayers(), which onUncertaintyCombo calls); subscribe only
        # once so callbacks do not pile up.
        if not self._buttonCallbacksRegistered:
            def onButtonsEnabled(slot, roi):
                currObj = self.topLevelOperatorView.CurrentObjectName.value
                hasSeg = self.topLevelOperatorView.HasSegmentation.value

                self.labelingDrawerUi.currentObjectLabel.setText(currObj)
                self.labelingDrawerUi.save.setEnabled(hasSeg)

            self.topLevelOperatorView.CurrentObjectName.notifyDirty(
                onButtonsEnabled)
            self.topLevelOperatorView.HasSegmentation.notifyDirty(
                onButtonsEnabled)
            self.topLevelOperatorView.opLabelArray.NonzeroBlocks.notifyDirty(
                onButtonsEnabled)
            self._buttonCallbacksRegistered = True

        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            labellayer._allowToggleVisible = False
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        #uncertainty
        if self._showUncertaintyLayer:
            uncert = self.topLevelOperatorView.Uncertainty
            if uncert.ready():
                # Red ramp whose alpha grows with the value.
                colortable = [QColor(i, 0, 0, i).rgba() for i in range(256)]
                layer = ColortableLayer(LazyflowSource(uncert),
                                        colortable,
                                        direct=True)
                layer.name = "Uncertainty"
                layer.visible = True
                layer.opacity = 0.3
                layers.append(layer)

        #segmentation
        seg = self.topLevelOperatorView.Segmentation
        if seg.ready():
            # 0 and 1 transparent, 2 green, the rest random.
            colortable = [
                QColor(0, 0, 0, 0).rgba(),
                QColor(0, 0, 0, 0).rgba(),
                QColor(0, 255, 0).rgba()
            ]
            colortable.extend(self._randomColor().rgba()
                              for _ in range(256 - len(colortable)))

            layer = ColortableLayer(LazyflowSource(seg),
                                    colortable,
                                    direct=True)
            layer.name = "Segmentation"
            layer.setToolTip("This layer displays the <i>current</i> segmentation. Simply add foreground and background " \
                             "labels, then press <i>Segment</i>.")
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        #done
        done = self.topLevelOperatorView.DoneObjects
        if done.ready():
            colortable = [QColor(0, 0, 0, 0).rgba(), QColor(0, 0, 255).rgba()]
            colortable.extend(self._randomColorAvoidingRedGreen().rgba()
                              for _ in range(254 - len(colortable)))
            #have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(LazyflowSource(done),
                                    colortable,
                                    direct=True)
            layer.name = "Completed segments (unicolor)"
            layer.setToolTip("In order to keep track of which objects you have already completed, this layer " \
                             "shows <b>all completed object</b> in one color (<b>blue</b>). " \
                             "The reason for only one color is that for finding out which " \
                              "objects to label next, the identity of already completed objects is unimportant " \
                              "and destracting.")
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

        #done seg
        doneSeg = self.topLevelOperatorView.DoneSegmentation
        if doneSeg.ready():
            # Reuse the layer across calls so its (possibly randomized)
            # colortable is preserved.
            if self._doneSegmentationLayer is None:
                layer = ColortableLayer(LazyflowSource(doneSeg),
                                        self._doneSegmentationColortable,
                                        direct=True)
                layer.name = "Completed segments (one color per object)"
                layer.setToolTip("<html>In order to keep track of which objects you have already completed, this layer " \
                                 "shows <b>all completed object</b>, each with a random color.</html>")
                layer.visible = False
                layer.opacity = 0.5
                self._doneSegmentationLayer = layer
            layers.append(self._doneSegmentationLayer)

        #supervoxel
        sv = self.topLevelOperatorView.Supervoxels
        if sv.ready():
            # Build a fresh table: reusing a table from an earlier section
            # failed when none of those sections ran.
            colortable = [self._randomColor().rgba() for _ in range(256)]
            layer = ColortableLayer(LazyflowSource(sv),
                                    colortable,
                                    direct=True)
            layer.name = "Supervoxels"
            layer.setToolTip("<html>This layer shows the partitioning of the input image into <b>supervoxels</b>. The carving " \
                             "algorithm uses these tiny puzzle-piceces to piece together the segmentation of an " \
                             "object. Sometimes, supervoxels are too large and straddle two distinct objects " \
                             "(undersegmentation). In this case, it will be impossible to achieve the desired " \
                             "segmentation. This layer helps you to understand these cases.</html>")
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        inputSlot = self.topLevelOperatorView.InputData
        if inputSlot.ready():
            layer = GrayscaleLayer(LazyflowSource(inputSlot), direct=True)
            layer.name = "Input Data"
            layer.setToolTip(
                "<html>The data originally loaded into ilastik (unprocessed).</html>"
            )
            layer.visible = True
            layer.opacity = 1.0
            layers.append(layer)

        filteredSlot = self.topLevelOperatorView.FilteredInputData
        if filteredSlot.ready():
            layer = GrayscaleLayer(LazyflowSource(filteredSlot))
            layer.name = "Filtered Input"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        return layers
class PixelClassificationGui(LabelingGui):
    """
    Labeling GUI for the pixel classification applet.

    On top of the generic LabelingGui it provides:
      * an interactive ("live update") mode, tied to the operator's
        FreezePredictions slot,
      * one prediction layer and one segmentation layer per label, plus an
        uncertainty layer,
      * optional 3D rendering of segmentation layers (only when a
        RenderingManager can be created for the editor's 3D view),
      * persistence of label names/colors into the operator's LabelNames,
        LabelColors and PmapColors slots.
    """

    # Name prefixes of the per-label layers created in setupLayers(). The
    # show/hide checkboxes match layers by these prefixes rather than by a
    # bare substring, so a user-chosen label name such as "Prediction" cannot
    # make a segmentation layer be treated as a prediction layer (or vice versa).
    _PREDICTION_LAYER_PREFIX = "Prediction for "
    _SEGMENTATION_LAYER_PREFIX = "Segmentation ("

    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################
    def centralWidget( self ):
        return self

    def stopAndCleanUp(self):
        """
        Leave interactive mode and run every cleanup callable registered in
        __init__ (which unregisters the operator dirty-callbacks).
        """
        # Base class first
        super(PixelClassificationGui, self).stopAndCleanUp()

        # Ensure that we are NOT in interactive mode
        self.labelingDrawerUi.liveUpdateButton.setChecked(False)
        self._viewerControlUi.checkShowPredictions.setChecked(False)
        self._viewerControlUi.checkShowSegmentation.setChecked(False)
        self.toggleInteractive(False)

        for fn in self.__cleanup_fns:
            fn()

    def viewerControlWidget(self):
        return self._viewerControlUi

    ###########################################
    ###########################################

    def __init__(self, topLevelOperatorView, shellRequestSignal, guiControlSignal, predictionSerializer ):
        """
        :param topLevelOperatorView: the pixel classification top-level operator (view).
        :param shellRequestSignal: stored as self.shellRequestSignal (not otherwise used in this class).
        :param guiControlSignal: used to emit ControlCommand.DisableUpstream / Pop when
            entering / leaving interactive mode.
        :param predictionSerializer: stored as self.predictionSerializer (not otherwise used in this class).
        """
        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.opLabelArray.deleteLabel
        labelSlots.maxLabelValue = topLevelOperatorView.MaxLabelValue
        labelSlots.labelsAllowed = topLevelOperatorView.LabelsAllowedFlags
        labelSlots.LabelNames = topLevelOperatorView.LabelNames

        # Callables run by stopAndCleanUp() to undo the callback registrations below.
        self.__cleanup_fns = []

        # We provide our own UI file (which adds an extra control for interactive mode)
        labelingDrawerUiPath = os.path.split(__file__)[0] + '/labelingDrawer.ui'

        # Base class init
        super(PixelClassificationGui, self).__init__( labelSlots, topLevelOperatorView, labelingDrawerUiPath )

        self.topLevelOperatorView = topLevelOperatorView
        self.shellRequestSignal = shellRequestSignal
        self.guiControlSignal = guiControlSignal
        self.predictionSerializer = predictionSerializer

        self.interactiveModeActive = False
        # Immediately update our interactive state
        self.toggleInteractive( not self.topLevelOperatorView.FreezePredictions.value )

        self._currentlySavingPredictions = False

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon( QIcon(ilastikIcons.Play) )
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelingDrawerUi.liveUpdateButton.toggled.connect( self.toggleInteractive )

        # Register ONE callback object and hand that very object to
        # unregisterDirty, so cleanup does not rely on two separately
        # constructed bind() wrappers comparing equal.
        onMaxLabelDirty = bind(self.handleLabelSelectionChange)
        self.topLevelOperatorView.MaxLabelValue.notifyDirty( onMaxLabelDirty )
        self.__cleanup_fns.append( partial( self.topLevelOperatorView.MaxLabelValue.unregisterDirty, onMaxLabelDirty ) )

        self._initShortcuts()

        # 3D rendering is optional: if the RenderingManager cannot be created
        # for the editor's 3D view, rendering is simply disabled.
        try:
            self.render = True
            self._renderedLayers = {} # (layer name, label number)
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except Exception:
            logger.debug("3D rendering unavailable; disabling it.", exc_info=True)
            self.render = False

        # toggle interactive mode according to freezePredictions.value
        self.toggleInteractive(not self.topLevelOperatorView.FreezePredictions.value)
        def FreezePredDirty():
            self.toggleInteractive(not self.topLevelOperatorView.FreezePredictions.value)
        # listen to freezePrediction changes (same object registered and unregistered)
        onFreezeDirty = bind(FreezePredDirty)
        self.topLevelOperatorView.FreezePredictions.notifyDirty( onFreezeDirty )
        self.__cleanup_fns.append( partial( self.topLevelOperatorView.FreezePredictions.unregisterDirty, onFreezeDirty ) )

    def initViewerControlUi(self):
        """Load viewerControls.ui and wire the show-predictions/segmentation checkboxes."""
        localDir = os.path.split(__file__)[0]
        self._viewerControlUi = uic.loadUi( os.path.join( localDir, "viewerControls.ui" ) )

        # Connect checkboxes. Overriding nextCheckState makes a click toggle
        # between checked/unchecked only, skipping the partially-checked state.
        def nextCheckState(checkbox):
            checkbox.setChecked( not checkbox.isChecked() )
        self._viewerControlUi.checkShowPredictions.nextCheckState = partial(nextCheckState, self._viewerControlUi.checkShowPredictions)
        self._viewerControlUi.checkShowSegmentation.nextCheckState = partial(nextCheckState, self._viewerControlUi.checkShowSegmentation)

        self._viewerControlUi.checkShowPredictions.clicked.connect( self.handleShowPredictionsClicked )
        self._viewerControlUi.checkShowSegmentation.clicked.connect( self.handleShowSegmentationClicked )

        # The editor's layerstack is in charge of which layer movement buttons are enabled
        model = self.editor.layerStack
        self._viewerControlUi.viewerControls.setupConnections(model)

    def _initShortcuts(self):
        """Register the 'p', 's' and 'l' keyboard shortcuts with the ShortcutManager."""
        mgr = ShortcutManager()
        shortcutGroupName = "Predictions"

        togglePredictions = QShortcut( QKeySequence("p"), self, member=self._viewerControlUi.checkShowPredictions.click )
        mgr.register( shortcutGroupName,
                      "Toggle Prediction Layer Visibility",
                      togglePredictions,
                      self._viewerControlUi.checkShowPredictions )

        toggleSegmentation = QShortcut( QKeySequence("s"), self, member=self._viewerControlUi.checkShowSegmentation.click )
        mgr.register( shortcutGroupName,
                      "Toggle Segmentation Layer Visibility",
                      toggleSegmentation,
                      self._viewerControlUi.checkShowSegmentation )

        toggleLivePredict = QShortcut( QKeySequence("l"), self, member=self.labelingDrawerUi.liveUpdateButton.toggle )
        mgr.register( shortcutGroupName,
                      "Toggle Live Prediction Mode",
                      toggleLivePredict,
                      self.labelingDrawerUi.liveUpdateButton )

    def _setup_contexts(self, layer):
        """
        Add a 'Toggle 3D rendering' context-menu entry to *layer* (only when
        rendering is available). The entry adds/removes the layer from the
        set of objects shown by the RenderingManager.
        """
        def callback(pos, clayer=layer):
            name = clayer.name
            if name in self._renderedLayers:
                label = self._renderedLayers.pop(name)
                self._renderMgr.removeObject(label)
                self._update_rendering()
            else:
                label = self._renderMgr.addObject()
                self._renderedLayers[clayer.name] = label
                self._update_rendering()

        if self.render:
            layer.contexts.append(('Toggle 3D rendering', callback))

    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.
        """
        # Base class provides the label layer.
        layers = super(PixelClassificationGui, self).setupLayers()

        # Add the uncertainty estimate layer
        uncertaintySlot = self.topLevelOperatorView.UncertaintyEstimate
        if uncertaintySlot.ready():
            uncertaintySrc = LazyflowSource(uncertaintySlot)
            uncertaintyLayer = AlphaModulatedLayer( uncertaintySrc,
                                                    tintColor=QColor( Qt.cyan ),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
            uncertaintyLayer.name = "Uncertainty"
            uncertaintyLayer.visible = False
            uncertaintyLayer.opacity = 1.0
            uncertaintyLayer.shortcutRegistration = (
                "Prediction Layers",
                "Show/Hide Uncertainty",
                QShortcut( QKeySequence("u"), self.viewerControlWidget(), uncertaintyLayer.toggleVisible ),
                uncertaintyLayer )
            layers.append(uncertaintyLayer)

        labels = self.labelListData

        # Add each of the segmentations
        for channel, segmentationSlot in enumerate(self.topLevelOperatorView.SegmentationChannels):
            if segmentationSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                segsrc = LazyflowSource(segmentationSlot)
                segLayer = AlphaModulatedLayer( segsrc,
                                                tintColor=ref_label.pmapColor(),
                                                range=(0.0, 1.0),
                                                normalize=(0.0, 1.0) )

                segLayer.opacity = 1
                segLayer.visible = False #self.labelingDrawerUi.liveUpdateButton.isChecked()
                segLayer.visibleChanged.connect(self.updateShowSegmentationCheckbox)

                def setLayerColor(c, segLayer=segLayer):
                    segLayer.tintColor = c
                    self._update_rendering()

                def setSegLayerName(n, segLayer=segLayer):
                    oldname = segLayer.name
                    newName = "Segmentation (%s)" % n
                    segLayer.name = newName
                    if not self.render:
                        return
                    # Keep the 3D-rendering bookkeeping keyed by the new name.
                    if oldname in self._renderedLayers:
                        label = self._renderedLayers.pop(oldname)
                        self._renderedLayers[newName] = label

                setSegLayerName(ref_label.name)

                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setSegLayerName)
                #check if layer is 3d before adding the "Toggle 3D" option
                #this check is done this way to match the VolumeRenderer, in
                #case different 3d-axistags should be rendered like t-x-y
                #_axiskeys = segmentationSlot.meta.getAxisKeys()
                if len(segmentationSlot.meta.shape) == 4:
                    #the Renderer will cut out the last shape-dimension, so
                    #we're checking for 4 dimensions
                    self._setup_contexts(segLayer)
                layers.append(segLayer)

        # Add each of the predictions
        for channel, predictionSlot in enumerate(self.topLevelOperatorView.PredictionProbabilityChannels):
            if predictionSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                predictsrc = LazyflowSource(predictionSlot)
                predictLayer = AlphaModulatedLayer( predictsrc,
                                                    tintColor=ref_label.pmapColor(),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
                predictLayer.opacity = 0.25
                predictLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked()
                predictLayer.visibleChanged.connect(self.updateShowPredictionCheckbox)

                def setLayerColor(c, predictLayer=predictLayer):
                    predictLayer.tintColor = c

                def setPredLayerName(n, predictLayer=predictLayer):
                    newName = "Prediction for %s" % n
                    predictLayer.name = newName

                setPredLayerName(ref_label.name)
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setPredLayerName)
                layers.append(predictLayer)

        # Add the raw data last (on the bottom)
        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot( inputDataSlot )
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0

            def toggleTopToBottom():
                index = self.layerstack.layerIndex( inputLayer )
                self.layerstack.selectRow( index )
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            inputLayer.shortcutRegistration = (
                "Prediction Layers",
                "Bring Input To Top/Bottom",
                QShortcut( QKeySequence("i"), self.viewerControlWidget(), toggleTopToBottom),
                inputLayer )
            layers.append(inputLayer)

        self.handleLabelSelectionChange()
        return layers

    def toggleInteractive(self, checked):
        """
        Enter (checked=True) or leave (checked=False) interactive mode.

        Entering is refused (with a message box) while FeatureImages is not
        ready or has no shape. On a mode change, label deletion/addition is
        disabled/enabled and DisableUpstream/Pop is emitted on
        guiControlSignal. FreezePredictions is always set to ``not checked``.
        """
        logger.debug("toggling interactive mode to '%r'" % checked)

        if checked:
            if not self.topLevelOperatorView.FeatureImages.ready() \
            or self.topLevelOperatorView.FeatureImages.meta.shape is None:
                self.labelingDrawerUi.liveUpdateButton.setChecked(False)
                mexBox=QMessageBox()
                mexBox.setText("There are no features selected ")
                mexBox.exec_()
                return

        # If we're changing modes, enable/disable our controls and other applets accordingly
        if self.interactiveModeActive != checked:
            if checked:
                self.labelingDrawerUi.labelListView.allowDelete = False
                self.labelingDrawerUi.AddLabelButton.setEnabled( False )
                self.guiControlSignal.emit( ControlCommand.DisableUpstream )
            else:
                self.labelingDrawerUi.labelListView.allowDelete = True
                self.labelingDrawerUi.AddLabelButton.setEnabled( True )
                self.guiControlSignal.emit( ControlCommand.Pop )
        self.interactiveModeActive = checked

        self.topLevelOperatorView.FreezePredictions.setValue( not checked )
        self.labelingDrawerUi.liveUpdateButton.setChecked(checked)
        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked( True )
            self.handleShowPredictionsClicked()

    def _layersWithPrefix(self, prefix):
        """Return the layers in the layer stack whose name starts with *prefix*."""
        return [layer for layer in self.layerstack if layer.name.startswith(prefix)]

    def _setLayersVisible(self, prefix, visible):
        """Set the visibility of every layer whose name starts with *prefix*."""
        for layer in self._layersWithPrefix(prefix):
            layer.visible = visible

    def _syncTriStateCheckbox(self, checkbox, prefix):
        """
        Reflect the visibility of the layers named with *prefix* in *checkbox*:
        Unchecked if none is visible, Checked if all are, else PartiallyChecked.
        """
        matching = self._layersWithPrefix(prefix)
        visibleCount = sum(1 for layer in matching if layer.visible)
        if visibleCount == 0:
            checkbox.setCheckState(Qt.Unchecked)
        elif visibleCount == len(matching):
            checkbox.setCheckState(Qt.Checked)
        else:
            checkbox.setCheckState(Qt.PartiallyChecked)

    @pyqtSlot()
    def handleShowPredictionsClicked(self):
        """Show/hide all prediction layers according to the 'show predictions' checkbox."""
        checked = self._viewerControlUi.checkShowPredictions.isChecked()
        self._setLayersVisible(self._PREDICTION_LAYER_PREFIX, checked)

    @pyqtSlot()
    def handleShowSegmentationClicked(self):
        """Show/hide all segmentation layers according to the 'show segmentation' checkbox."""
        checked = self._viewerControlUi.checkShowSegmentation.isChecked()
        self._setLayersVisible(self._SEGMENTATION_LAYER_PREFIX, checked)

    @pyqtSlot()
    def updateShowPredictionCheckbox(self):
        """Sync the 'show predictions' checkbox with the prediction layers' visibility."""
        self._syncTriStateCheckbox(self._viewerControlUi.checkShowPredictions,
                                   self._PREDICTION_LAYER_PREFIX)

    @pyqtSlot()
    def updateShowSegmentationCheckbox(self):
        """Sync the 'show segmentation' checkbox with the segmentation layers' visibility."""
        self._syncTriStateCheckbox(self._viewerControlUi.checkShowSegmentation,
                                   self._SEGMENTATION_LAYER_PREFIX)

    @pyqtSlot()
    @threadRouted
    def handleLabelSelectionChange(self):
        """
        Enable the live-update button and the show checkboxes only when there
        are at least two labels and the cached feature images have a known,
        non-empty shape; otherwise leave interactive mode and hide the
        prediction/segmentation layers.
        """
        enabled = False
        if self.topLevelOperatorView.MaxLabelValue.ready():
            enabled = True
            enabled &= self.topLevelOperatorView.MaxLabelValue.value >= 2
            featureShape = self.topLevelOperatorView.CachedFeatureImages.meta.shape
            enabled &= featureShape is not None and bool(numpy.all(numpy.asarray(featureShape) > 0))
            # FIXME: also check that each label has scribbles?

        if not enabled:
            self.labelingDrawerUi.liveUpdateButton.setChecked(False)
            self._viewerControlUi.checkShowPredictions.setChecked(False)
            self._viewerControlUi.checkShowSegmentation.setChecked(False)
            self.handleShowPredictionsClicked()
            self.handleShowSegmentationClicked()

        self.labelingDrawerUi.liveUpdateButton.setEnabled(enabled)
        self._viewerControlUi.checkShowPredictions.setEnabled(enabled)
        self._viewerControlUi.checkShowSegmentation.setEnabled(enabled)

    def _getNext(self, slot, parentFun, transform=None):
        """
        Return the stored value for the next label (index = current label
        count) from *slot*'s list, optionally passed through *transform*;
        fall back to *parentFun()* when the list has no such entry.
        """
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def _onLabelChanged(self, parentFun, mapf, slot):
        """Call *parentFun*, then write ``mapf`` of every label into *slot* via _listReplace."""
        parentFun()
        new = map(mapf, self.labelListData)
        old = slot.value
        slot.setValue(_listReplace(old, new))

    def _onLabelRemoved(self, parent, start, end):
        """Drop the removed label's name/colors from the operator, then defer to the base class."""
        # Update the label names/colors BEFORE calling the base class,
        #  which will update the operator and expects the
        #  label names list to be correct.
        op = self.topLevelOperatorView
        for slot in (op.LabelNames, op.LabelColors, op.PmapColors):
            # Work on a copy: mutating the slot's own list in place and
            # handing the same object back to setValue() may not register
            # as a change.
            value = list(slot.value)
            if start < len(value):
                value.pop(start)
            slot.setValue(value)

        # Call the base class to update the operator.
        super(PixelClassificationGui, self)._onLabelRemoved(parent, start, end)

    def getNextLabelName(self):
        return self._getNext(self.topLevelOperatorView.LabelNames,
                             super(PixelClassificationGui, self).getNextLabelName)

    def getNextLabelColor(self):
        return self._getNext(
            self.topLevelOperatorView.LabelColors,
            super(PixelClassificationGui, self).getNextLabelColor,
            lambda x: QColor(*x)
        )

    def getNextPmapColor(self):
        return self._getNext(
            self.topLevelOperatorView.PmapColors,
            super(PixelClassificationGui, self).getNextPmapColor,
            lambda x: QColor(*x)
        )

    def onLabelNameChanged(self):
        self._onLabelChanged(super(PixelClassificationGui, self).onLabelNameChanged,
                             lambda l: l.name,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        self._onLabelChanged(super(PixelClassificationGui, self).onLabelColorChanged,
                             lambda l: (l.brushColor().red(),
                                        l.brushColor().green(),
                                        l.brushColor().blue()),
                             self.topLevelOperatorView.LabelColors)


    def onPmapColorChanged(self):
        self._onLabelChanged(super(PixelClassificationGui, self).onPmapColorChanged,
                             lambda l: (l.pmapColor().red(),
                                        l.pmapColor().green(),
                                        l.pmapColor().blue()),
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        """
        Rebuild the 3D label volume from the layers toggled for rendering
        and push it to the RenderingManager. Does nothing when rendering is
        unavailable or the input is not 5D.
        """
        if not self.render:
            return
        fullShape = self.topLevelOperatorView.InputImages.meta.shape
        # Check the dimensionality of the FULL shape before slicing: the
        # sliced spatial shape can never have 5 entries.
        if fullShape is None or len(fullShape) != 5:
            #this might be a 2D image, no need for updating any 3D stuff
            return
        # Spatial part of the input shape (drops the first and last axes).
        shape = fullShape[1:4]

        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        # Forget rendered layers that no longer exist in the layer stack.
        layernames = set(layer.name for layer in self.layerstack)
        self._renderedLayers = dict((k, v) for k, v in self._renderedLayers.iteritems()
                                if k in layernames)

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                indices = numpy.where(vol != 0)
                newvolume[indices] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Copy each rendered layer's tint color (scaled to 0..1 RGB) into the RenderingManager."""
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            color = layer.tintColor
            color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
            self._renderMgr.setColor(label, color)
Esempio n. 12
0
class CarvingGui(LabelingGui):
    def __init__(self, topLevelOperatorView, drawerUiPath=None ):
        self.topLevelOperatorView = topLevelOperatorView

        labelingSlots = LabelingGui.LabelingSlots()
        labelingSlots.labelInput = topLevelOperatorView.WriteSeeds
        labelingSlots.labelOutput = topLevelOperatorView.opLabelArray.Output
        labelingSlots.labelEraserValue = topLevelOperatorView.opLabelArray.EraserLabelValue
        labelingSlots.labelDelete = topLevelOperatorView.opLabelArray.DeleteLabel
        labelingSlots.maxLabelValue = topLevelOperatorView.opLabelArray.MaxLabelValue
        labelingSlots.labelsAllowed = topLevelOperatorView.LabelsAllowed

        # We provide our own UI file (which adds an extra control for interactive mode)
        directory = os.path.split(__file__)[0]
        if drawerUiPath is None:
            drawerUiPath = os.path.join(directory, 'carvingDrawer.ui')
        self.dialogdirCOM = os.path.join(directory, 'carvingObjectManagement.ui')
        self.dialogdirSAD = os.path.join(directory, 'saveAsDialog.ui')

        rawInputSlot = topLevelOperatorView.RawData        
        super(CarvingGui, self).__init__(labelingSlots, topLevelOperatorView, drawerUiPath, rawInputSlot)

        # Init special base class members
        self.minLabelNumber = 2
        self.maxLabelNumber = 2
        
        mgr = ShortcutManager()
        
        #set up keyboard shortcuts
        segmentShortcut = QShortcut(QKeySequence("3"), self, member=self.labelingDrawerUi.segment.click,
                                    ambiguousMember=self.labelingDrawerUi.segment.click)
        mgr.register("Carving", "Run interactive segmentation", segmentShortcut, self.labelingDrawerUi.segment)
        

        self._doneSegmentationLayer = None
        self._showSegmentationIn3D = False
        
        #volume rendering
        try:
            self.render = True
            self._shownObjects3D = {}
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except:
            self.render = False

        
        self.labelingDrawerUi.segment.clicked.connect(self.onSegmentButton)
        self.labelingDrawerUi.segment.setEnabled(True)

        self.topLevelOperatorView.Segmentation.notifyDirty( bind( self._update_rendering ) )

        def onUncertaintyFGButton():
            print "uncertFG button clicked"
            pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=2)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        self.labelingDrawerUi.pushButtonUncertaintyFG.clicked.connect(onUncertaintyFGButton)
        self.labelingDrawerUi.pushButtonUncertaintyFG.setEnabled(True)

        def onUncertaintyBGButton():
            print "uncertBG button clicked"
            pos = self.topLevelOperatorView.getMaxUncertaintyPos(label=1)
            self.editor.posModel.slicingPos = (pos[0], pos[1], pos[2])
        self.labelingDrawerUi.pushButtonUncertaintyBG.clicked.connect(onUncertaintyBGButton)
        self.labelingDrawerUi.pushButtonUncertaintyBG.setEnabled(True)


        def onBackgroundPrioritySpin(value):
            print "background priority changed to %f" % value
            self.topLevelOperatorView.BackgroundPriority.setValue(value)
        self.labelingDrawerUi.backgroundPrioritySpin.valueChanged.connect(onBackgroundPrioritySpin)

        def onuncertaintyCombo(value):
            if value == 0:
                value = "none"
            if value == 1:
                value = "localMargin"
            if value == 2:
                value = "exchangeCount"
            if value == 3:
                value = "gabow"
            print "uncertainty changed to %r" % value
            self.topLevelOperatorView.UncertaintyType.setValue(value)
        self.labelingDrawerUi.uncertaintyCombo.currentIndexChanged.connect(onuncertaintyCombo)

        def onBackgroundPriorityDirty(slot, roi):
            oldValue = self.labelingDrawerUi.backgroundPrioritySpin.value()
            newValue = self.topLevelOperatorView.BackgroundPriority.value
            if  newValue != oldValue:
                self.labelingDrawerUi.backgroundPrioritySpin.setValue(newValue)
        self.topLevelOperatorView.BackgroundPriority.notifyDirty(onBackgroundPriorityDirty)
        
        def onNoBiasBelowDirty(slot, roi):
            oldValue = self.labelingDrawerUi.noBiasBelowSpin.value()
            newValue = self.topLevelOperatorView.NoBiasBelow.value
            if  newValue != oldValue:
                self.labelingDrawerUi.noBiasBelowSpin.setValue(newValue)
        self.topLevelOperatorView.NoBiasBelow.notifyDirty(onNoBiasBelowDirty)
        
        def onNoBiasBelowSpin(value):
            print "background priority changed to %f" % value
            self.topLevelOperatorView.NoBiasBelow.setValue(value)
        self.labelingDrawerUi.noBiasBelowSpin.valueChanged.connect(onNoBiasBelowSpin)

        self.labelingDrawerUi.saveAs.clicked.connect(self.onSaveAsButton)
        self.labelingDrawerUi.save.clicked.connect(self.onSaveButton)
        self.labelingDrawerUi.save.setEnabled(False) #initially, the user need to use "Save As"

        self.labelingDrawerUi.clear.clicked.connect(self.topLevelOperatorView.clearCurrentLabeling)
        self.labelingDrawerUi.clear.setEnabled(True)
        
        self.labelingDrawerUi.namesButton.clicked.connect(self.onShowObjectNames)
        
        def labelBackground():
            self.selectLabel(0)
        def labelObject():
            self.selectLabel(1)

        self._labelControlUi.labelListModel.allowRemove(False)

        bg = QShortcut(QKeySequence("1"), self, member=labelBackground, ambiguousMember=labelBackground)
        bgToolTipObject = LabelListModel.EntryToolTipAdapter(self._labelControlUi.labelListModel, 0)
        mgr.register("Carving", "Select background label", bg, bgToolTipObject)
        fg = QShortcut(QKeySequence("2"), self, member=labelObject, ambiguousMember=labelObject)
        fgToolTipObject = LabelListModel.EntryToolTipAdapter(self._labelControlUi.labelListModel, 1)
        mgr.register("Carving", "Select object label", fg, fgToolTipObject)

        def layerIndexForName(name):
            return self.layerstack.findMatchingIndex(lambda x: x.name == name)

        def addLayerToggleShortcut(layername, shortcut):
            def toggle():
                row = layerIndexForName(layername)
                self.layerstack.selectRow(row)
                layer = self.layerstack[row]
                layer.visible = not layer.visible
                self.viewerControlWidget().layerWidget.setFocus()
            shortcut = QShortcut(QKeySequence(shortcut), self, member=toggle, ambiguousMember=toggle)
            mgr.register("Carving", "Toggle layer %s" % layername, shortcut)

        addLayerToggleShortcut("done", "d")
        addLayerToggleShortcut("segmentation", "s")
        addLayerToggleShortcut("raw", "r")
        addLayerToggleShortcut("pmap", "v")
        addLayerToggleShortcut("hints","t")

        '''
        def updateLayerTimings():
            s = "Layer timings:\n"
            for l in self.layerstack:
                s += "%s: %f sec.\n" % (l.name, l.averageTimePerTile)
            self.labelingDrawerUi.layerTimings.setText(s)
        t = QTimer(self)
        t.setInterval(1*1000) # 10 seconds
        t.start()
        t.timeout.connect(updateLayerTimings)
        '''

        def makeColortable():
            self._doneSegmentationColortable = [QColor(0,0,0,0).rgba()]
            for i in range(254):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                self._doneSegmentationColortable.append(QColor(r,g,b).rgba())
            self._doneSegmentationColortable[1:17] = colortables.default16
            self._doneSegmentationColortable.append(QColor(0,255,0).rgba())
        makeColortable()
        self._doneSegmentationLayer = None
        def onRandomizeColors():
            if self._doneSegmentationLayer is not None:
                print "randomizing colors ..."
                makeColortable()
                self._doneSegmentationLayer.colorTable = self._doneSegmentationColortable
                if self.render and self._renderMgr.ready:
                    self._update_rendering()
        #self.labelingDrawerUi.randomizeColors.clicked.connect(onRandomizeColors)
        
    def onSegmentButton(self):
        print "segment button clicked"
        self.topLevelOperatorView.Trigger.setDirty(slice(None))
    
    def saveAsDialog(self):
        '''
        Ask the user for a new object name.

        Names already present in AllObjectNames (read once, when the dialog
        opens) are rejected: while such a name is typed, the Ok button is
        disabled and the warning label is shown.

        :returns: the entered name as a str if the dialog was accepted,
            otherwise None.
        '''
        dlg = uic.loadUi(self.dialogdirSAD)
        dlg.warning.setVisible(False)
        dlg.Ok.clicked.connect(dlg.accept)
        dlg.Cancel.clicked.connect(dlg.reject)
        takenNames = self.topLevelOperatorView.AllObjectNames[:].wait()
        dlg.isDisabled = False

        def onTextChanged():
            if dlg.lineEdit.text() in takenNames:
                # Name collision: block acceptance and show the warning.
                dlg.Ok.setEnabled(False)
                dlg.warning.setVisible(True)
                dlg.isDisabled = True
                return
            if not dlg.isDisabled:
                return
            # Previously blocked, now valid again: re-enable acceptance.
            dlg.Ok.setEnabled(True)
            dlg.warning.setVisible(False)
            dlg.isDisabled = False

        dlg.lineEdit.textChanged.connect(onTextChanged)
        if dlg.exec_():
            return str(dlg.lineEdit.text())
        return None
    
    def onSaveAsButton(self):
        print "save object as?"
        if self.topLevelOperatorView.dataIsStorable():
            name = self.saveAsDialog()
            if name is None:
                return
            objects = self.topLevelOperatorView.AllObjectNames[:].wait()
            if name in objects:
                QMessageBox.critical(self, "Save Object As", "An object with name '%s' already exists.\nPlease choose a different name." % name)
                return
            self.topLevelOperatorView.saveObjectAs(name)
            print "save object as %s" % name
        else:
            msgBox = QMessageBox(self)
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(2)
            msgBox.exec_()
            print "object not saved due to faulty data."
    
    def onSaveButton(self):
        """Save the currently loaded object under its own name.

        When no object is loaded yet, this falls back to ``onSaveAsButton`` so
        the user can pick a name.  If the operator reports that the data is not
        storable, a warning box is shown and nothing is saved.
        """
        op = self.topLevelOperatorView
        if not op.dataIsStorable():
            msgBox = QMessageBox(self)
            # Same text and title as onSaveAsButton, so both save paths look alike.
            msgBox.setText("The data does not seem fit to be stored.")
            msgBox.setWindowTitle("Problem with Data")
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.exec_()
            print("object not saved due to faulty data.")
            return

        if op.hasCurrentObject():
            op.saveObjectAs(op.currentObjectName())
        else:
            # No object loaded: ask for a name instead.
            self.onSaveAsButton()
    
    def onShowObjectNames(self):
        '''Show the stored object names in a dialog and let the user load or delete them.'''
        op = self.topLevelOperatorView
        dialog = uic.loadUi(self.dialogdirCOM)
        dialog.objectNames.addItems(sorted(op.AllObjectNames[:].wait()))

        def loadSelection():
            # Each selected entry is loaded in turn.
            for item in dialog.objectNames.selectedItems():
                op.loadObject(str(item.text()))

        def deleteSelection():
            selected = dialog.objectNames.selectedItems()
            if not self.confirmAndDelete([str(item.text()) for item in selected]):
                return
            # The rows are hidden, not removed, once their objects are deleted.
            for item in selected:
                item.setHidden(True)

        dialog.loadButton.clicked.connect(loadSelection)
        dialog.deleteButton.clicked.connect(deleteSelection)
        dialog.cancelButton.clicked.connect(dialog.close)
        dialog.exec_()
    
    def confirmAndDelete(self, namelist):
        """Ask the user to confirm deleting the given objects and delete them if confirmed.

        :param namelist: names of the objects to delete
        :returns: True when the user answered Yes (and the objects were deleted),
            False otherwise
        """
        print(namelist)
        bulletLines = "".join("\n  " + str(objectName) for objectName in namelist)
        answer = QMessageBox.question(
            self,
            "Delete Object",
            "Do you want to delete these objects?" + bulletLines,
            QMessageBox.Yes | QMessageBox.Cancel,
            defaultButton=QMessageBox.Yes)

        if answer != QMessageBox.Yes:
            return False
        for objectName in namelist:
            self.topLevelOperatorView.deleteObject(objectName)
        return True
    
    def labelingContextMenu(self,names,op,position5d):
        """Build the editor's right-click context menu.

        :param names: names of the finished objects at the clicked position; each
            one gets a submenu with load / delete actions and, when 3D rendering
            is available, a show/remove-from-3D action
        :param op: the top-level operator (used to load objects and to check
            whether the current data is storable)
        :param position5d: clicked 5D position; entries 1-3 are shown as coordinates
        :returns: the populated QMenu
        """
        menu = QMenu(self)
        menu.setObjectName("carving_context_menu")
        posItem = menu.addAction("position %d %d %d" % (position5d[1], position5d[2], position5d[3]))
        posItem.setEnabled(False)
        menu.addSeparator()
        for name in names:
            submenu = QMenu(name,menu)

            # Load
            loadAction = submenu.addAction("Load %s" % name)
            loadAction.triggered.connect( partial(op.loadObject, name) )

            # Delete
            def onDelAction(_name):
                # Only refresh the 3D view if the user actually confirmed the deletion.
                if self.confirmAndDelete([_name]) and self.render and self._renderMgr.ready:
                    self._update_rendering()
            delAction = submenu.addAction("Delete %s" % name)
            delAction.triggered.connect( partial(onDelAction, name) )

            if self.render:
                if name in self._shownObjects3D:
                    # Remove
                    def onRemove3D(_name):
                        label = self._shownObjects3D.pop(_name)
                        self._renderMgr.removeObject(label)
                        self._update_rendering()
                    removeAction = submenu.addAction("Remove %s from 3D view" % name)
                    removeAction.triggered.connect( partial(onRemove3D, name) )
                else:
                    # Show
                    def onShow3D(_name):
                        label = self._renderMgr.addObject()
                        self._shownObjects3D[_name] = label
                        self._update_rendering()
                    showAction = submenu.addAction("Show 3D %s" % name)
                    showAction.triggered.connect( partial(onShow3D, name ) )

            menu.addMenu(submenu)

        # A single separator between the per-object submenus and the general actions.
        menu.addSeparator()
        showSeg3DAction = menu.addAction( "Show Editing Segmentation in 3D" )
        showSeg3DAction.setCheckable(True)
        showSeg3DAction.setChecked( self._showSegmentationIn3D )
        showSeg3DAction.triggered.connect( self._toggleSegmentation3D )

        # Each general action is connected to the handler its caption describes.
        if op.dataIsStorable():
            menu.addAction("Save objects").triggered.connect( self.onSaveButton )
        menu.addAction("Browse objects").triggered.connect( self.onShowObjectNames )
        menu.addAction("Segment").triggered.connect( self.onSegmentButton )
        menu.addAction("Clear").triggered.connect( self.topLevelOperatorView.clearCurrentLabeling )
        return menu
    
    def handleEditorRightClick(self, position5d, globalWindowCoordinate):
        """Show the context menu for the objects found at the clicked position.

        :param position5d: clicked 5D position; entries 1-3 are passed on as the
            spatial position
        :param globalWindowCoordinate: screen position where the menu pops up
        """
        op = self.topLevelOperatorView
        objectNames = op.doneObjectNamesForPosition(position5d[1:4])

        # Subclasses may override labelingContextMenu and return None for "no menu".
        contextMenu = self.labelingContextMenu(objectNames, op, position5d)
        if contextMenu is None:
            return
        contextMenu.exec_(globalWindowCoordinate)

    def _toggleSegmentation3D(self):
        """Flip whether the editing segmentation is drawn in the 3D view, then re-render."""
        showIt = not self._showSegmentationIn3D
        self._showSegmentationIn3D = showIt
        if not showIt:
            self._renderMgr.removeObject(self._segmentation_3d_label)
        else:
            # Reserve a fresh render label for the segmentation.
            self._segmentation_3d_label = self._renderMgr.addObject()
        self._update_rendering()
    
    def _update_rendering(self):
        """Rebuild the 3D render volume from the shown objects (and optionally the segmentation).

        Does nothing when 3D rendering is unavailable (``self.render`` is False).
        Entries of ``_shownObjects3D`` whose object no longer exists in the MST
        are dropped.  Every supervoxel of a shown object is mapped to that
        object's render label; when the editing segmentation is shown, the
        supervoxels whose segmentation lut value is 2 are mapped to the
        segmentation's label instead.
        """
        if not self.render:
            return

        op = self.topLevelOperatorView
        if not self._renderMgr.ready:
            self._renderMgr.setup(op.InputData.meta.shape[1:4])

        # Read the MST slot once and use that single snapshot below.
        mst = op.MST.value
        objectLut = mst.object_lut

        # Remove nonexistent objects (dict membership instead of scanning keys()).
        self._shownObjects3D = dict((k, v) for k, v in self._shownObjects3D.iteritems()
                                    if k in objectLut)

        lut = numpy.zeros(len(mst.objects.lut), dtype=numpy.int32)
        for name, label in self._shownObjects3D.iteritems():
            lut[objectLut[name]] = label

        if self._showSegmentationIn3D:
            # Add segmentation as label, which is green
            lut[:] = numpy.where( mst.segmentation.lut == 2, self._segmentation_3d_label, lut )

        self._renderMgr.volume = lut[mst.regionVol] # (Advanced indexing)
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        """Send the colors of all objects shown in 3D to the render manager.

        Each shown object takes its color from the "done seg" layer's colortable,
        indexed through ``MST.value.object_names``; the editing segmentation,
        when shown, is drawn green.
        """
        colorTable = self._doneSegmentationLayer.colorTable
        topOp = self.topLevelOperatorView

        for objName, renderLabel in self._shownObjects3D.iteritems():
            qcolor = QColor(colorTable[topOp.MST.value.object_names[objName]])
            # setColor takes float channels: scale 0..255 down to 0.0..1.0.
            rgb = (qcolor.red() / 255.0, qcolor.green() / 255.0, qcolor.blue() / 255.0)
            self._renderMgr.setColor(renderLabel, rgb)

        if self._showSegmentationIn3D:
            green = (0.0, 1.0, 0.0)
            self._renderMgr.setColor(self._segmentation_3d_label, green)


    def getNextLabelName(self):
        """Default name for a new label: "Background" for the first one, "Object" afterwards."""
        existingLabels = len(self._labelControlUi.labelListModel)
        return "Object" if existingLabels else "Background"

    def appletDrawers(self):
        """Return this applet's drawers as a list of (title, widget) pairs."""
        carvingDrawer = ("Carving", self._labelControlUi)
        return [carvingDrawer]

    def setupLayers( self ):
        """Build the volume-editor layers for every output slot that is ready.

        Also registers dirty callbacks on CurrentObjectName, HasSegmentation and
        opLabelArray.NonzeroBlocks that refresh the "current object" label and
        the save / save-as buttons in the labeling drawer.

        :returns: list of layers in creation order (labels first, filtered input last)
        """
        layers = []
        op = self.topLevelOperatorView
        transparent = QColor(0,0,0,0).rgba()

        def randomColors(count):
            """Return *count* random opaque colors as rgba ints (channels in 0..254)."""
            colors = []
            for _ in range(count):
                r,g,b = numpy.random.randint(0,255), numpy.random.randint(0,255), numpy.random.randint(0,255)
                colors.append(QColor(r,g,b).rgba())
            return colors

        def onButtonsEnabled(slot, roi):
            currObj = self.topLevelOperatorView.CurrentObjectName.value
            hasSeg  = self.topLevelOperatorView.HasSegmentation.value

            self.labelingDrawerUi.currentObjectLabel.setText("current object: %s" % currObj)
            # "Save" needs a loaded object; "Save As" is only offered when none is loaded.
            self.labelingDrawerUi.save.setEnabled(currObj != "" and hasSeg)
            self.labelingDrawerUi.saveAs.setEnabled(currObj == "" and hasSeg)
            # TODO: rethink enabling the segment/clear buttons from the nonzero label blocks.
        op.CurrentObjectName.notifyDirty(onButtonsEnabled)
        op.HasSegmentation.notifyDirty(onButtonsEnabled)
        op.opLabelArray.NonzeroBlocks.notifyDirty(onButtonsEnabled)

        # Labels
        labellayer, labelsrc = self.createLabelLayer(direct=True)
        if labellayer is not None:
            layers.append(labellayer)
            # Tell the editor where to draw label data
            self.editor.setLabelSink(labelsrc)

        # Uncertainty: red ramp whose alpha grows with the value.
        uncert = op.Uncertainty
        if uncert.ready():
            colortable = [QColor(i,0,0,i).rgba() for i in range(256)]
            layer = ColortableLayer(LazyflowSource(uncert), colortable, direct=True)
            layer.name = "uncertainty"
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        # Segmentation: values 0 and 1 transparent, 2 green, the rest random.
        seg = op.Segmentation
        if seg.ready():
            colortable = [transparent, transparent, QColor(0,255,0).rgba()]
            colortable += randomColors(256 - len(colortable))
            layer = ColortableLayer(LazyflowSource(seg), colortable, direct=True)
            layer.name = "segmentation"
            layer.visible = True
            layer.opacity = 0.3
            layers.append(layer)

        # Done objects: 0 transparent, 1 blue, the rest random.
        done = op.DoneObjects
        if done.ready():
            colortable = [transparent, QColor(0,0,255).rgba()]
            colortable += randomColors(254 - len(colortable))
            #have to use lazyflow because it provides dirty signals
            layer = ColortableLayer(LazyflowSource(done), colortable, direct=True)
            layer.name = "done"
            layer.visible = False
            layer.opacity = 0.5
            layers.append(layer)

        #hints
        useLazyflow = True
        ctable = [transparent, QColor(255,0,0).rgba()]
        # Plain rgba ints, like every other colortable here (not QColor objects).
        ctable.extend( [QColor(int(255*random.random()), int(255*random.random()), int(255*random.random())).rgba() for x in range(254)] )
        if useLazyflow:
            hints = op.HintOverlay
            layer = ColortableLayer(LazyflowSource(hints), ctable, direct=True)
        else:
            hints = op._hints
            layer = ColortableLayer(ArraySource(hints), ctable, direct=True)
        if not useLazyflow or hints.ready():
            layer.name = "hints"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        #pmaps
        pmaps = op._pmap
        if pmaps is not None:
            layer = GrayscaleLayer(ArraySource(pmaps), direct=True)
            layer.name = "pmap"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        #done seg
        doneSeg = op.DoneSegmentation
        if doneSeg.ready():
            if self._doneSegmentationLayer is None:
                layer = ColortableLayer(LazyflowSource(doneSeg), self._doneSegmentationColortable, direct=True)
                layer.name = "done seg"
                layer.visible = False
                layer.opacity = 0.5
                self._doneSegmentationLayer = layer
                layers.append(layer)
            else:
                layers.append(self._doneSegmentationLayer)

        #supervoxel
        sv = op.Supervoxels
        if sv.ready():
            # Use a fresh table: reusing `colortable` from an earlier branch either
            # failed (no branch ran) or prepended that branch's special colors.
            colortable = randomColors(256)
            layer = ColortableLayer(LazyflowSource(sv), colortable, direct=True)
            layer.name = "supervoxels"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        #raw data
        rawSlot = op.RawData
        if rawSlot.ready():
            raw5D = op.RawData.value
            layer = GrayscaleLayer(ArraySource(raw5D), direct=True)
            layer.name = "raw"
            layer.visible = True
            layer.opacity = 1.0
            layers.append(layer)

        inputSlot = op.InputData
        if inputSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(inputSlot), direct=True )
            layer.name = "input"
            # Only show the input by default when there is no raw layer.
            layer.visible = not rawSlot.ready()
            layer.opacity = 1.0
            layers.append(layer)

        filteredSlot = op.FilteredInputData
        if filteredSlot.ready():
            layer = GrayscaleLayer( LazyflowSource(filteredSlot) )
            layer.name = "filtered input"
            layer.visible = False
            layer.opacity = 1.0
            layers.append(layer)

        return layers
Esempio n. 13
0
class GraphCutsGui(LabelingGui):
    """
    """
    
    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################

    def appletDrawer(self):
        """Return this applet's drawer widget (as provided by getAppletDrawerUi)."""
        drawerUi = self.getAppletDrawerUi()
        return drawerUi

    def viewerControlWidget(self):
        """Return the viewer-control widget loaded in initViewerControlUi."""
        controls = self._viewerControlUi
        return controls


    ###########################################
    ###########################################
    
    @traceLogged(traceLogger)
    def __init__(self, topLevelOperatorView):
        """
        Create the graph-cuts labeling GUI for *topLevelOperatorView*.

        Initializes the LabelingGui base class exactly once, wires up the
        live-update and save-predictions buttons, registers keyboard shortcuts
        and, if possible, creates a 3D RenderingManager.
        """
        self.topLevelOperatorView = topLevelOperatorView
        # 3D rendering stays off until a RenderingManager is created below; set the
        # defaults now in case layers are built during the base-class init.
        self.render = False
        self._renderedLayers = {} # (layer name, label number)

        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.opLabelArray.deleteLabel
        labelSlots.maxLabelValue = topLevelOperatorView.MaxLabelValue

        # We provide our own UI file; the code below expects it to contain the
        # savePredictionsButton and liveUpdateButton controls.
        labelingDrawerUiPath = os.path.split(__file__)[0] + '/labelingDrawer.ui'

        # Base class init (exactly once, with the full argument list)
        super(GraphCutsGui, self).__init__( labelSlots, topLevelOperatorView, labelingDrawerUiPath )

        self.interactiveModeActive = False
        self._currentlySavingPredictions = False
        # NOTE(review): onSavePredictionsButtonClicked uses self.predictionSerializer,
        # self.shellRequestSignal and self.guiControlSignal, none of which are set
        # here -- confirm they are provided elsewhere.

        self.labelingDrawerUi.savePredictionsButton.clicked.connect(self.onSavePredictionsButtonClicked)
        self.labelingDrawerUi.savePredictionsButton.setIcon( QIcon(ilastikIcons.Save) )

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon( QIcon(ilastikIcons.Play) )
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelingDrawerUi.liveUpdateButton.toggled.connect( self.toggleInteractive )

        self.topLevelOperatorView.MaxLabelValue.notifyDirty( bind(self.handleLabelSelectionChange) )

        self._initShortcuts()

        try:
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except Exception:
            # 3D rendering is best-effort: without a usable 3D view, run in 2D only.
            logger.debug("3D rendering disabled", exc_info=True)
            self.render = False
        else:
            self.render = True
            
    def initAppletDrawerUi(self):
        """Load drawer.ui and add a threshold widget kept in sync with MinValue/MaxValue."""
        with Tracer(traceLogger):
            # The .ui file sits next to this module.
            drawerUiPath = os.path.split(__file__)[0] + "/drawer.ui"
            self._drawer = uic.loadUi(drawerUiPath)

            vbox = QVBoxLayout()
            vbox.setSpacing(0)
            self._drawer.setLayout( vbox )

            threshold = ThresholdingWidget(self)
            threshold.valueChanged.connect( self.handleThresholdGuiValuesChanged )
            # Ask the drawer for its layout: if drawer.ui already installed one,
            # Qt ignores the setLayout call above.
            drawerLayout = self._drawer.layout()
            drawerLayout.addWidget( threshold )
            drawerLayout.addSpacerItem( QSpacerItem(0,0,vPolicy=QSizePolicy.Expanding) )

            op = self.topLevelOperatorView

            def updateDrawerFromOperator():
                # Any bound that isn't set yet falls back to the 0..255 range.
                lowValue = op.MinValue.value if op.MinValue.ready() else 0
                highValue = op.MaxValue.value if op.MaxValue.ready() else 255
                threshold.setValue(lowValue, highValue)

            op.MinValue.notifyDirty( bind(updateDrawerFromOperator) )
            op.MaxValue.notifyDirty( bind(updateDrawerFromOperator) )
            updateDrawerFromOperator()
            

    @traceLogged(traceLogger)
    def initViewerControlUi(self):
        """Load viewerControls.ui and connect its show-predictions/segmentation checkboxes."""
        uiFile = os.path.join( os.path.split(__file__)[0], "viewerControls.ui" )
        self._viewerControlUi = uic.loadUi( uiFile )
        controls = self._viewerControlUi

        # A user click only ever toggles between checked and unchecked; the
        # partially-checked state is set programmatically elsewhere.
        def nextCheckState(checkbox):
            checkbox.setChecked( not checkbox.isChecked() )
        for box in (controls.checkShowPredictions, controls.checkShowSegmentation):
            box.nextCheckState = partial(nextCheckState, box)

        controls.checkShowPredictions.clicked.connect( self.handleShowPredictionsClicked )
        controls.checkShowSegmentation.clicked.connect( self.handleShowSegmentationClicked )

        # The editor's layerstack is in charge of which layer movement buttons are enabled
        controls.viewerControls.setupConnections(self.editor.layerStack)
       
    def _initShortcuts(self):
        """Register the keyboard shortcuts of the "Predictions" group.

        p -- toggle prediction layer visibility
        s -- toggle segmentation layer visibility
        l -- toggle live prediction mode
        """
        mgr = ShortcutManager()
        shortcutGroupName = "Predictions"
        controls = self._viewerControlUi
        liveUpdateButton = self.labelingDrawerUi.liveUpdateButton

        # (key, description, triggered callable, widget the shortcut belongs to)
        shortcutSpecs = [
            ("p", "Toggle Prediction Layer Visibility",
             controls.checkShowPredictions.click, controls.checkShowPredictions),
            ("s", "Toggle Segmentation Layer Visibility",
             controls.checkShowSegmentation.click, controls.checkShowSegmentation),
            ("l", "Toggle Live Prediction Mode",
             liveUpdateButton.toggle, liveUpdateButton),
        ]
        for key, description, callback, widget in shortcutSpecs:
            # Parented to self, so Qt keeps the shortcut alive.
            shortcut = QShortcut( QKeySequence(key), self, member=callback )
            mgr.register( shortcutGroupName, description, shortcut, widget )

    def _setup_contexts(self, layer):
        """Add a 'Toggle 3D rendering' context-menu entry to *layer* (only if 3D is available)."""
        def callback(pos, clayer=layer):
            layerName = clayer.name
            if layerName not in self._renderedLayers:
                # Start rendering this layer under a freshly allocated render label.
                self._renderedLayers[layerName] = self._renderMgr.addObject()
            else:
                self._renderMgr.removeObject(self._renderedLayers.pop(layerName))
            self._update_rendering()

        if not self.render:
            return
        layer.contexts.append(('Toggle 3D rendering', callback))

    def handleThresholdGuiValuesChanged(self, minVal, maxVal):
        """Push the threshold widget's (minVal, maxVal) into the operator's MinValue/MaxValue slots."""
        with Tracer(traceLogger):
            op = self.topLevelOperatorView
            op.MinValue.setValue(minVal)
            op.MaxValue.setValue(maxVal)
    
    def getAppletDrawerUi(self):
        """Return the drawer widget created in initAppletDrawerUi."""
        drawerWidget = self._drawer
        return drawerWidget
    
    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.

        Layers are appended in this order: whatever the base class returns
        (the label layer), one segmentation layer per ready channel that has a
        matching label, one prediction layer per such channel, and finally the
        input data.

        :returns: the list of layers
        """
        # Base class provides the label layer.
        layers = super(GraphCutsGui, self).setupLayers()
        labels = self.labelListData

        # Add each of the segmentations
        for channel, segmentationSlot in enumerate(self.topLevelOperatorView.SegmentationChannels):
            # Skip channels that aren't ready or have no corresponding label entry.
            if segmentationSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                segsrc = LazyflowSource(segmentationSlot)
                # Tinted with the label's pmap color.
                segLayer = AlphaModulatedLayer( segsrc,
                                                tintColor=ref_label.pmapColor(),
                                                range=(0.0, 1.0),
                                                normalize=(0.0, 1.0) )

                segLayer.opacity = 1
                # Segmentations start out visible only while live update is on.
                segLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked()
                # Keep the "show segmentation" checkbox in sync with this layer.
                segLayer.visibleChanged.connect(self.updateShowSegmentationCheckbox)

                # segLayer is bound as a default argument so each closure keeps its
                # own layer instead of late-binding to the loop variable.
                def setLayerColor(c, segLayer=segLayer):
                    segLayer.tintColor = c
                    self._update_rendering()

                def setSegLayerName(n, segLayer=segLayer):
                    oldname = segLayer.name
                    newName = "Segmentation (%s)" % n
                    segLayer.name = newName
                    if not self.render:
                        return
                    # Re-key an existing 3D rendering entry under the new layer name.
                    if oldname in self._renderedLayers:
                        label = self._renderedLayers.pop(oldname)
                        self._renderedLayers[newName] = label

                setSegLayerName(ref_label.name)

                # Follow later color/name changes of the label.
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setSegLayerName)
                #check if layer is 3d before adding the "Toggle 3D" option
                #this check is done this way to match the VolumeRenderer, in
                #case different 3d-axistags should be rendered like t-x-y
                #_axiskeys = segmentationSlot.meta.getAxisKeys()
                if len(segmentationSlot.meta.shape) == 4:
                    #the Renderer will cut out the last shape-dimension, so
                    #we're checking for 4 dimensions
                    self._setup_contexts(segLayer)
                layers.append(segLayer)

        # Add each of the predictions
        for channel, predictionSlot in enumerate(self.topLevelOperatorView.PredictionProbabilityChannels):
            # Same channel/label pairing rule as for the segmentations above.
            if predictionSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                predictsrc = LazyflowSource(predictionSlot)
                predictLayer = AlphaModulatedLayer( predictsrc,
                                                    tintColor=ref_label.pmapColor(),
                                                    range=(0.0, 1.0),
                                                    normalize=(0.0, 1.0) )
                predictLayer.opacity = 0.25
                # Predictions start out visible only while live update is on.
                predictLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked()
                # Keep the "show predictions" checkbox in sync with this layer.
                predictLayer.visibleChanged.connect(self.updateShowPredictionCheckbox)

                # predictLayer is bound as a default argument (avoids late binding).
                def setLayerColor(c, predictLayer=predictLayer):
                    predictLayer.tintColor = c

                def setPredLayerName(n, predictLayer=predictLayer):
                    newName = "Prediction for %s" % n
                    predictLayer.name = newName

                setPredLayerName(ref_label.name)
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setPredLayerName)
                layers.append(predictLayer)

        # Add the raw data last (on the bottom)
        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot( inputDataSlot )
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0

            # Shortcut target: move the input layer to the bottom when it is at
            # index 0 of the layerstack, otherwise move it to the top.
            def toggleTopToBottom():
                index = self.layerstack.layerIndex( inputLayer )
                self.layerstack.selectRow( index )
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            # Registers the "i" shortcut for the toggle above.
            inputLayer.shortcutRegistration = (
                "Prediction Layers",
                "Bring Input To Top/Bottom",
                QShortcut( QKeySequence("i"), self.viewerControlWidget(), toggleTopToBottom),
                inputLayer )
            layers.append(inputLayer)

        # Refresh which prediction controls are enabled.
        self.handleLabelSelectionChange()
        return layers

    @traceLogged(traceLogger)
    def toggleInteractive(self, checked):
        """
        Enter or leave interactive ("live update") mode.

        :param checked: True to turn live predictions on, False to freeze them.

        Turning the mode on is refused when no feature images are available:
        the live-update button is unchecked again and a message is shown.
        While the mode is on, the save-predictions button is disabled and
        labels can be neither added nor deleted.
        """
        logger.debug("toggling interactive mode to '%r'" % checked)

        if checked:
            features = self.topLevelOperatorView.FeatureImages
            if not features.ready() or features.meta.shape is None:
                # Unchecking emits toggled(False), which calls this method again.
                self.labelingDrawerUi.liveUpdateButton.setChecked(False)
                mexBox = QMessageBox(self)
                mexBox.setText("There are no features selected ")
                mexBox.exec_()
                return

        self.labelingDrawerUi.savePredictionsButton.setEnabled(not checked)
        self.topLevelOperatorView.FreezePredictions.setValue( not checked )

        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked( True )
            self.handleShowPredictionsClicked()

        # If we're changing modes, (dis)allow adding and deleting labels accordingly.
        if self.interactiveModeActive != checked:
            self.labelingDrawerUi.labelListView.allowDelete = not checked
            self.labelingDrawerUi.AddLabelButton.setEnabled( not checked )
        self.interactiveModeActive = checked

    @pyqtSlot()
    @traceLogged(traceLogger)
    def handleShowPredictionsClicked(self):
        """Show or hide every layer whose name contains "Prediction", following the checkbox."""
        showThem = self._viewerControlUi.checkShowPredictions.isChecked()
        predictionLayers = [lyr for lyr in self.layerstack if "Prediction" in lyr.name]
        for lyr in predictionLayers:
            lyr.visible = showThem

    @pyqtSlot()
    @traceLogged(traceLogger)
    def handleShowSegmentationClicked(self):
        """Show or hide every layer whose name contains "Segmentation", following the checkbox."""
        showThem = self._viewerControlUi.checkShowSegmentation.isChecked()
        segmentationLayers = [lyr for lyr in self.layerstack if "Segmentation" in lyr.name]
        for lyr in segmentationLayers:
            lyr.visible = showThem

    @pyqtSlot()
    @traceLogged(traceLogger)
    def updateShowPredictionCheckbox(self):
        """Reflect prediction-layer visibility in the checkbox: none, all, or some shown."""
        visibilities = [lyr.visible for lyr in self.layerstack if "Prediction" in lyr.name]
        shownCount = sum(1 for isShown in visibilities if isShown)

        if shownCount == 0:
            newState = Qt.Unchecked
        elif shownCount == len(visibilities):
            newState = Qt.Checked
        else:
            newState = Qt.PartiallyChecked
        self._viewerControlUi.checkShowPredictions.setCheckState(newState)

    @pyqtSlot()
    @traceLogged(traceLogger)
    def updateShowSegmentationCheckbox(self):
        """Reflect segmentation-layer visibility in the checkbox: none, all, or some shown."""
        visibilities = [lyr.visible for lyr in self.layerstack if "Segmentation" in lyr.name]
        shownCount = sum(1 for isShown in visibilities if isShown)

        if shownCount == 0:
            newState = Qt.Unchecked
        elif shownCount == len(visibilities):
            newState = Qt.Checked
        else:
            newState = Qt.PartiallyChecked
        self._viewerControlUi.checkShowSegmentation.setCheckState(newState)

    @pyqtSlot()
    @threadRouted
    @traceLogged(traceLogger)
    def handleLabelSelectionChange(self):
        """Enable the prediction controls only when predictions can be computed.

        They require MaxLabelValue to be ready and at least 2, and the cached
        feature images to have a known shape with no empty dimension.
        """
        op = self.topLevelOperatorView
        enabled = False
        if op.MaxLabelValue.ready():
            featureShape = op.CachedFeatureImages.meta.shape
            # An unknown (None) shape counts as "no features" instead of being compared to 0.
            enabled = bool( op.MaxLabelValue.value >= 2
                            and featureShape is not None
                            and numpy.all(numpy.asarray(featureShape) > 0) )
            # FIXME: also check that each label has scribbles?

        self.labelingDrawerUi.savePredictionsButton.setEnabled(enabled)
        self.labelingDrawerUi.liveUpdateButton.setEnabled(enabled)
        self._viewerControlUi.checkShowPredictions.setEnabled(enabled)
        self._viewerControlUi.checkShowSegmentation.setEnabled(enabled)

    @pyqtSlot()
    @traceLogged(traceLogger)
    def onSavePredictionsButtonClicked(self):
        """
        The user clicked "Train and Predict" (the save-predictions button).

        While a save is running, the same button acts as a cancel button and
        forwards to ``predictionSerializer.cancel()``.  Otherwise the predictions
        are unfrozen and a background thread asks the shell to save twice:
        once normally, then once with prediction storage enabled.  Whether or
        not the saves succeed, the thread afterwards restores the freeze state,
        the button caption, the drawer controls and the other applets.
        """
        # The button does double-duty as a cancel button while predictions are being stored
        if self._currentlySavingPredictions:
            self.predictionSerializer.cancel()
            return

        # Compute new predictions as needed
        predictionsFrozen = self.topLevelOperatorView.FreezePredictions.value
        self.topLevelOperatorView.FreezePredictions.setValue(False)
        self._currentlySavingPredictions = True

        saveButton = self.labelingDrawerUi.savePredictionsButton
        # Remember the real caption so exactly that text is restored afterwards.
        originalButtonText = saveButton.text()
        saveButton.setText("Cancel Full Predict")

        # Make sure the user can't paint anything while the computation is in progress.
        self._changeInteractionMode(Tool.Navigation)

        def disableAllInWidgetButName(widget, exceptName):
            # Disable every child that does not contain the named button;
            # recurse into the child that does.
            for child in widget.children():
                if child.findChild( QPushButton, exceptName) is None:
                    child.setEnabled(False)
                else:
                    disableAllInWidgetButName(child, exceptName)

        def enableAll(widget):
            for child in widget.children():
                if isinstance( child, QWidget ):
                    child.setEnabled(True)
                    enableAll(child)

        @traceLogged(traceLogger)
        def saveThreadFunc():
            logger.info("Starting full volume save...")
            # Disable all other applets
            self.guiControlSignal.emit( ControlCommand.DisableUpstream )
            self.guiControlSignal.emit( ControlCommand.DisableDownstream )

            # Widgets may only be touched from the GUI thread, so post these calls.
            # Disable everything in our drawer *except* the cancel button ...
            self.thunkEventHandler.post(disableAllInWidgetButName, self.labelingDrawerUi, "savePredictionsButton")
            # ... and keep the button itself enabled so the user can cancel the save.
            self.thunkEventHandler.post(saveButton.setEnabled, True)

            try:
                # First, do a regular save.
                # During a regular save, predictions are not saved to the project file.
                # (It takes too much time if the user only needs the classifier.)
                self.shellRequestSignal.emit( ShellRequest.RequestSave )

                # Enable prediction storage and ask the shell to save the project again.
                # (This way the second save will occupy the whole progress bar.)
                self.predictionSerializer.predictionStorageEnabled = True
                try:
                    self.shellRequestSignal.emit( ShellRequest.RequestSave )
                finally:
                    self.predictionSerializer.predictionStorageEnabled = False
            finally:
                # Restore original states even if a save failed (UI calls go through events)
                self.thunkEventHandler.post(saveButton.setText, originalButtonText)
                self.topLevelOperatorView.FreezePredictions.setValue(predictionsFrozen)
                self._currentlySavingPredictions = False

                # Re-enable our controls
                self.thunkEventHandler.post(enableAll, self.labelingDrawerUi)

                # Re-enable all other applets
                self.guiControlSignal.emit( ControlCommand.Pop )
                self.guiControlSignal.emit( ControlCommand.Pop )
            logger.info("Finished full volume save.")

        saveThread = threading.Thread(target=saveThreadFunc)
        saveThread.start()

    def _getNext(self, slot, parentFun, transform=None):
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def _onLabelChanged(self, parentFun, mapf, slot):
        """Run the base-class change handler, then mirror one label property into ``slot``.

        :param parentFun: base-class handler, called first.
        :param mapf: maps one entry of ``self.labelListData`` to the value
            stored in ``slot`` (e.g. the label name or an RGB tuple).
        :param slot: operator slot holding a list with one entry per label.
        """
        parentFun()
        # Materialize the values as a real list: on Python 3 ``map`` returns a
        # one-shot iterator, whereas on Python 2 it returned a list.
        new = [mapf(label) for label in self.labelListData]
        old = slot.value
        slot.setValue(_listReplace(old, new))

    def _onLabelRemoved(self, parent, start, end):
        """Drop the removed label's name/colors from the operator, then call the base class.

        :param parent: passed through to the base class unchanged.
        :param start: presumably the index of the removed label row; its entry
            is removed from ``LabelNames``, ``LabelColors`` and ``PmapColors``.
        :param end: passed through to the base class unchanged.
        """
        # Update the per-label lists BEFORE calling the base class. This is the
        # same order PixelClassificationGui uses, whose base class updates the
        # operator and expects the label names list to already be correct.
        op = self.topLevelOperatorView
        for slot in (op.LabelNames, op.LabelColors, op.PmapColors):
            # Edit a copy instead of mutating the slot's current list in place,
            # so setValue() is handed a new object (a slot that compares the new
            # value with its stored one would otherwise see "no change").
            value = list(slot.value)
            # A slot may hold fewer entries than there are labels (_getNext
            # handles that case); there is then nothing to remove.
            if start < len(value):
                del value[start]
                slot.setValue(value)

        super(GraphCutsGui, self)._onLabelRemoved(parent, start, end)

    def getNextLabelName(self):
        """Name for a new label: the stored name if one exists, else the base-class default."""
        namesSlot = self.topLevelOperatorView.LabelNames
        fallback = super(GraphCutsGui, self).getNextLabelName
        return self._getNext(namesSlot, fallback)

    def getNextLabelColor(self):
        """Brush color for a new label: the stored RGB entry as a QColor, else the base-class default."""
        def toQColor(rgb):
            return QColor(*rgb)

        colorsSlot = self.topLevelOperatorView.LabelColors
        fallback = super(GraphCutsGui, self).getNextLabelColor
        return self._getNext(colorsSlot, fallback, toQColor)

    def getNextPmapColor(self):
        """Prediction-map color for a new label: the stored RGB entry as a QColor, else the base-class default."""
        def toQColor(rgb):
            return QColor(*rgb)

        pmapSlot = self.topLevelOperatorView.PmapColors
        fallback = super(GraphCutsGui, self).getNextPmapColor
        return self._getNext(pmapSlot, fallback, toQColor)

    def onLabelNameChanged(self):
        """Mirror every label's name into the ``LabelNames`` slot."""
        def labelName(label):
            return label.name

        baseHandler = super(GraphCutsGui, self).onLabelNameChanged
        self._onLabelChanged(baseHandler, labelName,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        """Mirror every label's brush color into ``LabelColors`` as an (r, g, b) tuple."""
        def brushRgb(label):
            color = label.brushColor()
            return (color.red(), color.green(), color.blue())

        baseHandler = super(GraphCutsGui, self).onLabelColorChanged
        self._onLabelChanged(baseHandler, brushRgb,
                             self.topLevelOperatorView.LabelColors)


    def onPmapColorChanged(self):
        """Mirror every label's prediction-map color into ``PmapColors`` as an (r, g, b) tuple."""
        def pmapRgb(label):
            color = label.pmapColor()
            return (color.red(), color.green(), color.blue())

        baseHandler = super(GraphCutsGui, self).onPmapColorChanged
        self._onLabelChanged(baseHandler, pmapRgb,
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        if not self.render:
            return
        shape = self.topLevelOperatorView.InputImages.meta.shape[1:4]
        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        layernames = set(layer.name for layer in self.layerstack)
        self._renderedLayers = dict((k, v) for k, v in self._renderedLayers.iteritems()
                                if k in layernames)

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                indices = numpy.where(vol != 0)
                newvolume[indices] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            color = layer.tintColor
            color = (color.red() / 255.0, color.green() / 255.0, color.blue() / 255.0)
            self._renderMgr.setColor(label, color)
Esempio n. 14
0
class PixelClassificationGui(LabelingGui):

    ###########################################
    ### AppletGuiInterface Concrete Methods ###
    ###########################################
    def centralWidget(self):
        """Return the widget for the shell's central area, which is this GUI itself."""
        widget = self
        return widget

    def stopAndCleanUp(self):
        """Shut this GUI down.

        After the base-class cleanup, the live-update button and both "show"
        check boxes are unchecked and interactive mode is switched off. Then
        every unregistration callable collected in ``__init__`` is run.
        """
        super(PixelClassificationGui, self).stopAndCleanUp()

        # Make sure we are NOT left in interactive mode.
        self.labelingDrawerUi.liveUpdateButton.setChecked(False)
        for checkbox in (self._viewerControlUi.checkShowPredictions,
                         self._viewerControlUi.checkShowSegmentation):
            checkbox.setChecked(False)
        self.toggleInteractive(False)

        # Undo the slot-callback registrations made in __init__.
        for unregister in self.__cleanup_fns:
            unregister()

    def viewerControlWidget(self):
        """Return the viewer-control panel created by ``initViewerControlUi``."""
        controls = self._viewerControlUi
        return controls

    ###########################################
    ###########################################

    def __init__(self, topLevelOperatorView, shellRequestSignal,
                 guiControlSignal, predictionSerializer):
        """Create the pixel classification GUI on top of LabelingGui.

        :param topLevelOperatorView: view of the applet's top-level operator.
            Its label slots are handed to LabelingGui, and its
            ``FreezePredictions`` / ``MaxLabelValue`` slots drive this GUI.
        :param shellRequestSignal: stored as ``self.shellRequestSignal``.
        :param guiControlSignal: stored as ``self.guiControlSignal``. It is
            used to enable/disable other applets when live update is toggled.
        :param predictionSerializer: stored as ``self.predictionSerializer``.
        """
        # Tell our base class which slots to monitor
        labelSlots = LabelingGui.LabelingSlots()
        labelSlots.labelInput = topLevelOperatorView.LabelInputs
        labelSlots.labelOutput = topLevelOperatorView.LabelImages
        labelSlots.labelEraserValue = topLevelOperatorView.opLabelPipeline.opLabelArray.eraser
        labelSlots.labelDelete = topLevelOperatorView.opLabelPipeline.opLabelArray.deleteLabel
        labelSlots.maxLabelValue = topLevelOperatorView.MaxLabelValue
        labelSlots.labelsAllowed = topLevelOperatorView.LabelsAllowedFlags
        labelSlots.LabelNames = topLevelOperatorView.LabelNames

        # Unregistration callables, run by stopAndCleanUp().
        self.__cleanup_fns = []

        # We provide our own UI file (which adds an extra control for interactive mode)
        labelingDrawerUiPath = os.path.split(
            __file__)[0] + '/labelingDrawer.ui'

        # Base class init
        super(PixelClassificationGui,
              self).__init__(labelSlots, topLevelOperatorView,
                             labelingDrawerUiPath)

        self.topLevelOperatorView = topLevelOperatorView
        self.shellRequestSignal = shellRequestSignal
        self.guiControlSignal = guiControlSignal
        self.predictionSerializer = predictionSerializer

        self.interactiveModeActive = False
        # Immediately update our interactive state
        self.toggleInteractive(
            not self.topLevelOperatorView.FreezePredictions.value)

        self._currentlySavingPredictions = False

        self.labelingDrawerUi.liveUpdateButton.setEnabled(False)
        self.labelingDrawerUi.liveUpdateButton.setIcon(QIcon(
            ilastikIcons.Play))
        self.labelingDrawerUi.liveUpdateButton.setToolButtonStyle(
            Qt.ToolButtonTextBesideIcon)
        self.labelingDrawerUi.liveUpdateButton.toggled.connect(
            self.toggleInteractive)

        # Create each dirty callback ONCE and pass that same object to both
        # notifyDirty() and unregisterDirty(). Building a second bind(...) for
        # the unregistration only works if two separate wrappers compare
        # equal. Reusing one object makes the cleanup reliable either way.
        maxLabelDirtyCallback = bind(self.handleLabelSelectionChange)
        self.topLevelOperatorView.MaxLabelValue.notifyDirty(
            maxLabelDirtyCallback)
        self.__cleanup_fns.append(
            partial(self.topLevelOperatorView.MaxLabelValue.unregisterDirty,
                    maxLabelDirtyCallback))

        self._initShortcuts()

        try:
            self.render = True
            self._renderedLayers = {}  # layer name -> render label
            self._renderMgr = RenderingManager(
                renderer=self.editor.view3d.qvtk.renderer,
                qvtk=self.editor.view3d.qvtk)
        except Exception:
            # 3D rendering is optional: fall back to no rendering, but record
            # why instead of swallowing everything. (A bare `except:` would
            # also trap KeyboardInterrupt and SystemExit.)
            logger.debug("3D rendering is not available", exc_info=True)
            self.render = False

        # toggle interactive mode according to freezePredictions.value
        self.toggleInteractive(
            not self.topLevelOperatorView.FreezePredictions.value)

        def FreezePredDirty():
            self.toggleInteractive(
                not self.topLevelOperatorView.FreezePredictions.value)

        # listen to freezePrediction changes (same single-callback pattern as above)
        freezeDirtyCallback = bind(FreezePredDirty)
        self.topLevelOperatorView.FreezePredictions.notifyDirty(
            freezeDirtyCallback)
        self.__cleanup_fns.append(
            partial(
                self.topLevelOperatorView.FreezePredictions.unregisterDirty,
                freezeDirtyCallback))

    def initViewerControlUi(self):
        """Load the viewer-control panel and wire up its "show" check boxes."""
        uiFile = os.path.join(os.path.split(__file__)[0], "viewerControls.ui")
        self._viewerControlUi = uic.loadUi(uiFile)
        ui = self._viewerControlUi

        # A user click simply flips checked/unchecked, even when the box is in
        # the partially-checked state that updateShow*Checkbox may set.
        def nextCheckState(checkbox):
            checkbox.setChecked(not checkbox.isChecked())

        for checkbox in (ui.checkShowPredictions, ui.checkShowSegmentation):
            checkbox.nextCheckState = partial(nextCheckState, checkbox)

        ui.checkShowPredictions.clicked.connect(
            self.handleShowPredictionsClicked)
        ui.checkShowSegmentation.clicked.connect(
            self.handleShowSegmentationClicked)

        # The editor's layer stack decides which layer-movement buttons are enabled.
        layerStackModel = self.editor.layerStack
        ui.viewerControls.setupConnections(layerStackModel)

    def _initShortcuts(self):
        """Register the keyboard shortcuts of the "Predictions" group.

        p toggles the prediction layers, s the segmentation layers, and l
        live prediction mode.
        """
        mgr = ShortcutManager()
        groupName = "Predictions"
        ui = self._viewerControlUi
        liveButton = self.labelingDrawerUi.liveUpdateButton

        # NOTE: "Segmentaton" is misspelled, but the description is kept
        # verbatim in case it is used to identify the shortcut elsewhere.
        entries = (
            ("p", ui.checkShowPredictions.click,
             "Toggle Prediction Layer Visibility", ui.checkShowPredictions),
            ("s", ui.checkShowSegmentation.click,
             "Toggle Segmentaton Layer Visibility", ui.checkShowSegmentation),
            ("l", liveButton.toggle,
             "Toggle Live Prediction Mode", liveButton),
        )
        for keys, callback, description, target in entries:
            shortcut = QShortcut(QKeySequence(keys), self, member=callback)
            mgr.register(groupName, description, shortcut, target)

    def _setup_contexts(self, layer):
        def callback(pos, clayer=layer):
            name = clayer.name
            if name in self._renderedLayers:
                label = self._renderedLayers.pop(name)
                self._renderMgr.removeObject(label)
                self._update_rendering()
            else:
                label = self._renderMgr.addObject()
                self._renderedLayers[clayer.name] = label
                self._update_rendering()

        if self.render:
            layer.contexts.append(('Toggle 3D rendering', callback))

    def setupLayers(self):
        """
        Called by our base class when one of our data slots has changed.
        This function creates a layer for each slot we want displayed in the volume editor.

        The returned list holds, in order: the base class's layers, the
        uncertainty layer (if its slot is ready), one segmentation layer and
        one prediction layer per channel that has a matching entry in the
        label list, and finally the input data layer.
        """
        # Base class provides the label layer.
        layers = super(PixelClassificationGui, self).setupLayers()

        # Add the uncertainty estimate layer
        uncertaintySlot = self.topLevelOperatorView.UncertaintyEstimate
        if uncertaintySlot.ready():
            uncertaintySrc = LazyflowSource(uncertaintySlot)
            uncertaintyLayer = AlphaModulatedLayer(uncertaintySrc,
                                                   tintColor=QColor(Qt.cyan),
                                                   range=(0.0, 1.0),
                                                   normalize=(0.0, 1.0))
            uncertaintyLayer.name = "Uncertainty"
            uncertaintyLayer.visible = False
            uncertaintyLayer.opacity = 1.0
            # Pressing "u" toggles the visibility of the uncertainty layer.
            uncertaintyLayer.shortcutRegistration = (
                "Prediction Layers", "Show/Hide Uncertainty",
                QShortcut(QKeySequence("u"), self.viewerControlWidget(),
                          uncertaintyLayer.toggleVisible), uncertaintyLayer)
            layers.append(uncertaintyLayer)

        labels = self.labelListData

        # Add each of the segmentations
        # (channels without a matching entry in the label list are skipped)
        for channel, segmentationSlot in enumerate(
                self.topLevelOperatorView.SegmentationChannels):
            if segmentationSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                segsrc = LazyflowSource(segmentationSlot)
                segLayer = AlphaModulatedLayer(segsrc,
                                               tintColor=ref_label.pmapColor(),
                                               range=(0.0, 1.0),
                                               normalize=(0.0, 1.0))

                segLayer.opacity = 1
                segLayer.visible = False  #self.labelingDrawerUi.liveUpdateButton.isChecked()
                segLayer.visibleChanged.connect(
                    self.updateShowSegmentationCheckbox)

                # Keep the layer's tint in sync with the label's pmap color and
                # refresh the 3D view. (segLayer is bound as a default argument
                # so each closure keeps its own layer.)
                def setLayerColor(c, segLayer=segLayer):
                    segLayer.tintColor = c
                    self._update_rendering()

                # Rename the layer after the label and, if it is being
                # rendered in 3D, re-key its render entry under the new name.
                def setSegLayerName(n, segLayer=segLayer):
                    oldname = segLayer.name
                    newName = "Segmentation (%s)" % n
                    segLayer.name = newName
                    if not self.render:
                        return
                    if oldname in self._renderedLayers:
                        label = self._renderedLayers.pop(oldname)
                        self._renderedLayers[newName] = label

                setSegLayerName(ref_label.name)

                # NOTE(review): these connections are made on every
                # setupLayers() call and are not disconnected here, so closures
                # from earlier calls may stay connected. Confirm this is intended.
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setSegLayerName)
                #check if layer is 3d before adding the "Toggle 3D" option
                #this check is done this way to match the VolumeRenderer, in
                #case different 3d-axistags should be rendered like t-x-y
                #_axiskeys = segmentationSlot.meta.getAxisKeys()
                if len(segmentationSlot.meta.shape) == 4:
                    #the Renderer will cut out the last shape-dimension, so
                    #we're checking for 4 dimensions
                    self._setup_contexts(segLayer)
                layers.append(segLayer)

        # Add each of the predictions
        # (channels without a matching entry in the label list are skipped)
        for channel, predictionSlot in enumerate(
                self.topLevelOperatorView.PredictionProbabilityChannels):
            if predictionSlot.ready() and channel < len(labels):
                ref_label = labels[channel]
                predictsrc = LazyflowSource(predictionSlot)
                predictLayer = AlphaModulatedLayer(
                    predictsrc,
                    tintColor=ref_label.pmapColor(),
                    range=(0.0, 1.0),
                    normalize=(0.0, 1.0))
                predictLayer.opacity = 0.25
                # Prediction layers start visible only in live-update mode.
                predictLayer.visible = self.labelingDrawerUi.liveUpdateButton.isChecked(
                )
                predictLayer.visibleChanged.connect(
                    self.updateShowPredictionCheckbox)

                # Keep the layer's tint in sync with the label's pmap color.
                def setLayerColor(c, predictLayer=predictLayer):
                    predictLayer.tintColor = c

                # Keep the layer's name in sync with the label's name.
                def setPredLayerName(n, predictLayer=predictLayer):
                    newName = "Prediction for %s" % n
                    predictLayer.name = newName

                setPredLayerName(ref_label.name)
                ref_label.pmapColorChanged.connect(setLayerColor)
                ref_label.nameChanged.connect(setPredLayerName)
                layers.append(predictLayer)

        # Add the raw data last (on the bottom)
        inputDataSlot = self.topLevelOperatorView.InputImages
        if inputDataSlot.ready():
            inputLayer = self.createStandardLayerFromSlot(inputDataSlot)
            inputLayer.name = "Input Data"
            inputLayer.visible = True
            inputLayer.opacity = 1.0

            # Select the input layer and move it to the bottom of the stack if
            # it is currently at index 0, otherwise to the top.
            def toggleTopToBottom():
                index = self.layerstack.layerIndex(inputLayer)
                self.layerstack.selectRow(index)
                if index == 0:
                    self.layerstack.moveSelectedToBottom()
                else:
                    self.layerstack.moveSelectedToTop()

            # Pressing "i" runs toggleTopToBottom.
            inputLayer.shortcutRegistration = ("Prediction Layers",
                                               "Bring Input To Top/Bottom",
                                               QShortcut(
                                                   QKeySequence("i"),
                                                   self.viewerControlWidget(),
                                                   toggleTopToBottom),
                                               inputLayer)
            layers.append(inputLayer)

        # Sync the enabled state of the live-update/show controls with the labels.
        self.handleLabelSelectionChange()
        return layers

    def toggleInteractive(self, checked):
        """Enter or leave live-update (interactive prediction) mode.

        :param checked: True to enter live-update mode, False to leave it.

        Entering is refused when the ``FeatureImages`` slot is not ready or
        has no shape. In that case the live-update button is unchecked, a
        message box is shown, and the method returns.

        Otherwise, on an actual mode change the label controls are
        locked/unlocked and a command is emitted on ``guiControlSignal``.
        ``FreezePredictions`` is then set to the opposite of ``checked``, the
        live-update button is synced, and when entering live mode the
        prediction layers are shown.
        """
        logger.debug("toggling interactive mode to '%r'" % checked)

        if checked == True:
            # Live prediction needs features; refuse to switch on without them.
            if not self.topLevelOperatorView.FeatureImages.ready() \
            or self.topLevelOperatorView.FeatureImages.meta.shape==None:
                # NOTE: setChecked() may re-enter this method, because the
                # button's toggled signal is connected to it in __init__.
                self.labelingDrawerUi.liveUpdateButton.setChecked(False)
                mexBox = QMessageBox()
                mexBox.setText("There are no features selected ")
                mexBox.exec_()
                return

        # If we're changing modes, enable/disable our controls and other applets accordingly
        if self.interactiveModeActive != checked:
            if checked:
                # Labels cannot be added or deleted while predicting live.
                self.labelingDrawerUi.labelListView.allowDelete = False
                self.labelingDrawerUi.AddLabelButton.setEnabled(False)
                self.guiControlSignal.emit(ControlCommand.DisableUpstream)
            else:
                self.labelingDrawerUi.labelListView.allowDelete = True
                self.labelingDrawerUi.AddLabelButton.setEnabled(True)
                # Presumably undoes the DisableUpstream emitted on entering live mode.
                self.guiControlSignal.emit(ControlCommand.Pop)
        self.interactiveModeActive = checked

        # Predictions are frozen exactly when live update is off.
        self.topLevelOperatorView.FreezePredictions.setValue(not checked)
        self.labelingDrawerUi.liveUpdateButton.setChecked(checked)
        # Auto-set the "show predictions" state according to what the user just clicked.
        if checked:
            self._viewerControlUi.checkShowPredictions.setChecked(True)
            self.handleShowPredictionsClicked()

    @pyqtSlot()
    def handleShowPredictionsClicked(self):
        """Show or hide every prediction layer to match the "show predictions" check box."""
        show = self._viewerControlUi.checkShowPredictions.isChecked()
        for layer in self.layerstack:
            if "Prediction" not in layer.name:
                continue
            layer.visible = show

    @pyqtSlot()
    def handleShowSegmentationClicked(self):
        """Show or hide every segmentation layer to match the "show segmentation" check box."""
        show = self._viewerControlUi.checkShowSegmentation.isChecked()
        for layer in self.layerstack:
            if "Segmentation" not in layer.name:
                continue
            layer.visible = show

    @pyqtSlot()
    def updateShowPredictionCheckbox(self):
        """Reflect prediction-layer visibility in the tri-state "show predictions" check box.

        The box is unchecked when no prediction layer is visible (including
        when there are none), checked when all are, and partially checked
        otherwise.
        """
        visibility = [layer.visible for layer in self.layerstack
                      if "Prediction" in layer.name]
        shown = sum(1 for isVisible in visibility if isVisible)

        if shown == 0:
            state = Qt.Unchecked
        elif shown == len(visibility):
            state = Qt.Checked
        else:
            state = Qt.PartiallyChecked
        self._viewerControlUi.checkShowPredictions.setCheckState(state)

    @pyqtSlot()
    def updateShowSegmentationCheckbox(self):
        """Reflect segmentation-layer visibility in the tri-state "show segmentation" check box.

        The box is unchecked when no segmentation layer is visible (including
        when there are none), checked when all are, and partially checked
        otherwise.
        """
        visibility = [layer.visible for layer in self.layerstack
                      if "Segmentation" in layer.name]
        shown = sum(1 for isVisible in visibility if isVisible)

        if shown == 0:
            state = Qt.Unchecked
        elif shown == len(visibility):
            state = Qt.Checked
        else:
            state = Qt.PartiallyChecked
        self._viewerControlUi.checkShowSegmentation.setCheckState(state)

    @pyqtSlot()
    @threadRouted
    def handleLabelSelectionChange(self):
        """Enable the live-update button and the "show" check boxes only when they are usable.

        They are enabled when ``MaxLabelValue`` is ready and at least 2, and
        the cached feature images have a known shape with no zero-sized axis.
        When disabled, the controls are also unchecked and the prediction and
        segmentation layers are hidden.
        """
        op = self.topLevelOperatorView
        enabled = False
        if op.MaxLabelValue.ready():
            featureShape = op.CachedFeatureImages.meta.shape
            # Treat a missing feature shape as "not usable" explicitly:
            # `numpy.asarray(None) > 0` relied on Python 2's ordering of None
            # and raises TypeError on Python 3.
            featuresOk = (featureShape is not None and
                          bool(numpy.all(numpy.asarray(featureShape) > 0)))
            # Plain bool (not numpy.bool_) for the Qt setters below.
            enabled = bool(op.MaxLabelValue.value >= 2) and featuresOk
            # FIXME: also check that each label has scribbles?

        if not enabled:
            self.labelingDrawerUi.liveUpdateButton.setChecked(False)
            self._viewerControlUi.checkShowPredictions.setChecked(False)
            self._viewerControlUi.checkShowSegmentation.setChecked(False)
            self.handleShowPredictionsClicked()
            self.handleShowSegmentationClicked()

        self.labelingDrawerUi.liveUpdateButton.setEnabled(enabled)
        self._viewerControlUi.checkShowPredictions.setEnabled(enabled)
        self._viewerControlUi.checkShowSegmentation.setEnabled(enabled)

    def _getNext(self, slot, parentFun, transform=None):
        numLabels = self.labelListData.rowCount()
        value = slot.value
        if numLabels < len(value):
            result = value[numLabels]
            if transform is not None:
                result = transform(result)
            return result
        else:
            return parentFun()

    def _onLabelChanged(self, parentFun, mapf, slot):
        """Run the base-class change handler, then mirror one label property into ``slot``.

        :param parentFun: base-class handler, called first.
        :param mapf: maps one entry of ``self.labelListData`` to the value
            stored in ``slot`` (e.g. the label name or an RGB tuple).
        :param slot: operator slot holding a list with one entry per label.
        """
        parentFun()
        # Materialize the values as a real list: on Python 3 ``map`` returns a
        # one-shot iterator, whereas on Python 2 it returned a list.
        new = [mapf(label) for label in self.labelListData]
        old = slot.value
        slot.setValue(_listReplace(old, new))

    def _onLabelRemoved(self, parent, start, end):
        """Drop the removed label's name/colors from the operator, then call the base class.

        :param parent: passed through to the base class unchanged.
        :param start: presumably the index of the removed label row; its entry
            is removed from ``LabelNames``, ``LabelColors`` and ``PmapColors``.
        :param end: passed through to the base class unchanged.
        """
        # Update the label names/colors BEFORE calling the base class,
        #  which will update the operator and expects the
        #  label names list to be correct.
        op = self.topLevelOperatorView
        for slot in (op.LabelNames, op.LabelColors, op.PmapColors):
            # Edit a copy instead of mutating the slot's current list in place,
            # so setValue() is handed a new object (a slot that compares the new
            # value with its stored one would otherwise see "no change").
            value = list(slot.value)
            # A slot may hold fewer entries than there are labels (_getNext
            # handles that case); there is then nothing to remove.
            if start < len(value):
                del value[start]
                slot.setValue(value)

        # Call the base class to update the operator.
        super(PixelClassificationGui, self)._onLabelRemoved(parent, start, end)

    def getNextLabelName(self):
        """Name for a new label: the stored name if one exists, else the base-class default."""
        namesSlot = self.topLevelOperatorView.LabelNames
        fallback = super(PixelClassificationGui, self).getNextLabelName
        return self._getNext(namesSlot, fallback)

    def getNextLabelColor(self):
        """Brush color for a new label: the stored RGB entry as a QColor, else the base-class default."""
        def toQColor(rgb):
            return QColor(*rgb)

        colorsSlot = self.topLevelOperatorView.LabelColors
        fallback = super(PixelClassificationGui, self).getNextLabelColor
        return self._getNext(colorsSlot, fallback, toQColor)

    def getNextPmapColor(self):
        """Prediction-map color for a new label: the stored RGB entry as a QColor, else the base-class default."""
        def toQColor(rgb):
            return QColor(*rgb)

        pmapSlot = self.topLevelOperatorView.PmapColors
        fallback = super(PixelClassificationGui, self).getNextPmapColor
        return self._getNext(pmapSlot, fallback, toQColor)

    def onLabelNameChanged(self):
        """Mirror every label's name into the ``LabelNames`` slot."""
        def labelName(label):
            return label.name

        baseHandler = super(PixelClassificationGui, self).onLabelNameChanged
        self._onLabelChanged(baseHandler, labelName,
                             self.topLevelOperatorView.LabelNames)

    def onLabelColorChanged(self):
        """Mirror every label's brush color into ``LabelColors`` as an (r, g, b) tuple."""
        def brushRgb(label):
            color = label.brushColor()
            return (color.red(), color.green(), color.blue())

        baseHandler = super(PixelClassificationGui, self).onLabelColorChanged
        self._onLabelChanged(baseHandler, brushRgb,
                             self.topLevelOperatorView.LabelColors)

    def onPmapColorChanged(self):
        """Mirror every label's prediction-map color into ``PmapColors`` as an (r, g, b) tuple."""
        def pmapRgb(label):
            color = label.pmapColor()
            return (color.red(), color.green(), color.blue())

        baseHandler = super(PixelClassificationGui, self).onPmapColorChanged
        self._onLabelChanged(baseHandler, pmapRgb,
                             self.topLevelOperatorView.PmapColors)

    def _update_rendering(self):
        if not self.render:
            return
        shape = self.topLevelOperatorView.InputImages.meta.shape[1:4]
        if len(shape) != 5:
            #this might be a 2D image, no need for updating any 3D stuff
            return

        time = self.editor.posModel.slicingPos5D[0]
        if not self._renderMgr.ready:
            self._renderMgr.setup(shape)

        layernames = set(layer.name for layer in self.layerstack)
        self._renderedLayers = dict(
            (k, v) for k, v in self._renderedLayers.iteritems()
            if k in layernames)

        newvolume = numpy.zeros(shape, dtype=numpy.uint8)
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            for ds in layer.datasources:
                vol = ds.dataSlot.value[time, ..., 0]
                indices = numpy.where(vol != 0)
                newvolume[indices] = label

        self._renderMgr.volume = newvolume
        self._update_colors()
        self._renderMgr.update()

    def _update_colors(self):
        for layer in self.layerstack:
            try:
                label = self._renderedLayers[layer.name]
            except KeyError:
                continue
            color = layer.tintColor
            color = (color.red() / 255.0, color.green() / 255.0,
                     color.blue() / 255.0)
            self._renderMgr.setColor(label, color)