Esempio n. 1
0
class MainWindow(QMainWindow, Ui4_MainWindow):
    """Qt main window hosting an animated VTK scene inside ``self.frame``.

    ``frame`` and the ``*Button`` widgets come from the Designer-generated
    ``Ui4_MainWindow`` mix-in (via ``setupUi``). ``OpenVTK`` builds the scene
    and connects the buttons to ``vtkTimerCallback`` objects (defined
    elsewhere) whose handlers run on every VTK ``TimerEvent``.
    """

    # STL file read at start-up; also the fallback when the dialog returns no file.
    DEFAULT_STL_PATH = (
        "/Users/steve/PycharmProjects/vtktest/"
        "CylinderHead-stl/CylinderHead-binary.stl"
    )

    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent)
        self.setupUi(self)
        self.OpenVTK()
        self.update()

    def update(self):
        """Place ``self.frame`` over the right half of the window.

        NOTE: this overrides ``QWidget.update``; Python code calling
        ``self.update()`` runs this layout step, not a repaint request.
        """
        # Use locals: assigning to self.width/self.height would replace the
        # QWidget.width()/height() methods with ints and break the next call.
        width = self.width()
        height = self.height()
        half = width // 2  # setGeometry() takes ints, not floats
        self.frame.setGeometry(half, 0, half, height)

    @staticmethod
    def _cube_source(x_length, y_length, z_length):
        """Return a vtkCubeSource with the given edge lengths, centred on the origin."""
        cube = vtk.vtkCubeSource()
        cube.SetXLength(x_length)
        cube.SetYLength(y_length)
        cube.SetZLength(z_length)
        cube.SetCenter(0, 0, 0)
        return cube

    @staticmethod
    def _cylinder_source(radius, height):
        """Return a 100-segment vtkCylinderSource centred on the origin."""
        cylinder = vtk.vtkCylinderSource()
        cylinder.SetRadius(radius)
        cylinder.SetHeight(height)
        cylinder.SetCenter(0, 0, 0)
        cylinder.SetResolution(100)
        return cylinder

    @staticmethod
    def _actor_for(source, color, rotate_x=None):
        """Return an actor drawing ``source`` through its own mapper in ``color``.

        ``rotate_x``, when given, is passed to ``vtkActor.RotateX`` (degrees).
        """
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(source.GetOutputPort())
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        actor.GetProperty().SetColor(*color)
        if rotate_x is not None:
            actor.RotateX(rotate_x)
        return actor

    def OpenVTK(self):
        """Build the VTK scene in ``self.frame`` and wire the control buttons.

        Side effects: sets ``self.vtkWidget``, ``self.vl``, ``self.ren`` and
        ``self.iren``; shows the window; registers TimerEvent observers;
        connects the UI buttons; starts a repeating 1 ms VTK timer and then
        the interactor.
        """
        self.vtkWidget = QVTKRenderWindowInteractor(self.frame)
        self.vl = Qt.QVBoxLayout()
        self.vl.addWidget(self.vtkWidget)

        render_window = self.vtkWidget.GetRenderWindow()
        self.ren = vtk.vtkRenderer()
        render_window.AddRenderer(self.ren)
        self.iren = render_window.GetInteractor()
        self.iren.GetInteractorStyle().SetCurrentStyleToTrackballCamera()

        # Geometry sources (cube edge lengths X, Y, Z; cylinder radius, height).
        source = self._cube_source(1500, 1500, 100)
        source2 = self._cube_source(25, 200, 550)
        source3 = self._cube_source(1500, 25, 150)
        source4 = self._cube_source(200, 25, 550)
        source5 = self._cylinder_source(70, 350)
        source6 = self._cylinder_source(20, 100)
        source8 = self._cube_source(10, 100, 100)

        # The STL reader's output is not attached to any mapper, so the model
        # is read (here and again by getfiles) but not drawn.
        reader = vtk.vtkSTLReader()
        reader.SetFileName(self.DEFAULT_STL_PATH)
        reader.Update()

        # One mapper + actor per part; source2 feeds both BaseActor2 and BaseActor3.
        BaseActor = self._actor_for(source, (1.0, 0, 0))
        BaseActor2 = self._actor_for(source2, (0, 0, 1.0))
        BaseActor3 = self._actor_for(source2, (0, 0, 1.0))
        BaseActor4 = self._actor_for(source3, (0, 0, 1.0))
        BaseActor5 = self._actor_for(source4, (1.0, 1.0, 0))
        BaseActor6 = self._actor_for(source5, (1.0, 69 / 255, 0), rotate_x=90)
        BaseActor7 = self._actor_for(source6, (1.0, 215 / 255, 0), rotate_x=90)
        BaseActor8 = self._actor_for(source8, (0, 0, 1))

        for actor in (BaseActor, BaseActor2, BaseActor3, BaseActor4,
                      BaseActor5, BaseActor6, BaseActor7, BaseActor8):
            self.ren.AddActor(actor)
        self.ren.ResetCamera()

        self.frame.setLayout(self.vl)

        self.show()
        self.iren.Initialize()

        # One timer callback per moving actor (vtkTimerCallback is defined elsewhere).
        moving_actors = (BaseActor2, BaseActor3, BaseActor4, BaseActor5,
                         BaseActor6, BaseActor7, BaseActor8)
        callbacks = [vtkTimerCallback() for _ in moving_actors]
        for callback, actor in zip(callbacks, moving_actors):
            callback.actor = actor
        cbBase2, cbBase3, cbBase4, cbBase5, cbBase6, cbBase7, cbBase8 = callbacks
        y_callbacks = callbacks       # all seven are driven by the Y buttons
        z_callbacks = callbacks[3:]   # cbBase5..cbBase8 are driven by the Z buttons

        def getfiles():
            """Ask for an STL file and point ``reader`` at it (default on cancel)."""
            try:
                # getOpenFileName returns a (filename, selected_filter) tuple;
                # filename is '' when the dialog is cancelled.
                filename, _ = QFileDialog.getOpenFileName(self, 'Single File',
                                                          QtCore.QDir.rootPath(),
                                                          '*.stl')
                if not filename:
                    filename = self.DEFAULT_STL_PATH
                reader.SetFileName(filename)
                reader.Update()
            except Exception as e:
                print("Exception in method")
                print(e)

        def stopR():
            cbBase8.timerVar3 = 0

        def startR():
            cbBase8.timerVar3 = 1

        def stopY():
            for cb in y_callbacks:
                cb.timerVar1 = 1

        def _run_y(count):
            # Re-enable every Y callback (cbBase8 included, which stopY also
            # disables) and set its direction code.
            for cb in y_callbacks:
                cb.timerVar1 = 0
                cb.count = count

        def forwardY():
            _run_y(3)

        def backY():
            _run_y(1)

        def _run_z(count2):
            for cb in z_callbacks:
                cb.timerVar2 = 0
                cb.count2 = count2

        def upZ():
            _run_z(1)

        def downZ():
            _run_z(3)

        def stopZ():
            for cb in z_callbacks:
                cb.timerVar2 = 1
                cb.count2 = 3
            # NOTE(review): cbBase5 gets count2 = 2 while the others get 3;
            # confirm which value vtkTimerCallback expects when stopped.
            cbBase5.count2 = 2

        # Registration order is kept as-is in case handler order matters.
        timer_handlers = (
            cbBase2.execute3,
            cbBase3.execute4,
            cbBase4.execute,
            cbBase5.execute5,
            cbBase5.execute6,
            cbBase6.execute7,
            cbBase6.execute8,
            cbBase7.execute9,
            cbBase7.execute10,
            cbBase8.execute11,
            cbBase8.execute12,
            cbBase8.execute2,
        )
        for handler in timer_handlers:
            self.vtkWidget.AddObserver('TimerEvent', handler)

        self.stopRotationButton.clicked.connect(stopR)
        self.startRotationButton.clicked.connect(startR)
        try:
            self.STLButton.clicked.connect(getfiles)
        except Exception as ex:
            print("Exception in button call")
            print(ex)

        self.stopYButton.clicked.connect(stopY)
        self.forwardYButton.clicked.connect(forwardY)
        self.backYButton.clicked.connect(backY)
        self.stopZButton.clicked.connect(stopZ)
        self.upZButton.clicked.connect(upZ)
        self.downZButton.clicked.connect(downZ)

        self.vtkWidget.CreateRepeatingTimer(1)

        self.iren.Start()
class MainWindow(qtw.QMainWindow):

  def __init__(self,
               input_file,
               gaussian,
               radius,
               thresh,
               zoom,
               zSlice,
               brightness,
               window_size,
               *args, **kwargs):
    """MainWindow constructor.

    Args:
      input_file: image to load (see createPipeline for supported inputs),
        or None to start with an empty viewer.
      gaussian: initial Gaussian standard deviation (all three axes).
      radius: initial Gaussian radius factor (all three axes).
      thresh: initial upper threshold for the global threshold filter.
      zoom: in-plane magnification applied to the displayed images.
      zSlice: initial slice index; None falls back to 100.
      brightness: intensity painted into lesions by addLesion.
      window_size: (width, height) of the window in pixels.
    """
    super().__init__(*args, **kwargs)

    # Window setup
    self.resize(window_size[0], window_size[1])
    self.title = "Qt Viewer for Lesion Augmentation"

    self.statusBar().showMessage("Welcome.", 8000)

    # Capture defaults
    self.gaussian = gaussian
    self.radius = radius
    self.thresh = thresh
    self.zoom = zoom
    self.brightness = brightness
    self.shape_dic = None
    self.lesion_dic = {}  # slice index -> boolean lesion mask (filled by addLesion)
    self.thresholdArray = None
    self.imageArray = None
    # Honour the requested start slice (initUI seeds the slice spin box from it).
    self.zSlice = 100 if zSlice is None else zSlice
    self.shape = None
    self.crop = None
    self.colorWindow = 1000
    self.colorLevel = 500
    # Initialize the window
    self.initUI()

    # Set up some VTK pipeline classes
    self.reader = None
    self.gauss = vtk.vtkImageGaussianSmooth()
    self.lesion = vtk.vtkImageData()
    self.threshold = vtk.vtkImageThreshold()
    self.mapToColors = vtk.vtkImageMapToColors()

    self.imageViewer = vtk.vtkImageViewer2()

    self.resizeImage = vtk.vtkImageResize()
    self.resizeSeg = vtk.vtkImageResize()

    self.contourRep = vtk.vtkOrientedGlyphContourRepresentation()
    self.contourWidget = vtk.vtkContourWidget()
    self.placer = vtk.vtkImageActorPointPlacer()

    self.polyData = None

    self.origmapper = vtk.vtkImageMapper()  # alternative: vtkImageSliceMapper()
    self.mapper = vtk.vtkImageMapper()
    self.stencilmapper = vtk.vtkPolyDataMapper()

    self.origactor = vtk.vtkActor2D()  # alternative: vtkImageActor()
    self.actor = vtk.vtkActor2D()
    self.stencilactor = vtk.vtkActor()
    # Take inputs from command line. Only use these if there is an input file specified
    if input_file is not None:
      if not os.path.exists(input_file):
        qtw.QMessageBox.warning(self, "Error", "Invalid input file.")
        return

      self.createPipeline(input_file)
      self.statusBar().showMessage("Loading file " + input_file, 4000)
      self.changeSigma(gaussian)
      self.changeRadius(radius)
      self.changeThreshold(thresh)
      self.changeBrightness(brightness)
      self.changeSlice(zSlice)

  def initUI(self):
    """Create widgets and menus, lay them out, connect signals, start VTK.

    Reads the defaults captured in ``__init__`` (gaussian, radius, thresh,
    brightness, zSlice, colorLevel, colorWindow) and ``self.title``; ends by
    showing the window.
    """
    ########################################
    # Create Widgets
    ########################################

    def makeSpinBox(boxClass, name, value, minimum, maximum, step, **extra):
      """Return a spin box whose range is applied before its value.

      Setting the value first (e.g. as an early constructor keyword) lets Qt
      clamp it to the default range (0..99 for QSpinBox) before the real
      range is installed.
      """
      box = boxClass(self, objectName=name, singleStep=step,
                     keyboardTracking=False, **extra)
      box.setRange(minimum, maximum)
      box.setValue(value)
      return box

    self.loadPushButton = qtw.QPushButton(
      "Load Image",
      self,
      objectName="loadPushButton",
      shortcut=qtg.QKeySequence("Ctrl+f")
    )
    # decimals is passed to the constructor so it is in effect before range/value.
    self.sigmaSpinBox = makeSpinBox(qtw.QDoubleSpinBox, "sigmaSpinBox",
                                    self.gaussian, 0.1, 20.0, 0.1, decimals=1)
    self.radiusSpinBox = makeSpinBox(qtw.QSpinBox, "radiusSpinBox",
                                     self.radius, 1, 20, 1)
    self.threshSpinBox = makeSpinBox(qtw.QSpinBox, "threshSpinBox",
                                     self.thresh, -3000, 3000, 5)
    self.brightnessSpinBox = makeSpinBox(qtw.QSpinBox, "brightnessSpinBox",
                                         self.brightness, -3000, 3000, 5)
    self.sliceSpinBox = makeSpinBox(qtw.QSpinBox, "sliceSpinBox",
                                    self.zSlice, -3000, 3000, 1)
    self.levelSpinBox = makeSpinBox(qtw.QSpinBox, "levelSpinBox",
                                    self.colorLevel, -3000, 5000, 50)
    self.windowSpinBox = makeSpinBox(qtw.QSpinBox, "windowSpinBox",
                                     self.colorWindow, -3000, 5000, 50)

    # Only "Add Lesion" owns Ctrl+L. When several buttons share a shortcut,
    # Qt reports it as ambiguous and does not click any of them.
    self.lesionPushButton = qtw.QPushButton(
      "Add Lesion",
      self,
      objectName="lesionPushButton",
      shortcut=qtg.QKeySequence("Ctrl+l")
    )
    self.savePushButton = qtw.QPushButton(
      "Save Lesion", self, objectName="savePushButton")
    self.deletePushButton = qtw.QPushButton(
      "Delete Lesion", self, objectName="deletePushButton")
    self.resetPushButton = qtw.QPushButton(
      "Reset Contour", self, objectName="resetPushButton")
    self.saveCropPushButton = qtw.QPushButton(
      "Save Cropped Region", self, objectName="saveCropPushButton")
    self.confirmCropPushButton = qtw.QPushButton(
      "Confirm Position of Cropped Contour", self, objectName="confirmCropPushButton")
    self.addCropPushButton = qtw.QPushButton(
      "Add Cropped Pixels to Current Slice", self, objectName="addCropPushButton")

    # Create the menu options --------------------------------------------------------------------
    menubar = qtw.QMenuBar()
    self.setMenuBar(menubar)
    menubar.setNativeMenuBar(False)

    file_menu = menubar.addMenu("File")
    open_action = file_menu.addAction("Open Image")
    file_menu.addSeparator()
    about_action = file_menu.addAction("About")
    quit_action = file_menu.addAction("Quit")

    # Lay out the GUI ----------------------------------------------------------------------------
    self.mainGroupBox = qtw.QGroupBox("Image Controls")
    self.mainGroupBox.setLayout(qtw.QGridLayout())

    self.controlsGroupBox = qtw.QGroupBox("Thresholding controls")
    self.controlsGroupBox.setLayout(qtw.QVBoxLayout())
    self.controlsFormLayout = qtw.QFormLayout()
    for label, box in (
      ("Sigma", self.sigmaSpinBox),
      ("Radius", self.radiusSpinBox),
      ("Global Threshold", self.threshSpinBox),
      ("Lesion Brightness", self.brightnessSpinBox),
      ("Slice Index", self.sliceSpinBox),
      ("Color Level", self.levelSpinBox),
      ("Color Window", self.windowSpinBox),
    ):
      self.controlsFormLayout.addRow(label, box)
    self.controlsGroupBox.layout().addLayout(self.controlsFormLayout)

    for widget in (
      self.loadPushButton,
      self.lesionPushButton,
      self.savePushButton,
      self.deletePushButton,
      self.resetPushButton,
      self.controlsGroupBox,
      self.saveCropPushButton,
      self.confirmCropPushButton,
      self.addCropPushButton,
    ):
      self.mainGroupBox.layout().addWidget(widget)

    # Assemble the side control panel and put it in a QFrame widget ------------------------------
    self.panel = qtw.QVBoxLayout()
    self.panel.addWidget(self.mainGroupBox)
    self.panelWidget = qtw.QFrame()
    self.panelWidget.setLayout(self.panel)

    # Create the VTK rendering window ------------------------------------------------------------
    self.vtkWidget = QVTKRenderWindowInteractor()
    self.vtkWidget.AddObserver("ExitEvent", lambda o, e, a=self: a.quit())

    # Create main layout and add VTK window (stretch 4) and control panel (stretch 1)
    self.mainLayout = qtw.QHBoxLayout()
    self.mainLayout.addWidget(self.vtkWidget, 4)
    self.mainLayout.addWidget(self.panelWidget, 1)

    self.frame = qtw.QFrame()
    self.frame.setLayout(self.mainLayout)
    self.setCentralWidget(self.frame)

    self.setWindowTitle(self.title)
    self.centreWindow()

    # Set size policies --------------------------------------------------------------------------
    for box in (self.sigmaSpinBox, self.radiusSpinBox, self.threshSpinBox,
                self.brightnessSpinBox, self.sliceSpinBox, self.levelSpinBox,
                self.windowSpinBox):
      box.setMinimumSize(70, 20)

    self.mainGroupBox.setMaximumSize(1000, 1000)

    self.vtkWidget.setSizePolicy(
      qtw.QSizePolicy.MinimumExpanding,
      qtw.QSizePolicy.MinimumExpanding
    )

    self.mainGroupBox.setSizePolicy(
      qtw.QSizePolicy.Maximum,
      qtw.QSizePolicy.Maximum
    )

    # Connect signals and slots ------------------------------------------------------------------
    self.loadPushButton.clicked.connect(self.openFile)
    self.sigmaSpinBox.valueChanged.connect(lambda s: self.changeSigma(s))
    self.radiusSpinBox.valueChanged.connect(lambda s: self.changeRadius(s))
    self.threshSpinBox.valueChanged.connect(lambda s: self.changeThreshold(s))
    self.brightnessSpinBox.valueChanged.connect(lambda s: self.changeBrightness(s))
    self.sliceSpinBox.valueChanged.connect(lambda s: self.changeSlice(s))
    self.windowSpinBox.valueChanged.connect(lambda s: self.changeWindow(s))
    self.levelSpinBox.valueChanged.connect(lambda s: self.changeLevel(s))

    self.lesionPushButton.clicked.connect(self.addLesion)
    self.savePushButton.clicked.connect(self.saveLesion)
    self.deletePushButton.clicked.connect(self.deleteLesion)
    self.resetPushButton.clicked.connect(self.resetContour)

    self.saveCropPushButton.clicked.connect(self.saveCrop)
    self.confirmCropPushButton.clicked.connect(self.confirmCrop)
    self.addCropPushButton.clicked.connect(self.addCrop)
    self.initRenderWindow()

    # Menu actions
    open_action.triggered.connect(self.openFile)
    about_action.triggered.connect(self.about)
    quit_action.triggered.connect(self.quit)

    self.pipe = None

    # End main UI code
    self.show()

    ########################################
    # Define methods for controlling GUI
    ########################################

  def centreWindow(self):
    """Move the window so its frame is centred on the available desktop area."""
    desktop_centre = qtw.QDesktopWidget().availableGeometry().center()
    frame_rect = self.frameGeometry()
    frame_rect.moveCenter(desktop_centre)
    self.move(frame_rect.topLeft())

  def initRenderWindow(self):
    """Attach a dark-grey renderer to the VTK widget and start its interactor.

    Sets ``self.renderer``, ``self.renWin`` and ``self.iren``. Any observers
    registered for LeftButtonPressEvent on the interactor are removed first.
    """
    renderer = vtk.vtkRenderer()
    renderer.SetBackground((.2, .2, .2))  # dark grey
    self.renderer = renderer

    window = self.vtkWidget.GetRenderWindow()
    self.renWin = window
    window.AddRenderer(renderer)

    interactor = window.GetInteractor()
    self.iren = interactor
    interactor.RemoveObservers("LeftButtonPressEvent")
    interactor.Initialize()
    interactor.Start()


  def refreshRenderWindow(self):
    """Render the VTK window, first directly and then through its interactor."""
    for target in (self.renWin, self.iren):
      target.Render()

  def createPipeline(self, _filename):
    """Load ``_filename`` and (re)build the processing and display pipeline.

    Supported inputs:
      * ``.nii`` / ``.nii.gz`` -> vtkNIFTIImageReader,
      * a ``.dcm`` file -> vtkDICOMImageReader on its containing directory,
      * a directory -> vtkDICOMImageReader on that directory.
    Raises SystemExit with an error message when no reader matches.

    Display path: lesion copy -> gauss -> resizeImage -> imageViewer.
    The threshold -> mapToColors -> resizeSeg branch is configured here but
    is not connected to the viewer in this method. A contour widget is
    placed on the viewer's image actor.
    """
    lowered = _filename.lower()
    reader = None
    if lowered.endswith(('.nii', '.nii.gz')):
      reader = vtk.vtkNIFTIImageReader()
      reader.SetFileName(_filename)
    elif lowered.endswith('.dcm'):
      reader = vtk.vtkDICOMImageReader()
      reader.SetDirectoryName(os.path.dirname(_filename))
    elif os.path.isdir(_filename):
      reader = vtk.vtkDICOMImageReader()
      reader.SetDirectoryName(_filename)

    # Decide from this file only, so a reader left over from an earlier load
    # is never silently reused for an unsupported file.
    if reader is None:
      # Equivalent to sys.exit(message): prints the message and exits.
      raise SystemExit("[ERROR] Cannot find reader for file \"{}\"".format(_filename))
    self.reader = reader
    self.reader.Update()

    # Global threshold: voxels above self.thresh -> 1, everything else -> 0.
    self.threshold.SetInputConnection(self.reader.GetOutputPort())
    self.threshold.ThresholdByUpper(self.thresh)
    self.threshold.SetInValue(1)
    self.threshold.SetOutValue(0)
    self.threshold.Update()

    # Cache NumPy views of the image and mask, and reset the lesion image.
    self.convertImageData()
    self.convertThresholdData()
    self.initializeLesion()

    # Gaussian smoothing of the (editable) lesion image.
    self.gauss.SetStandardDeviation(self.gaussian, self.gaussian, self.gaussian)
    self.gauss.SetRadiusFactors(self.radius, self.radius, self.radius)
    self.gauss.SetInputData(self.lesion)

    lookupTable = vtk.vtkLookupTable()
    lookupTable.SetNumberOfTableValues(2)
    lookupTable.SetRange(0.0, 1.0)
    lookupTable.SetTableValue(0, 0.0, 0.0, 0.0, 0.0)  # out-of-range label: transparent
    lookupTable.SetTableValue(1, 0.0, 1.0, 0.0, 1.0)  # in-range label: opaque green
    lookupTable.Build()

    self.mapToColors.SetLookupTable(lookupTable)
    self.mapToColors.PassAlphaToOutputOn()
    self.mapToColors.SetInputConnection(self.threshold.GetOutputPort())

    # NOTE(review): both resizes use a z magnification factor of 0 — confirm
    # how vtkImageResize treats a zero factor.
    # Resize segmented image
    self.resizeSeg.SetInputConnection(self.mapToColors.GetOutputPort())
    self.resizeSeg.SetResizeMethodToMagnificationFactors()
    self.resizeSeg.SetMagnificationFactors(self.zoom, self.zoom, 0)

    # Resize smoothed image
    self.resizeImage.SetInputConnection(self.gauss.GetOutputPort())
    self.resizeImage.SetResizeMethodToMagnificationFactors()
    self.resizeImage.SetMagnificationFactors(self.zoom, self.zoom, 0)

    self.imageViewer.SetInputConnection(self.resizeImage.GetOutputPort())
    self.imageViewer.SetRenderWindow(self.renWin)
    self.imageViewer.SetSlice(self.zSlice)
    self.imageViewer.Render()

    # Place contour points on the plane of the viewer's image actor.
    self.placer.SetImageActor(self.imageViewer.GetImageActor())
    self.contourRep.SetPointPlacer(self.placer)
    self.contourRep.GetProperty().SetColor(0, 1, 0)
    self.contourRep.GetLinesProperty().SetColor(0, 1, 0)

    # Contour widget
    self.contourWidget.SetRepresentation(self.contourRep)
    self.contourWidget.SetInteractor(self.iren)
    self.contourWidget.FollowCursorOn()
    self.contourWidget.SetContinuousDraw(1)
    self.contourWidget.SetEnabled(True)
    self.contourWidget.ProcessEventsOn()
    self.contourWidget.CloseLoop()

    self.refreshRenderWindow()

  def read_dictionary(self):
    """Load the predefined shape dictionary into ``self.shape_dic``.

    The path is relative to the current working directory.
    SECURITY: pickle.load can execute arbitrary code; only load a trusted file.
    """
    # `with` closes the file even if unpickling raises.
    with open('../Image_Datasets/shape_dic.pkl', 'rb') as handle:
      shape_dic = pickle.load(handle)
    self.shape_dic = shape_dic

  def resetContour(self):
    """Reset the contour widget (Initialize with no input) and redraw."""
    widget = self.contourWidget
    widget.Initialize()
    self.refreshRenderWindow()


  def convertImageData(self):
    """Store the reader's scalars in ``self.imageArray`` as a NumPy array.

    The flat VTK scalars are reshaped to ``GetDimensions()`` in Fortran
    order, so the array is indexed in VTK dimension order ``[i, j, k]``.
    """
    self.reader.Update()
    image = self.reader.GetOutput()
    flat = vtk_to_numpy(image.GetPointData().GetScalars())
    self.imageArray = flat.reshape(image.GetDimensions(), order='F')

  def convertThresholdData(self):
    """Store the threshold filter's output in ``self.thresholdArray``.

    Reshaped in Fortran order to ``GetDimensions()``, matching the indexing
    used for ``self.imageArray``.
    """
    self.threshold.Update()
    mask_image = self.threshold.GetOutput()
    flat_mask = vtk_to_numpy(mask_image.GetPointData().GetScalars())
    self.thresholdArray = flat_mask.reshape(mask_image.GetDimensions(), order='F')

  def initializeLesion(self):
    """Reset ``self.lesion`` to the reader's geometry with a deep copy of ``self.imageArray``."""
    self.reader.Update()
    lesion = self.lesion
    lesion.CopyStructure(self.reader.GetOutput())
    flat = self.imageArray.ravel(order='F')
    lesion.GetPointData().SetScalars(numpy_to_vtk(num_array=flat, deep=True))

  def saveCrop(self):
    """Rasterise the closed contour and store the image pixels under it.

    Sets ``self.zSlice`` to the viewer's slice. When the contour widget is
    in state 2, this method:
      * fills the contour into a boolean mask stored in ``self.shape``,
      * shows the mask with matplotlib,
      * stores ``self.imageArray`` pixels under the mask on that slice in
        ``self.crop`` (1-D).
    In state 0 it only shows a status message; other states do nothing.
    Also publishes the contour points in the module-level ``nodes``.
    """
    global nodes
    state = self.contourWidget.GetWidgetState()
    dim = self.reader.GetOutput().GetDimensions()
    self.zSlice = self.imageViewer.GetSlice()
    if state == 0:
      self.statusBar().showMessage("Draw lesion before saving", 4000)
    elif state == 2:
      self.polyData = self.contourRep.GetContourRepresentationAsPolyData()
      nodes = vtk_to_numpy(self.polyData.GetPoints().GetData())
      mask = np.zeros([dim[0], dim[1]])
      # NOTE(review): world coordinates are floored straight into pixel
      # indices, which assumes unit spacing and a zero origin — confirm.
      contour = np.floor(nodes).astype(int)
      # Clip so border points cannot raise IndexError or wrap via negatives.
      xs = np.clip(contour[:, 0], 0, dim[0] - 1)
      ys = np.clip(contour[:, 1], 0, dim[1] - 1)
      mask[xs, ys] = 1
      binary = ndimage.binary_fill_holes(mask)
      self.shape = binary
      plt.figure()
      plt.imshow(binary)
      plt.gca().invert_yaxis()
      plt.show()

      self.crop = self.imageArray[:, :, self.zSlice][binary]

  def confirmCrop(self):
    """Rasterise the closed contour into ``self.shape`` without touching the crop.

    When the contour widget is in state 2, fills the contour into a boolean
    mask, stores it in ``self.shape`` and shows it with matplotlib. In state
    0 it only shows a status message; other states do nothing. Also
    publishes the contour points in the module-level ``nodes``.
    """
    global nodes
    state = self.contourWidget.GetWidgetState()
    dim = self.reader.GetOutput().GetDimensions()
    if state == 0:
      self.statusBar().showMessage("Draw lesion before saving", 4000)
    elif state == 2:
      self.polyData = self.contourRep.GetContourRepresentationAsPolyData()
      nodes = vtk_to_numpy(self.polyData.GetPoints().GetData())
      mask = np.zeros([dim[0], dim[1]])
      # NOTE(review): assumes unit spacing and zero origin (world == index).
      contour = np.floor(nodes).astype(int)
      # Clip so border points cannot raise IndexError or wrap via negatives.
      xs = np.clip(contour[:, 0], 0, dim[0] - 1)
      ys = np.clip(contour[:, 1], 0, dim[1] - 1)
      mask[xs, ys] = 1
      binary = ndimage.binary_fill_holes(mask)
      self.shape = binary
      plt.figure()
      plt.imshow(binary)
      plt.gca().invert_yaxis()
      plt.show()

  def addCrop(self):
    """Paste the saved crop pixels into the lesion image on the current slice.

    The first ``min(mask pixel count, crop size)`` positions of
    ``self.shape`` (in row-major mask order, the same order used when the
    crop was extracted) receive the matching ``self.crop`` values. Requires
    a drawn contour plus a saved mask and crop; otherwise only a status
    message is shown. The render window is refreshed in every case.
    """
    state = self.contourWidget.GetWidgetState()
    if state == 0 or self.shape is None or self.crop is None:
      self.statusBar().showMessage("Draw and save lesion before adding to slice", 4000)
    else:
      # add the cropped pixels to the current zSlice
      self.zSlice = self.imageViewer.GetSlice()
      lesionDataArray = self.lesion.GetPointData().GetScalars()
      lesionArray = vtk_to_numpy(lesionDataArray).reshape(self.lesion.GetDimensions(), order='F')

      # Basic slicing gives a view, so writes below land in lesionArray.
      img_slice = lesionArray[:, :, self.zSlice]
      if self.shape.shape[0] != img_slice.shape[0]:
        self.shape = np.transpose(self.shape)

      # Index mask positions explicitly: assigning through img_slice[mask][:n]
      # would write into a temporary copy and be lost.
      rows, cols = np.nonzero(self.shape)
      count = min(rows.size, self.crop.size)
      img_slice[rows[:count], cols[:count]] = self.crop[:count]

      self.lesion.CopyStructure(self.reader.GetOutput())
      self.lesion.GetPointData().SetScalars(numpy_to_vtk(num_array=
                                      lesionArray.ravel(order='F'),
                                      deep=True))
    self.refreshRenderWindow()


  def saveLesion(self):
    """Rasterise the closed contour into the boolean mask ``self.shape``.

    Prints debugging output (image dimensions, contour tolerance, contour
    indices) and shows the mask with matplotlib. Only acts when the contour
    widget is in state 2; state 0 shows a status message. Also publishes the
    contour points in the module-level ``nodes``.
    """
    global nodes
    state = self.contourWidget.GetWidgetState()

    dim = self.reader.GetOutput().GetDimensions()
    print(dim)
    if state == 0:
      self.statusBar().showMessage("Draw lesion before saving", 4000)
    elif state == 2:
      print(dim)
      self.polyData = self.contourRep.GetContourRepresentationAsPolyData()
      print(self.contourRep.GetWorldTolerance())

      nodes = vtk_to_numpy(self.polyData.GetPoints().GetData())

      mask = np.zeros([dim[0], dim[1]])
      # NOTE(review): assumes unit spacing and zero origin (world == index).
      contour = np.floor(nodes).astype(int)
      # Clip so border points cannot raise IndexError or wrap via negatives.
      xs = np.clip(contour[:, 0], 0, dim[0] - 1)
      ys = np.clip(contour[:, 1], 0, dim[1] - 1)
      mask[xs, ys] = 1
      binary = ndimage.binary_fill_holes(mask)
      print(contour)

      plt.figure()
      plt.imshow(binary)
      plt.gca().invert_yaxis()
      plt.show()
      self.shape = binary
      self.statusBar().showMessage("Lesion saved to dictionary", 4000)

  def deleteLesion(self):
    """Restore the current slice of the lesion image from the original image.

    Only acts when ``self.lesion_dic`` has an entry for the viewer's current
    slice. That entry is removed, so later lookups in ``lesion_dic`` (e.g.
    adjustBrightness) no longer see the deleted lesion. Always refreshes
    the render window.
    """
    self.zSlice = self.imageViewer.GetSlice()
    if self.zSlice in self.lesion_dic:
      self.threshold.Update()

      lesionDataArray = self.lesion.GetPointData().GetScalars()
      lesionArray = vtk_to_numpy(lesionDataArray).reshape(self.lesion.GetDimensions(), order='F')
      lesionArray[:, :, self.zSlice] = self.imageArray[:, :, self.zSlice]
      self.lesion.GetPointData().SetScalars(numpy_to_vtk(num_array=lesionArray.ravel(order='F'),
                                  deep=True))
      del self.lesion_dic[self.zSlice]
    self.refreshRenderWindow()


  def addLesion(self):
    """Paint the saved lesion mask into the current slice of the lesion image.

    Pixels under ``self.shape`` are set to ``self.brightness``, and the mask
    is recorded in ``self.lesion_dic`` under the slice index. Requires a
    drawn contour and a saved mask; otherwise only a status message is
    shown. The render window is refreshed in every case.
    """
    state = self.contourWidget.GetWidgetState()
    if state == 0 or self.shape is None:
      self.statusBar().showMessage("Draw and save lesion before adding to slice", 4000)
    else:
      # add lesion to current zSlice
      self.zSlice = self.imageViewer.GetSlice()
      lesionDataArray = self.lesion.GetPointData().GetScalars()
      lesionArray = vtk_to_numpy(lesionDataArray).reshape(self.lesion.GetDimensions(), order='F')

      # Basic slicing gives a view, so the assignment lands in lesionArray.
      img_slice = lesionArray[:, :, self.zSlice]
      if self.shape.shape[0] != img_slice.shape[0]:
        self.shape = np.transpose(self.shape)

      img_slice[self.shape] = self.brightness

      self.lesion.CopyStructure(self.reader.GetOutput())
      self.lesion.GetPointData().SetScalars(numpy_to_vtk(num_array=
                                      lesionArray.ravel(order='F'),
                                      deep=True))
      self.lesion_dic[self.zSlice] = self.shape
    self.refreshRenderWindow()



  def addLesion1(self):
    """Alias of ``addLesion``; kept so existing callers/connections keep working.

    This method used to carry a line-for-line copy of ``addLesion``.
    Delegating keeps the two behaviours identical, so a fix applied to one
    cannot silently miss the other.
    """
    return self.addLesion()




  def adjustBrightness(self):
    """Repaint the stored lesion on the current slice using ``self.brightness``.

    If ``self.lesion_dic`` holds a mask for the displayed slice, that mask
    becomes ``self.shape``. Every masked pixel of the slice in the lesion
    volume is set to ``self.brightness``, and the volume is pushed back into
    ``self.lesion``. Slices without a stored lesion are left untouched. The
    render window is refreshed either way.
    """
    self.zSlice = self.imageViewer.GetSlice()
    volume = vtk_to_numpy(self.lesion.GetPointData().GetScalars()).reshape(
        self.lesion.GetDimensions(), order='F')

    if self.zSlice not in self.lesion_dic:
        self.refreshRenderWindow()
        return

    self.shape = self.lesion_dic[self.zSlice]
    # volume[:, :, z] is a view, so this masked assignment writes into volume.
    volume[:, :, self.zSlice][self.shape] = self.brightness

    self.lesion.GetPointData().SetScalars(
        numpy_to_vtk(num_array=volume.ravel(order='F'), deep=True))
    self.refreshRenderWindow()

  def changeSigma(self, _value):
    """Use ``_value`` as the Gaussian standard deviation on all three axes, then re-render."""
    sigma_xyz = (_value,) * 3
    self.gauss.SetStandardDeviation(*sigma_xyz)
    self.statusBar().showMessage(f"Changing standard deviation to {_value}", 4000)
    self.refreshRenderWindow()

  def changeRadius(self, _value):
    """Use ``_value`` as the Gaussian radius factor on all three axes, then re-render."""
    factors = (_value,) * 3
    self.gauss.SetRadiusFactors(*factors)
    self.statusBar().showMessage(f"Changing radius to {_value}", 4000)
    self.refreshRenderWindow()

  def changeThreshold(self, _value):
    """Set the upper-threshold cut-off of ``self.threshold`` to ``_value`` and re-render."""
    threshold_filter = self.threshold
    threshold_filter.ThresholdByUpper(_value)
    self.statusBar().showMessage(f"Changing threshold to {_value}", 4000)
    self.refreshRenderWindow()

  def changeBrightness(self, _value):
    """Store a new lesion brightness, repaint the current slice's lesion, and re-render."""
    self.brightness = _value
    self.adjustBrightness()
    status_text = f"Changing lesion brightness to {_value}"
    self.statusBar().showMessage(status_text, 4000)
    self.refreshRenderWindow()
  def changeSlice(self, _value):
    """Display slice ``_value`` in the image viewer and re-render."""
    self.zSlice = _value
    viewer = self.imageViewer
    viewer.SetSlice(self.zSlice)
    self.statusBar().showMessage(f"Changing zSlice to {_value}", 4000)
    self.refreshRenderWindow()
  def changeLevel(self, _value):
    """Set the viewer's colour (window/level) level to ``_value`` and re-render."""
    self.colorLevel = _value
    viewer = self.imageViewer
    viewer.SetColorLevel(self.colorLevel)
    self.refreshRenderWindow()

  def changeWindow(self, _value):
    """Set the viewer's colour (window/level) window to ``_value`` and re-render."""
    self.colorWindow = _value
    viewer = self.imageViewer
    viewer.SetColorWindow(self.colorWindow)
    self.refreshRenderWindow()

  def validExtension(self, extension):
    """Return True if ``extension`` is one of ".nii", ".dcm" or ".gz", else False.

    The comparison is exact, so callers should pass a lower-cased extension
    that includes its leading dot. Any ".gz" file passes this check; whether
    it is really a NIfTI volume is presumably decided later -- confirm.
    """
    accepted = (".nii", ".dcm", ".gz")
    return extension in accepted

  def openFile(self):
    """Ask the user for an image file and build the pipeline for it.

    Cancelling the dialog does nothing. A file whose lower-cased extension
    fails ``validExtension`` produces a warning box instead of being loaded.
    """
    self.statusBar().showMessage("Load image types (.nii, .dcm)",4000)
    dialog_flags = (qtw.QFileDialog.DontUseNativeDialog |
                    qtw.QFileDialog.DontResolveSymlinks)
    chosen, _ = qtw.QFileDialog.getOpenFileName(
      self,
      "Select a 3D image file to open…",
      qtc.QDir.homePath(),
      "Nifti Files (*.nii) ;;DICOM Files (*.dcm) ;;All Files (*)",
      "All Files (*)",
      dialog_flags,
    )

    if not chosen:
      return

    suffix = os.path.splitext(chosen)[1].lower()
    if not self.validExtension(suffix):
      qtw.QMessageBox.warning(self, "Error", "Invalid file type.")
      return

    self.createPipeline(chosen)
    self.statusBar().showMessage("Loading file " + chosen,4000)

  def quit(self):
    """Ask for confirmation and terminate the process if the user agrees.

    Uses ``sys.exit`` instead of the ``exit`` builtin. ``exit`` is added by
    the ``site`` module for the interactive interpreter, is documented as not
    meant for programs, and is missing when Python runs with ``-S``.
    """
    import sys

    reply = qtw.QMessageBox.question(self, "Message",
      "Are you sure you want to quit?", qtw.QMessageBox.Yes |
      qtw.QMessageBox.No, qtw.QMessageBox.Yes)
    if reply == qtw.QMessageBox.Yes:
      sys.exit(0)

  def about(self):
    """Show the modal "About" message box with version and copyright text."""
    box = qtw.QMessageBox(self)
    box.setText("QtBasic 1.0")
    box.setInformativeText("Copyright (C) 2021\nBone Imaging Laboratory\nAll rights reserved.\[email protected]")
    buttons = qtw.QMessageBox.Ok | qtw.QMessageBox.Cancel
    box.setStandardButtons(buttons)
    box.exec_()
