示例#1
0
    def initialize(self):
        """Create the L-BFGS-B (v4) optimizer and optionally report progress.

        The optimizer is kept on ``self.optimizer`` so that calling scripts
        can attach their own observers to it. It is configured with
        ``self.MAX_FUNCTION_EVALUATIONS`` and ``self.MAX_CORRECTIONS``.

        When ``self.verbose`` is truthy, an ``itk.IterationEvent`` observer is
        added that prints the iteration number, the current metric value and
        the infinity norm of the projected gradient.
        """
        optimizer = itk.LBFGSBOptimizerv4.New(
            MaximumNumberOfFunctionEvaluations=self.MAX_FUNCTION_EVALUATIONS,
            MaximumNumberOfCorrections=self.MAX_CORRECTIONS)
        self.optimizer = optimizer

        if not self.verbose:
            return

        def report_progress():
            # Reads through self.optimizer at call time, so it always reports
            # on whatever optimizer the instance currently holds.
            print(
                f'Iteration: {self.optimizer.GetCurrentIteration()}'
                f' Metric: {self.optimizer.GetCurrentMetricValue()}'
                f' Infinity Norm: {self.optimizer.GetInfinityNormOfProjectedGradient()}'
            )

        optimizer.AddObserver(itk.IterationEvent(), report_progress)
    def execute(self, inputs, update=0, last=0):
        """Register the moving image to a fixed timepoint and translate it.

        Runs a multi-resolution, translation-only ITK registration (mean
        squares metric, linear interpolation, regular-step gradient descent)
        of input 1 against a timepoint of the first source data unit. It then
        resamples input 1 with the resulting translation.

        Settings read from ``self.parameters``:

        * ``BackgroundPixelValue`` -- default pixel value of the resampler.
        * ``UsePreviousAsFixed`` -- if true, the fixed image is the previous
          timepoint (``scripting.processingTimepoint - 1``, or 0 for the first
          timepoint), and the result is added to ``self.totalTranslation``,
          whose total is then applied. Otherwise the fixed image is timepoint
          ``FixedTimepoint - 1``.
        * ``X`` / ``Y`` / ``Z`` -- axes whose translation is applied; the
          translation along a disabled axis is set to 0.0.
        * ``MinStepLength<i>``, ``MaxStepLength<i>``, ``MaxIterations<i>``,
          ``LevelX<i>``, ``LevelY<i>``, ``LevelZ<i>`` for i in 0..3 -- one
          pyramid level each. A level is used only if all six are truthy.

        Returns the translated image passed through
        ``self.convertITKtoVTK(..., cast="UC3")``, or None when the base-class
        ``execute`` reports failure.

        Raises ValueError if no pyramid level is fully specified.
        """
        if not lib.ProcessingFilter.ProcessingFilter.execute(self, inputs):
            return None

        backgroundValue = self.parameters["BackgroundPixelValue"]
        usePrevious = self.parameters["UsePreviousAsFixed"]
        transX = self.parameters["X"]
        transY = self.parameters["Y"]
        transZ = self.parameters["Z"]

        # Collect the pyramid levels whose six settings are all truthy.
        self.minSteps = []
        self.maxSteps = []
        self.maxIters = []
        levelsX = []
        levelsY = []
        levelsZ = []
        for i in range(4):
            minStepLength = self.parameters["MinStepLength%d" % i]
            maxStepLength = self.parameters["MaxStepLength%d" % i]
            maxIterations = self.parameters["MaxIterations%d" % i]
            x = self.parameters["LevelX%d" % i]
            y = self.parameters["LevelY%d" % i]
            z = self.parameters["LevelZ%d" % i]
            if all((minStepLength, maxStepLength, maxIterations, x, y, z)):
                self.minSteps.append(minStepLength)
                self.maxSteps.append(maxStepLength)
                self.maxIters.append(maxIterations)
                levelsX.append(x)
                levelsY.append(y)
                levelsZ.append(z)
        usedLevels = len(self.minSteps)
        if usedLevels == 0:
            # Without any level there is no schedule and no optimizer
            # settings; fail clearly instead of with an IndexError below.
            raise ValueError(
                "No registration level is fully specified: each level needs "
                "non-zero MinStepLength, MaxStepLength, MaxIterations and "
                "LevelX/LevelY/LevelZ values")

        if usePrevious:
            if scripting.processingTimepoint > 0:
                fixedtp = scripting.processingTimepoint - 1
            else:
                fixedtp = 0
        else:
            fixedtp = self.parameters["FixedTimepoint"] - 1

        movingImage = self.getInput(1)
        movingImage.SetUpdateExtent(movingImage.GetWholeExtent())
        movingImage.Update()

        # Create copy of data, otherwise movingImage will point to same image
        # as fixedImage
        mi = vtk.vtkImageData()
        mi.DeepCopy(movingImage)
        movingImage = mi
        movingImage.Update()
        # Following is dirty but currently has to be done this way since
        # convertVTKtoITK doesn't work with two dataset unless
        # itkConfig.ProgressCallback is set and that eats all memory.
        # The cast/converter objects stay referenced in locals for the rest
        # of the method.
        vtkToItk = itk.VTKImageToImageFilter.IF3.New()
        icast = vtk.vtkImageCast()
        icast.SetOutputScalarTypeToFloat()
        icast.SetInput(movingImage)
        vtkToItk.SetInput(icast.GetOutput())
        movingImage = vtkToItk.GetOutput()
        movingImage.Update()

        fixedImage = self.dataUnit.getSourceDataUnits()[0].getTimepoint(
            fixedtp)
        fixedImage.SetUpdateExtent(fixedImage.GetWholeExtent())
        fixedImage.Update()

        vtkToItk2 = itk.VTKImageToImageFilter.IF3.New()
        icast2 = vtk.vtkImageCast()
        icast2.SetOutputScalarTypeToFloat()
        icast2.SetInput(fixedImage)
        vtkToItk2.SetInput(icast2.GetOutput())
        fixedImage = vtkToItk2.GetOutput()
        fixedImage.Update()

        # Create registration framework's components
        fixedPyramid = itk.MultiResolutionPyramidImageFilter.IF3IF3.New()
        movingPyramid = itk.MultiResolutionPyramidImageFilter.IF3IF3.New()
        self.registration = itk.MultiResolutionImageRegistrationMethod.IF3IF3.New(
        )
        self.metric = itk.MeanSquaresImageToImageMetric.IF3IF3.New()
        self.transform = itk.TranslationTransform.D3.New()
        self.interpolator = itk.LinearInterpolateImageFunction.IF3D.New()
        self.optimizer = itk.RegularStepGradientDescentOptimizer.New()

        # Initialize registration framework's components
        self.registration.SetOptimizer(self.optimizer)
        self.registration.SetTransform(self.transform)
        self.registration.SetInterpolator(self.interpolator)
        self.registration.SetMetric(self.metric)
        self.registration.SetFixedImagePyramid(fixedPyramid)
        self.registration.SetMovingImagePyramid(movingPyramid)
        self.registration.SetFixedImage(fixedImage)
        self.registration.SetMovingImage(movingImage)
        self.registration.SetFixedImageRegion(
            fixedImage.GetLargestPossibleRegion())

        # Pyramid schedules: one row per used level, one column per axis;
        # the fixed and moving pyramids share the same values.
        fixedMatr = itk.vnl_matrix.UI(usedLevels, 3)
        movingMatr = itk.vnl_matrix.UI(usedLevels, 3)
        for i in range(usedLevels):
            for axis, levels in enumerate((levelsX, levelsY, levelsZ)):
                fixedMatr.put(i, axis, levels[i])
                movingMatr.put(i, axis, levels[i])

        fixedArray = itk.Array2D.UI(fixedMatr)
        movingArray = itk.Array2D.UI(movingMatr)
        self.registration.SetSchedules(fixedArray, movingArray)

        # Every run starts from the identity translation; seeding from a
        # previous run's transform is not enabled.
        self.transform.SetIdentity()
        initialParameters = self.transform.GetParameters()
        self.registration.SetInitialTransformParameters(
            tuple(initialParameters[i] for i in range(3)))

        # The optimizer starts with the settings stored at index
        # usedLevels - 1. NOTE(review): presumably self.updateParameters
        # (attached below to the registration's IterationEvent) switches them
        # per level -- confirm.
        self.optimizer.SetMaximumStepLength(self.maxSteps[usedLevels - 1])
        self.optimizer.SetMinimumStepLength(self.minSteps[usedLevels - 1])
        self.optimizer.SetNumberOfIterations(self.maxIters[usedLevels - 1])

        iterationCommand = itk.PyCommand.New()
        iterationCommand.SetCommandCallable(self.updateProgress)
        self.optimizer.AddObserver(itk.IterationEvent(), iterationCommand)
        levelCommand = itk.PyCommand.New()
        levelCommand.SetCommandCallable(self.updateParameters)
        self.registration.AddObserver(itk.IterationEvent(), levelCommand)

        def logTranslation(params):
            # Log the three translation components, one line per axis.
            for axis, name in enumerate("XYZ"):
                Logging.info("Translation %s = %f" %
                             (name, params.GetElement(axis)))

        # Execute registration
        Logging.info("Starting registration")
        startTime = time.time()
        self.registration.StartRegistration()
        finalParameters = self.registration.GetLastTransformParameters()
        Logging.info("Registration took %s seconds" %
                     (time.time() - startTime))
        Logging.info("Final Registration parameters")
        logTranslation(finalParameters)

        if usePrevious:
            # Add this step's translation to the running total and apply the
            # total; finalParameters then refers to self.totalTranslation.
            for i in range(3):
                self.totalTranslation.SetElement(
                    i,
                    self.totalTranslation.GetElement(i) +
                    finalParameters.GetElement(i))
            finalParameters = self.totalTranslation

        # Zero the translation along disabled axes. With UsePreviousAsFixed
        # this also writes into self.totalTranslation.
        for axis, enabled in enumerate((transX, transY, transZ)):
            if not enabled:
                finalParameters.SetElement(axis, 0.0)

        Logging.info("Use transform parameters")
        logTranslation(finalParameters)

        # Translate input image using results from the registration
        resampler = itk.ResampleImageFilter.IF3IF3.New()
        self.transform.SetParameters(finalParameters)
        resampler.SetTransform(self.transform)
        resampler.SetInput(movingImage)
        region = movingImage.GetLargestPossibleRegion()
        resampler.SetSize(region.GetSize())
        resampler.SetOutputSpacing(movingImage.GetSpacing())
        resampler.SetOutputOrigin(movingImage.GetOrigin())
        resampler.SetDefaultPixelValue(backgroundValue)
        data = resampler.GetOutput()
        data.Update()

        data = self.convertITKtoVTK(data, cast="UC3")
        return data