Esempio n. 3
0
class MainWindow(qtw.QMainWindow):
    """Qt main window that shows a smoothed isosurface of a 3D image with VTK.

    Pipeline: image reader -> vtkImageGaussianSmooth -> vtkImageMarchingCubes
    -> vtkPolyDataMapper -> vtkActor. A side panel exposes the Gaussian
    standard deviation, the Gaussian radius factor and the isosurface value.
    """

    def __init__(self, input_file, gaussian, radius, isosurface, window_size,
                 *args, **kwargs):
        """MainWindow constructor.

        Args:
            input_file: Image path (or DICOM directory) to load at start-up,
                or None to start with an empty view.
            gaussian: Initial Gaussian standard deviation.
            radius: Initial Gaussian radius factor.
            isosurface: Initial marching-cubes contour value.
            window_size: Initial window size as (width, height).
        """
        super().__init__(*args, **kwargs)

        # Window setup
        self.resize(window_size[0], window_size[1])
        self.title = "Basic Qt Viewer for MDSC 689.03"

        self.statusBar().showMessage("Welcome.", 8000)

        # Capture defaults (initUI reads these to seed the spin boxes)
        self.gaussian = gaussian
        self.radius = radius
        self.isosurface = isosurface

        # Initialize the window; this also creates self.renderer / self.renWin
        self.initUI()

        # Set up some VTK pipeline classes
        self.reader = None
        self.gauss = vtk.vtkImageGaussianSmooth()
        self.marchingCubes = vtk.vtkImageMarchingCubes()
        self.mapper = vtk.vtkPolyDataMapper()
        self.actor = vtk.vtkActor()

        # Take inputs from command line. Only use these if there is an input file specified
        if input_file is not None:
            if not os.path.exists(input_file):
                qtw.QMessageBox.warning(self, "Error", "Invalid input file.")
                return

            self.createPipeline(input_file)
            self.statusBar().showMessage("Loading file " + input_file, 4000)
            self.changeSigma(gaussian)
            self.changeRadius(radius)
            self.changeIsosurface(isosurface)

    def initUI(self):
        """Create the widgets, menus and layout, start the VTK render window, and show the window."""
        ########################################
        # Create Widgets
        ########################################

        self.loadPushButton = qtw.QPushButton(
            "Load Image",
            self,
            objectName="loadPushButton",
            shortcut=qtg.QKeySequence("Ctrl+f"))
        # Range keywords come before `value`: PyQt applies keyword properties
        # one at a time, so a value above Qt's default maximum (99) set first
        # would be clamped before the real maximum is applied.
        self.sigmaSpinBox = qtw.QDoubleSpinBox(self,
                                               objectName="sigmaSpinBox",
                                               decimals=1,
                                               maximum=20.0,
                                               minimum=0.1,
                                               singleStep=0.1,
                                               value=self.gaussian,
                                               keyboardTracking=False)
        self.radiusSpinBox = qtw.QSpinBox(self,
                                          objectName="radiusSpinBox",
                                          maximum=20,
                                          minimum=1,
                                          singleStep=1,
                                          value=self.radius,
                                          keyboardTracking=False)
        self.isosurfaceSpinBox = qtw.QSpinBox(self,
                                              objectName="isosurfaceSpinBox",
                                              maximum=32768,
                                              minimum=0,
                                              singleStep=1,
                                              value=self.isosurface,
                                              keyboardTracking=False)

        # Create the menu options --------------------------------------------------------------------
        menubar = qtw.QMenuBar()
        self.setMenuBar(menubar)
        menubar.setNativeMenuBar(False)

        file_menu = menubar.addMenu("File")
        open_action = file_menu.addAction("Open Image")
        file_menu.addSeparator()
        about_action = file_menu.addAction("About")
        quit_action = file_menu.addAction("Quit")

        # Lay out the GUI ----------------------------------------------------------------------------
        self.mainGroupBox = qtw.QGroupBox("Image Controls")
        self.mainGroupBox.setLayout(qtw.QHBoxLayout())

        self.isosurfaceGroupBox = qtw.QGroupBox("Isosurface controls")
        self.isosurfaceGroupBox.setLayout(qtw.QVBoxLayout())
        self.isosurfaceFormLayout = qtw.QFormLayout()
        self.isosurfaceFormLayout.addRow("Sigma", self.sigmaSpinBox)
        self.isosurfaceFormLayout.addRow("Radius", self.radiusSpinBox)
        self.isosurfaceFormLayout.addRow("Isosurface", self.isosurfaceSpinBox)
        self.isosurfaceGroupBox.layout().addLayout(self.isosurfaceFormLayout)

        self.mainGroupBox.layout().addWidget(self.loadPushButton)
        self.mainGroupBox.layout().addWidget(self.isosurfaceGroupBox)

        # Assemble the side control panel and put it in a QPanel widget ------------------------------
        self.panel = qtw.QVBoxLayout()
        self.panel.addWidget(self.mainGroupBox)
        self.panelWidget = qtw.QFrame()
        self.panelWidget.setLayout(self.panel)

        # Create the VTK rendering window ------------------------------------------------------------
        self.vtkWidget = QVTKRenderWindowInteractor()
        self.vtkWidget.AddObserver("ExitEvent", lambda o, e, a=self: a.quit())
        self.vtkWidget.AddObserver("KeyReleaseEvent", self.keyEventDetected)
        self.vtkWidget.AddObserver("LeftButtonPressEvent",
                                   self.mouseEventDetected)

        # Create main layout and add VTK window and control panel (4:1 stretch)
        self.mainLayout = qtw.QHBoxLayout()
        self.mainLayout.addWidget(self.vtkWidget, 4)
        self.mainLayout.addWidget(self.panelWidget, 1)

        self.frame = qtw.QFrame()
        self.frame.setLayout(self.mainLayout)
        self.setCentralWidget(self.frame)

        self.setWindowTitle(self.title)
        self.centreWindow()

        # Set size policies --------------------------------------------------------------------------
        self.sigmaSpinBox.setMinimumSize(70, 20)
        self.radiusSpinBox.setMinimumSize(70, 20)
        self.isosurfaceSpinBox.setMinimumSize(70, 20)

        self.mainGroupBox.setMaximumSize(1000, 200)

        self.vtkWidget.setSizePolicy(qtw.QSizePolicy.MinimumExpanding,
                                     qtw.QSizePolicy.MinimumExpanding)

        self.mainGroupBox.setSizePolicy(qtw.QSizePolicy.Maximum,
                                        qtw.QSizePolicy.Maximum)

        # Connect signals and slots ------------------------------------------------------------------
        self.loadPushButton.clicked.connect(self.openFile)
        self.sigmaSpinBox.valueChanged.connect(lambda s: self.changeSigma(s))
        self.radiusSpinBox.valueChanged.connect(lambda s: self.changeRadius(s))
        self.isosurfaceSpinBox.valueChanged.connect(
            lambda s: self.changeIsosurface(s))

        self.initRenderWindow()

        # Menu actions
        open_action.triggered.connect(self.openFile)
        about_action.triggered.connect(self.about)
        quit_action.triggered.connect(self.quit)

        self.pipe = None

        # End main UI code
        self.show()

        ########################################
        # Define methods for controlling GUI
        ########################################

    def centreWindow(self):
        """Move the window to the centre of the available desktop area."""
        # NOTE(review): QDesktopWidget is deprecated in later Qt5 releases and
        # removed in Qt6; QScreen.availableGeometry() is the replacement.
        qr = self.frameGeometry()
        cp = qtw.QDesktopWidget().availableGeometry().center()
        qr.moveCenter(cp)
        self.move(qr.topLeft())

    def initRenderWindow(self):
        """Create the renderer, attach it to the Qt VTK widget, and set trackball interaction."""
        # Create renderer
        self.renderer = vtk.vtkRenderer()
        self.renderer.SetBackground(
            (0.000, 0.000, 204.0 / 255.0))  # Scanco blue

        # Create interactor
        self.renWin = self.vtkWidget.GetRenderWindow()
        self.renWin.AddRenderer(self.renderer)
        self.iren = self.renWin.GetInteractor()
        self.iren.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())

        # Initialize
        self.iren.Initialize()
        self.iren.Start()

    def refreshRenderWindow(self):
        """Re-render the VTK window so pipeline or parameter changes become visible.

        Every ``change*`` method and ``createPipeline`` call this; it was
        previously missing, so each of those calls raised AttributeError.
        """
        self.renWin.Render()

    def createPipeline(self, _filename):
        """Build (or rebuild) the reader -> smooth -> isosurface -> actor pipeline.

        ``.nii``/``.nii.gz`` files use vtkNIFTIImageReader. A ``.dcm`` file,
        or a directory, is read as a DICOM series with vtkDICOMImageReader.
        The process exits with an error message when no reader matches.
        """
        import sys

        # Start from a clean slate so an unrecognised path cannot silently
        # reuse the reader left over from a previously loaded file.
        self.reader = None
        lower_name = _filename.lower()

        # Read in the file
        if lower_name.endswith(('.nii', '.nii.gz')):
            self.reader = vtk.vtkNIFTIImageReader()
            self.reader.SetFileName(_filename)
        elif lower_name.endswith('.dcm'):
            # vtkDICOMImageReader loads the whole series from the directory.
            self.reader = vtk.vtkDICOMImageReader()
            self.reader.SetDirectoryName(os.path.dirname(_filename))
        elif os.path.isdir(_filename):
            self.reader = vtk.vtkDICOMImageReader()
            self.reader.SetDirectoryName(_filename)

        if self.reader is None:
            sys.exit("[ERROR] Cannot find reader for file \"{}\"".format(
                _filename))
        self.reader.Update()

        # Gaussian smoothing
        self.gauss.SetStandardDeviation(self.gaussian, self.gaussian,
                                        self.gaussian)
        self.gauss.SetRadiusFactors(self.radius, self.radius, self.radius)
        self.gauss.SetInputConnection(self.reader.GetOutputPort())

        # Marching Cubes
        self.marchingCubes.SetInputConnection(self.gauss.GetOutputPort())
        self.marchingCubes.ComputeGradientsOn()
        self.marchingCubes.ComputeNormalsOn()
        self.marchingCubes.ComputeScalarsOff()
        self.marchingCubes.SetNumberOfContours(1)
        self.marchingCubes.SetValue(0, self.isosurface)

        # Set mapper for image data
        self.mapper.SetInputConnection(self.marchingCubes.GetOutputPort())

        # Actor (vtkRenderer ignores duplicate AddActor calls on reload)
        self.actor.SetMapper(self.mapper)
        self.actor.GetProperty().SetColor((0.890, 0.855, 0.788))
        self.renderer.AddActor(self.actor)
        # The camera may already exist from an earlier (empty) render, so
        # frame the new surface explicitly.
        self.renderer.ResetCamera()

        self.refreshRenderWindow()
        return

    def changeSigma(self, _value):
        """Set the Gaussian standard deviation on all axes and re-render."""
        self.gauss.SetStandardDeviation(_value, _value, _value)
        self.statusBar().showMessage(
            f"Changing standard deviation to {_value}", 4000)
        self.refreshRenderWindow()
        return

    def changeRadius(self, _value):
        """Set the Gaussian radius factor on all axes and re-render."""
        self.gauss.SetRadiusFactors(_value, _value, _value)
        self.statusBar().showMessage(f"Changing radius to {_value}", 4000)
        self.refreshRenderWindow()
        return

    def changeIsosurface(self, _value):
        """Set the marching-cubes contour value and re-render."""
        self.marchingCubes.SetValue(0, _value)
        self.statusBar().showMessage(f"Changing isosurface to {_value}", 4000)
        self.refreshRenderWindow()
        return

    def keyEventDetected(self, obj, event):
        """VTK KeyReleaseEvent observer: print the released key symbol."""
        key = self.vtkWidget.GetKeySym()
        # f-string so a missing key symbol (None) cannot raise TypeError.
        print(f"key press – clicked {key}!")
        return

    def mouseEventDetected(self, obj, event):
        """VTK LeftButtonPressEvent observer: log the click."""
        print("mouse press – click!")
        return

    def validExtension(self, extension):
        """Return True if ``extension`` (lower-case, leading dot) is loadable.

        ".nii.gz" is accepted because ``createPipeline`` can read it.
        ``openFile`` passes the compound extension for such files.
        """
        return extension in (".nii", ".nii.gz", ".dcm")

    def openFile(self):
        """Ask the user for an image file and build the pipeline for it."""
        self.statusBar().showMessage("Load image types (.nii, .nii.gz, .dcm)", 4000)
        filename, _ = qtw.QFileDialog.getOpenFileName(
            self, "Select a 3D image file to open…", qtc.QDir.homePath(),
            "Nifti Files (*.nii *.nii.gz) ;;DICOM Files (*.dcm) ;;All Files (*)",
            "All Files (*)", qtw.QFileDialog.DontUseNativeDialog
            | qtw.QFileDialog.DontResolveSymlinks)

        if filename:
            lower_name = filename.lower()
            # os.path.splitext only sees ".gz" for "scan.nii.gz".
            if lower_name.endswith(".nii.gz"):
                ext = ".nii.gz"
            else:
                _, ext = os.path.splitext(lower_name)
            if not self.validExtension(ext):
                qtw.QMessageBox.warning(self, "Error", "Invalid file type.")
                return

            self.createPipeline(filename)
            self.statusBar().showMessage("Loading file " + filename, 4000)
        return

    def quit(self):
        """Ask for confirmation and terminate the process if the user agrees."""
        import sys

        reply = qtw.QMessageBox.question(
            self, "Message", "Are you sure you want to quit?",
            qtw.QMessageBox.Yes | qtw.QMessageBox.No, qtw.QMessageBox.Yes)
        if reply == qtw.QMessageBox.Yes:
            # sys.exit rather than the site-injected `exit` builtin.
            sys.exit(0)

    def about(self):
        """Show the modal "About" message box."""
        about = qtw.QMessageBox(self)
        about.setText("blQtBasic 1.0")
        about.setInformativeText(
            "Copyright (C) 2021\nBone Imaging Laboratory\nAll rights reserved.\[email protected]"
        )
        about.setStandardButtons(qtw.QMessageBox.Ok | qtw.QMessageBox.Cancel)
        about.exec_()