示例#3
0
    def execute(self, inputs, update=0, last=0):
        """Register the moving image with a cubic B-spline deformable transform.

        Registers input 1 (moving) against timepoint ``FixedTimepoint - 1`` of
        the first source data unit. It uses a mean squares metric, linear
        interpolation and regular-step gradient descent over the B-spline
        coefficients, then resamples the moving image with the optimized
        transform using nearest-neighbour interpolation.

        If the fixed image is 3-D with fewer than two slices, the z = 0 plane
        of both images is extracted and a 2-D registration is run instead.

        Settings read from ``self.parameters``: BackgroundPixelValue (default
        pixel value of the resampler), MinStepLength, MaxStepLength,
        MaxIterations, UseCaching (passed to the metric's B-spline weight
        caching), GridSizeX/GridSizeY/GridSizeZ (grid nodes placed over the
        image along each axis; Z is ignored in 2-D), FixedTimepoint.

        Returns the warped image passed through
        ``self.convertITKtoVTK(..., cast="UC3")``, or None when the base-class
        ``execute`` reports failure.

        Raises ValueError if a grid size used for the registration is below 2.
        """
        if not lib.ProcessingFilter.ProcessingFilter.execute(self, inputs):
            return None

        backgroundValue = self.parameters["BackgroundPixelValue"]
        minStepLength = self.parameters["MinStepLength"]
        maxStepLength = self.parameters["MaxStepLength"]
        maxIterations = self.parameters["MaxIterations"]
        caching = self.parameters["UseCaching"]
        gridSizeX = self.parameters["GridSizeX"]
        gridSizeY = self.parameters["GridSizeY"]
        gridSizeZ = self.parameters["GridSizeZ"]
        fixedtp = self.parameters["FixedTimepoint"] - 1
        # Cubic B-spline; also the number of extra grid nodes around the image.
        splineOrder = 3

        movingImage = self.getInput(1)
        movingImage.SetUpdateExtent(movingImage.GetWholeExtent())
        movingImage.Update()

        # Create copy of data, otherwise movingImage will point to same image
        # as fixedImage
        mi = vtk.vtkImageData()
        mi.DeepCopy(movingImage)
        movingImage = mi
        movingImage.Update()
        # Following is dirty but currently has to be done this way since
        # convertVTKtoITK doesn't work with two dataset unless
        # itkConfig.ProgressCallback is set and that eats all memory.
        # The cast/converter objects stay referenced in locals for the rest
        # of the method.
        vtkToItk = itk.VTKImageToImageFilter.IF3.New()
        icast = vtk.vtkImageCast()
        icast.SetOutputScalarTypeToFloat()
        icast.SetInput(movingImage)
        vtkToItk.SetInput(icast.GetOutput())
        vtkToItk.Update()
        movingImage = vtkToItk.GetOutput()

        fixedImage = self.dataUnit.getSourceDataUnits()[0].getTimepoint(
            fixedtp)
        fixedImage.SetUpdateExtent(fixedImage.GetWholeExtent())
        fixedImage.Update()
        vtkToItk2 = itk.VTKImageToImageFilter.IF3.New()
        icast2 = vtk.vtkImageCast()
        icast2.SetOutputScalarTypeToFloat()
        icast2.SetInput(fixedImage)
        vtkToItk2.SetInput(icast2.GetOutput())
        vtkToItk2.Update()
        fixedImage = vtkToItk2.GetOutput()

        fixedRegion = fixedImage.GetLargestPossibleRegion()
        fixedSize = fixedRegion.GetSize()
        dim = fixedRegion.GetImageDimension()
        if dim == 3 and fixedSize.GetElement(2) < 2:
            # Single-slice volume: extract the z = 0 plane of both images and
            # register in 2-D.
            dim = 2
            extractFilter = itk.ExtractImageFilter.IF3IF2.New()
            extractRegion = itk.ImageRegion._3()
            extractSize = itk.Size._3()
            extractIndex = itk.Index._3()
            for i in range(3):
                extractIndex.SetElement(i, 0)
            extractRegion.SetIndex(extractIndex)
            for i in range(2):
                extractSize.SetElement(i, fixedSize.GetElement(i))
            # A zero extent marks z as the dimension to collapse; set it
            # explicitly rather than relying on the default-constructed value.
            extractSize.SetElement(2, 0)
            extractRegion.SetSize(extractSize)
            extractFilter.SetExtractionRegion(extractRegion)
            extractFilter.SetInput(fixedImage)
            fixedImage = extractFilter.GetOutput()
            fixedImage.Update()

            extractFilter2 = itk.ExtractImageFilter.IF3IF2.New()
            extractFilter2.SetExtractionRegion(extractRegion)
            extractFilter2.SetInput(movingImage)
            movingImage = extractFilter2.GetOutput()
            movingImage.Update()

        # Grid nodes on the image per used axis. At least two are needed,
        # otherwise the spacing computation below divides by zero (size 1) or
        # produces a negative spacing (size 0).
        gridSizes = (gridSizeX, gridSizeY, gridSizeZ)[:dim]
        if min(gridSizes) < 2:
            raise ValueError(
                "B-spline grid size must be at least 2 along each axis, "
                "got %s" % (gridSizes,))

        # Create registration framework's components for the chosen dimension.
        imagePair = "IF%dIF%d" % (dim, dim)
        self.registration = getattr(itk.ImageRegistrationMethod,
                                    imagePair).New()
        self.optimizer = itk.RegularStepGradientDescentOptimizer.New()
        self.metric = getattr(itk.MeanSquaresImageToImageMetric,
                              imagePair).New()
        self.interpolator = getattr(itk.LinearInterpolateImageFunction,
                                    "IF%dD" % dim).New()
        # The wrapped template suffix is D<dimension><spline order>; keep the
        # spline cubic in 2-D too, matching the grid border used below.
        self.transform = getattr(itk.BSplineDeformableTransform,
                                 "D%d%d" % (dim, splineOrder)).New()

        # Initialize registration framework's components
        self.registration.SetOptimizer(self.optimizer.GetPointer())
        self.registration.SetTransform(self.transform.GetPointer())
        self.registration.SetInterpolator(self.interpolator.GetPointer())
        self.registration.SetMetric(self.metric.GetPointer())
        self.registration.SetFixedImage(fixedImage)
        self.registration.SetMovingImage(movingImage)
        self.registration.SetFixedImageRegion(fixedImage.GetBufferedRegion())

        self.optimizer.SetMaximumStepLength(maxStepLength)
        self.optimizer.SetMinimumStepLength(minStepLength)
        self.optimizer.SetNumberOfIterations(maxIterations)
        self.metric.SetUseCachingOfBSplineWeights(caching)

        # Grid region: the requested nodes on the image plus a border of
        # splineOrder nodes per axis.
        gridRegion = getattr(itk.ImageRegion, "_%d" % dim)()
        sizeType = getattr(itk.Size, "_%d" % dim)
        gridSizeOnImage = sizeType()
        gridBorderSize = sizeType()
        totalGridSize = sizeType()

        for i, size in enumerate(gridSizes):
            gridSizeOnImage.SetElement(i, size)
        gridBorderSize.Fill(splineOrder)

        for i in range(dim):
            totalGridSize.SetElement(
                i,
                gridSizeOnImage.GetElement(i) + gridBorderSize.GetElement(i))
        gridRegion.SetSize(totalGridSize)

        gridSpacing = fixedImage.GetSpacing()
        gridOrigin = fixedImage.GetOrigin()
        fixedSize = fixedImage.GetLargestPossibleRegion().GetSize()

        # Spread the on-image nodes across the image (node spacing rounded
        # down to a whole number of pixels) and move the grid origin back by
        # one node.
        for i in range(dim):
            spacingElement = gridSpacing.GetElement(i)
            spacingElement *= int(
                math.floor(
                    float(fixedSize.GetElement(i) - 1) /
                    float(gridSizeOnImage.GetElement(i) - 1)))
            originElement = gridOrigin.GetElement(i) - spacingElement
            gridSpacing.SetElement(i, spacingElement)
            gridOrigin.SetElement(i, originElement)

        self.transform.SetGridSpacing(gridSpacing)
        self.transform.SetGridOrigin(gridOrigin)
        self.transform.SetGridRegion(gridRegion)
        self.transform.SetGridDirection(fixedImage.GetDirection())

        # Start from all-zero coefficients.
        numOfParameters = self.transform.GetNumberOfParameters()
        parameters = itk.Array.D(numOfParameters)
        parameters.Fill(0.0)
        self.transform.SetParameters(parameters)
        self.registration.SetInitialTransformParameters(
            self.transform.GetParameters())

        iterationCommand = itk.PyCommand.New()
        iterationCommand.SetCommandCallable(self.updateProgress)
        self.optimizer.AddObserver(itk.IterationEvent(),
                                   iterationCommand.GetPointer())

        Logging.info("Starting registration")
        startTime = time.time()
        self.registration.StartRegistration()
        finalParameters = self.registration.GetLastTransformParameters()
        self.transform.SetParameters(finalParameters)
        Logging.info("Registration took %s seconds" %
                     (time.time() - startTime))

        # Warp the moving image with the optimized transform.
        self.resampler = getattr(itk.ResampleImageFilter, imagePair).New()
        resampleInterpolator = getattr(
            itk.NearestNeighborInterpolateImageFunction, "IF%dD" % dim).New()
        self.resampler.SetTransform(self.transform.GetPointer())
        self.resampler.SetInput(movingImage)
        self.resampler.SetSize(
            movingImage.GetLargestPossibleRegion().GetSize())
        self.resampler.SetOutputSpacing(movingImage.GetSpacing())
        self.resampler.SetOutputOrigin(movingImage.GetOrigin())
        self.resampler.SetDefaultPixelValue(backgroundValue)
        self.resampler.SetInterpolator(resampleInterpolator.GetPointer())
        data = self.resampler.GetOutput()
        data.Update()

        data = self.convertITKtoVTK(data, cast="UC3")
        return data
示例#4
0
    def register(self, fixedData, movingData):
        """Estimate a translation aligning ``movingData`` to ``fixedData``.

        The initial translation comes from the two datasets' ``'clip'`` info,
        scaled by their resolutions. It is refined by an ITK registration
        (Mattes mutual information, linear interpolation, regular-step
        gradient descent) run on copies of both images rescaled to 0..255.
        The result is applied with SimpleITK to resample the moving image,
        using the fixed image as the reference image.

        If the ITK registration raises, the error is printed and a zero
        translation is used instead.

        Returns ``(array, {}, params)``:

        * ``array`` is the resampled moving image, as returned by
          ``sitk.GetArrayFromImage``.
        * ``params`` is the 12 affine parameters (identity matrix followed by
          the translation) with three zeros appended.
        """
        clip1 = movingData.info.getData('clip')
        clip2 = fixedData.info.getData('clip')

        fixed_res = fixedData.getResolution().tolist()
        moving_res = movingData.getResolution().tolist()

        def iterationUpdate():
            # Progress observer: metric value and current translation.
            currentParameter = transform.GetParameters()
            print("M: %f   P: %f %f %f" % (
                optimizer.GetValue(), currentParameter.GetElement(0),
                currentParameter.GetElement(1), currentParameter.GetElement(2)))

        image_type = fixedData.getITKImageType()
        fixedImage = fixedData.getITKImage()
        movingImage = movingData.getITKImage()
        # Rescale both images to 0..255; the filters stay referenced in
        # locals while the registration runs.
        rescale_filter_fixed = itk.RescaleIntensityImageFilter[
            image_type, image_type].New()
        rescale_filter_fixed.SetInput(fixedImage)
        rescale_filter_fixed.SetOutputMinimum(0)
        rescale_filter_fixed.SetOutputMaximum(255)
        rescale_filter_moving = itk.RescaleIntensityImageFilter[
            image_type, image_type].New()
        rescale_filter_moving.SetInput(movingImage)
        rescale_filter_moving.SetOutputMinimum(0)
        rescale_filter_moving.SetOutputMaximum(255)

        registration = itk.ImageRegistrationMethod[image_type,
                                                   image_type].New()
        imageMetric = itk.MattesMutualInformationImageToImageMetric[
            image_type, image_type].New()
        transform = itk.TranslationTransform.New()
        optimizer = itk.RegularStepGradientDescentOptimizer.New()
        interpolator = itk.LinearInterpolateImageFunction[image_type,
                                                          itk.D].New()

        registration.SetOptimizer(optimizer)
        registration.SetTransform(transform)
        registration.SetInterpolator(interpolator)
        registration.SetMetric(imageMetric)
        registration.SetFixedImage(rescale_filter_fixed.GetOutput())
        registration.SetMovingImage(rescale_filter_moving.GetOutput())
        registration.SetFixedImageRegion(fixedImage.GetBufferedRegion())

        # Initial translation from the 'clip' offsets: clip[4] drives x and
        # clip[2] drives y, each scaled by its dataset's resolution; z is 0.
        para = [
            -clip1[4] * moving_res[0] + clip2[4] * fixed_res[0],
            -clip1[2] * moving_res[1] + clip2[2] * fixed_res[1], 0
        ]
        transform.SetParameters(para)

        initialParameters = transform.GetParameters()
        print("Initial Registration Parameters ")
        print(initialParameters.GetElement(0))
        print(initialParameters.GetElement(1))
        print(initialParameters.GetElement(2))
        registration.SetInitialTransformParameters(initialParameters)

        # Optimizer scales: the same weight for every translation component.
        optimizerScales = itk.Array[itk.D](transform.GetNumberOfParameters())
        optimizerScales.SetElement(0, 1.0)
        optimizerScales.SetElement(1, 1.0)
        optimizerScales.SetElement(2, 1.0)
        optimizer.SetScales(optimizerScales)

        #imageMetric.UseAllPixelsOn()
        imageMetric.SetNumberOfHistogramBins(64)
        imageMetric.SetNumberOfSpatialSamples(800000)

        optimizer.MinimizeOn()
        optimizer.SetMaximumStepLength(2.00)
        optimizer.SetMinimumStepLength(0.001)
        optimizer.SetRelaxationFactor(0.8)
        optimizer.SetNumberOfIterations(200)

        iterationCommand = itk.PyCommand.New()
        iterationCommand.SetCommandCallable(iterationUpdate)
        optimizer.AddObserver(itk.IterationEvent(), iterationCommand)

        # Start the registration process. On failure, fall back to a zero
        # translation (deliberate best effort) but report the cause.
        try:
            registration.Update()
        except Exception as err:
            print("error: %s" % err)
            transform.SetParameters([0.0, 0.0, 0.0])

        # Report and use the transform's parameters; on the failure path these
        # are the zero translation set above.
        parameters = transform.GetParameters()
        x = parameters.GetElement(0)
        y = parameters.GetElement(1)
        z = parameters.GetElement(2)

        print("Final Registration Parameters ")
        print(x)
        print(y)
        print(z)

        # ITK's ResampleImageFilter did not work here, so the translation is
        # applied with SimpleITK as an affine transform: identity matrix
        # followed by the translation vector.
        para = [1, 0, 0, 0, 1, 0, 0, 0, 1] + [x, y, z]
        sitkTransform = sitk.Transform(3, sitk.sitkAffine)
        sitkTransform.SetParameters(para)

        movingImage = movingData.getSimpleITKImage()
        fixedImage = fixedData.getSimpleITKImage()
        resultImage = sitk.Resample(movingImage, fixedImage, sitkTransform,
                                    sitk.sitkLinear, 0, sitk.sitkFloat32)

        return sitk.GetArrayFromImage(resultImage), {}, para + [0, 0, 0]