Esempio n. 4
0
class MeshWindow(PluginWindowBase):
    """
    Window for displaying the mesh.

    Renders the blocks, side sets and node sets of a loaded file with VTK.
    It also provides view-mode, selection-mode and colour-profile controls,
    and watches the loaded file for changes on disk.
    """

    # Emitted from onLoadFinished with a dict of block/sideset/nodeset info
    # and element/node totals.
    fileLoaded = QtCore.pyqtSignal(object)
    # Emitted with the overall bounds [xmin, xmax, ymin, ymax, zmin, zmax].
    boundsChanged = QtCore.pyqtSignal(list)

    # Side-set / node-set display colours.
    SIDESET_CLR = QtGui.QColor(255, 173, 79)
    SIDESET_EDGE_CLR = QtGui.QColor(26, 26, 102)
    SIDESET_EDGE_WIDTH = 5
    NODESET_CLR = QtGui.QColor(168, 91, 2)

    # Highlight colours for the current selection.
    SELECTION_CLR = QtGui.QColor(255, 173, 79)
    SELECTION_EDGE_CLR = QtGui.QColor(179, 95, 0)

    # Render modes stored in self._render_mode (see the view-mode menu).
    SHADED = 0
    SHADED_WITH_EDGES = 1
    HIDDEN_EDGES_REMOVED = 2
    # sic: "translucent"; name kept as-is because other code may reference it.
    TRANSLUENT = 3

    # Selection modes; persisted under the "tools/select_mode" setting.
    MODE_SELECT_BLOCKS = 0
    MODE_SELECT_CELLS = 1
    MODE_SELECT_POINTS = 2

    # Colour profile ids; persisted under the "color_profile" setting.
    COLOR_PROFILE_DEFAULT = 0
    COLOR_PROFILE_LIGHT = 1
    COLOR_PROFILE_DARK = 2

    def __init__(self, plugin):
        """Build the window, its menus and VTK scene, then show it.

        Args:
            plugin: Owning plugin; its ``settings`` object is read for saved
                colour profile and select mode.
        """
        super().__init__(plugin)
        # State for the background loader and its progress dialog.
        self._load_thread = None
        self._progress = None
        self._file_name = None
        self._file_watcher = QtCore.QFileSystemWatcher()
        self._selected_block = None

        # Order matters: setupMenuBar reuses the view actions created in
        # setupWidgets, and updateMenuBar reads the info window.
        self.setupWidgets()
        self.setupMenuBar()
        self.updateWindowTitle()
        self.updateMenuBar()
        self.loadColorProfiles()

        self.setAcceptDrops(True)

        self.connectSignals()
        self.setupVtk()
        self.setColorProfile()

        self._vtk_interactor.Initialize()
        self._vtk_interactor.Start()

        self._setupOrientationMarker()
        self._setupCubeAxesActor()

        self.clear()
        self.show()

        # Periodic refresh: onUpdateWindow runs every 250 ms.
        self._update_timer = QtCore.QTimer()
        self._update_timer.timeout.connect(self.onUpdateWindow)
        self._update_timer.start(250)

        # Defer placing the view-mode button until the event loop runs,
        # presumably so the window geometry is final by then.
        QtCore.QTimer.singleShot(1, self._updateViewModeLocation)

    def setupWidgets(self):
        """Create the VTK view, the info window and the overlay widgets.

        The overlay widgets are the view-mode button, the file-changed banner,
        the selection info panel and the explode controls. Creation order is
        kept as-is because later-created child widgets stack on top of
        earlier ones.
        """
        self._vtk_widget = QVTKRenderWindowInteractor(self)
        self._vtk_renderer = vtk.vtkRenderer()
        self._vtk_widget.GetRenderWindow().AddRenderer(self._vtk_renderer)

        self.setCentralWidget(self._vtk_widget)

        self._info_window = InfoWindow(self.plugin, self)
        self._info_window.show()

        self.setupViewModeWidget(self)
        self.setupFileChangedNotificationWidget()
        self._selected_mesh_ent_info = SelectedMeshEntityInfoWidget(self)
        self._selected_mesh_ent_info.setVisible(False)

        # Space bar clears the current selection.
        self.deselect_sc = QtWidgets.QShortcut(
            QtGui.QKeySequence(QtCore.Qt.Key_Space), self)
        self.deselect_sc.activated.connect(self.onDeselect)

        self.setupExplodeWidgets()

    def setupViewModeWidget(self, frame):
        """Build the render-mode menu and the floating button that opens it.

        The four render-mode actions (Ctrl+1..Ctrl+4) are mutually exclusive,
        and "Shaded with edges" starts checked. Perspective and
        orientation-marker toggles follow. The button is parented to ``frame``
        and is positioned later by ``_updateViewModeLocation``.
        """
        self._view_menu = QtWidgets.QMenu()
        self._shaded_action = self._view_menu.addAction("Shaded")
        self._shaded_action.setCheckable(True)
        self._shaded_action.setShortcut("Ctrl+1")
        self._shaded_w_edges_action = self._view_menu.addAction(
            "Shaded with edges")
        self._shaded_w_edges_action.setCheckable(True)
        self._shaded_w_edges_action.setShortcut("Ctrl+2")
        self._hidden_edges_removed_action = self._view_menu.addAction(
            "Hidden edges removed")
        self._hidden_edges_removed_action.setCheckable(True)
        self._hidden_edges_removed_action.setShortcut("Ctrl+3")
        self._transluent_action = self._view_menu.addAction("Transluent")
        self._transluent_action.setCheckable(True)
        self._transluent_action.setShortcut("Ctrl+4")
        # Default render mode; keep the checked action and the stored mode in sync.
        self._shaded_w_edges_action.setChecked(True)
        self._render_mode = self.SHADED_WITH_EDGES

        # Radio-button behaviour for the four render modes.
        self._visual_repr = QtWidgets.QActionGroup(self._view_menu)
        self._visual_repr.addAction(self._shaded_action)
        self._visual_repr.addAction(self._shaded_w_edges_action)
        self._visual_repr.addAction(self._hidden_edges_removed_action)
        self._visual_repr.addAction(self._transluent_action)
        self._visual_repr.setExclusive(True)

        self._view_menu.addSeparator()
        self._perspective_action = self._view_menu.addAction("Perspective")
        self._perspective_action.setCheckable(True)
        self._perspective_action.setChecked(True)

        self._view_menu.addSeparator()
        self._ori_marker_action = self._view_menu.addAction(
            "Orientation marker")
        self._ori_marker_action.setCheckable(True)
        self._ori_marker_action.setChecked(True)

        # Connected after the initial setChecked calls, so those do not fire handlers.
        self._shaded_action.triggered.connect(self.onShadedTriggered)
        self._shaded_w_edges_action.triggered.connect(
            self.onShadedWithEdgesTriggered)
        self._hidden_edges_removed_action.triggered.connect(
            self.onHiddenEdgesRemovedTriggered)
        self._transluent_action.triggered.connect(self.onTransluentTriggered)
        self._perspective_action.toggled.connect(self.onPerspectiveToggled)
        self._ori_marker_action.toggled.connect(
            self.onOrientationmarkerVisibilityChanged)

        self._view_mode = QtWidgets.QPushButton(frame)
        self._view_mode.setFixedSize(60, 32)
        self._view_mode.setIcon(Assets().icons['render-mode'])
        self._view_mode.setMenu(self._view_menu)
        self._view_mode.show()

    def setupFileChangedNotificationWidget(self):
        """Create the hidden "file changed on disk" banner; its reload button triggers onReloadFile."""
        self._file_changed_notification = banner = FileChangedNotificationWidget(self)
        banner.setVisible(False)
        banner.reload.connect(self.onReloadFile)

    def setupMenuBar(self):
        """Populate the File, View and Tools menus of ``self._menubar``.

        The View menu reuses the render-mode actions created in
        ``setupViewModeWidget``, so that method must have run first.
        """
        file_menu = self._menubar.addMenu("File")
        self._new_action = file_menu.addAction("New", self.onNewFile, "Ctrl+N")
        self._open_action = file_menu.addAction("Open", self.onOpenFile,
                                                "Ctrl+O")
        self._recent_menu = file_menu.addMenu("Open Recent")
        self.buildRecentFilesMenu()
        file_menu.addSeparator()
        export_menu = file_menu.addMenu("Export as...")
        self.setupExportMenu(export_menu)
        file_menu.addSeparator()
        self._close_action = file_menu.addAction("Close", self.onClose,
                                                 "Ctrl+W")

        view_menu = self._menubar.addMenu("View")
        view_menu.addAction(self._shaded_action)
        view_menu.addAction(self._shaded_w_edges_action)
        view_menu.addAction(self._hidden_edges_removed_action)
        view_menu.addAction(self._transluent_action)
        view_menu.addSeparator()
        # Checkable so updateMenuBar can mirror the info window's visibility.
        self._view_info_wnd_action = view_menu.addAction(
            "Info window", self.onViewInfoWindow)
        self._view_info_wnd_action.setCheckable(True)
        color_profile_menu = view_menu.addMenu("Color profile")
        self.setupColorProfileMenu(color_profile_menu)

        tools_menu = self._menubar.addMenu("Tools")
        self.setupSelectModeMenu(tools_menu)
        # Enabled/disabled by updateMenuBar depending on whether a file is loaded.
        self._tools_explode_action = tools_menu.addAction(
            "Explode", self.onToolsExplode)

    def setupExportMenu(self, menu):
        """Add the PNG and JPG image-export entries to ``menu``."""
        entries = (
            ("PNG...", self.onExportAsPng),
            ("JPG...", self.onExportAsJpg),
        )
        for label, handler in entries:
            menu.addAction(label, handler)

    def setupColorProfileMenu(self, menu):
        """Fill ``menu`` with one checkable, mutually exclusive action per colour profile.

        The saved profile id (settings key ``"color_profile"``) is stored in
        ``self._color_profile_id`` as an int, and the matching action is
        pre-checked. Each action carries its profile id via ``setData``. The
        group's ``triggered`` signal is routed to ``onColorProfileTriggered``.
        """
        self._color_profile_action_group = QtWidgets.QActionGroup(self)
        saved = self.plugin.settings.value(
            "color_profile", self.COLOR_PROFILE_DEFAULT)
        # Some QSettings backends (e.g. INI files) return stored numbers as
        # strings such as "1"; normalise so the comparison below matches the
        # integer COLOR_PROFILE_* constants.
        try:
            self._color_profile_id = int(saved)
        except (TypeError, ValueError):
            self._color_profile_id = self.COLOR_PROFILE_DEFAULT

        color_profiles = [
            ('Default', self.COLOR_PROFILE_DEFAULT),
            ('Light', self.COLOR_PROFILE_LIGHT),
            ('Dark', self.COLOR_PROFILE_DARK),
        ]
        for name, profile_id in color_profiles:
            action = menu.addAction(name)
            action.setCheckable(True)
            action.setData(profile_id)
            self._color_profile_action_group.addAction(action)
            if profile_id == self._color_profile_id:
                action.setChecked(True)

        self._color_profile_action_group.triggered.connect(
            self.onColorProfileTriggered)

    def setupSelectModeMenu(self, tools_menu):
        """Add a "Select mode" submenu with exclusive Blocks / Cells / Points actions.

        The saved mode (settings key ``"tools/select_mode"``) is stored in
        ``self._select_mode`` as an int, and its action is pre-checked. Each
        action carries its mode via ``setData``. The group's ``triggered``
        signal is routed to ``onSelectModeTriggered``.
        """
        select_menu = tools_menu.addMenu("Select mode")
        self._mode_select_action_group = QtWidgets.QActionGroup(self)
        saved = self.plugin.settings.value("tools/select_mode",
                                           self.MODE_SELECT_BLOCKS)
        # Some QSettings backends (e.g. INI files) return stored numbers as
        # strings; normalise so the comparison below matches MODE_SELECT_*.
        try:
            self._select_mode = int(saved)
        except (TypeError, ValueError):
            self._select_mode = self.MODE_SELECT_BLOCKS

        mode_actions = [
            ('Blocks', self.MODE_SELECT_BLOCKS),
            ('Cells', self.MODE_SELECT_CELLS),
            ('Points', self.MODE_SELECT_POINTS),
        ]
        for name, mode in mode_actions:
            action = select_menu.addAction(name)
            action.setCheckable(True)
            action.setData(mode)
            self._mode_select_action_group.addAction(action)
            if mode == self._select_mode:
                action.setChecked(True)

        self._mode_select_action_group.triggered.connect(
            self.onSelectModeTriggered)

    def setupExplodeWidgets(self):
        """Create the explode control (initially hidden) and route its value changes."""
        self._explode = explode = ExplodeWidget(self)
        explode.valueChanged.connect(self.onExplodeValueChanged)
        explode.setVisible(False)

    def updateMenuBar(self):
        """Mirror the info window's visibility in its menu check mark; enable Explode only with a file loaded."""
        info_action = self._view_info_wnd_action
        info_action.setChecked(self._info_window.isVisible())
        has_file = self._file_name is not None
        self._tools_explode_action.setEnabled(has_file)

    def connectSignals(self):
        """Wire this window and its info window together, plus the file watcher."""
        info = self._info_window
        # (signal, slot) pairs, connected in this order.
        wiring = (
            (self.fileLoaded, info.onFileLoaded),
            (self.boundsChanged, info.onBoundsChanged),
            (info.blockVisibilityChanged, self.onBlockVisibilityChanged),
            (info.blockColorChanged, self.onBlockColorChanged),
            (info.blockSelectionChanged, self.onBlockSelectionChanged),
            (info.sidesetVisibilityChanged, self.onSidesetVisibilityChanged),
            (info.sidesetSelectionChanged, self.onSidesetSelectionChanged),
            (info.nodesetVisibilityChanged, self.onNodesetVisibilityChanged),
            (info.nodesetSelectionChanged, self.onNodesetSelectionChanged),
            (info.dimensionsStateChanged, self.onCubeAxisVisibilityChanged),
            (self._file_watcher.fileChanged, self.onFileChanged),
        )
        for signal, slot in wiring:
            signal.connect(slot)

    def setupVtk(self):
        """Cache the render window and interactor, install the 3D style, and configure rendering."""
        render_window = self._vtk_widget.GetRenderWindow()
        self._vtk_render_window = render_window
        interactor = render_window.GetInteractor()
        self._vtk_interactor = interactor

        interactor.SetInteractorStyle(OtterInteractorStyle3D(self))

        # TODO: set background from preferences/templates
        renderer = self._vtk_renderer
        renderer.SetGradientBackground(True)
        # Anti-aliasing via FXAA; hardware multisampling limited to 1 sample.
        renderer.SetUseFXAA(True)
        render_window.SetMultiSamples(1)

        observers = (
            ('StartInteractionEvent', self.onStartInteraction),
            ('EndInteractionEvent', self.onEndInteraction),
        )
        for event_name, callback in observers:
            self._vtk_widget.AddObserver(event_name, callback)

    def resizeEvent(self, event):
        """Let the base class handle the resize, then re-pin the view-mode button."""
        result = super().resizeEvent(event)
        self._updateViewModeLocation()
        del result

    def getRenderWindowWidth(self):
        """Return the width (pixels) used to place overlays: the window geometry's width."""
        rect = self.geometry()
        return rect.width()

    def _updateViewModeLocation(self):
        """Place the view-mode button 5 px from the right edge and 10 px from the top."""
        right_margin, top = 5, 10
        available = self.getRenderWindowWidth()
        x = available - right_margin - self._view_mode.width()
        self._view_mode.move(x, top)

    def onStartInteraction(self, obj, event):
        """Observer for VTK 'StartInteractionEvent'; intentionally a no-op for now."""

    def onEndInteraction(self, obj, event):
        """Observer for VTK 'EndInteractionEvent'; intentionally a no-op for now."""

    def dragEnterEvent(self, event):
        """Accept a drag only when it carries URLs (e.g. dragged files)."""
        if not event.mimeData().hasUrls():
            event.ignore()
            return
        event.accept()

    def dropEvent(self, event):
        """Load the first local file among the dropped URLs.

        Drops without URLs are ignored. Non-local URLs (e.g. http links) are
        skipped: ``QUrl.toLocalFile()`` returns "" for them, and handing ""
        to ``loadFile`` would clear the current scene and start a load of a
        non-existent file.
        """
        mime = event.mimeData()
        if not mime.hasUrls():
            event.ignore()
            return

        event.setDropAction(QtCore.Qt.CopyAction)
        event.accept()

        file_names = [url.toLocalFile() for url in mime.urls()
                      if url.isLocalFile()]
        if file_names:
            self.loadFile(file_names[0])

    def clear(self):
        """Drop all mesh entities and VTK props, stop watching files, and reset the selection."""
        self._blocks, self._side_sets, self._node_sets = {}, {}, {}
        self._vtk_renderer.RemoveAllViewProps()

        # files() returns a snapshot list, so removing while iterating is safe.
        watcher = self._file_watcher
        for path in watcher.files():
            watcher.removePath(path)

        self._selection = None

    def loadFile(self, file_name):
        """Clear the scene and start loading ``file_name`` with a LoadThread.

        A window-modal progress dialog with a 0..0 range (Qt's busy
        indicator) is shown while loading. ``onLoadFinished`` runs when the
        thread's ``finished`` signal fires.
        """
        self.clear()

        label = "Loading {}...".format(os.path.basename(file_name))
        self._progress = progress = QtWidgets.QProgressDialog(
            label, None, 0, 0, self)
        progress.setWindowModality(QtCore.Qt.WindowModal)
        progress.setMinimumDuration(0)
        progress.show()

        self._load_thread = loader = LoadThread(file_name)
        loader.finished.connect(self.onLoadFinished)
        loader.start(QtCore.QThread.IdlePriority)

    def onLoadFinished(self):
        """Populate the scene once the LoadThread has finished reading the file.

        This method:
        - adds block, side-set and node-set actors;
        - computes the overall bounds from the blocks and feeds them to the
          cube-axes actor;
        - emits ``fileLoaded`` and ``boundsChanged``;
        - records and watches the file name, and creates the selection
          helper;
        - hides the progress dialog;
        - picks a 2D or 3D interactor style from the reader's dimensionality;
        - resets the camera.
        """
        reader = self._load_thread.getReader()

        self._addBlocks()
        self._addSidesets()
        self._addNodeSets()

        # Union of all block bounding boxes.
        # NOTE(review): with zero blocks these stay +/-inf and are passed on as-is.
        gmin = QtGui.QVector3D(float('inf'), float('inf'), float('inf'))
        gmax = QtGui.QVector3D(float('-inf'), float('-inf'), float('-inf'))
        for block in self._blocks.values():
            bmin, bmax = block.bounds
            gmin = common.point_min(bmin, gmin)
            gmax = common.point_max(bmax, gmax)
        bnds = [gmin.x(), gmax.x(), gmin.y(), gmax.y(), gmin.z(), gmax.z()]

        self._com = common.centerOfBounds(bnds)
        self._cube_axes_actor.SetBounds(*bnds)
        self._vtk_renderer.AddViewProp(self._cube_axes_actor)

        params = {
            'blocks': reader.getBlocks(),
            'sidesets': reader.getSideSets(),
            'nodesets': reader.getNodeSets(),
            'total_elems': reader.getTotalNumberOfElements(),
            'total_nodes': reader.getTotalNumberOfNodes()
        }
        self.fileLoaded.emit(params)
        self.boundsChanged.emit(bnds)

        self._file_name = reader.getFileName()
        self.updateWindowTitle()
        self.addToRecentFiles(self._file_name)
        self._file_watcher.addPath(self._file_name)
        self._file_changed_notification.setFileName(self._file_name)

        # NOTE(review): self._geometry is set by _addBlocks (last block only);
        # for a file with no blocks it is missing or stale from a prior load.
        self._selection = Selection(self._geometry.GetOutput())
        self._setSelectionProperties(self._selection)
        self._vtk_renderer.AddActor(self._selection.getActor())

        self._progress.hide()
        self._progress = None

        self.updateMenuBar()

        if reader.getDimensionality() == 3:
            style = OtterInteractorStyle3D(self)
        else:
            style = OtterInteractorStyle2D(self)
        self._vtk_interactor.SetInteractorStyle(style)

        # Look down the +z axis at the current focal point, then fit to the scene.
        camera = self._vtk_renderer.GetActiveCamera()
        focal_point = camera.GetFocalPoint()
        camera.SetPosition(focal_point[0], focal_point[1], 1)
        camera.SetRoll(0)
        self._vtk_renderer.ResetCamera()

    def _addBlocks(self):
        """Create a BlockObject (surface actor + silhouette) for each block.

        Every block is pulled out of the reader's multiblock output with its
        own ``vtkExtractBlock`` filter, styled for the current render mode,
        registered in ``self._blocks`` under its block number and added to
        the renderer.
        """
        camera = self._vtk_renderer.GetActiveCamera()
        reader = self._load_thread.getReader()

        # The positional index was never used; iterate the block infos directly.
        for binfo in reader.getBlocks():
            eb = vtk.vtkExtractBlock()
            eb.SetInputConnection(reader.getVtkOutputPort())
            eb.AddIndex(binfo.multiblock_index)
            eb.Update()

            block = BlockObject(eb, camera)
            self._setBlockProperties(block)
            self._blocks[binfo.number] = block

            self._vtk_renderer.AddViewProp(block.actor)
            self._vtk_renderer.AddViewProp(block.silhouette_actor)
            # FIXME: make this work with multiple blocks -- as written, the
            # geometry of the last block added wins.
            self._geometry = block.geometry

    def _addSidesets(self):
        """Create a SideSetObject actor for every side set in the loaded file.

        Each side set is extracted from the reader's multiblock output with
        its own ``vtkExtractBlock`` filter, registered in ``self._side_sets``
        under its number, added to the renderer and styled for the current
        render mode.
        """
        reader = self._load_thread.getReader()

        # The positional index was never used; iterate the side-set infos directly.
        for finfo in reader.getSideSets():
            eb = vtk.vtkExtractBlock()
            eb.SetInputConnection(reader.getVtkOutputPort())
            eb.AddIndex(finfo.multiblock_index)
            eb.Update()

            sideset = SideSetObject(eb)
            self._side_sets[finfo.number] = sideset
            self._vtk_renderer.AddViewProp(sideset.actor)
            self._setSideSetProperties(sideset)

    def _addNodeSets(self):
        """Create a NodeSetObject actor for every node set in the loaded file.

        Each node set is extracted from the reader's multiblock output with
        its own ``vtkExtractBlock`` filter, registered in ``self._node_sets``
        under its number, added to the renderer and given the node-set style.
        """
        reader = self._load_thread.getReader()

        # The positional index was never used; iterate the node-set infos directly.
        for ninfo in reader.getNodeSets():
            eb = vtk.vtkExtractBlock()
            eb.SetInputConnection(reader.getVtkOutputPort())
            eb.AddIndex(ninfo.multiblock_index)
            eb.Update()

            nodeset = NodeSetObject(eb)
            self._node_sets[ninfo.number] = nodeset
            self._vtk_renderer.AddViewProp(nodeset.actor)
            self._setNodeSetProperties(nodeset)

    def _setupCubeAxesActor(self):
        """Create the cube-axes actor, hidden by default and tied to the active camera."""
        axes = vtk.vtkCubeAxesActor()
        self._cube_axes_actor = axes
        axes.VisibilityOff()
        axes.SetCamera(self._vtk_renderer.GetActiveCamera())
        # Draw grid lines on every face; labels fly to the outer edges.
        axes.SetGridLineLocation(vtk.vtkCubeAxesActor.VTK_GRID_LINES_ALL)
        axes.SetFlyMode(vtk.vtkCubeAxesActor.VTK_FLY_OUTER_EDGES)

    def _getBlock(self, block_id):
        return self._blocks[block_id]

    def _getSideSet(self, sideset_id):
        return self._side_sets[sideset_id]

    def _getNodeSet(self, nodeset_id):
        return self._node_sets[nodeset_id]

    def renderMode(self):
        """Return the active render mode (SHADED, SHADED_WITH_EDGES, ... constant)."""
        mode = self._render_mode
        return mode

    def _setupOrientationMarker(self):
        """Show a non-interactive XYZ axes glyph in the lower-right corner."""
        axes_glyph = vtk.vtkAxesActor()
        marker = vtk.vtkOrientationMarkerWidget()
        self._ori_marker = marker
        marker.SetOrientationMarker(axes_glyph)
        # Normalized viewport (xmin, ymin, xmax, ymax): the lower-right 20%.
        marker.SetViewport(0.8, 0, 1.0, 0.2)
        marker.SetInteractor(self._vtk_interactor)
        marker.SetEnabled(1)
        marker.SetInteractive(False)

    def onBlockVisibilityChanged(self, block_id, visible):
        """Show/hide a block; its silhouette follows only in silhouette modes.

        Silhouettes are used in the hidden-edges-removed and translucent
        render modes; in every other mode they stay hidden.
        """
        block = self._getBlock(block_id)
        block.setVisible(visible)
        silhouette_modes = (self.HIDDEN_EDGES_REMOVED, self.TRANSLUENT)
        show_silhouette = (block.visible
                           if self.renderMode() in silhouette_modes else False)
        block.setSilhouetteVisible(show_silhouette)

    def onBlockColorChanged(self, block_id, qcolor):
        """Change the base color of a block.

        The color is always stored on the block via ``setColor``. The VTK
        property is repainted immediately only when the block is not
        selected. A selected block keeps its selection highlight until
        ``_deselectBlocks`` restores the deselected look. Previously the
        highlight was silently overwritten while ``_selected_block`` still
        pointed at the block.

        In hidden-edges-removed mode blocks are drawn white regardless of
        their color, matching ``_setDeselectedBlockProperties``.

        Args:
            block_id: key of the block in ``self._blocks``.
            qcolor: a QColor-like object exposing ``redF/greenF/blueF``.
        """
        clr = [qcolor.redF(), qcolor.greenF(), qcolor.blueF()]
        block = self._getBlock(block_id)
        block.setColor(clr)

        if self._selected_block == block_id:
            # Keep the selection highlight visible.
            return

        prop = block.property
        if self.renderMode() == self.HIDDEN_EDGES_REMOVED:
            prop.SetColor([1, 1, 1])
        else:
            prop.SetColor(clr)

    def onSidesetVisibilityChanged(self, sideset_id, visible):
        """Show or hide the side set with the given id."""
        self._getSideSet(sideset_id).setVisible(visible)

    def onNodesetVisibilityChanged(self, nodeset_id, visible):
        """Show or hide the node set with the given id."""
        self._getNodeSet(nodeset_id).setVisible(visible)

    def onCubeAxisVisibilityChanged(self, visible):
        """Toggle the cube-axes actor according to the truthiness of ``visible``."""
        actor = self._cube_axes_actor
        toggle = actor.VisibilityOn if visible else actor.VisibilityOff
        toggle()

    def onOrientationmarkerVisibilityChanged(self, visible):
        """Enable or disable the orientation-marker widget."""
        marker = self._ori_marker
        toggle = marker.EnabledOn if visible else marker.EnabledOff
        toggle()

    def onOpenFile(self):
        """Ask the user for a mesh file and load it unless the dialog is cancelled."""
        name_filters = ("ExodusII files (*.e *.exo);;"
                        "HDF5 PETSc files (*.h5);;"
                        "VTK Unstructured Grid files (*.vtk)")
        # The second value (the chosen filter) is not needed.
        file_name, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Open File', "", name_filters)
        if not file_name:
            return
        self.loadFile(file_name)

    def onNewFile(self):
        """Discard the current mesh and return to an empty, untitled window."""
        self.clear()
        # Tell listeners there is no file content and no bounding box.
        self.fileLoaded.emit(None)
        self.boundsChanged.emit([])
        # NOTE(review): the previous path stays registered in
        # self._file_watcher unless clear() removes it -- confirm.
        self._file_name = None
        self.updateWindowTitle()

    def updateWindowTitle(self):
        """Title the window "Mesh Inspector", plus " — <basename>" when a file is open."""
        title = "Mesh Inspector"
        if self._file_name is not None:
            base_name = os.path.basename(self._file_name)
            title = "Mesh Inspector \u2014 {}".format(base_name)
        self.setWindowTitle(title)

    def onShadedTriggered(self, checked):
        """Switch to plain shaded rendering: no edges and no silhouettes."""
        self._render_mode = self.SHADED
        selected_id = self._selected_block
        for blk_id, blk in self._blocks.items():
            self._setBlockProperties(blk, blk_id == selected_id)
            blk.setSilhouetteVisible(False)
        for ss in self._side_sets.values():
            self._setSideSetProperties(ss)

    def onShadedWithEdgesTriggered(self, checked):
        """Switch to shaded rendering with element edges; silhouettes off."""
        self._render_mode = self.SHADED_WITH_EDGES
        selected_id = self._selected_block
        for blk_id, blk in self._blocks.items():
            self._setBlockProperties(blk, blk_id == selected_id)
            blk.setSilhouetteVisible(False)
        for ss in self._side_sets.values():
            self._setSideSetProperties(ss)

    def onHiddenEdgesRemovedTriggered(self, checked):
        """Switch to hidden-edges-removed mode; visible blocks get silhouettes."""
        self._render_mode = self.HIDDEN_EDGES_REMOVED
        selected_id = self._selected_block
        for blk_id, blk in self._blocks.items():
            self._setBlockProperties(blk, blk_id == selected_id)
            blk.setSilhouetteVisible(blk.visible)
        for ss in self._side_sets.values():
            self._setSideSetProperties(ss)

    def onTransluentTriggered(self, checked):
        """Switch to translucent mode; visible blocks get silhouettes."""
        self._render_mode = self.TRANSLUENT
        selected_id = self._selected_block
        for blk_id, blk in self._blocks.items():
            self._setBlockProperties(blk, blk_id == selected_id)
            blk.setSilhouetteVisible(blk.visible)
        for ss in self._side_sets.values():
            self._setSideSetProperties(ss)

    def onPerspectiveToggled(self, checked):
        """Use perspective projection when checked, parallel projection otherwise."""
        camera = self._vtk_renderer.GetActiveCamera()
        if not checked:
            camera.ParallelProjectionOn()
        else:
            camera.ParallelProjectionOff()

    def _setSelectedBlockProperties(self, block):
        """Give a selected block its highlight look for the current render mode.

        Every known mode paints the block with SIDESET_CLR. Translucent mode
        uses opacity 0.33, all others 1.0. Only shaded-with-edges draws edges,
        in SIDESET_EDGE_CLR with line width 2. Unknown modes are left untouched.
        """
        prop = block.property
        mode = self.renderMode()
        known_modes = (self.SHADED, self.SHADED_WITH_EDGES,
                       self.HIDDEN_EDGES_REMOVED, self.TRANSLUENT)
        if mode not in known_modes:
            return
        show_edges = mode == self.SHADED_WITH_EDGES
        prop.SetColor(common.qcolor2vtk(self.SIDESET_CLR))
        prop.SetOpacity(0.33 if mode == self.TRANSLUENT else 1.0)
        prop.SetEdgeVisibility(show_edges)
        if show_edges:
            prop.SetEdgeColor(common.qcolor2vtk(self.SIDESET_EDGE_CLR))
            prop.SetLineWidth(2)

    def _setDeselectedBlockProperties(self, block):
        property = block.property
        if self.renderMode() == self.SHADED:
            property.SetColor(block.color)
            property.SetOpacity(1.0)
            property.SetEdgeVisibility(False)
        elif self.renderMode() == self.SHADED_WITH_EDGES:
            property.SetColor(block.color)
            property.SetOpacity(1.0)
            property.SetEdgeVisibility(True)
            property.SetEdgeColor(common.qcolor2vtk(self.SIDESET_EDGE_CLR))
            property.SetLineWidth(2)
        elif self.renderMode() == self.HIDDEN_EDGES_REMOVED:
            property.SetColor([1, 1, 1])
            property.SetOpacity(1.0)
            property.SetEdgeVisibility(False)
        elif self.renderMode() == self.TRANSLUENT:
            property.SetColor(block.color)
            property.SetOpacity(0.33)
            property.SetEdgeVisibility(False)

    def _setBlockProperties(self, block, selected=False):
        property = block.property
        property.SetAmbient(0.4)
        property.SetDiffuse(0.6)
        if selected:
            self._setSelectedBlockProperties(block)
        else:
            self._setDeselectedBlockProperties(block)

    def _setSideSetProperties(self, sideset):
        """Style a side set: unlit SIDESET_CLR surface without visible edges.

        In the two shaded modes the edge color and width are configured too
        (even though edge visibility is off); unknown modes are untouched.
        """
        prop = sideset.property
        mode = self.renderMode()
        if mode in (self.SHADED, self.SHADED_WITH_EDGES):
            prop.SetColor(common.qcolor2vtk(self.SIDESET_CLR))
            # NOTE(review): edge color/width are set but edges stay hidden in
            # both shaded modes -- confirm this is intended.
            prop.SetEdgeVisibility(False)
            prop.SetEdgeColor(common.qcolor2vtk(self.SIDESET_EDGE_CLR))
            prop.SetLineWidth(self.SIDESET_EDGE_WIDTH)
            prop.LightingOff()
        elif mode in (self.HIDDEN_EDGES_REMOVED, self.TRANSLUENT):
            prop.SetColor(common.qcolor2vtk(self.SIDESET_CLR))
            prop.SetEdgeVisibility(False)
            prop.LightingOff()

    def _setNodeSetProperties(self, nodeset):
        """Draw a node set as opaque, unshaded NODESET_CLR spheres of size 10."""
        prop = nodeset.property
        # Point representation rendered as spheres, no edges.
        prop.SetRepresentationToPoints()
        prop.SetRenderPointsAsSpheres(True)
        prop.SetVertexVisibility(True)
        prop.SetEdgeVisibility(False)
        prop.SetPointSize(10)
        # Flat color: fully ambient, no diffuse term.
        prop.SetColor(common.qcolor2vtk(self.NODESET_CLR))
        prop.SetOpacity(1)
        prop.SetAmbient(1)
        prop.SetDiffuse(0)

    def _setSelectionProperties(self, selection):
        """Style the selection actor for the current selection mode.

        Cell mode shows a half-transparent SELECTION_CLR surface with thick
        SELECTION_EDGE_CLR edges; point mode shows SELECTION_CLR spheres of
        size 15. Other modes leave the actor untouched.
        """
        prop = selection.getActor().GetProperty()
        mode = self._select_mode
        if mode == self.MODE_SELECT_CELLS:
            prop.SetRepresentationToSurface()
            prop.SetRenderPointsAsSpheres(False)
            prop.SetVertexVisibility(False)
            prop.SetPointSize(0)
            prop.EdgeVisibilityOn()
            prop.SetColor(common.qcolor2vtk(self.SELECTION_CLR))
            prop.SetLineWidth(7)
            prop.SetEdgeColor(common.qcolor2vtk(self.SELECTION_EDGE_CLR))
            prop.SetOpacity(0.5)
            prop.SetAmbient(1)
            prop.SetDiffuse(0)
        elif mode == self.MODE_SELECT_POINTS:
            prop.SetRepresentationToPoints()
            prop.SetRenderPointsAsSpheres(True)
            prop.SetVertexVisibility(True)
            prop.SetEdgeVisibility(False)
            prop.SetPointSize(15)
            prop.SetColor(common.qcolor2vtk(self.SELECTION_CLR))
            prop.SetOpacity(1)
            prop.SetAmbient(1)
            prop.SetDiffuse(0)

    def onUpdateWindow(self):
        """Re-render the VTK render window."""
        window = self._vtk_render_window
        window.Render()

    def event(self, e):
        """Handle LoadFileEvent by loading its file; defer everything else to Qt."""
        if e.type() != LoadFileEvent.TYPE:
            return super().event(e)
        self.loadFile(e.fileName())
        return True

    def closeEvent(self, event):
        """Persist the selection mode and color profile, then close normally."""
        settings = self.plugin.settings
        settings.setValue("tools/select_mode", self._select_mode)
        settings.setValue("color_profile", self._color_profile_id)
        super().closeEvent(event)

    def onFileChanged(self, path):
        """React to an on-disk change of the watched file.

        QFileSystemWatcher stops watching a file once it is renamed/removed
        (as happens when some editors save), so the path is re-added if it
        dropped out of the watch list before the notification is shown.
        """
        watched = self._file_watcher.files()
        if path not in watched:
            self._file_watcher.addPath(path)
        self.showFileChangedNotification()

    def showFileChangedNotification(self):
        """Show the "file changed" banner centered horizontally, 10 px from the top.

        The horizontal position is computed from the render-window width.
        QWidget.setGeometry() takes ints; the previous true division handed
        it a float, which PyQt5 rejects with TypeError on recent Python
        versions, so floor division (plus int()) is used instead.
        """
        notification = self._file_changed_notification
        notification.adjustSize()
        width = notification.width()
        height = notification.height()
        left = int((self.getRenderWindowWidth() - width) // 2)
        top = 10
        notification.setGeometry(left, top, width, height)
        notification.show()

    def onReloadFile(self):
        """Reload the current file from disk.

        Does nothing when no file is open (e.g. after "New"). Previously this
        case forwarded ``None`` to ``loadFile``.
        """
        if self._file_name is None:
            return
        self.loadFile(self._file_name)

    def _showSelectedMeshEntity(self):
        self._selected_mesh_ent_info.adjustSize()

        wnd_geom = self.geometry()
        widget_geom = self._selected_mesh_ent_info.geometry()

        self._selected_mesh_ent_info.move(
            wnd_geom.width() - widget_geom.width() - 10,
            wnd_geom.height() - widget_geom.height() - 5)
        self._selected_mesh_ent_info.show()

    def onBlockSelectionChanged(self, block_id):
        """Select the block ``block_id`` (or nothing when it is unknown/None).

        Any previous block selection is cleared first; for a known block its
        info is shown and it is restyled as selected, otherwise the info
        widget is hidden.
        """
        self._deselectBlocks()
        if block_id not in self._blocks:
            self._selected_mesh_ent_info.hide()
            return
        block = self._getBlock(block_id)
        self._selected_mesh_ent_info.setBlockInfo(block_id, block.info)
        self._showSelectedMeshEntity()
        self._selected_block = block_id
        self._setBlockProperties(block, selected=True)

    def _deselectBlocks(self):
        blk_id = self._selected_block
        if blk_id is not None:
            block = self._getBlock(blk_id)
            self._setBlockProperties(block, selected=False)
            self._selected_block = None

    def onSidesetSelectionChanged(self, sideset_id):
        """Show the info widget for the chosen side set, or hide it if unknown."""
        info_widget = self._selected_mesh_ent_info
        if sideset_id not in self._side_sets:
            info_widget.hide()
            return
        info_widget.setSidesetInfo(sideset_id, self._side_sets[sideset_id].info)
        self._showSelectedMeshEntity()

    def onNodesetSelectionChanged(self, nodeset_id):
        """Show the info widget for the chosen node set, or hide it if unknown."""
        info_widget = self._selected_mesh_ent_info
        if nodeset_id not in self._node_sets:
            info_widget.hide()
            return
        info_widget.setNodesetInfo(nodeset_id, self._node_sets[nodeset_id].info)
        self._showSelectedMeshEntity()

    def _blockActorToId(self, actor):
        # TODO: when we start to have 1000s of actors, this should be an
        # inverse dictionary from 'actor' to 'block_id'
        for blk_id, block in self._blocks.items():
            if block.actor == actor:
                return blk_id
        return None

    def _selectBlock(self, pt):
        """Select the block under screen point ``pt``; nothing happens on a miss."""
        picker = vtk.vtkPropPicker()
        if not picker.PickProp(pt.x(), pt.y(), self._vtk_renderer):
            return
        picked_actor = picker.GetViewProp()
        self.onBlockSelectionChanged(self._blockActorToId(picked_actor))

    def _buildCellInfo(self, cell):
        nfo = {'type': cell.GetCellType()}
        return nfo

    def _selectCell(self, pt):
        """Select the cell under screen point ``pt`` and show its info."""
        picker = vtk.vtkCellPicker()
        if not picker.Pick(pt.x(), pt.y(), 0, self._vtk_renderer):
            return
        cell_id = picker.GetCellId()
        selection = self._selection
        selection.selectCell(cell_id)
        self._setSelectionProperties(selection)

        # NOTE(review): GetCell(0) assumes the selection is non-empty after
        # selectCell(); confirm for picks that hit a non-mesh prop.
        first_cell = selection.get().GetCell(0)
        self._selected_mesh_ent_info.setCellInfo(
            cell_id, self._buildCellInfo(first_cell))
        self._showSelectedMeshEntity()

    def _buildPointInfo(self, points):
        coords = points.GetPoint(0)
        nfo = {'coords': coords}
        return nfo

    def _selectPoint(self, pt):
        """Select the mesh point nearest screen point ``pt`` and show its info."""
        picker = vtk.vtkPointPicker()
        if not picker.Pick(pt.x(), pt.y(), 0, self._vtk_renderer):
            return
        point_id = picker.GetPointId()
        selection = self._selection
        selection.selectPoint(point_id)
        self._setSelectionProperties(selection)

        selected_points = selection.get().GetPoints()
        self._selected_mesh_ent_info.setPointInfo(
            point_id, self._buildPointInfo(selected_points))
        self._showSelectedMeshEntity()

    def onClicked(self, pt):
        """Clear the current selection, then pick according to the selection mode."""
        self.onDeselect()
        handlers = (
            (self.MODE_SELECT_BLOCKS, lambda: self._selectBlock(pt)),
            (self.MODE_SELECT_CELLS, lambda: self._selectCell(pt)),
            (self.MODE_SELECT_POINTS, lambda: self._selectPoint(pt)),
        )
        for mode, handler in handlers:
            if self._select_mode == mode:
                handler()
                break

    def onViewInfoWindow(self):
        """Toggle the info window's visibility and refresh the menu bar."""
        info_window = self._info_window
        if not info_window.isVisible():
            info_window.show()
        else:
            info_window.hide()
        self.updateMenuBar()

    def onSelectModeTriggered(self, action):
        """Check the triggered action and adopt its data as the selection mode."""
        action.setChecked(True)
        new_mode = action.data()
        self._select_mode = new_mode

    def onDeselect(self):
        """Drop the block selection and clear the cell/point selection, if any."""
        if self._selection is None:
            return
        self.onBlockSelectionChanged(None)
        self._selection.clear()

    def onColorProfileTriggered(self, action):
        """Check the triggered action, store its profile id and apply it."""
        action.setChecked(True)
        profile_id = action.data()
        self._color_profile_id = profile_id
        self.setColorProfile()

    def setColorProfile(self):
        """Apply the background of the active color profile (default if unknown)."""
        profiles = self._color_profiles
        key = self._color_profile_id
        if key not in profiles:
            key = self.COLOR_PROFILE_DEFAULT
        background = common.rgb2vtk(profiles[key]['bkgnd'])
        renderer = self._vtk_renderer
        # Same color for both gradient ends => a flat background.
        renderer.SetBackground(background)
        renderer.SetBackground2(background)

    def loadColorProfiles(self):
        """Register the built-in default, light and dark color profiles."""
        # TODO: load the profiles via import and iterating over files in
        # 'color_profile' folder
        self._color_profiles = {
            self.COLOR_PROFILE_DEFAULT: default.profile,
            self.COLOR_PROFILE_LIGHT: light.profile,
            self.COLOR_PROFILE_DARK: dark.profile,
        }

    def getFileName(self, window_title, name_filter, default_suffix):
        """Run a modal "save as" dialog; return the chosen path or None if cancelled.

        ``default_suffix`` is appended by the dialog when the user omits an
        extension.
        """
        dialog = QtWidgets.QFileDialog()
        dialog.setWindowTitle(window_title)
        dialog.setNameFilter(name_filter)
        dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
        dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
        dialog.setDefaultSuffix(default_suffix)

        accepted = dialog.exec_() == QtWidgets.QDialog.Accepted
        return str(dialog.selectedFiles()[0]) if accepted else None

    def onExportAsPng(self):
        """Save the render window contents (RGBA) to a PNG chosen by the user."""
        file_name = self.getFileName('Export to PNG', 'PNG files (*.png)', 'png')
        if not file_name:
            return

        grabber = vtk.vtkWindowToImageFilter()
        grabber.SetInput(self._vtk_render_window)
        grabber.SetInputBufferTypeToRGBA()
        grabber.ReadFrontBufferOff()  # grab from the back buffer
        grabber.Update()

        png_writer = vtk.vtkPNGWriter()
        png_writer.SetFileName(file_name)
        png_writer.SetInputConnection(grabber.GetOutputPort())
        png_writer.Write()

    def onExportAsJpg(self):
        """Save the render window contents to a JPEG chosen by the user."""
        file_name = self.getFileName('Export to JPG', 'JPG files (*.jpg)', 'jpg')
        if not file_name:
            return

        grabber = vtk.vtkWindowToImageFilter()
        grabber.SetInput(self._vtk_render_window)
        grabber.ReadFrontBufferOff()  # grab from the back buffer
        grabber.Update()

        jpeg_writer = vtk.vtkJPEGWriter()
        jpeg_writer.SetFileName(file_name)
        jpeg_writer.SetInputConnection(grabber.GetOutputPort())
        jpeg_writer.Write()

    def onToolsExplode(self):
        """Show the explode slider centered horizontally, 10 px above the bottom edge.

        QWidget.setGeometry() takes ints; the previous true division passed a
        float for ``left``, which PyQt5 rejects with TypeError on recent Python
        versions, so floor division is used.
        """
        explode = self._explode
        explode.adjustSize()
        window_rect = self.geometry()
        width = explode.width()
        height = explode.height()
        left = (window_rect.width() - width) // 2
        top = window_rect.height() - height - 10
        explode.setGeometry(left, top, width, height)
        explode.show()

    def onExplodeValueChanged(self, value):
        """Offset every block actor along the model-center/block-center axis.

        ``value`` is normalized by ``self._explode.range()``. Each block actor
        is positioned at ``-dist`` times the unit vector pointing from the
        model center (``self._com``) to the block center (``block.cob``).

        Compared with the original, the model-center vector is built once
        instead of per block, the builtin ``dir`` is no longer shadowed, and
        the unused block ids are not iterated.
        """
        # NOTE(review): raises ZeroDivisionError if the slider range is 0.
        dist = value / self._explode.range()
        center = QtGui.QVector3D(self._com[0], self._com[1], self._com[2])
        for block in self._blocks.values():
            blk_com = block.cob
            blk_center = QtGui.QVector3D(blk_com[0], blk_com[1], blk_com[2])
            direction = blk_center - center
            direction.normalize()
            # NOTE(review): the negative sign moves blocks toward the model
            # center -- confirm this is the intended explode direction.
            offset = -dist * direction
            block.actor.SetPosition([offset.x(), offset.y(), offset.z()])
Esempio n. 5
0
class E_VolumeRenderingWidget(QWidget):
    """Side panel for volume rendering.

    Shows a combo box of color presets and, when the "Volume OTF" check box
    is ticked, a small VTK chart plotting the current color transfer
    function (CTF) and opacity transfer function (OTF).

    The widget relies on a manager attached via SetManager(). That manager
    must expose ``Redraw()`` and ``VolumeMgr``. ``VolumeMgr`` in turn must
    provide:
      - ``SetPresetFunctions``
      - ``m_colorFunction``
      - ``m_opacityFunction``
      - ``m_scalarRange``

    Until a manager is attached, the slots and mouse observers are no-ops.
    Previously they raised AttributeError, which PyQt5 turns into an abort
    when it happens inside a slot.
    """

    def __init__(self, parent=None):
        super(E_VolumeRenderingWidget, self).__init__(parent)
        self.setMaximumWidth(300)

        # Attached later through SetManager().
        self.Mgr = None

        self.mainLayout = QVBoxLayout()
        self.mainLayout.setSpacing(0)
        self.setLayout(self.mainLayout)

        # Small embedded VTK view hosting the transfer-function chart.
        # Observer priorities: higher values run earlier (VTK AddObserver).
        self.m_widget = QVTKRenderWindowInteractor()
        self.m_widget.setMaximumHeight(100)
        self.m_widget.AddObserver('MouseMoveEvent', self.onMouseMove)
        self.m_widget.AddObserver('LeftButtonPressEvent', self.onLeftDown, 1.0)
        self.m_widget.AddObserver('LeftButtonReleaseEvent', self.onLeftUp, -1.0)
        self.m_widget.hide()
        self.m_bClicked = False  # True while the left mouse button is held

        self.m_view = vtk.vtkContextView()
        self.m_histogramChart = vtk.vtkChartXY()

        self.Initialize()

    def SetManager(self, Mgr):
        """Attach the manager that owns the volume and the main render view."""
        self.Mgr = Mgr

    def Initialize(self):
        """Build the preset combo box, the OTF check box and the chart view."""
        # CTF preset selector; item order defines the preset index passed to
        # VolumeMgr.SetPresetFunctions.
        self.combo = QComboBox()
        self.combo.addItem("white")
        self.combo.addItem("SKIN")
        self.combo.addItem("BONE")
        self.combo.addItem("Binary Voxel")
        self.combo.currentIndexChanged.connect(self.onChangeIndex)
        self.addWidget(self.combo)

        # OTF controller: toggles the chart view.
        onOffVolumeOTF = QCheckBox("Volume OTF")
        onOffVolumeOTF.setCheckState(0)
        onOffVolumeOTF.stateChanged.connect(self.onVolumeOTFState)
        self.addWidget(onOffVolumeOTF)

        # Chart view rendered into the embedded VTK widget.
        self.addWidget(self.m_widget)
        self.m_view.SetRenderWindow(self.m_widget.GetRenderWindow())
        self.m_view.GetRenderer().SetBackground(0.0, 0.0, 0.0)
        self.m_view.GetScene().AddItem(self.m_histogramChart)

        # Chart configuration: fixed axes, hidden axis lines.
        self.m_histogramChart.ForceAxesToBoundsOn()
        self.m_histogramChart.SetAutoAxes(False)
        self.m_histogramChart.SetAutoSize(True)
        self.m_histogramChart.SetHiddenAxisBorder(0)
        self.m_histogramChart.GetAxis(0).SetVisible(False)
        self.m_histogramChart.GetAxis(1).SetVisible(False)
        self.m_histogramChart.SetActionToButton(1, -1)

    def GetCurrentColorIndex(self):
        """Return the index of the currently selected color preset."""
        return self.combo.currentIndex()

    def addWidget(self, widget):
        """Append ``widget`` to the panel's vertical layout."""
        self.mainLayout.addWidget(widget)

    def Redraw(self):
        """Update and re-render the chart view."""
        self.m_view.Update()
        self.m_view.Render()

    def onChangeIndex(self, idx, Update=True):
        """Apply color preset ``idx`` and re-plot the CTF/OTF in the chart.

        Args:
            idx: preset index (0 white, 1 SKIN, 2 BONE, 3 Binary Voxel).
            Update: when True, also call ``self.Mgr.Redraw()`` afterwards.
        """
        if self.Mgr is None:
            return

        volume_mgr = self.Mgr.VolumeMgr
        volume_mgr.SetPresetFunctions(idx)

        # Plot the color transfer function.
        self.m_histogramChart.ClearPlots()
        colorPlot = vtk.vtkColorTransferFunctionItem()
        colorPlot.SetColorTransferFunction(volume_mgr.m_colorFunction)
        self.m_histogramChart.AddPlot(colorPlot)

        # Plot the opacity transfer function plus its editable control points.
        opacityFunc = volume_mgr.m_opacityFunction
        opacPlot = vtk.vtkPiecewiseFunctionItem()
        opacPlot.SetPiecewiseFunction(opacityFunc)
        self.m_histogramChart.AddPlot(opacPlot)

        opacityPoint = vtk.vtkPiecewiseControlPointsItem()
        opacityPoint.SetPiecewiseFunction(opacityFunc)
        opacityPoint.SetWidth(10.0)
        self.m_histogramChart.AddPlot(opacityPoint)

        # Fit the bottom axis to the volume's scalar range.
        sRange = volume_mgr.m_scalarRange
        bottom_axis = self.m_histogramChart.GetAxis(vtk.vtkAxis.BOTTOM)
        bottom_axis.SetRange(sRange[0], sRange[1])
        bottom_axis.Update()

        self.Redraw()

        if Update:
            self.Mgr.Redraw()

    def onVolumeOTFState(self, state):
        """Show the chart when the check box is checked (state 2), hide it otherwise."""
        if state == 2:
            self.m_widget.show()
        else:
            self.m_widget.hide()

    def onLeftDown(self, obj, event):
        """VTK observer: remember that the left button is down."""
        self.m_bClicked = True

    def onLeftUp(self, obj, event):
        """VTK observer: left button released; redraw the main view."""
        self.m_bClicked = False
        if self.Mgr is not None:
            self.Mgr.Redraw()

    def onMouseMove(self, obj, event):
        """VTK observer: while dragging, keep the main view in sync with edits."""
        if self.m_bClicked and self.Mgr is not None:
            self.Mgr.Redraw()
Esempio n. 6
0
class XYZviewer(QtWidgets.QFrame):
    """Qt frame that embeds a VTK view of an XYZ point cloud.

    The frame holds a single ``QVTKRenderWindowInteractor``. Its renderer
    shows the current point cloud (or a surface derived from it), a
    ``vtkCubeAxesActor`` and a ``vtkScalarBarActor``. In picker mode (see
    ``setPickerMode``), every picked point ID is:

    * stored in ``self.pickedID``,
    * marked with a red sphere, and
    * re-emitted through ``pickedPointSignal``.
    """

    # Emitted with the ID of each point picked in picker mode.
    pickedPointSignal = QtCore.pyqtSignal(int)

    def __init__(self, parent, dataPath):
        """Build the embedded VTK widget, renderer and interactor.

        Args:
            parent: Parent Qt widget.
            dataPath: Path of a text file whose first three columns are
                loaded into ``self.pointCloud`` via ``load_data``, or
                ``None`` to start empty.
        """
        super(XYZviewer, self).__init__(parent)
        self.interactor = QVTKRenderWindowInteractor(self)
        # NOTE(review): this attribute shadows QWidget.layout(); it is kept
        # for backward compatibility with code reading ``viewer.layout``.
        self.layout = QtWidgets.QHBoxLayout()
        self.layout.addWidget(self.interactor)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(self.layout)

        self.pointCloud = VtkPointCloud()
        self.pcdCollection = []
        self.actors = []
        self.pickedID = []

        # Observe VTK errors reported through the interactor.
        self.e = ErrorObserver()
        self.interactor.AddObserver("ErrorEvent", self.e)
        if self.e.ErrorOccurred():
            # Bug fix: the original printed an undefined local ``e``.
            # NOTE(review): assumes ErrorObserver.ErrorMessage is a method,
            # like ErrorOccurred -- confirm against its definition.
            print(self.e.ErrorMessage())

        if dataPath is not None:
            # Bug fix: the original called a non-existent ``add_data``;
            # ``load_data`` is the loader this class defines.
            self.load_data(dataPath)

        # Renderer and cube axes. setCubeAxesActor applies the cloud's
        # bounds, which are then overridden by a fixed 0..100 box.
        self.renderer = vtk.vtkRenderer()
        self.cubeAxesActor = vtk.vtkCubeAxesActor()
        self.setCubeAxesActor()
        self.cubeAxesActor.SetBounds(0, 100, 0, 100, 0, 100)
        self.renderer.AddActor(self.cubeAxesActor)

        # Scalar bar bound to the point cloud's lookup table.
        self.scalarBarActor = vtk.vtkScalarBarActor()
        self.setScalarBar()
        self.renderer.AddActor(self.scalarBarActor)
        self.pointCloud.setLUTRange(0, 10)

        # Render window owned by the Qt interactor widget.
        renderWindow = self.interactor.GetRenderWindow()
        renderWindow.AddRenderer(self.renderer)

        # Interactor: terrain-style camera manipulation by default.
        self.interactor.SetRenderWindow(renderWindow)
        self.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTerrain())

        # First render, then begin interaction.
        self.renderer.ResetCamera()
        renderWindow.Render()
        renderWindow.SetWindowName("XYZ Data Viewer:" + "xyz")
        self.interactor.Start()

    def start(self):
        """Start the VTK interactor."""
        self.interactor.Start()

    def load_data(self, filename):
        """Append the points of a text file to ``self.pointCloud``.

        Columns 0-2 of *filename* are read as floats with ``genfromtxt``.
        Each row is passed to ``pointCloud.addPoint``.
        """
        print("start viewer")
        data = genfromtxt(filename, dtype=float, usecols=[0, 1, 2])
        for point in data:
            self.pointCloud.addPoint(point)

    def addLogo(self):
        """Return a non-interactive layer-0 renderer showing ``benano.png``.

        The PNG path is relative to the current working directory.
        """
        imgReader = vtk.vtkPNGReader()
        imgReader.SetFileName("benano.png")
        imgReader.Update()
        imgActor = vtk.vtkImageActor()
        imgActor.SetInputData(imgReader.GetOutput())
        background_renderer = vtk.vtkRenderer()
        background_renderer.SetLayer(0)
        background_renderer.InteractiveOff()
        background_renderer.AddActor(imgActor)
        return background_renderer

    def setScalarBar(self):
        """Configure ``self.scalarBarActor`` from the point cloud's LUT."""
        scalarBar = self.scalarBarActor
        scalarBar.SetOrientationToVertical()
        scalarBar.SetLookupTable(self.pointCloud.getLUT())
        scalarBar.SetBarRatio(0.12)
        scalarBar.SetTitleRatio(0.12)
        scalarBar.SetMaximumWidthInPixels(60)
        scalarBar.SetMaximumHeightInPixels(300)
        scalarBar.SetDisplayPosition(60, 400)
        textP = vtk.vtkTextProperty()
        textP.SetFontSize(10)
        scalarBar.SetLabelTextProperty(textP)
        scalarBar.SetTitleTextProperty(textP)
        scalarBar.SetNumberOfLabels(8)
        # Label output format.
        scalarBar.SetLabelFormat("%-#6.3f")

    def setCubeAxesActor(self):
        """Configure ``self.cubeAxesActor`` for the current cloud and camera."""
        cubeAxesActor = self.cubeAxesActor
        grey = (0.5, 0.5, 0.5)
        # Axis lower/upper bounds come from the point cloud.
        cubeAxesActor.SetBounds(self.pointCloud.getBounds())
        # Attach the renderer's camera to the axes.
        cubeAxesActor.SetCamera(self.renderer.GetActiveCamera())
        # Title and label text colour for axes 0, 1, 2.
        for axis in range(3):
            cubeAxesActor.GetTitleTextProperty(axis).SetColor(*grey)
            cubeAxesActor.GetLabelTextProperty(axis).SetColor(*grey)
        # Axis line width.
        cubeAxesActor.GetXAxesLinesProperty().SetLineWidth(0.5)
        cubeAxesActor.GetYAxesLinesProperty().SetLineWidth(0.5)
        cubeAxesActor.GetZAxesLinesProperty().SetLineWidth(0.5)
        # Enable gridlines.
        cubeAxesActor.DrawXGridlinesOn()
        cubeAxesActor.DrawYGridlinesOn()
        cubeAxesActor.DrawZGridlinesOn()
        # Do not draw inner gridlines.
        cubeAxesActor.SetDrawXInnerGridlines(0)
        cubeAxesActor.SetDrawYInnerGridlines(0)
        cubeAxesActor.SetDrawZInnerGridlines(0)
        # Gridline colour.
        cubeAxesActor.GetXAxesGridlinesProperty().SetColor(*grey)
        cubeAxesActor.GetYAxesGridlinesProperty().SetColor(*grey)
        cubeAxesActor.GetZAxesGridlinesProperty().SetColor(*grey)
        # Fly mode (outer, closest, furthest, static closest, static outer).
        cubeAxesActor.SetFlyMode(4)
        # Tick location (inside, outside, both).
        cubeAxesActor.SetTickLocation(1)
        # Gridline location (all, closest, furthest).
        cubeAxesActor.SetGridLineLocation(2)
        # NOTE(review): minor ticks are off for X/Y but on for Z -- confirm
        # this asymmetry is intended.
        cubeAxesActor.XAxisMinorTickVisibilityOff()
        cubeAxesActor.YAxisMinorTickVisibilityOff()
        cubeAxesActor.ZAxisMinorTickVisibilityOn()

    def add_axisWidget(self):
        """Show an interactive orientation-marker (axes) widget in the view."""
        axes = vtk.vtkAxesActor()
        axisWidget = vtk.vtkOrientationMarkerWidget()
        axisWidget.SetOutlineColor(0.9, 0.5, 0.1)
        axisWidget.SetOrientationMarker(axes)
        iren = self.interactor.GetRenderWindow().GetInteractor()
        axisWidget.SetInteractor(iren)
        axisWidget.SetViewport(0, 0, 0.4, 0.4)
        axisWidget.EnabledOn()
        axisWidget.InteractiveOn()
        # Keep a Python reference. With only a local name, the widget can be
        # destroyed when this method returns and the marker vanishes.
        self.axisWidget = axisWidget

    def add_newData(self, pcd):
        """Make *pcd* the current point cloud and rebuild the scene."""
        self.pointCloud = pcd
        self.addActor()

    def addActor(self):
        """Replace the scene's actors with ones built from ``self.pointCloud``.

        Steps:

        1. Clear all actors with ``removeAll``.
        2. Build ``self.mainActor`` according to the hard-coded mode flags
           below (surface reconstruction is currently selected).
        3. Rebuild the cube axes and scalar bar.
        4. Reset the camera and render.
        5. Store the camera's explicit projection transform matrix in
           ``self.oriMatrix``.
        """
        self.removeAll()
        # Display-mode switches; the first true one wins.
        isMesh = False
        isDelaunay3D = False
        isSurfRecon = 1
        if isMesh:
            self.pointCloud.generateMesh()
            self.mainActor = self.pointCloud.boundaryActor
        elif isDelaunay3D:
            self.mainActor = self.pointCloud.delaunay3D()
        elif isSurfRecon:
            self.mainActor = self.pointCloud.surfaceRecon()
        else:
            self.mainActor = self.pointCloud.vtkActor
        self.renderer.AddActor(self.mainActor)
        self.setCubeAxesActor()
        self.renderer.AddActor(self.cubeAxesActor)
        self.setScalarBar()
        self.renderer.AddActor(self.scalarBarActor)

        self.renderer.ResetCamera()
        self.refresh_renderer()
        cam = self.renderer.GetActiveCamera()
        self.oriMatrix = cam.GetExplicitProjectionTransformMatrix()

    def removeAll(self):
        """Remove every actor from ``renderer.GetActors()``; empty ``pcdCollection``.

        NOTE(review): GetActors() presumably returns only 3-D ``vtkActor``
        props, so 2-D props such as the scalar bar may remain attached.
        """
        # Iterate a snapshot. RemoveActor also removes the actor from the
        # collection GetActors() returns, so iterating it live while
        # removing can skip actors.
        for actor in list(self.renderer.GetActors()):
            self.renderer.RemoveActor(actor)
        self.pcdCollection.clear()

    def reset_Camera(self):
        """Look at the main actor from +Z with +Y up, then refit the view."""
        print(self.oriMatrix)
        center_x, center_y, center_z = self.mainActor.GetCenter()
        cam = self.renderer.GetActiveCamera()
        cam.SetPosition(center_x, center_y, center_z + 1)
        cam.SetViewUp(0, 1, 0)
        self.renderer.ResetCamera()
        self.refresh_renderer()

    def setCameraTop(self):
        """Look at the main actor from +X with +Z up, then refit the view.

        NOTE(review): confirm that this camera placement is the intended
        "top" view for this data's orientation.
        """
        center_x, center_y, center_z = self.mainActor.GetCenter()
        cam = self.renderer.GetActiveCamera()
        cam.SetPosition(center_x + 1, center_y, center_z)
        cam.SetViewUp(0, 0, 1)
        print(cam.GetPosition())
        self.renderer.ResetCamera()
        self.refresh_renderer()

    def setCameraLeft(self):
        """Rotate the camera by Azimuth(-10) and refit the view."""
        self.renderer.ResetCamera()
        cam = self.renderer.GetActiveCamera()
        cam.Azimuth(-10)
        self.renderer.ResetCamera()
        self.refresh_renderer()

    def setCameraRight(self):
        """Rotate the camera by Azimuth(10) and refit the view."""
        self.renderer.ResetCamera()
        cam = self.renderer.GetActiveCamera()
        cam.Azimuth(10)
        self.renderer.ResetCamera()
        self.refresh_renderer()

    def refresh_renderer(self):
        """Re-render the interactor's render window."""
        renderWindow = self.interactor.GetRenderWindow()
        renderWindow.Render()

    def applyTransform(self, x, y, z):
        """Set the main actor's user transform to RotateX/Y/Z(x, y, z).

        The rotations are applied in that call order, with angles in degrees.
        """
        w = vtk.vtkTransform()
        # VTK does not coerce strings to double, so convert explicitly
        # (the angles may arrive as text).
        w.RotateX(float(x))
        w.RotateY(float(y))
        w.RotateZ(float(z))
        self.mainActor.SetUserTransform(w)
        self.refresh_renderer()

    def setParallelCamera(self, state):
        """Switch parallel projection on or off (*state*) and refit the view."""
        cam = self.renderer.GetActiveCamera()
        cam.SetParallelProjection(state)
        self.renderer.ResetCamera()
        self.refresh_renderer()

    def setPickerMode(self, state):
        """Choose the interactor style from a check state.

        When *state* is 2, the point-picking style from
        ``utilities.pointPicker`` is used. Any other value restores the
        terrain style.
        """
        import utilities.pointPicker as pStyle
        print(pStyle)
        print(state)
        if state == 2:
            self.interactor.SetInteractorStyle(pStyle.testStyle(self.emitPickedPoint))
        else:
            self.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTerrain())

    def emitPickedPoint(self, int):
        """Handle a picked point ID.

        The ID is recorded, emitted and marked with a red sphere. Once more
        than three points have been picked, a Kochanek spline is drawn
        through them. (The parameter name shadows the builtin ``int``; it is
        kept for interface compatibility.)
        """
        pointID = int
        self.pickedID.append(pointID)
        x, y, z = self.pointCloud.vtkPoints.GetPoint(pointID)
        print("emit:", pointID, x, y, z)
        self.pickedPointSignal.emit(pointID)
        sphereSource = vtk.vtkSphereSource()
        sphereSource.SetCenter(x, y, z)
        sphereSource.SetRadius(1)
        sphereSource.SetThetaResolution(10)
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(sphereSource.GetOutputPort())
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        actor.GetProperty().SetColor(1, 0, 0)
        print(actor)
        actors = self.renderer.GetActors()
        print(actors)
        self.renderer.AddActor(actor)
        self.refresh_renderer()
        if len(self.pickedID) > 3:
            self.drawKochanekSpline(self.pickedID)

    def _pointsFromIDs(self, IDList):
        """Return a new vtkPoints holding the cloud points with the given IDs."""
        points = vtk.vtkPoints()
        for pointID in IDList:
            points.InsertNextPoint(self.pointCloud.vtkPoints.GetPoint(pointID))
        return points

    def _showParametricSpline(self, spline):
        """Tessellate *spline*, add it to the renderer and re-render."""
        functionSource = vtk.vtkParametricFunctionSource()
        functionSource.SetParametricFunction(spline)
        functionSource.Update()
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(functionSource.GetOutputPort())
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        self.renderer.AddActor(actor)
        self.refresh_renderer()

    def drawParametricSpline(self, IDList):
        """Draw a default parametric spline through the cloud points *IDList*."""
        spline = vtk.vtkParametricSpline()
        spline.SetPoints(self._pointsFromIDs(IDList))
        self._showParametricSpline(spline)

    def drawKochanekSpline(self, IDList):
        """Draw a Kochanek spline (per-axis) through the cloud points *IDList*."""
        spline = vtk.vtkParametricSpline()
        spline.SetXSpline(vtk.vtkKochanekSpline())
        spline.SetYSpline(vtk.vtkKochanekSpline())
        spline.SetZSpline(vtk.vtkKochanekSpline())
        spline.SetPoints(self._pointsFromIDs(IDList))
        self._showParametricSpline(spline)

    def SurfaceReconstruction(self, IDList=None):
        """Reconstruct and display a surface through selected cloud points.

        Args:
            IDList: Point IDs (indices into ``self.pointCloud.vtkPoints``)
                to reconstruct from. Defaults to the points picked so far
                (``self.pickedID``).
        """
        if IDList is None:
            IDList = self.pickedID
        # Snapshot the IDs: the programmable source runs lazily at render time.
        ids = list(IDList)
        pointSource = vtk.vtkProgrammableSource()

        def readPoints():
            # Bug fix: ``IDList`` was previously undefined in this scope,
            # which raised NameError when the pipeline executed.
            output = pointSource.GetPolyDataOutput()
            output.SetPoints(self._pointsFromIDs(ids))

        pointSource.SetExecuteMethod(readPoints)
        surf = vtk.vtkSurfaceReconstructionFilter()
        surf.SetInputConnection(pointSource.GetOutputPort())
        # Extract the zero iso-contour of the reconstructed field.
        cf = vtk.vtkContourFilter()
        cf.SetInputConnection(surf.GetOutputPort())
        cf.SetValue(0, 0)
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(cf.GetOutputPort())
        mapper.ScalarVisibilityOff()
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        actor.GetProperty().SetDiffuseColor(1, 0.3882, 0.2784)
        actor.GetProperty().SetSpecularColor(1, 1, 1)
        actor.GetProperty().SetSpecular(.4)
        actor.GetProperty().SetSpecularPower(50)
        self.renderer.AddActor(actor)
        self.refresh_renderer()
Esempio n. 7
0
class MainWindow(qtw.QMainWindow):
    def __init__(self, input_file, gaussian, radius, isosurface, window_size,
                 *args, **kwargs):
        """MainWindow constructor.

        Builds the UI. When *input_file* is given, it is validated and then
        loaded into the "in1" pipeline with the supplied parameters.

        Args:
            input_file: Path of the image to load at start-up, or ``None``.
            gaussian: Value passed to ``changeSigma`` for "in1".
            radius: Value passed to ``changeRadius`` for "in1".
            isosurface: Value passed to ``changeIsosurface`` for "in1".
            window_size: Indexable pair; ``window_size[0]`` and
                ``window_size[1]`` are passed to ``resize``.
            *args, **kwargs: Forwarded to ``QMainWindow.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Window setup
        self.resize(window_size[0], window_size[1])
        self.title = "Bone Imaging Laboratory – Viewer"
        self.iconPath = os.path.abspath(
            os.path.join(os.getcwd(), "qtviewer", "icon.png"))

        self.statusBar().showMessage("Welcome.", 8000)

        # Initialize the window
        self.initUI()

        # Command-line options only apply when an input file was specified.
        if input_file is None:
            return

        if not os.path.exists(input_file):
            qtw.QMessageBox.warning(self, "Error", "Invalid input file.")
            return

        _, ext = os.path.splitext(input_file)
        if not self.validExtension(ext.lower()):
            qtw.QMessageBox.warning(self, "Error", "Invalid file type.")
            return

        self.createPipeline(input_file, "in1")
        self.statusBar().showMessage("Loading file " + input_file, 4000)
        self.changeSigma(gaussian, "in1")
        self.changeRadius(radius, "in1")
        self.changeIsosurface(isosurface, "in1")
        self.updateGUI()

    def initUI(self):
        ########################################
        # Create Widgets
        ########################################

        # Fixed image (in1)
        self.in1_loadPushButton = qtw.QPushButton(
            "Load Image 1",
            self,
            objectName="in1_loadPushButton",
            shortcut=qtg.QKeySequence("Ctrl+f"))
        self.in1_pickableCheckBox = qtw.QCheckBox(
            "Pickable",
            self,
            objectName="in1_pickableCheckBox",
            checkable=True,
            checked=True)
        self.in1_visibilityCheckBox = qtw.QCheckBox(
            "Visibility",
            self,
            objectName="in1_visibilityCheckBox",
            checkable=True,
            checked=True)
        self.in1_filenameLabel = qtw.QLabel("", self)

        self.in1_sigmaSpinBox = qtw.QDoubleSpinBox(
            self,
            objectName="in1_sigmaSpinBox",
            value=1.2,
            decimals=1,
            maximum=20.0,
            minimum=0.1,
            singleStep=0.1,
            keyboardTracking=False)
        self.in1_radiusSpinBox = qtw.QSpinBox(self,
                                              objectName="in1_radiusSpinBox",
                                              value=2,
                                              maximum=20,
                                              minimum=1,
                                              singleStep=1,
                                              keyboardTracking=False)
        self.in1_isosurfaceSpinBox = qtw.QSpinBox(
            self,
            objectName="in1_isosurfaceSpinBox",
            value=0,
            maximum=32768,
            minimum=0,
            singleStep=1,
            keyboardTracking=False)

        # Moving image (in2)
        self.in2_loadPushButton = qtw.QPushButton(
            "Load Image 2",
            self,
            objectName="in2_loadPushButton",
            shortcut=qtg.QKeySequence("Ctrl+m"))
        self.in2_pickableCheckBox = qtw.QCheckBox(
            "Pickable",
            self,
            objectName="in2_pickableCheckBox",
            checkable=True,
            checked=False)
        self.in2_visibilityCheckBox = qtw.QCheckBox(
            "Visibility",
            self,
            objectName="in2_visibilityCheckBox",
            checkable=True,
            checked=True)
        self.in2_filenameLabel = qtw.QLabel("", self)

        self.in2_sigmaSpinBox = qtw.QDoubleSpinBox(
            self,
            objectName="in2_sigmaSpinBox",
            value=1.2,
            decimals=1,
            maximum=20.0,
            minimum=0.1,
            singleStep=0.1,
            keyboardTracking=False)
        self.in2_radiusSpinBox = qtw.QSpinBox(self,
                                              objectName="in2_radiusSpinBox",
                                              value=2,
                                              maximum=20,
                                              minimum=1,
                                              singleStep=1,
                                              keyboardTracking=False)
        self.in2_isosurfaceSpinBox = qtw.QSpinBox(
            self,
            objectName="in2_isosurfaceSpinBox",
            value=0,
            maximum=32768,
            minimum=0,
            singleStep=1,
            keyboardTracking=False)

        # Camera controls
        self.rollCameraPushButton = qtw.QPushButton(
            "Roll", self, objectName="rollCameraPushButton")
        self.elevationCameraPushButton = qtw.QPushButton(
            "Elevation", self, objectName="elevationCameraPushButton")
        self.azimuthCameraPushButton = qtw.QPushButton(
            "Azimuth", self, objectName="azimuthCameraPushButton")
        self.incrementCameraSpinBox = qtw.QSpinBox(
            self,
            objectName="incrementCameraSpinBox",
            value=90,
            maximum=90,
            minimum=-90,
            singleStep=10)

        # Landmark transform
        self.landmarkTransformPushButton = qtw.QPushButton(
            "Landmark",
            self,
            objectName="Landmark Transform",
            enabled=False,
            shortcut=qtg.QKeySequence("Ctrl+l"))

        # ICP transform
        self.icpTransformPushButton = qtw.QPushButton(
            "ICP",
            self,
            objectName="ICP Transform",
            enabled=True,
            shortcut=qtg.QKeySequence("Ctrl+i"))

        # Reset transform
        self.resetTransformPushButton = qtw.QPushButton(
            "Reset",
            self,
            objectName="Reset Transform",
            enabled=True,
            shortcut=qtg.QKeySequence("Ctrl+r"))

        self.viewTransformCheckBox = qtw.QCheckBox("Toggle View",
                                                   self,
                                                   objectName="View Transform",
                                                   checkable=True,
                                                   checked=True)

        self.in1_points_count_label = qtw.QLabel("in1_points_label",
                                                 self,
                                                 text="Points")
        self.in2_points_count_label = qtw.QLabel("in2_points_label",
                                                 self,
                                                 text="Points")
        self.in1_points_count = qtw.QLCDNumber(
            self, intValue=0, segmentStyle=qtw.QLCDNumber.Flat)
        self.in2_points_count = qtw.QLCDNumber(
            self, intValue=0, segmentStyle=qtw.QLCDNumber.Flat)

        # Vectors
        self.rot1 = qtw.QLabel("rot1", self, text="  0.000")
        self.rot2 = qtw.QLabel("rot2", self, text="  0.000")
        self.rot3 = qtw.QLabel("rot3", self, text="  0.000")
        self.trans1 = qtw.QLabel("trans1", self, text="  0.000")
        self.trans2 = qtw.QLabel("trans2", self, text="  0.000")
        self.trans3 = qtw.QLabel("trans3", self, text="  0.000")
        self.rotLabel = qtw.QLabel("R:", self)
        self.transLabel = qtw.QLabel("T:", self)

        # 4x4 matrix
        self.mat11 = qtw.QLabel("mat11", self, text="  0.000")
        self.mat12 = qtw.QLabel("mat12", self, text="  0.000")
        self.mat13 = qtw.QLabel("mat13", self, text="  0.000")
        self.mat14 = qtw.QLabel("mat14", self, text="  0.000")
        self.mat21 = qtw.QLabel("mat21", self, text="  0.000")
        self.mat22 = qtw.QLabel("mat22", self, text="  0.000")
        self.mat23 = qtw.QLabel("mat23", self, text="  0.000")
        self.mat24 = qtw.QLabel("mat24", self, text="  0.000")
        self.mat31 = qtw.QLabel("mat31", self, text="  0.000")
        self.mat32 = qtw.QLabel("mat32", self, text="  0.000")
        self.mat33 = qtw.QLabel("mat33", self, text="  0.000")
        self.mat34 = qtw.QLabel("mat34", self, text="  0.000")
        self.mat41 = qtw.QLabel("mat41", self, text="  0.000")
        self.mat42 = qtw.QLabel("mat42", self, text="  0.000")
        self.mat43 = qtw.QLabel("mat43", self, text="  0.000")
        self.mat44 = qtw.QLabel("mat44", self, text="  0.000")

        self.updateMatrixGUI(vtk.vtkMatrix4x4())

        # Log window
        self.log_window = qtw.QTextEdit(self,
                                        objectName="log_window",
                                        acceptRichText=False,
                                        lineWrapMode=qtw.QTextEdit.NoWrap,
                                        lineWrapColumnOrWidth=80,
                                        placeholderText="Ready...")

        # Create the menu options --------------------------------------------------------------------
        menubar = qtw.QMenuBar()
        self.setMenuBar(menubar)
        menubar.setNativeMenuBar(False)

        file_menu = menubar.addMenu("File")
        open_in1_action = file_menu.addAction("Open Image1")
        open_in2_action = file_menu.addAction("Open Image2")
        file_menu.addSeparator()
        about_action = file_menu.addAction("About")
        quit_action = file_menu.addAction("Quit")

        save_menu = menubar.addMenu("Save")
        save_log_action = save_menu.addAction("Log Window")
        save_menu.addSeparator()
        save_points_action = save_menu.addAction("Points")
        save_transform_matrix_action = save_menu.addAction("Transform Matrix")
        save_transform_vector_action = save_menu.addAction("Transform Vector")
        save_menu.addSeparator()
        save_extrusion_action = save_menu.addAction("Extruded Shape")

        ########################################
        # Layouts
        ########################################

        # Fixed Image (in1) --------------------------------------------------------------------------
        self.in1_mainGroupBox = qtw.QGroupBox("Fixed Image (in1)")
        self.in1_mainGroupBox.setLayout(qtw.QHBoxLayout())

        self.in1_pointWidget = qtw.QFrame()
        self.in1_pointWidget.setLayout(qtw.QHBoxLayout())
        self.in1_pointWidget.layout().addWidget(self.in1_points_count_label)
        self.in1_pointWidget.layout().addWidget(self.in1_points_count)

        self.in1_loadGridLayout = qtw.QGridLayout()
        self.in1_loadGridLayout.addWidget(self.in1_loadPushButton, 0, 0, 1, 1)
        self.in1_loadGridLayout.addWidget(self.in1_pickableCheckBox, 0, 1, 1,
                                          1)
        self.in1_loadGridLayout.addWidget(self.in1_pointWidget, 1, 0, 1, 1)
        self.in1_loadGridLayout.addWidget(self.in1_visibilityCheckBox, 1, 1, 1,
                                          1)
        self.in1_loadGridLayout.addWidget(self.in1_filenameLabel, 2, 0, 1, 4)

        self.in1_isosurfaceGroupBox = qtw.QGroupBox("Isosurface controls")
        self.in1_isosurfaceGroupBox.setLayout(qtw.QVBoxLayout())
        self.in1_isosurfaceFormLayout = qtw.QFormLayout()
        self.in1_isosurfaceFormLayout.addRow("Sigma", self.in1_sigmaSpinBox)
        self.in1_isosurfaceFormLayout.addRow("Radius", self.in1_radiusSpinBox)
        self.in1_isosurfaceFormLayout.addRow("Isosurface",
                                             self.in1_isosurfaceSpinBox)
        self.in1_isosurfaceGroupBox.layout().addLayout(
            self.in1_isosurfaceFormLayout)

        self.in1_mainGroupBox.layout().addLayout(self.in1_loadGridLayout)
        self.in1_mainGroupBox.layout().addWidget(self.in1_isosurfaceGroupBox)

        # Moving Image (in2) -------------------------------------------------------------------------
        self.in2_mainGroupBox = qtw.QGroupBox("Moving Image (in2)")
        self.in2_mainGroupBox.setLayout(qtw.QHBoxLayout())

        self.in2_pointWidget = qtw.QFrame()
        self.in2_pointWidget.setLayout(qtw.QHBoxLayout())
        self.in2_pointWidget.layout().addWidget(self.in2_points_count_label)
        self.in2_pointWidget.layout().addWidget(self.in2_points_count)

        self.in2_loadGridLayout = qtw.QGridLayout()
        self.in2_loadGridLayout.addWidget(self.in2_loadPushButton, 0, 0, 1, 1)
        self.in2_loadGridLayout.addWidget(self.in2_pickableCheckBox, 0, 1, 1,
                                          1)
        self.in2_loadGridLayout.addWidget(self.in2_pointWidget, 1, 0, 1, 1)
        self.in2_loadGridLayout.addWidget(self.in2_visibilityCheckBox, 1, 1, 1,
                                          1)
        self.in2_loadGridLayout.addWidget(self.in2_filenameLabel, 2, 0, 1, 4)

        self.in2_isosurfaceGroupBox = qtw.QGroupBox("Isosurface controls")
        self.in2_isosurfaceGroupBox.setLayout(qtw.QVBoxLayout())
        self.in2_isosurfaceFormLayout = qtw.QFormLayout()
        self.in2_isosurfaceFormLayout.addRow("Sigma", self.in2_sigmaSpinBox)
        self.in2_isosurfaceFormLayout.addRow("Radius", self.in2_radiusSpinBox)
        self.in2_isosurfaceFormLayout.addRow("Isosurface",
                                             self.in2_isosurfaceSpinBox)
        self.in2_isosurfaceGroupBox.layout().addLayout(
            self.in2_isosurfaceFormLayout)

        self.in2_mainGroupBox.layout().addLayout(self.in2_loadGridLayout)
        self.in2_mainGroupBox.layout().addWidget(self.in2_isosurfaceGroupBox)

        # Camera controls panel ----------------------------------------------------------------------
        self.cameraControlsGroupBox = qtw.QGroupBox("Camera")
        self.cameraControlsGroupBox.setLayout(qtw.QHBoxLayout())
        self.cameraControlsGroupBox.layout().addWidget(
            self.rollCameraPushButton)
        self.cameraControlsGroupBox.layout().addWidget(
            self.elevationCameraPushButton)
        self.cameraControlsGroupBox.layout().addWidget(
            self.azimuthCameraPushButton)
        self.incrementCameraFormLayout = qtw.QFormLayout()
        self.incrementCameraFormLayout.addRow("Increment",
                                              self.incrementCameraSpinBox)
        self.cameraControlsGroupBox.layout().addLayout(
            self.incrementCameraFormLayout)

        # Transform panel ----------------------------------------------------------------------------
        self.transformPanelGroupBox = qtw.QGroupBox("Transform")
        self.transformPanelGroupBox.setLayout(qtw.QHBoxLayout())

        # Transform actions
        self.transformLayout = qtw.QGridLayout()
        self.transformLayout.addWidget(self.landmarkTransformPushButton, 0, 0,
                                       1, 1)
        self.transformLayout.addWidget(self.icpTransformPushButton, 0, 1, 1, 1)
        self.transformLayout.addWidget(self.resetTransformPushButton, 1, 0, 1,
                                       1)
        self.transformLayout.addWidget(self.viewTransformCheckBox, 1, 1, 1, 1)

        # Vectors display
        self.vectorsLayout = qtw.QGridLayout()
        self.vectorsLayout.addWidget(self.rotLabel, 0, 0, 1, 1)
        self.vectorsLayout.addWidget(self.rot1, 0, 1, 1, 1)
        self.vectorsLayout.addWidget(self.rot2, 0, 2, 1, 1)
        self.vectorsLayout.addWidget(self.rot3, 0, 3, 1, 1)
        self.vectorsLayout.addWidget(self.transLabel, 1, 0, 1, 1)
        self.vectorsLayout.addWidget(self.trans1, 1, 1, 1, 1)
        self.vectorsLayout.addWidget(self.trans2, 1, 2, 1, 1)
        self.vectorsLayout.addWidget(self.trans3, 1, 3, 1, 1)

        self.vectorsGroupBox = qtw.QGroupBox()
        self.vectorsGroupBox.setLayout(self.vectorsLayout)
        self.transformLayout.addWidget(self.vectorsGroupBox, 2, 0, 2, 2)

        # Matrix display
        self.matrixLayout = qtw.QGridLayout()
        self.matrixLayout.addWidget(self.mat11, 0, 0, 1, 1)
        self.matrixLayout.addWidget(self.mat12, 0, 1, 1, 1)
        self.matrixLayout.addWidget(self.mat13, 0, 2, 1, 1)
        self.matrixLayout.addWidget(self.mat14, 0, 3, 1, 1)
        self.matrixLayout.addWidget(self.mat21, 1, 0, 1, 1)
        self.matrixLayout.addWidget(self.mat22, 1, 1, 1, 1)
        self.matrixLayout.addWidget(self.mat23, 1, 2, 1, 1)
        self.matrixLayout.addWidget(self.mat24, 1, 3, 1, 1)
        self.matrixLayout.addWidget(self.mat31, 2, 0, 1, 1)
        self.matrixLayout.addWidget(self.mat32, 2, 1, 1, 1)
        self.matrixLayout.addWidget(self.mat33, 2, 2, 1, 1)
        self.matrixLayout.addWidget(self.mat34, 2, 3, 1, 1)
        self.matrixLayout.addWidget(self.mat41, 3, 0, 1, 1)
        self.matrixLayout.addWidget(self.mat42, 3, 1, 1, 1)
        self.matrixLayout.addWidget(self.mat43, 3, 2, 1, 1)
        self.matrixLayout.addWidget(self.mat44, 3, 3, 1, 1)

        self.matrixGroupBox = qtw.QGroupBox("Matrix4x4 (in2)")
        self.matrixGroupBox.setLayout(self.matrixLayout)

        self.transformPanelGroupBox.layout().addLayout(self.transformLayout)
        self.transformPanelGroupBox.layout().addWidget(self.matrixGroupBox)

        # Set up the log window ----------------------------------------------------------------------
        font = qtg.QFont("Courier")
        font.setStyleHint(qtg.QFont.TypeWriter)
        font.setWeight(25)
        self.log_window.setTextColor(qtg.QColor("blue"))
        self.log_window.setCurrentFont(font)

        # Add logo -----------------------------------------------------------------------------------
        logo = qtg.QPixmap(':/bonelab/icon.png')
        self.logoLabel = qtw.QLabel("", self)
        self.logoLabel.setPixmap(logo)

        # Assemble the side control panel and put it in a QPanel widget ------------------------------
        self.panel = qtw.QVBoxLayout()
        self.panel.addWidget(self.in1_mainGroupBox)
        self.panel.addWidget(self.in2_mainGroupBox)
        self.panel.addWidget(self.cameraControlsGroupBox)
        self.panel.addWidget(self.transformPanelGroupBox)
        self.panel.addWidget(self.log_window)
        self.panel.addWidget(self.logoLabel, alignment=qtc.Qt.AlignRight)
        self.panelWidget = qtw.QFrame()
        self.panelWidget.setLayout(self.panel)

        # Create the VTK rendering window ------------------------------------------------------------
        self.vtkWidget = QVTKRenderWindowInteractor()
        self.vtkWidget.AddObserver("ExitEvent", lambda o, e, a=self: a.quit())
        #self.vtkWidget.AddObserver("KeyReleaseEvent", self.keyEventDetected)
        self.vtkWidget.AddObserver("LeftButtonReleaseEvent",
                                   self.mouseEventDetected)

        # Create main layout and add VTK window and control panel
        self.mainLayout = qtw.QHBoxLayout()
        self.mainLayout.addWidget(self.vtkWidget, 4)
        self.mainLayout.addWidget(self.panelWidget, 2)

        self.frame = qtw.QFrame()
        self.frame.setLayout(self.mainLayout)
        self.setCentralWidget(self.frame)

        self.setWindowTitle(self.title)
        self.centreWindow()
        #print(self.iconPath)
        self.setWindowIcon(qtg.QIcon(self.iconPath))

        self.cp = ColourPalette()

        # Set size policies --------------------------------------------------------------------------
        self.in1_sigmaSpinBox.setMinimumSize(70, 20)
        self.in1_radiusSpinBox.setMinimumSize(70, 20)
        self.in1_isosurfaceSpinBox.setMinimumSize(70, 20)

        self.in2_sigmaSpinBox.setMinimumSize(70, 20)
        self.in2_radiusSpinBox.setMinimumSize(70, 20)
        self.in2_isosurfaceSpinBox.setMinimumSize(70, 20)

        self.in1_mainGroupBox.setMaximumSize(1000, 1000)
        self.in2_mainGroupBox.setMaximumSize(1000, 1000)
        self.transformPanelGroupBox.setMaximumSize(1000, 1000)

        self.rotLabel.setMaximumSize(15, 20)
        self.transLabel.setMaximumSize(15, 20)

        self.vtkWidget.setSizePolicy(qtw.QSizePolicy.MinimumExpanding,
                                     qtw.QSizePolicy.MinimumExpanding)

        self.in1_points_count.setMaximumSize(50, 30)
        self.in2_points_count.setMaximumSize(50, 30)

        self.in1_mainGroupBox.setSizePolicy(qtw.QSizePolicy.Maximum,
                                            qtw.QSizePolicy.Maximum)

        self.in2_mainGroupBox.setSizePolicy(qtw.QSizePolicy.Maximum,
                                            qtw.QSizePolicy.Maximum)

        self.transformPanelGroupBox.setSizePolicy(qtw.QSizePolicy.Maximum,
                                                  qtw.QSizePolicy.Maximum)

        self.log_window.setSizePolicy(qtw.QSizePolicy.MinimumExpanding,
                                      qtw.QSizePolicy.MinimumExpanding)

        # Connect signals and slots ------------------------------------------------------------------
        # The use of the lambda function is so that the same slot can be used for either pipeline
        self.in1_loadPushButton.clicked.connect(lambda: self.openFile("in1"))
        self.in1_pickableCheckBox.stateChanged.connect(
            lambda s: self.togglePickable(s, "in1"))
        self.in1_visibilityCheckBox.stateChanged.connect(
            lambda s: self.toggleVisibility(s, "in1"))
        self.in1_sigmaSpinBox.valueChanged.connect(
            lambda s: self.changeSigma(s, "in1"))
        self.in1_radiusSpinBox.valueChanged.connect(
            lambda s: self.changeRadius(s, "in1"))
        self.in1_isosurfaceSpinBox.valueChanged.connect(
            lambda s: self.changeIsosurface(s, "in1"))

        self.in2_loadPushButton.clicked.connect(lambda: self.openFile("in2"))
        self.in2_pickableCheckBox.stateChanged.connect(
            lambda s: self.togglePickable(s, "in2"))
        self.in2_visibilityCheckBox.stateChanged.connect(
            lambda s: self.toggleVisibility(s, "in2"))
        self.in2_sigmaSpinBox.valueChanged.connect(
            lambda s: self.changeSigma(s, "in2"))
        self.in2_radiusSpinBox.valueChanged.connect(
            lambda s: self.changeRadius(s, "in2"))
        self.in2_isosurfaceSpinBox.valueChanged.connect(
            lambda s: self.changeIsosurface(s, "in2"))

        self.initRenderWindow()
        self.in2_mainGroupBox.setEnabled(
            False)  # Force user to load in1 before in2

        # Camera
        self.rollCameraPushButton.clicked.connect(
            lambda: self.updateCamera("roll"))
        self.elevationCameraPushButton.clicked.connect(
            lambda: self.updateCamera("elevation"))
        self.azimuthCameraPushButton.clicked.connect(
            lambda: self.updateCamera("azimuth"))

        # Transform
        self.landmarkTransformPushButton.clicked.connect(
            self.applyLandmarkTransform)
        self.icpTransformPushButton.clicked.connect(self.applyICPTransform)
        self.resetTransformPushButton.clicked.connect(self.resetTransform)
        self.viewTransformCheckBox.stateChanged.connect(
            self.toggleTransformApplied)

        # Menu actions
        open_in1_action.triggered.connect(lambda: self.openFile("in1"))
        open_in2_action.triggered.connect(lambda: self.openFile("in2"))
        about_action.triggered.connect(self.about)
        quit_action.triggered.connect(self.quit)

        save_log_action.triggered.connect(self.saveLogFile)
        save_points_action.triggered.connect(self.savePointsFile)
        save_transform_matrix_action.triggered.connect(
            lambda: self.saveTransformFile("matrix"))
        save_transform_vector_action.triggered.connect(
            lambda: self.saveTransformFile("vector"))
        save_extrusion_action.triggered.connect(self.extrudeFromPoints)

        # Variables for managing the VTK pipelines
        self.in1_pipe = None
        self.in2_pipe = None

        # End main UI code
        self.show()

    def centreWindow(self):
        """Centre the window's frame on the available area of the desktop."""
        geometry = self.frameGeometry()
        target_centre = qtw.QDesktopWidget().availableGeometry().center()
        geometry.moveCenter(target_centre)
        self.move(geometry.topLeft())

    def initRenderWindow(self):
        """Create the renderer, attach it to the Qt VTK widget and install the picker style.

        Sets ``self.renderer``, ``self.renWin``, ``self.iren`` and
        ``self.pickerstyle``. The picker style's ``UpdateEvent`` is routed to
        ``keyEventDetected``. The interactor is then initialised and started.
        """
        # Renderer using the palette's background colour
        self.renderer = vtk.vtkRenderer()
        self.renderer.SetBackground(self.cp.getColour("background2"))

        # Hook the renderer into the widget's render window and grab its interactor
        window = self.vtkWidget.GetRenderWindow()
        self.renWin = window
        window.AddRenderer(self.renderer)
        self.iren = window.GetInteractor()

        # Custom picking style; its UpdateEvent drives GUI refreshes
        self.pickerstyle = MyInteractorStyle()
        self.pickerstyle.AddObserver("UpdateEvent", self.keyEventDetected)
        self.iren.SetInteractorStyle(self.pickerstyle)

        self.iren.Initialize()
        self.iren.Start()

    def refreshRenderWindow(self):
        """Render the scene, reset the camera to the visible actors, then render via the interactor."""
        window, renderer, interactor = self.renWin, self.renderer, self.iren
        window.Render()
        renderer.ResetCamera()
        interactor.Render()

    def setPipelineAttributesFromGUI(self, pipeline):
        """Copy the side-panel settings for ``pipeline`` ("in1"/"in2") onto its pipeline.

        The settings are applied in this order: pickability, visibility, Gaussian
        standard deviation, Gaussian radius and isosurface value. Unknown pipeline
        names are ignored. The pipeline object for the slot must already exist.
        """
        if pipeline == "in1":
            target = self.in1_pipe
            controls = (self.in1_pickableCheckBox, self.in1_visibilityCheckBox,
                        self.in1_sigmaSpinBox, self.in1_radiusSpinBox,
                        self.in1_isosurfaceSpinBox)
        elif pipeline == "in2":
            target = self.in2_pipe
            controls = (self.in2_pickableCheckBox, self.in2_visibilityCheckBox,
                        self.in2_sigmaSpinBox, self.in2_radiusSpinBox,
                        self.in2_isosurfaceSpinBox)
        else:
            return

        pickable, visible, sigma, radius, isosurface = controls
        target.setActorPickable(1 if pickable.isChecked() else 0)
        target.setActorVisibility(1 if visible.isChecked() else 0)
        target.setGaussStandardDeviation(sigma.value())
        target.setGaussRadius(radius.value())
        target.setIsosurface(isosurface.value())

    def createPipeline(self, _filename, pipeline):
        """Build a fresh VTK pipeline for slot ``pipeline`` ("in1"/"in2") from ``_filename``.

        Any existing pipeline in that slot is torn down first: its actor is removed
        from the renderer and its picked points are discarded. The new pipeline is
        coloured, configured from the GUI controls, constructed and added to the
        renderer. It is then registered with the picker as ``<pipeline>_pipeline``,
        and its processing/image logs are appended to the log window. Unknown slot
        names are ignored.
        """
        if pipeline == "in1":
            colour_name, label = "bone1", self.in1_filenameLabel
        elif pipeline == "in2":
            colour_name, label = "bone2", self.in2_filenameLabel
        else:
            return

        attr = pipeline + "_pipe"
        tag = pipeline + "_pipeline"

        # Tear down whatever currently occupies this slot
        if getattr(self, attr) is not None:
            self.renderer.RemoveActor(getattr(self, attr).getActor())
            self.pickerstyle.removePoints(tag)
            delattr(self, attr)
            setattr(self, attr, None)

        setattr(self, attr, Pipeline())
        ptr = getattr(self, attr)
        ptr.setActorColor(self.cp.getColour(colour_name))
        label.setText(os.path.basename(_filename))
        label.setToolTip(_filename)

        self.setPipelineAttributesFromGUI(pipeline)
        ptr.constructPipeline(_filename)
        ptr.addActor(self.renderer)
        self.pickerstyle.setMainActor(ptr.getActor(), tag)
        for text in (ptr.getProcessingLog(), ptr.getImageInfoLog()):
            self.log_window.append(text)
        self.refreshRenderWindow()
        self.updateGUI()

    def togglePickable(self, _state, pipeline):
        """Enable or disable picking on the actor of ``pipeline`` ("in1"/"in2").

        Args:
            _state: check state emitted by the checkbox's ``stateChanged`` signal.
                Only ``Qt.Checked`` enables picking; any other state disables it.
            pipeline: slot name, "in1" or "in2".

        Does nothing if the slot is empty or the name is unknown.
        """
        if pipeline == "in1":
            ptr = self.in1_pipe
        elif pipeline == "in2":
            ptr = self.in2_pipe
        else:
            # Unknown slot: must not fall through with ``ptr`` unbound
            return
        if ptr is None:
            return

        # Only two states are possible
        if qtc.Qt.Checked == _state:
            ptr.setActorPickable(1)
            message = "Toggling actor pickability ON"
        else:
            ptr.setActorPickable(0)
            message = "Toggling actor pickability OFF"
        self.statusBar().showMessage(message, 4000)
        self.refreshRenderWindow()

    def toggleVisibility(self, _state, pipeline):
        """Show or hide the actor of ``pipeline`` ("in1"/"in2") together with its picked points.

        Args:
            _state: check state emitted by the checkbox's ``stateChanged`` signal.
                Only ``Qt.Checked`` shows the actor; any other state hides it.
            pipeline: slot name, "in1" or "in2".

        Does nothing if the slot is empty or the name is unknown.
        """
        if pipeline == "in1":
            ptr = self.in1_pipe
        elif pipeline == "in2":
            ptr = self.in2_pipe
        else:
            # Unknown slot: must not fall through with ``ptr``/``name`` unbound
            return
        if ptr is None:
            return
        name = pipeline + "_pipeline"

        # Only two states are possible
        if qtc.Qt.Checked == _state:
            ptr.setActorVisibility(1)
            self.pickerstyle.setVisibilityOfPoints(name, 1)
            message = "Toggling actor visibility ON"
        else:
            ptr.setActorVisibility(0)
            self.pickerstyle.setVisibilityOfPoints(name, 0)
            message = "Toggling actor visibility OFF"
        self.statusBar().showMessage(message, 4000)
        self.refreshRenderWindow()

    def changeSigma(self, _value, pipeline):
        """Set the Gaussian standard deviation of ``pipeline`` ("in1"/"in2") and re-render.

        Ignored when the slot is empty or the name is unknown.
        """
        if pipeline == "in1":
            target = self.in1_pipe
        elif pipeline == "in2":
            target = self.in2_pipe
        else:
            return
        if target is None:
            return
        target.setGaussStandardDeviation(_value)
        self.statusBar().showMessage(
            f"Changing standard deviation to {_value}", 4000)
        self.refreshRenderWindow()

    def changeRadius(self, _value, pipeline):
        """Set the Gaussian radius of ``pipeline`` ("in1"/"in2") and re-render.

        Ignored when the slot is empty or the name is unknown.
        """
        if pipeline not in ("in1", "in2"):
            return
        pipe = getattr(self, pipeline + "_pipe")
        if pipe is not None:
            pipe.setGaussRadius(_value)
            self.statusBar().showMessage(f"Changing radius to {_value}", 4000)
            self.refreshRenderWindow()

    def changeIsosurface(self, _value, pipeline):
        """Set the isosurface value of ``pipeline`` ("in1"/"in2") and re-render.

        Ignored when the slot is empty or the name is unknown.
        """
        if pipeline == "in1":
            pipe = self.in1_pipe
        elif pipeline == "in2":
            pipe = self.in2_pipe
        else:
            pipe = None
        if pipe is None:
            return
        pipe.setIsosurface(_value)
        status = f"Changing isosurface to {_value}"
        self.statusBar().showMessage(status, 4000)
        self.refreshRenderWindow()

    def updateCamera(self, type):
        """Rotate the active camera by the increment from the camera spin box.

        Args:
            type: "roll", "elevation" or "azimuth". Any other value leaves the
                camera orientation unchanged apart from re-orthogonalising view-up.
        """
        cam = self.renderer.GetActiveCamera()
        step = self.incrementCameraSpinBox.value()

        moves = (("roll", cam.Roll), ("elevation", cam.Elevation),
                 ("azimuth", cam.Azimuth))
        for name, turn in moves:
            if type == name:
                turn(step)

        cam.OrthogonalizeViewUp()
        self.refreshRenderWindow()

    def applyLandmarkTransform(self):
        """Rigidly register image 2 onto image 1 using paired landmark points.

        Both images need the same number of picked points, and at least three.
        A rigid-body ``vtkLandmarkTransform`` is computed with the in2 points as
        source and the in1 points as target. The resulting matrix is passed to the
        in2 pipeline's ``setRigidBodyTransformConcatenateMatrix``. All picked points
        are then cleared and the GUI refreshed. Shows an error in the status bar if
        the point counts are unsuitable. Does nothing if image 2 is not loaded.
        """
        if self.in2_pipe is None:
            return
        style = self.pickerstyle
        n_fixed = style.getNumberOfPoints("in1_pipeline")
        n_moving = style.getNumberOfPoints("in2_pipeline")
        if n_fixed != n_moving or n_fixed < 3:
            self.statusBar().showMessage(
                "ERROR: Landmark transform could not be executed", 4000)
            return

        landmarks = vtk.vtkLandmarkTransform()
        landmarks.SetTargetLandmarks(style.getPoints("in1_pipeline"))
        landmarks.SetSourceLandmarks(style.getPoints("in2_pipeline"))
        landmarks.SetModeToRigidBody()
        landmarks.Update()
        self.in2_pipe.setRigidBodyTransformConcatenateMatrix(landmarks.GetMatrix())

        for name in ("in1_pipeline", "in2_pipeline"):
            style.removePoints(name)
        self.refreshRenderWindow()
        self.updateGUI()
        self.statusBar().showMessage(
            f"Landmark transform complete based on {n_fixed}", 4000)

    def applyICPTransform(self):
        """Refine the alignment of image 2 onto image 1 with iterative closest point (ICP).

        Runs up to 10 rigid-body ICP iterations, with the in2 surface as source and
        the in1 surface as target, starting by matching centroids. The resulting
        matrix is passed to the in2 pipeline's
        ``setRigidBodyTransformConcatenateMatrix``. Picked points are cleared and
        the GUI refreshed. Does nothing unless both images are loaded.
        """
        if self.in1_pipe is None or self.in2_pipe is None:
            return
        icp = vtk.vtkIterativeClosestPointTransform()
        icp.SetTarget(self.in1_pipe.getPolyData())
        icp.SetSource(self.in2_pipe.getPolyData())
        # The ICP's internal vtkLandmarkTransform defaults to similarity mode, which
        # also estimates an isotropic scale. The result is applied as a rigid-body
        # transform and the landmark registration uses SetModeToRigidBody(), so
        # restrict ICP to rigid motion too.
        icp.GetLandmarkTransform().SetModeToRigidBody()
        icp.SetMaximumNumberOfIterations(10)
        icp.StartByMatchingCentroidsOn()
        icp.Update()

        # Concatenate the transform
        self.in2_pipe.setRigidBodyTransformConcatenateMatrix(icp.GetMatrix())

        # Clean up
        self.pickerstyle.removePoints("in1_pipeline")
        self.pickerstyle.removePoints("in2_pipeline")
        self.refreshRenderWindow()
        self.updateGUI()
        self.statusBar().showMessage("ICP transform complete", 4000)

    def resetTransform(self):
        """After user confirmation, reset the in2 rigid-body transform to identity.

        Also clears all picked points and refreshes the view and GUI. Does nothing
        if image 2 is not loaded or the user declines.
        """
        if self.in2_pipe is None:
            return
        answer = qtw.QMessageBox.question(
            self, "Message",
            "Are you sure you want to reset the transform?",
            qtw.QMessageBox.Yes | qtw.QMessageBox.No, qtw.QMessageBox.Yes)
        if not answer == qtw.QMessageBox.Yes:
            return

        self.in2_pipe.setRigidBodyTransformToIdentity()
        for name in ("in1_pipeline", "in2_pipeline"):
            self.pickerstyle.removePoints(name)
        self.refreshRenderWindow()
        self.updateGUI()
        self.statusBar().showMessage("Reset transform complete", 4000)

    def toggleTransformApplied(self, _state):
        """Turn display of the in2 transform on (``Qt.Checked``) or off (any other state).

        Does nothing if image 2 is not loaded.
        """
        pipe = self.in2_pipe
        if pipe is None:
            return
        if qtc.Qt.Checked == _state:
            pipe.useTransform(True)
            message = "Toggling transform visibility ON"
        else:
            pipe.useTransform(False)
            message = "Toggling transform visibility OFF"
        self.statusBar().showMessage(message, 4000)
        self.refreshRenderWindow()

    def updateMatrixGUI(self, _mat):
        """Display the 16 elements of the 4x4 matrix ``_mat`` in the ``mat<row><col>`` labels.

        Each element is formatted with three decimals in a six-character field.
        Labels are filled in row-major order.
        """
        for row in range(4):
            for col in range(4):
                label = getattr(self, f"mat{row + 1}{col + 1}")
                label.setText(f"{float(_mat.GetElement(row, col)):6.3f}")

    def updateVectorsGUI(self):
        """Show the rotation and translation vectors derived from the in2 transform.

        Both images' dimensions, positions and element sizes and the in2 matrix are
        fed to a ``ScancoMatrixConverter``. The three rotation and three
        translation components are then written to the ``rot1..3`` and
        ``trans1..3`` labels with three decimals. Does nothing if image 2 is not
        loaded.
        """
        if self.in2_pipe is None:
            return
        img1, img2 = self.in1_pipe, self.in2_pipe
        conv = ScancoMatrixConverter()
        matrix = img2.getMatrix()
        conv.setDimImage1(img1.getDimensions())
        conv.setDimImage2(img2.getDimensions())
        conv.setPosImage1(img1.getPosition())
        conv.setPosImage2(img2.getPosition())
        conv.setElSizeMMImage1(img1.getElementSize())
        conv.setElSizeMMImage2(img2.getElementSize())
        conv.setTransform(matrix)
        conv.calculateVectors()

        rotation = conv.getRotationVector()
        translation = conv.getTranslationVector()
        for prefix, vector in (("rot", rotation), ("trans", translation)):
            for i in range(3):
                getattr(self, f"{prefix}{i + 1}").setText(f"{float(vector[i]):6.3f}")

    def updateGUI(self):
        """Synchronise the side panel with the picked points and the loaded pipelines.

        This method:
        - Updates both point counters.
        - Enables the landmark button only for matching counts of at least three.
        - Once image 1 exists, enables the in2 controls and mirrors the in1
          smoothing/isosurface values.
        - Once image 2 exists, refreshes the matrix, the vectors and the in2 values.
        """
        n1 = self.pickerstyle.getNumberOfPoints("in1_pipeline")
        n2 = self.pickerstyle.getNumberOfPoints("in2_pipeline")
        self.in1_points_count.display(n1)
        self.in2_points_count.display(n2)
        # Landmark registration needs matching pairs of at least three points
        self.landmarkTransformPushButton.setEnabled(n1 == n2 and n1 >= 3)

        first = self.in1_pipe
        if first is not None:
            # in2 may only be loaded once in1 exists
            self.in2_mainGroupBox.setEnabled(True)
            self.in1_sigmaSpinBox.setValue(first.getGaussStandardDeviation())
            self.in1_radiusSpinBox.setValue(first.getGaussRadius())
            self.in1_isosurfaceSpinBox.setValue(first.getIsosurface())

        second = self.in2_pipe
        if second is not None:
            self.updateMatrixGUI(second.getMatrix())
            self.updateVectorsGUI()
            self.in2_sigmaSpinBox.setValue(second.getGaussStandardDeviation())
            self.in2_radiusSpinBox.setValue(second.getGaussRadius())
            self.in2_isosurfaceSpinBox.setValue(second.getIsosurface())

    def keyEventDetected(self, obj, event):
        """Observer for the picker style's ``UpdateEvent``.

        Refreshes the GUI. If the interactor's current key symbol is exactly "p"
        (pick) or "d" (delete), also appends the picker's description of the point
        action to the log window.

        Args:
            obj: VTK object that fired the event (unused).
            event: VTK event name (unused).
        """
        self.updateGUI()
        key = self.vtkWidget.GetKeySym()
        # Exact membership: the old ``key in 'p'`` was a substring test that also
        # matched "" and raised TypeError if GetKeySym() returned None.
        if key in ("p", "d"):  # pick or delete points
            self.log_window.append(self.pickerstyle.getPointActionString())

    def mouseEventDetected(self, obj, event):
        """Debug observer for ``LeftButtonReleaseEvent``: print a trace line to stdout."""
        print("mouserelease – click!")

    def validExtension(self, extension):
        """Return True if ``extension`` (including the leading dot) is a supported input type.

        The comparison is case-sensitive; ``openFile`` passes a lower-cased extension.
        """
        return extension in (".aim", ".nii", ".dcm", ".stl")

    def openFile(self, pipeline_name):
        """Ask the user for an image file and build pipeline ``pipeline_name`` from it.

        Refuses to load "in2" before "in1" has been loaded. Rejects files whose
        extension (checked case-insensitively) is not .aim, .nii, .dcm or .stl.
        Does nothing if the dialog is cancelled.
        """
        if pipeline_name == "in2" and self.in1_pipe is None:
            qtw.QMessageBox.warning(
                self, "Warning", "Image 2 cannot be loaded before image 1.")
            return
        self.statusBar().showMessage("Load image types (.aim, .nii, .dcm)",
                                     4000)

        file_filter = ("Aim Files (*.aim) ;;Nifti Files (*.nii) ;;DICOM Files (*.dcm) "
                       ";;STL Files (*.stl) ;;All Files (*)")
        dialog_options = (qtw.QFileDialog.DontUseNativeDialog
                          | qtw.QFileDialog.DontResolveSymlinks)
        filename, _ = qtw.QFileDialog.getOpenFileName(
            self, "Select a 3D image file to open…", qtc.QDir.homePath(),
            file_filter, "All Files (*)", dialog_options)
        if not filename:
            return

        extension = os.path.splitext(filename)[1]
        if not self.validExtension(extension.lower()):
            qtw.QMessageBox.warning(self, "Error", "Invalid file type.")
            return

        self.createPipeline(filename, pipeline_name)
        self.statusBar().showMessage("Loading file " + filename, 4000)

    def saveLogFile(self):
        """Ask for a destination file and write the log window's plain text to it.

        Any error while writing is reported in a critical message box. Nothing
        happens if the dialog is cancelled.
        """
        filename, _ = qtw.QFileDialog.getSaveFileName(
            self, "Select the file to save to…", qtc.QDir.homePath(),
            "Text Files (*.txt) ")
        if not filename:
            return
        try:
            with open(filename, 'w') as fh:
                fh.write(self.log_window.toPlainText())
        except Exception as e:
            # QMessageBox.critical takes (parent, title, text). The previous
            # two-argument call raised TypeError inside this handler.
            qtw.QMessageBox.critical(self, "Error", f"Could not save file: {e}")

    def savePointsFile(self):
        """Ask for a destination file and write all picked points to it as text.

        The text comes from the picker's ``getAllPointsAsString``. Any error while
        writing is reported in a critical message box. Nothing happens if the
        dialog is cancelled.
        """
        filename, _ = qtw.QFileDialog.getSaveFileName(
            self, "Select the file to save to…", qtc.QDir.homePath(),
            "Text Files (*.txt) ;;Python Files (*.py) ;;All Files (*)")
        if not filename:
            return
        try:
            with open(filename, 'w') as fh:
                fh.write(self.pickerstyle.getAllPointsAsString())
        except Exception as e:
            # QMessageBox.critical takes (parent, title, text). The previous
            # two-argument call raised TypeError inside this handler.
            qtw.QMessageBox.critical(self, "Error", f"Could not save file: {e}")

    def saveTransformFile(self, type):
        """Write the current in2 -> in1 transform to a user-chosen text file.

        Args:
            type: "matrix" writes ``getTransformAsString()``; "vector" writes
                ``getVectorsAsString()`` of a ``ScancoMatrixConverter``. Any other
                value writes an empty file.

        If only image 1 is loaded, a default-constructed ``vtkMatrix4x4`` (identity)
        is used with image 1's geometry. If no image is loaded, a warning is shown
        and nothing is saved. Errors while writing are reported in a critical
        message box.
        """
        in1, in2 = self.in1_pipe, self.in2_pipe
        if in1 is None:
            # Previously crashed with AttributeError on self.in1_pipe.getDimensions()
            qtw.QMessageBox.warning(
                self, "Warning",
                "Image 1 must be loaded before a transform can be saved.")
            return

        converter = ScancoMatrixConverter()
        if in2 is None:
            mat = vtk.vtkMatrix4x4()
            converter.setDimImage1(in1.getDimensions())
            converter.setPosImage1(in1.getPosition())
            converter.setElSizeMMImage1(in1.getElementSize())
        else:
            mat = in2.getMatrix()
            converter.setDimImage1(in1.getDimensions())
            converter.setDimImage2(in2.getDimensions())
            converter.setPosImage1(in1.getPosition())
            converter.setPosImage2(in2.getPosition())
            converter.setElSizeMMImage1(in1.getElementSize())
            converter.setElSizeMMImage2(in2.getElementSize())
        converter.setTransform(mat)
        converter.calculateVectors()

        filename, _ = qtw.QFileDialog.getSaveFileName(
            self, "Select the file to save to…", qtc.QDir.homePath(),
            "Text Files (*.txt) ;;Python Files (*.py) ;;All Files (*)")
        if not filename:
            return
        try:
            # Build the text before opening so a failure does not truncate the file
            s = ""
            if type == "matrix":
                s += converter.getTransformAsString()
            if type == "vector":
                s += converter.getVectorsAsString()
            with open(filename, 'w') as fh:
                fh.write(s)
        except Exception as e:
            # QMessageBox.critical takes (parent, title, text)
            qtw.QMessageBox.critical(self, "Error", f"Could not save file: {e}")

    def quit(self):
        """Ask for confirmation and, if given, terminate the process with exit status 0."""
        import sys

        reply = qtw.QMessageBox.question(
            self, "Message", "Are you sure you want to quit?",
            qtw.QMessageBox.Yes | qtw.QMessageBox.No, qtw.QMessageBox.Yes)
        if reply == qtw.QMessageBox.Yes:
            # The builtin ``exit`` is an interactive-shell helper added by ``site``
            # and is not guaranteed to exist (e.g. ``python -S``, frozen apps).
            sys.exit(0)

    def about(self):
        """Show a modal "About" box with the version and copyright information."""
        about = qtw.QMessageBox(self)
        # Reuse the main window's icon rather than a hard-coded absolute path
        about.setWindowIcon(qtg.QIcon(self.iconPath))
        about.setIcon(qtw.QMessageBox.Information)
        about.setText("blQtViewer 1.0")
        # The contact line follows a newline like the other lines; the previous
        # "\[" was an invalid escape sequence left where the "\n" was lost.
        about.setInformativeText(
            "Copyright (C) 2020\nBone Imaging Laboratory\nAll rights reserved.\n[email protected]"
        )
        about.setStandardButtons(qtw.QMessageBox.Ok | qtw.QMessageBox.Cancel)
        about.exec_()

    def extrudeFromPoints(self):
        """Create a binary mask by extruding a closed spline through the in1 points.

        The points picked on image 1 define a closed parametric spline. It is
        extruded by two vtkLinearExtrusionFilters (vectors (0, 0, 1) and
        (0, 0, -1), scale factor 100), and the two extrusions are appended and
        cleaned. The resulting surface stencils a constant image that has image 1's
        spacing, extent and origin. Voxels inside keep ``foregroundValue`` (127);
        voxels outside get ``backgroundValue`` (0). The user then chooses an output
        file and the result is written with ``vtkboneAIMWriter``.

        Shows a warning and returns if fewer than 3 points are picked on image 1.
        Warns but continues if image 1 is not flagged valid for extrusion.
        """

        # NOTE(review): assumes getPoints returns a vtkPoints-like object even when
        # nothing is picked; if in1 is not loaded, self.in1_pipe below is None.
        pts = self.pickerstyle.getPoints("in1_pipeline")

        if (pts.GetNumberOfPoints() < 3):
            qtw.QMessageBox.warning(
                self, "Warning",
                "At least 3 points must be defined on image 1 to create extrusion."
            )
            return

        if (not self.in1_pipe.getIsValidForExtrusion()):
            qtw.QMessageBox.warning(
                self, "Warning",
                "Extrusion may not work properly when input file is not of type AIM."
            )

        # Spline: closed curve through the picked points
        spline = vtk.vtkParametricSpline()
        spline.SetPoints(pts)
        spline.ClosedOn()

        parametricFunction = vtk.vtkParametricFunctionSource()
        parametricFunction.SetParametricFunction(spline)
        parametricFunction.Update()

        # Extrude
        extrusionFactor = 100.0  # mm above and below surface
        # A large number will cause the extrusion to fill the extent of the input image

        # NOTE(review): with NormalExtrusion, VTK uses the input's point normals when
        # present and otherwise falls back to the vector set below. Confirm which
        # applies to the spline output.
        positiveExtruder = vtk.vtkLinearExtrusionFilter()
        positiveExtruder.SetInputConnection(parametricFunction.GetOutputPort())
        positiveExtruder.SetExtrusionTypeToNormalExtrusion()
        positiveExtruder.SetVector(0, 0, 1)
        positiveExtruder.CappingOn()
        positiveExtruder.SetScaleFactor(extrusionFactor)

        posTriFilter = vtk.vtkTriangleFilter()
        posTriFilter.SetInputConnection(positiveExtruder.GetOutputPort())

        # Same extrusion in the opposite direction
        negativeExtruder = vtk.vtkLinearExtrusionFilter()
        negativeExtruder.SetInputConnection(parametricFunction.GetOutputPort())
        negativeExtruder.SetExtrusionTypeToNormalExtrusion()
        negativeExtruder.SetVector(0, 0, -1)
        negativeExtruder.CappingOn()
        negativeExtruder.SetScaleFactor(extrusionFactor)

        negTriFilter = vtk.vtkTriangleFilter()
        negTriFilter.SetInputConnection(negativeExtruder.GetOutputPort())

        # Combine data: append both halves and merge duplicate points
        combiner = vtk.vtkAppendPolyData()
        combiner.AddInputConnection(posTriFilter.GetOutputPort())
        combiner.AddInputConnection(negTriFilter.GetOutputPort())

        cleaner = vtk.vtkCleanPolyData()
        cleaner.SetInputConnection(combiner.GetOutputPort())
        cleaner.Update()

        # Output geometry is copied from image 1
        el_size_mm = self.in1_pipe.getElementSize()
        dim = self.in1_pipe.getDimensions()
        extent = self.in1_pipe.getExtent()
        origin = self.in1_pipe.getOrigin()
        foregroundValue = 127  # largest value representable by VTK_CHAR (signed char)
        backgroundValue = 0

        # Stencil: constant image on image 1's grid
        whiteImage = vtk.vtkImageData()
        whiteImage.SetSpacing(el_size_mm)
        # NOTE(review): SetExtent below overrides the extent implied by SetDimensions
        whiteImage.SetDimensions(dim)
        whiteImage.SetExtent(extent)
        whiteImage.SetOrigin(origin)
        whiteImage.AllocateScalars(vtk.VTK_CHAR, 1)
        whiteImage.GetPointData().GetScalars().Fill(foregroundValue)

        # Use our extruded polydata to stencil the solid image
        poly2sten = vtk.vtkPolyDataToImageStencil()
        poly2sten.SetTolerance(0)
        #poly2sten.SetInputConnection(clipper.GetOutputPort())
        poly2sten.SetInputConnection(cleaner.GetOutputPort())
        poly2sten.SetOutputOrigin(origin)
        poly2sten.SetOutputSpacing(el_size_mm)
        poly2sten.SetOutputWholeExtent(whiteImage.GetExtent())

        # Voxels outside the surface are set to backgroundValue
        stencil = vtk.vtkImageStencil()
        stencil.SetInputData(whiteImage)
        stencil.SetStencilConnection(poly2sten.GetOutputPort())
        #stencil.ReverseStencilOff()
        stencil.SetBackgroundValue(backgroundValue)
        stencil.Update()

        # Write image (only if the user picked a destination)
        filename, _ = qtw.QFileDialog.getSaveFileName(
            self, "Select the file to save to…", qtc.QDir.homePath(),
            "AIM File (*.aim)")

        if (filename):
            writer = vtkbone.vtkboneAIMWriter()
            writer.SetInputConnection(stencil.GetOutputPort())
            writer.SetFileName(filename)
            writer.SetProcessingLog(
                '!-------------------------------------------------------------------------------\n'
                + 'Written by blQtViewer.')
            writer.Update()
            self.statusBar().showMessage("File " + filename + " written.",
                                         4000)
Esempio n. 8
0
class QLiverViewer(QtWidgets.QFrame):
    colors = vtk.vtkNamedColors()
    widgetMoved = Signal(object, object, object)  #float)

    widgetRegistered = Signal(object, object, object)

    def __init__(self, parent):
        """Create the liver viewer inside *parent*.

        Puts a VTK render-window interactor in as the only child widget, loads
        the liver and vessel surfaces, creates the reference plane widget(s)
        (one when ``IOUSFAN`` is set, otherwise two), resets the camera and
        installs a trackball-camera interactor style plus a key handler.

        Args:
            parent: Qt parent widget forwarded to ``QFrame``.
        """
        super(QLiverViewer, self).__init__(parent)

        # Make the actual QtWidget a child so that it can be re-parented.
        self.interactor = QVTKRenderWindowInteractor(self)
        # NOTE(review): this attribute (like ``self.style`` below) shadows the
        # QWidget method of the same name on this instance.
        self.layout = QtWidgets.QHBoxLayout()
        self.layout.addWidget(self.interactor)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(self.layout)

        self.imgWidth: float = 50.0  # mm, side length of each reference plane
        # Uniform scale applied to loaded geometry; 1.0 leaves the data
        # untouched (see ``scale``). 0.05 is an alternative tried earlier.
        self.worldScale: float = 1.0
        self.brighter25: bool = True  # use the 25% lighter liver/vessel colours
        self.opacity: float = 0.35  # liver and reference-plane opacity

        # Whether translucent blue reference-plane actors are created.
        self.showReferencePlane = False

        self.nReferences = 2

        # Per-reference state, indexed like ``planeWidgets``.
        self.refplanes = []  # Displayed reference planes (or None)
        self.planeWidgets = []  # Widgets for manipulation
        self.contours = []  # Contours clipped to reference plane
        self.fullcontours = []  # Full contour data
        self.planeSources = []  # Plane sources of the reference planes
        self.userAttempts = []  # Red contour actors following the widgets
        self.contourResults = self.nReferences * [None]  # ICP result actors
        self.vessels = None
        self.liver = None

        self.initScene()
        self.initLiver()
        self.initVessels()  # Must be before planeWidgets

        # Last recorded widget pose per reference ('origin', 'normal',
        # 'axis1') and the pose a widget is reset to ('reset').
        self.lastPositions = {
            'origin': self.nReferences * [None],
            'normal': self.nReferences * [None],
            'axis1': self.nReferences * [None],
            'reset': self.nReferences * [None],
        }

        self.lastIndex = None  # Index of the widget the user moved last
        self.initPlaneWidgets(0)
        if not IOUSFAN:
            self.initPlaneWidgets(1)

        self.resetCamera()

        self.style = vtk.vtkInteractorStyleTrackballCamera()
        self.style.SetDefaultRenderer(self.renderer)
        self.interactor.SetInteractorStyle(self.style)

        self.interactor.AddObserver('KeyPressEvent', self.KeyPress, 1.0)

    def KeyPress(self, obj, ev):
        """VTK ``KeyPressEvent`` observer acting on the last moved widget.

        Keys:
            ``c`` -- discard the user's attempt and put the plane widget back
            at its reset pose.
            ``s`` -- register the attempt's contours to the vessels with ICP,
            show the corrected contours in yellow, emit
            ``widgetRegistered((angle_deg, axis), unit_normal, position_error)``
            and reset the widget.
            ``m`` -- move the reference plane, its clipped contours and the
            attempt to the current widget pose, which becomes the new reset
            pose.

        Other keys are ignored, as is every key until a widget has been moved
        (``self.lastIndex`` is set by ``onWidgetMoved``).
        """
        index = self.lastIndex
        if index is None:
            return
        key = obj.GetKeySym()
        if key == 'c':
            self._keyResetAttempt(index)
        elif key == 's':
            self._keyRegisterAttempt(index)
        elif key == 'm':
            self._keyMoveReference(index)

    def _restorePlaneWidget(self, index):
        """Move plane widget *index* back to its reset pose and record that
        pose as the widget's last known position."""
        planeWidget = self.planeWidgets[index]
        resetPositions = self.lastPositions['reset'][index]

        planeWidget.SetEnabled(0)
        planeWidget.SetOrigin(resetPositions[0])
        planeWidget.SetPoint1(resetPositions[1])
        planeWidget.SetPoint2(resetPositions[2])
        planeWidget.Modified()
        planeWidget.SetEnabled(1)

        lastAxis1 = vtk.vtkVector3d()
        vtk.vtkMath.Subtract(planeWidget.GetPoint1(), planeWidget.GetOrigin(),
                             lastAxis1)
        self.lastPositions['origin'][index] = planeWidget.GetCenter()
        self.lastPositions['normal'][index] = planeWidget.GetNormal()
        self.lastPositions['axis1'][index] = lastAxis1

    def _keyResetAttempt(self, index):
        """Discard attempt *index* and its ICP result (key ``c``)."""
        print('Reset')
        userAttempt = self.userAttempts[index]
        userAttempt.SetUserTransform(None)
        userAttempt.Modified()

        self._restorePlaneWidget(index)

        if self.contourResults[index] is not None:
            self.renderer.RemoveActor(self.contourResults[index])
            self.contourResults[index] = None
        self.render_window.Render()

    def _keyRegisterAttempt(self, index):
        """Register attempt *index* to the vessels with ICP (key ``s``)."""
        print('Registration')
        from vtkUtils import CloudMeanDist

        userAttempt = self.userAttempts[index]
        userTransform = userAttempt.GetUserTransform()
        if userTransform is None:
            # After 'c' or a previous 's' the attempt has no user transform;
            # start from identity instead of failing on ``None``.
            userTransform = vtk.vtkTransform()
            userTransform.PostMultiply()
            userAttempt.SetUserTransform(userTransform)

        # Contours as currently placed by the user.
        tfpdf0 = vtk.vtkTransformPolyDataFilter()
        tfpdf0.SetInputData(self.fullcontours[index])
        tfpdf0.SetTransform(userTransform)
        tfpdf0.Update()
        wrongContours = tfpdf0.GetOutput()

        # ============ run ICP ==============
        icp = vtk.vtkIterativeClosestPointTransform()
        icp.SetSource(wrongContours)
        icp.SetTarget(self.vesselPolyData)
        icp.GetLandmarkTransform().SetModeToRigidBody()
        icp.DebugOn()
        icp.SetMaximumNumberOfIterations(10)
        icp.StartByMatchingCentroidsOff()
        icp.SetCheckMeanDistance(1)  # Experiment
        # RMS is the default; absolute-value mode was the original choice.
        icp.SetMeanDistanceModeToRMS()
        print(CloudMeanDist(wrongContours, self.vesselPolyData))
        icp.Modified()
        icp.Update()

        icpTransformFilter = vtk.vtkTransformPolyDataFilter()
        icpTransformFilter.SetInputData(wrongContours)
        icpTransformFilter.SetTransform(icp)
        icpTransformFilter.Update()
        correctedContours = icpTransformFilter.GetOutput()
        print(CloudMeanDist(correctedContours, self.vesselPolyData))

        # Show the corrected contours as yellow tubes (radius scaled like
        # the other contour tubes).
        tubes = vtk.vtkTubeFilter()
        tubes.SetInputData(correctedContours)
        tubes.CappingOn()
        tubes.SidesShareVerticesOff()
        tubes.SetNumberOfSides(12)
        tubes.SetRadius(self.worldScale * 1.0)

        edgeMapper = vtk.vtkPolyDataMapper()
        edgeMapper.ScalarVisibilityOff()
        edgeMapper.SetInputConnection(tubes.GetOutputPort())

        if self.contourResults[index] is not None:
            self.renderer.RemoveActor(self.contourResults[index])
        resultActor = vtk.vtkActor()
        resultActor.SetMapper(edgeMapper)
        prop = resultActor.GetProperty()
        prop.SetColor(yellow)
        prop.SetLineWidth(3)
        self.renderer.AddActor(resultActor)
        self.contourResults[index] = resultActor

        # Total transform: the user's manual motion followed (post-multiply)
        # by the ICP correction. Orientation is (w, x, y, z).
        userTransform.Concatenate(icp.GetMatrix())
        userTransform.Update()
        (deg, x, y, z) = userTransform.GetOrientationWXYZ()

        # Displacement of the reset plane centre under the total transform.
        resetPositions = self.lastPositions['reset'][index]
        positionError = np.array(
            userTransform.TransformPoint(resetPositions[3])) - np.array(
                resetPositions[3])

        # Reset afterwards
        userAttempt.SetUserTransform(None)
        userAttempt.Modified()
        self._restorePlaneWidget(index)
        self.render_window.Render()

        originalNormal = resetPositions[4]
        originalNormal = originalNormal / np.sqrt(np.sum(originalNormal**2))
        self.widgetRegistered.emit((deg, np.r_[x, y, z]), originalNormal,
                                   positionError)

    def _keyMoveReference(self, index):
        """Move reference *index* to the current widget pose (key ``m``).

        Rebuilds the reference plane actor (only if one was shown), the
        clipped green contours and the red attempt, and stores the current
        widget pose as the new reset pose.
        """
        print('Moving reference')
        planeWidget = self.planeWidgets[index]
        planeSource = self.planeSources[index]

        # Remove the old reference plane; recreate it only if one was shown.
        refActor = self.refplanes[index]
        showReferencePlane = refActor is not None
        if showReferencePlane:
            self.removeActor(refActor)

        # Move reference plane
        planeSource.SetOrigin(planeWidget.GetOrigin())
        planeSource.SetPoint1(planeWidget.GetPoint1())
        planeSource.SetPoint2(planeWidget.GetPoint2())

        if showReferencePlane:
            mapper = vtk.vtkPolyDataMapper()
            mapper.SetInputConnection(planeSource.GetOutputPort())
            refActor = vtk.vtkActor()
            refActor.SetMapper(mapper)
            prop = refActor.GetProperty()
            prop.SetColor(blue)
            prop.SetOpacity(self.opacity)
            self.refplanes[index] = refActor
            self.renderer.AddActor(refActor)

        # Green contour tubes clipped to the moved plane.
        self.removeActor(self.contours[index])
        tubes, oldContours = self.computeContoursAndTubes(planeSource)
        edgeMapper = vtk.vtkPolyDataMapper()
        edgeMapper.ScalarVisibilityOff()
        edgeMapper.SetInputConnection(tubes.GetOutputPort())
        edgeMapper.SetClippingPlanes(self.computeClippingPlanes(planeSource))
        actor = vtk.vtkActor()
        actor.SetMapper(edgeMapper)
        prop = actor.GetProperty()
        prop.SetColor(green)
        prop.SetLineWidth(3)
        self.contours[index] = actor
        self.renderer.AddActor(actor)

        # Fresh red user attempt carrying an identity transform.
        self.removeActor(self.userAttempts[index])
        transform = vtk.vtkTransform()
        identity = vtk.vtkMatrix4x4()
        identity.Identity()
        transform.SetMatrix(identity)
        transform.PostMultiply()
        attempt = vtk.vtkActor()
        attempt.SetUserTransform(transform)

        self.fullcontours[index] = oldContours
        mapper = vtk.vtkPolyDataMapper()
        mapper.ScalarVisibilityOff()
        mapper.SetInputData(oldContours)
        attempt.SetMapper(mapper)
        attempt.GetProperty().SetColor(red)
        self.userAttempts[index] = attempt
        self.renderer.AddActor(attempt)

        # Update reset positions of plane widget:
        # [origin, point1, point2, center, normal]
        self.lastPositions['reset'][index] = [
            planeWidget.GetOrigin(),
            planeWidget.GetPoint1(),
            planeWidget.GetPoint2(),
            planeWidget.GetCenter(),
            np.array(planeWidget.GetNormal())
        ]
        sys.stdout.write('Plane origin: ')
        print(planeWidget.GetOrigin())
        sys.stdout.write('Plane point1: ')
        print(planeWidget.GetPoint1())
        sys.stdout.write('Plane point2: ')
        print(planeWidget.GetPoint2())
        print(
            "To make the app start at this location, insert these points near the tag\n 'FIXED PLANE POSITION AND ORIENTATION' in LiverView.py"
        )
        self.render_window.Render()

    def removeActor(self, actor):
        """Remove *actor* from the renderer with the interactor disabled.

        The interactor is re-enabled in a ``finally`` block, so a failing
        removal cannot leave it disabled. This does not re-render; callers
        call ``render_window.Render()`` themselves.
        """
        self.interactor.Disable()
        try:
            self.renderer.RemoveActor(actor)
        finally:
            self.interactor.Enable()

    def scale(self, polyData):
        """Return *polyData* uniformly scaled by ``self.worldScale``.

        When the scale is exactly 1.0, the input object itself is returned
        (no copy). Otherwise, a newly computed transformed polydata is
        returned.
        """
        factor = self.worldScale
        if factor != 1.0:
            scaling = vtk.vtkTransform()
            scaling.Scale(factor, factor, factor)
            scaler = vtk.vtkTransformPolyDataFilter()
            scaler.SetTransform(scaling)
            scaler.SetInputData(polyData)
            scaler.Update()
            return scaler.GetOutput()
        return polyData

    def getReferencePosition(self, index):
        """Return the hard-coded ``(center, htrans)`` pose of reference *index*.

        The values were dumped from the Slicer program. ``center`` is a
        3-vector and ``htrans`` a 4x4 homogeneous matrix. With ``IOUSFAN``
        set only index 0 exists; otherwise indices 0 and 1 are available.
        """
        if IOUSFAN:
            poses = [
                (np.r_[117.54695990115218, 103.95861040356766,
                       127.81653778974703],
                 np.array([[1, 0, 0, -191.357],
                           [0, 0, -1, 17.9393],
                           [0, 1, 0, 192.499],
                           [0, 0, 0, 1]])),
            ]
        else:
            poses = [
                (np.r_[-75.696, -149.42, -231.76],
                 np.array([[0.99722, -0.00658366270883304,
                            0.0741938205770847, -321.64],
                           [0.07419, 0.17584, -0.98162, -227.46],
                           [-0.0065837, 0.98440, 0.17584, -563.67],
                           [0, 0, 0, 1.0]])),
                (np.r_[-31.317285034663634, -174.62449255285645,
                       -193.39018826551072],
                 np.array([[0.9824507428729312, -0.028608856565971154,
                            0.1843151408713164, -221.425151769367],
                           [0.18431514087131629, 0.3004711475787132,
                            -0.935812491003576, -325.6553959586223],
                           [-0.028608856565971223, 0.9533617481306448,
                            0.3004711475787133, -547.1574253306663],
                           [0, 0, 0, 1]])),
            ]
        center, htrans = poses[index]
        return center, htrans

    def resetCamera(self):
        """Reset the renderer's camera so the scene is in view."""
        qDebug('resetCamera')
        renderer = self.renderer
        renderer.ResetCamera()

    def initScene(self):
        """Create the renderer in the interactor's render window, install a
        cell picker and give the view a grey gradient background."""
        qDebug('initScene()')
        self.renderer = vtk.vtkOpenGLRenderer()
        window = self.interactor.GetRenderWindow()
        self.render_window = window
        window.AddRenderer(self.renderer)

        picker = vtk.vtkCellPicker()
        picker.SetTolerance(30.0)
        self.cellPicker = picker
        self.interactor.SetPicker(picker)

        # VTK blends the gradient from Background (bottom of the viewport)
        # to Background2 (top), so the lighter grey ends up at the bottom.
        # NOTE(review): earlier comments labelled 245/255 as the top colour;
        # confirm which orientation is intended.
        lightGrey = np.ones(3) * 245.0 / 255.0
        darkGrey = np.ones(3) * 170.0 / 255.0
        self.renderer.SetBackground(lightGrey)
        self.renderer.SetBackground2(darkGrey)
        self.renderer.GradientBackgroundOn()

    def computeClippingPlanes(self, source):
        """Return a vtkPlaneCollection bounding the parallelogram of *source*.

        With origin O and corners P1, P2 of the plane source, two planes pass
        through O (normals O->P2 and O->P1) and two pass through the opposite
        corner P1 + P2 - O (normals reversed). All four normals therefore
        point into the parallelogram.
        """
        origin = np.array(source.GetOrigin())
        corner1 = np.array(source.GetPoint1())
        axis1 = corner1 - origin
        axis2 = np.array(source.GetPoint2()) - origin
        farCorner = corner1 + axis2

        planes = vtk.vtkPlaneCollection()
        for planeOrigin, normal in ((source.GetOrigin(), axis2),
                                    (source.GetOrigin(), axis1),
                                    (farCorner, -axis2),
                                    (farCorner, -axis1)):
            plane = vtk.vtkPlane()
            plane.SetOrigin(planeOrigin)
            plane.SetNormal(normal)
            planes.AddItem(plane)
        return planes

    def computeContoursAndTubes(self, source):
        """Cut the vessel surface with the plane of *source*.

        Returns ``(tubes, contours)``:
          * ``tubes`` is a vtkTubeFilter (not yet updated) around the cut
            lines;
          * ``contours`` is the stripped cut polydata itself.
        """
        cutPlane = vtk.vtkPlane()
        cutPlane.SetNormal(source.GetNormal())
        cutPlane.SetOrigin(source.GetOrigin())

        cutter = vtk.vtkCutter()
        cutter.SetCutFunction(cutPlane)
        cutter.SetInputConnection(self.vesselNormals.GetOutputPort())
        cutter.GenerateCutScalarsOff()  # Was on
        # NOTE(review): contour value 0.5 cuts where the plane function equals
        # 0.5, i.e. offset from the plane rather than on it; confirm intended.
        cutter.SetValue(0, 0.5)
        cutter.Update()

        stripper = vtk.vtkStripper()
        stripper.SetInputConnection(cutter.GetOutputPort())
        stripper.Update()

        tubes = vtk.vtkTubeFilter()
        tubes.SetInputConnection(stripper.GetOutputPort())
        tubes.CappingOn()
        tubes.SidesShareVerticesOff()
        tubes.SetNumberOfSides(12)
        tubes.SetRadius(self.worldScale * 1.0)
        return tubes, stripper.GetOutput()

    def initPlaneWidgets(self, index):
        """Create reference plane *index* with its contours and plane widget.

        The following per-reference lists are appended to:
          * ``planeSources``: plane source placed at the hard-coded
            reference pose (``getReferencePosition``), scaled by
            ``worldScale``;
          * ``refplanes``: translucent blue plane actor, or ``None`` when
            ``showReferencePlane`` is off;
          * ``contours``: green tubes of the vessel cut, clipped to the
            plane's extent;
          * ``fullcontours`` / ``userAttempts``: the unclipped cut and the
            red actor showing it, whose user transform ``onWidgetMoved``
            updates;
          * ``planeWidgets``: the red interactive vtkPlaneWidget.

        The widget's initial pose is stored in ``lastPositions``. Because the
        lists are appended to, *index* must equal the number of references
        created so far.
        """
        qDebug('initPlaneWidgets()')
        center, htrans = self.getReferencePosition(index)

        #####################################
        # Reference plane source
        #####################################
        # Square with side imgWidth in the xy-plane, scaled like the geometry.
        shw = self.worldScale * self.imgWidth
        source = vtk.vtkPlaneSource()
        source.SetOrigin(0, 0, 0)
        source.SetPoint1(shw, 0, 0)
        source.SetPoint2(0, shw, 0)

        # Reference pose as a VTK matrix with its translation scaled by
        # worldScale. NOTE(review): whether the translation should also be
        # scaled was an open question in the original code.
        mat = vtk.vtkMatrix4x4()
        for i in range(4):
            for j in range(4):
                mat.SetElement(i, j, htrans[i, j])
        for i in range(3):
            mat.SetElement(i, 3, self.worldScale * mat.GetElement(i, 3))
        transform = vtk.vtkTransform()
        transform.SetMatrix(mat)
        transform.Update()

        source.SetOrigin(transform.TransformPoint(source.GetOrigin()))
        source.SetPoint1(transform.TransformPoint(source.GetPoint1()))
        source.SetPoint2(transform.TransformPoint(source.GetPoint2()))
        source.Update()

        # Translate the plane onto the stored reference centre.
        source.SetCenter(self.worldScale * center)
        source.Update()

        # Test position good for slice 17 (HACK)
        if IOUSFAN:
            source.SetOrigin(-36.00039424299387, 58.447421532729656,
                             116.93018531955384)
            source.SetPoint1(13.731795848152041, 54.203001711976306,
                             119.87877296847647)
            source.SetPoint2(-40.18599847580337, 8.635225461941415,
                             115.82300881527104)
            source.Update()

        self.planeSources.append(source)

        #####################################
        # Blue reference plane
        #####################################
        if self.showReferencePlane:
            mapper0 = vtk.vtkPolyDataMapper()
            mapper0.SetInputConnection(source.GetOutputPort())

            refActor = vtk.vtkActor()
            refActor.SetMapper(mapper0)
            prop = refActor.GetProperty()
            prop.SetColor(blue)
            prop.SetOpacity(self.opacity)
            self.refplanes.append(refActor)
            self.renderer.AddActor(refActor)
        else:
            self.refplanes.append(None)

        #####################################
        # Compute contours, tubes and clipping
        #####################################
        tubes, oldContours = self.computeContoursAndTubes(source)

        edgeMapper = vtk.vtkPolyDataMapper()
        edgeMapper.ScalarVisibilityOff()
        edgeMapper.SetInputConnection(tubes.GetOutputPort())
        edgeMapper.SetClippingPlanes(self.computeClippingPlanes(source))

        actor = vtk.vtkActor()
        actor.SetMapper(edgeMapper)
        prop = actor.GetProperty()
        prop.SetColor(green)
        prop.SetLineWidth(3)

        self.contours.append(actor)
        self.renderer.AddActor(actor)

        ###################################################
        # Plane widget for interaction
        ###################################################
        planeWidget = vtk.vtkPlaneWidget()
        planeWidget.SetInteractor(self.interactor)
        planeWidget.SetOrigin(source.GetOrigin())
        planeWidget.SetPoint1(source.GetPoint1())
        planeWidget.SetPoint2(source.GetPoint2())
        planeWidget.GetHandleProperty().SetColor(
            QLiverViewer.colors.GetColor3d("Red"))
        planeWidget.GetPlaneProperty().SetColor(
            QLiverViewer.colors.GetColor3d("Red"))

        # Original pose of the reference plane:
        # [origin, point1, point2, center, normal]
        self.lastPositions['reset'][index] = [
            planeWidget.GetOrigin(),
            planeWidget.GetPoint1(),
            planeWidget.GetPoint2(),
            planeWidget.GetCenter(),
            np.array(planeWidget.GetNormal())
        ]
        print('normal')
        print(planeWidget.GetNormal())
        planeWidget.SetEnabled(1)
        planeWidget.AddObserver(vtk.vtkCommand.EndInteractionEvent,
                                self.onWidgetMoved, 1.0)

        #####################################
        # Red user attempt (unclipped cut)
        #####################################
        attempt = vtk.vtkActor()

        self.fullcontours.append(oldContours)
        mapper = vtk.vtkPolyDataMapper()
        # Same as the attempt mapper built for the 'm' key: if the vessel
        # data carries point scalars, they would otherwise override the red.
        mapper.ScalarVisibilityOff()
        mapper.SetInputData(oldContours)
        attempt.SetMapper(mapper)
        attempt.GetProperty().SetColor(red)
        self.userAttempts.append(attempt)
        self.renderer.AddActor(attempt)

        lastAxis1 = vtk.vtkVector3d()
        vtk.vtkMath.Subtract(planeWidget.GetPoint1(), planeWidget.GetOrigin(),
                             lastAxis1)
        self.lastPositions['origin'][index] = planeWidget.GetCenter()
        self.lastPositions['normal'][index] = planeWidget.GetNormal()
        self.lastPositions['axis1'][index] = lastAxis1
        self.planeWidgets.append(planeWidget)
        self.render_window.Render()

    def onWidgetMoved(self, obj, ev):
        """Observer for a plane widget's ``EndInteractionEvent``.

        Converts the widget's motion since its last recorded pose into a
        transform via ``AxesToTransform`` and accumulates it on the matching
        user-attempt actor. It then remembers the new pose and the widget
        index (used by ``KeyPress``), re-renders, and emits ``widgetMoved``
        with ``((angle_deg, axis), unit original normal, centre displacement
        from the reset pose)``.

        Events from objects that are not one of ``self.planeWidgets`` are
        ignored.
        """
        if obj not in self.planeWidgets:
            return
        index = self.planeWidgets.index(obj)
        self.lastIndex = index

        normal0 = self.lastPositions['normal'][index]
        first0 = self.lastPositions['axis1'][index]
        origin0 = self.lastPositions['origin'][index]

        normal1 = np.array(obj.GetNormal())
        first1 = np.array(obj.GetPoint1()) - np.array(obj.GetOrigin())
        origin1 = obj.GetCenter()

        trans = AxesToTransform(normal0, first0, origin0, normal1, first1,
                                origin1)
        userAttempt = self.userAttempts[index]
        userTransform = userAttempt.GetUserTransform()
        if userTransform is None:
            userTransform = vtk.vtkTransform()
            userTransform.SetMatrix(trans)
            # Later motions are concatenated after the existing transform.
            userTransform.PostMultiply()
            userAttempt.SetUserTransform(userTransform)
        else:
            userTransform.Concatenate(trans)

        (deg, x, y, z) = userTransform.GetOrientationWXYZ()

        self.lastPositions['origin'][index] = obj.GetCenter()
        self.lastPositions['axis1'][index] = [first1[0], first1[1], first1[2]]
        self.lastPositions['normal'][index] = (normal1[0], normal1[1],
                                               normal1[2])

        userAttempt.Modified()
        self.render_window.Render()

        resetPositions = self.lastPositions['reset'][index]
        originalNormal = resetPositions[4]
        originalNormal = originalNormal / np.sqrt(np.sum(originalNormal**2))

        self.widgetMoved.emit(
            (deg, np.r_[x, y, z]), originalNormal,
            np.array(obj.GetCenter()) - np.array(resetPositions[3]))

    def initVessels(self):
        """Load the vessel surface, keep its largest connected region and add
        it to the renderer.

        Sets ``self.vesselPolyData`` (the scaled surface, also the ICP
        target), ``self.vesselNormals`` (the normals filter used as cutter
        input for the contours) and ``self.vessels`` (the actor).
        """
        qDebug('initVessels()')
        onWindows = os.name == 'nt'
        if onWindows:
            filename = os.path.join(filedir,
                                    '../../data/Abdomen/Connected.vtp')
        else:
            # NOTE(review): there is no IOUSFAN path here, so with IOUSFAN set
            # this .vtp file goes to vtkGenericDataObjectReader. Confirm.
            filename = '/home/jmh/bkmedical/data/CT/Connected.vtp'
        if onWindows and IOUSFAN:
            filename = 'e:/analogic/TrialVTK/data/LiverVesselMeshData.vtk'

        readerType = (vtk.vtkGenericDataObjectReader if IOUSFAN
                      else vtk.vtkXMLPolyDataReader)
        reader = readerType()
        reader.SetFileName(filename)
        reader.Update()

        largest = vtk.vtkPolyDataConnectivityFilter()
        largest.SetInputConnection(reader.GetOutputPort())
        largest.SetExtractionModeToLargestRegion()
        largest.Update()
        self.vesselPolyData = self.scale(largest.GetOutput())

        self.vesselNormals = vtk.vtkPolyDataNormals()
        self.vesselNormals.SetInputData(self.vesselPolyData)

        vesselMapper = vtk.vtkPolyDataMapper()
        vesselMapper.SetInputConnection(self.vesselNormals.GetOutputPort())

        # "#517487" is the 25% lighter variant of "#415d6c".
        colorHex = "#517487" if self.brighter25 else "#415d6c"
        self.vessels = vtk.vtkActor()
        self.vessels.SetMapper(vesselMapper)
        self.vessels.GetProperty().SetColor(vtk.vtkColor3d(hexCol(colorHex)))
        self.renderer.AddActor(self.vessels)

    def initLiver(self):
        """Load the liver surface, keep its largest connected region, scale
        it and add it to the renderer as the translucent actor ``self.liver``.
        """
        qDebug('initLiver()')
        onWindows = os.name == 'nt'
        if onWindows:
            filename = os.path.join(
                filedir, '../../data/Abdomen/Liver_3D-interpolation.vtp')
        else:
            filename = '/home/jmh/bkmedical/data/CT/Liver_3D-interpolation.vtp'
        if onWindows and IOUSFAN:
            filename = 'e:/analogic/TrialVTK/data/segmented_liver_ITK_snap.vtk'

        readerType = (vtk.vtkGenericDataObjectReader if IOUSFAN
                      else vtk.vtkXMLPolyDataReader)
        reader = readerType()
        reader.SetFileName(filename)
        reader.Update()

        largest = vtk.vtkPolyDataConnectivityFilter()
        largest.SetInputConnection(reader.GetOutputPort())
        largest.SetExtractionModeToLargestRegion()
        largest.Update()

        normals = vtk.vtkPolyDataNormals()
        normals.SetInputData(self.scale(largest.GetOutput()))

        liverMapper = vtk.vtkPolyDataMapper()
        liverMapper.SetInputConnection(normals.GetOutputPort())

        self.liver = vtk.vtkActor()
        self.liver.SetMapper(liverMapper)
        liverProp = self.liver.GetProperty()
        liverHex = "#873927" if self.brighter25 else "#6c2e1f"
        liverProp.SetColor(vtk.vtkColor3d(hexCol(liverHex)))
        liverProp.SetOpacity(self.opacity)
        self.renderer.AddActor(self.liver)

    def start(self):
        self.interactor.Initialize()