def initialize(self):
    """Build the L-BFGS-B (v4) optimizer and, in verbose mode, report progress.

    The optimizer is stored on ``self.optimizer`` (rather than kept local)
    so that calling scripts can reference it and attach custom observers.
    Its limits come from ``self.MAX_FUNCTION_EVALUATIONS`` and
    ``self.MAX_CORRECTIONS``.
    """
    self.optimizer = itk.LBFGSBOptimizerv4.New(
        MaximumNumberOfFunctionEvaluations=self.MAX_FUNCTION_EVALUATIONS,
        MaximumNumberOfCorrections=self.MAX_CORRECTIONS,
    )

    # Progress printing is only wired up when verbose output was requested.
    if not self.verbose:
        return

    def print_iteration():
        # self.optimizer is looked up at call time, so the printed values
        # always come from whatever optimizer the attribute currently holds.
        iteration = self.optimizer.GetCurrentIteration()
        metric = self.optimizer.GetCurrentMetricValue()
        norm = self.optimizer.GetInfinityNormOfProjectedGradient()
        print(f'Iteration: {iteration} Metric: {metric} Infinity Norm: {norm}')

    self.optimizer.AddObserver(itk.IterationEvent(), print_iteration)
def execute(self, inputs, update=0, last=0):
    """
    Initializes and executes the registration process.

    Registers the moving image (input 1) against a fixed timepoint of the
    first source data unit using a multi-resolution translation
    registration. The resulting translation is then applied to the moving
    image, and the translated image is returned converted back to VTK
    (cast "UC3").

    Up to four resolution levels are read from the parameters
    MinStepLength<i>, MaxStepLength<i>, MaxIterations<i> and
    LevelX<i>/LevelY<i>/LevelZ<i> (i = 0..3). A level is used only when all
    six of its values are truthy. The optimizer's step lengths and
    iteration count are initialised from the last used level.

    Returns None if the base ProcessingFilter.execute() rejects the inputs.
    Raises ValueError if no resolution level is fully configured.
    """
    if not lib.ProcessingFilter.ProcessingFilter.execute(self, inputs):
        return None

    backgroundValue = self.parameters["BackgroundPixelValue"]
    usePrevious = self.parameters["UsePreviousAsFixed"]
    # Per-axis switches; a disabled axis gets a zero translation at the end.
    transX = self.parameters["X"]
    transY = self.parameters["Y"]
    transZ = self.parameters["Z"]

    # Collect the configured resolution levels. The keys are built from the
    # level index and looked up directly (no eval needed).
    self.minSteps = []
    self.maxSteps = []
    self.maxIters = []
    levelsX = []
    levelsY = []
    levelsZ = []
    for i in range(4):
        minStepLength = self.parameters["MinStepLength%d" % i]
        maxStepLength = self.parameters["MaxStepLength%d" % i]
        maxIterations = self.parameters["MaxIterations%d" % i]
        x = self.parameters["LevelX%d" % i]
        y = self.parameters["LevelY%d" % i]
        z = self.parameters["LevelZ%d" % i]
        if minStepLength and maxStepLength and maxIterations and x and y and z:
            self.minSteps.append(minStepLength)
            self.maxSteps.append(maxStepLength)
            self.maxIters.append(maxIterations)
            levelsX.append(x)
            levelsY.append(y)
            levelsZ.append(z)
    usedLevels = len(self.minSteps)
    if usedLevels == 0:
        # The schedules and optimizer settings below need at least one level;
        # fail early with a clear message instead of an IndexError later.
        raise ValueError(
            "No resolution level configured: at least one level needs "
            "MinStepLength, MaxStepLength, MaxIterations and LevelX/Y/Z set")

    if usePrevious:
        # Register against the previous timepoint; timepoint 0 (or any
        # non-positive value) uses timepoint 0.
        fixedtp = max(scripting.processingTimepoint - 1, 0)
    else:
        fixedtp = self.parameters["FixedTimepoint"] - 1

    movingImage = self.getInput(1)
    movingImage.SetUpdateExtent(movingImage.GetWholeExtent())
    movingImage.Update()
    # Create copy of data, otherwise movingImage will point to same image
    # as fixedImage
    mi = vtk.vtkImageData()
    mi.DeepCopy(movingImage)
    movingImage = mi
    movingImage.Update()

    # Following is dirty but currently has to be done this way since
    # convertVTKtoITK doesn't work with two dataset unless
    # itkConfig.ProgressCallback is set and that eats all memory.
    # NOTE(review): the cast/bridge objects stay referenced as locals for the
    # whole call; the ITK images may alias their buffers, so keep them alive
    # if this is ever moved into a helper -- TODO confirm.
    vtkToItk = itk.VTKImageToImageFilter.IF3.New()
    icast = vtk.vtkImageCast()
    icast.SetOutputScalarTypeToFloat()
    icast.SetInput(movingImage)
    vtkToItk.SetInput(icast.GetOutput())
    movingImage = vtkToItk.GetOutput()
    movingImage.Update()

    fixedImage = self.dataUnit.getSourceDataUnits()[0].getTimepoint(fixedtp)
    fixedImage.SetUpdateExtent(fixedImage.GetWholeExtent())
    fixedImage.Update()
    vtkToItk2 = itk.VTKImageToImageFilter.IF3.New()
    icast2 = vtk.vtkImageCast()
    icast2.SetOutputScalarTypeToFloat()
    icast2.SetInput(fixedImage)
    vtkToItk2.SetInput(icast2.GetOutput())
    fixedImage = vtkToItk2.GetOutput()
    fixedImage.Update()

    # Create registration framework's components
    fixedPyramid = itk.MultiResolutionPyramidImageFilter.IF3IF3.New()
    movingPyramid = itk.MultiResolutionPyramidImageFilter.IF3IF3.New()
    self.registration = itk.MultiResolutionImageRegistrationMethod.IF3IF3.New()
    self.metric = itk.MeanSquaresImageToImageMetric.IF3IF3.New()
    self.transform = itk.TranslationTransform.D3.New()
    self.interpolator = itk.LinearInterpolateImageFunction.IF3D.New()
    self.optimizer = itk.RegularStepGradientDescentOptimizer.New()

    # Initialize registration framework's components
    self.registration.SetOptimizer(self.optimizer)
    self.registration.SetTransform(self.transform)
    self.registration.SetInterpolator(self.interpolator)
    self.registration.SetMetric(self.metric)
    self.registration.SetFixedImagePyramid(fixedPyramid)
    self.registration.SetMovingImagePyramid(movingPyramid)
    self.registration.SetFixedImage(fixedImage)
    self.registration.SetMovingImage(movingImage)
    self.registration.SetFixedImageRegion(
        fixedImage.GetLargestPossibleRegion())

    # Pyramid schedules: one row per used level, columns hold the
    # LevelX/Y/Z values. Fixed and moving pyramids share the same schedule.
    fixedMatr = itk.vnl_matrix.UI(usedLevels, 3)
    movingMatr = itk.vnl_matrix.UI(usedLevels, 3)
    for level, factors in enumerate(zip(levelsX, levelsY, levelsZ)):
        for axis, factor in enumerate(factors):
            fixedMatr.put(level, axis, factor)
            movingMatr.put(level, axis, factor)
    fixedArray = itk.Array2D.UI(fixedMatr)
    movingArray = itk.Array2D.UI(movingMatr)
    self.registration.SetSchedules(fixedArray, movingArray)

    # Every registration starts from the identity translation.
    self.transform.SetIdentity()
    initialParameters = self.transform.GetParameters()
    self.registration.SetInitialTransformParameters(
        tuple([initialParameters[i] for i in range(3)]))

    lastLevel = usedLevels - 1
    self.optimizer.SetMaximumStepLength(self.maxSteps[lastLevel])
    self.optimizer.SetMinimumStepLength(self.minSteps[lastLevel])
    self.optimizer.SetNumberOfIterations(self.maxIters[lastLevel])

    iterationCommand = itk.PyCommand.New()
    iterationCommand.SetCommandCallable(self.updateProgress)
    self.optimizer.AddObserver(itk.IterationEvent(), iterationCommand)
    # The registration method fires IterationEvent for each pyramid level.
    levelCommand = itk.PyCommand.New()
    levelCommand.SetCommandCallable(self.updateParameters)
    self.registration.AddObserver(itk.IterationEvent(), levelCommand)

    # Execute registration
    Logging.info("Starting registration")
    startTime = time.time()
    self.registration.StartRegistration()
    finalParameters = self.registration.GetLastTransformParameters()
    Logging.info("Registration took %s seconds" % (time.time() - startTime))
    Logging.info("Final Registration parameters")
    Logging.info("Translation X = %f" % (finalParameters.GetElement(0)))
    Logging.info("Translation Y = %f" % (finalParameters.GetElement(1)))
    Logging.info("Translation Z = %f" % (finalParameters.GetElement(2)))

    if usePrevious:
        # Registering against the previous timepoint yields a relative
        # shift; accumulate it so the output is relative to the first one.
        for i in range(3):
            self.totalTranslation.SetElement(
                i,
                self.totalTranslation.GetElement(i) +
                finalParameters.GetElement(i))
        finalParameters = self.totalTranslation

    for axis, enabled in enumerate((transX, transY, transZ)):
        if not enabled:
            finalParameters.SetElement(axis, 0.0)

    Logging.info("Use transform parameters")
    Logging.info("Translation X = %f" % (finalParameters.GetElement(0)))
    Logging.info("Translation Y = %f" % (finalParameters.GetElement(1)))
    Logging.info("Translation Z = %f" % (finalParameters.GetElement(2)))

    # Translate input image using results from the registration
    resampler = itk.ResampleImageFilter.IF3IF3.New()
    self.transform.SetParameters(finalParameters)
    resampler.SetTransform(self.transform)
    resampler.SetInput(movingImage)
    region = movingImage.GetLargestPossibleRegion()
    resampler.SetSize(region.GetSize())
    resampler.SetOutputSpacing(movingImage.GetSpacing())
    resampler.SetOutputOrigin(movingImage.GetOrigin())
    resampler.SetDefaultPixelValue(backgroundValue)
    data = resampler.GetOutput()
    data.Update()

    data = self.convertITKtoVTK(data, cast="UC3")
    return data
def execute(self, inputs, update=0, last=0):
    """
    Initializes and executes the registration process.

    Registers the moving image (input 1) against timepoint FixedTimepoint
    of the first source data unit. It uses a B-spline deformable transform,
    a mean-squares metric and a regular-step gradient-descent optimizer.
    When the fixed image has fewer than two slices, both images are
    reduced to 2-D first. The moving image is then resampled
    (nearest-neighbour) with the resulting transform and returned
    converted back to VTK (cast "UC3").

    Returns None if the base ProcessingFilter.execute() rejects the inputs.
    Raises ValueError if a grid size used for the registration is below 2,
    because the grid spacing is computed by dividing by (gridSize - 1).
    """
    if not lib.ProcessingFilter.ProcessingFilter.execute(self, inputs):
        return None

    backgroundValue = self.parameters["BackgroundPixelValue"]
    minStepLength = self.parameters["MinStepLength"]
    maxStepLength = self.parameters["MaxStepLength"]
    maxIterations = self.parameters["MaxIterations"]
    caching = self.parameters["UseCaching"]
    gridSizeX = self.parameters["GridSizeX"]
    gridSizeY = self.parameters["GridSizeY"]
    gridSizeZ = self.parameters["GridSizeZ"]
    fixedtp = self.parameters["FixedTimepoint"] - 1

    movingImage = self.getInput(1)
    movingImage.SetUpdateExtent(movingImage.GetWholeExtent())
    movingImage.Update()
    # Create copy of data, otherwise movingImage will point to same image
    # as fixedImage
    mi = vtk.vtkImageData()
    mi.DeepCopy(movingImage)
    movingImage = mi
    movingImage.Update()

    # Following is dirty but currently has to be done this way since
    # convertVTKtoITK doesn't work with two dataset unless
    # itkConfig.ProgressCallback is set and that eats all memory.
    vtkToItk = itk.VTKImageToImageFilter.IF3.New()
    icast = vtk.vtkImageCast()
    icast.SetOutputScalarTypeToFloat()
    icast.SetInput(movingImage)
    vtkToItk.SetInput(icast.GetOutput())
    vtkToItk.Update()
    movingImage = vtkToItk.GetOutput()

    fixedImage = self.dataUnit.getSourceDataUnits()[0].getTimepoint(fixedtp)
    fixedImage.SetUpdateExtent(fixedImage.GetWholeExtent())
    fixedImage.Update()
    vtkToItk2 = itk.VTKImageToImageFilter.IF3.New()
    icast2 = vtk.vtkImageCast()
    icast2.SetOutputScalarTypeToFloat()
    icast2.SetInput(fixedImage)
    vtkToItk2.SetInput(icast2.GetOutput())
    vtkToItk2.Update()
    fixedImage = vtkToItk2.GetOutput()

    # Size of the original (3-D) fixed image; also used for the grid below.
    fixedRegion = fixedImage.GetLargestPossibleRegion()
    fixedSize = fixedRegion.GetSize()
    dim = fixedRegion.GetImageDimension()
    if dim == 3 and fixedSize.GetElement(2) < 2:
        # Single-slice stack: collapse Z and register both images in 2-D.
        dim = 2
        extractFilter = itk.ExtractImageFilter.IF3IF2.New()
        extractRegion = itk.ImageRegion._3()
        extractSize = itk.Size._3()
        extractIndex = itk.Index._3()
        for i in range(0, 3):
            extractIndex.SetElement(i, 0)
        extractRegion.SetIndex(extractIndex)
        for i in range(0, 2):
            extractSize.SetElement(i, fixedSize.GetElement(i))
        # ExtractImageFilter drops the dimensions whose extraction size is
        # 0, so set Z explicitly instead of relying on default init.
        extractSize.SetElement(2, 0)
        extractRegion.SetSize(extractSize)
        extractFilter.SetExtractionRegion(extractRegion)
        extractFilter.SetInput(fixedImage)
        fixedImage = extractFilter.GetOutput()
        fixedImage.Update()

        extractFilter2 = itk.ExtractImageFilter.IF3IF2.New()
        extractFilter2.SetExtractionRegion(extractRegion)
        extractFilter2.SetInput(movingImage)
        movingImage = extractFilter2.GetOutput()
        movingImage.Update()

    # The grid spacing divides by (gridSize - 1): reject sizes below 2
    # instead of failing with ZeroDivisionError or a negative spacing.
    gridSizes = (gridSizeX, gridSizeY, gridSizeZ)[:dim]
    if min(gridSizes) < 2:
        raise ValueError(
            "B-spline grid size must be at least 2 along each axis, got %s"
            % (gridSizes,))

    # Create registration framework's components. The wrapped template
    # instantiations are selected by name for the current dimension.
    imageSuffix = "IF%dIF%d" % (dim, dim)
    self.registration = getattr(itk.ImageRegistrationMethod, imageSuffix).New()
    self.optimizer = itk.RegularStepGradientDescentOptimizer.New()
    self.metric = getattr(itk.MeanSquaresImageToImageMetric, imageSuffix).New()
    self.interpolator = getattr(
        itk.LinearInterpolateImageFunction, "IF%dD" % dim).New()
    # NOTE(review): the suffix is D<dim><dim>; if the wrapping only provides
    # spline order 3 (D<dim>3), the 2-D path would not find D22 -- confirm.
    self.transform = getattr(
        itk.BSplineDeformableTransform, "D%d%d" % (dim, dim)).New()

    # Initialize registration framework's components
    self.registration.SetOptimizer(self.optimizer.GetPointer())
    self.registration.SetTransform(self.transform.GetPointer())
    self.registration.SetInterpolator(self.interpolator.GetPointer())
    self.registration.SetMetric(self.metric.GetPointer())
    self.registration.SetFixedImage(fixedImage)
    self.registration.SetMovingImage(movingImage)
    self.registration.SetFixedImageRegion(fixedImage.GetBufferedRegion())

    self.optimizer.SetMaximumStepLength(maxStepLength)
    self.optimizer.SetMinimumStepLength(minStepLength)
    self.optimizer.SetNumberOfIterations(maxIterations)
    self.metric.SetUseCachingOfBSplineWeights(caching)

    # Initialize transform: the control grid covers the user-requested
    # nodes on the image plus a border of 3 nodes per axis.
    regionType = getattr(itk.ImageRegion, "_%d" % dim)
    sizeType = getattr(itk.Size, "_%d" % dim)
    gridRegion = regionType()
    gridSizeOnImage = sizeType()
    gridBorderSize = sizeType()
    totalGridSize = sizeType()
    gridSizeOnImage.SetElement(0, gridSizeX)
    gridSizeOnImage.SetElement(1, gridSizeY)
    if dim == 3:
        gridSizeOnImage.SetElement(2, gridSizeZ)
    gridBorderSize.Fill(3)
    for i in range(dim):
        totalGridSize.SetElement(
            i, gridSizeOnImage.GetElement(i) + gridBorderSize.GetElement(i))
    gridRegion.SetSize(totalGridSize)

    # NOTE(review): if the wrapper returns references, these may alias the
    # fixed image's own spacing/origin and the edits below would modify
    # it -- TODO confirm.
    gridSpacing = fixedImage.GetSpacing()
    gridOrigin = fixedImage.GetOrigin()
    for i in range(dim):
        # Node spacing = image spacing * floor((pixels - 1) / (nodes - 1));
        # the origin is shifted back by one node spacing.
        spacingElement = gridSpacing.GetElement(i)
        spacingElement *= int(
            math.floor(
                float(fixedSize.GetElement(i) - 1) /
                float(gridSizeOnImage.GetElement(i) - 1)))
        originElement = gridOrigin.GetElement(i) - spacingElement
        gridSpacing.SetElement(i, spacingElement)
        gridOrigin.SetElement(i, originElement)

    self.transform.SetGridSpacing(gridSpacing)
    self.transform.SetGridOrigin(gridOrigin)
    self.transform.SetGridRegion(gridRegion)
    self.transform.SetGridDirection(fixedImage.GetDirection())

    # Start from a zero deformation (all B-spline coefficients 0).
    numOfParameters = self.transform.GetNumberOfParameters()
    initialParameters = itk.Array.D(numOfParameters)
    initialParameters.Fill(0.0)
    self.transform.SetParameters(initialParameters)
    self.registration.SetInitialTransformParameters(
        self.transform.GetParameters())

    iterationCommand = itk.PyCommand.New()
    iterationCommand.SetCommandCallable(self.updateProgress)
    self.optimizer.AddObserver(itk.IterationEvent(),
                               iterationCommand.GetPointer())

    Logging.info("Starting registration")
    startTime = time.time()
    self.registration.StartRegistration()
    finalParameters = self.registration.GetLastTransformParameters()
    self.transform.SetParameters(finalParameters)
    Logging.info("Registration took %s seconds" % (time.time() - startTime))

    # Warp the moving image with the optimized transform.
    self.resampler = getattr(itk.ResampleImageFilter, imageSuffix).New()
    resampleInterpolator = getattr(
        itk.NearestNeighborInterpolateImageFunction, "IF%dD" % dim).New()
    self.resampler.SetTransform(self.transform.GetPointer())
    self.resampler.SetInput(movingImage)
    self.resampler.SetSize(movingImage.GetLargestPossibleRegion().GetSize())
    self.resampler.SetOutputSpacing(movingImage.GetSpacing())
    self.resampler.SetOutputOrigin(movingImage.GetOrigin())
    self.resampler.SetDefaultPixelValue(backgroundValue)
    self.resampler.SetInterpolator(resampleInterpolator.GetPointer())
    data = self.resampler.GetOutput()
    data.Update()

    data = self.convertITKtoVTK(data, cast="UC3")
    return data
def register(self, fixedData, movingData):
    """Register movingData to fixedData with a translation and resample it.

    The ITK translation registration uses Mattes mutual information and a
    regular-step gradient-descent optimizer. It runs on copies of both
    images rescaled to 0-255. The starting offset is derived from each
    dataset's 'clip' info and resolution. The resulting translation is then
    applied with SimpleITK, resampling the moving image onto the fixed
    image's grid (linear interpolation, default value 0, float32 output).

    If the ITK registration raises, the error is printed and a zero
    translation is used.

    Returns a tuple (array, {}, params):
    - array is sitk.GetArrayFromImage of the resampled image;
    - params is the 12 affine parameters (identity matrix followed by the
      translation) plus three trailing zeros.
    """
    clip1 = movingData.info.getData('clip')
    clip2 = fixedData.info.getData('clip')
    fixed_res = fixedData.getResolution().tolist()
    moving_res = movingData.getResolution().tolist()

    image_type = fixedData.getITKImageType()
    fixedImage = fixedData.getITKImage()
    movingImage = movingData.getITKImage()

    rescale_filter_fixed = itk.RescaleIntensityImageFilter[
        image_type, image_type].New()
    rescale_filter_fixed.SetInput(fixedImage)
    rescale_filter_fixed.SetOutputMinimum(0)
    rescale_filter_fixed.SetOutputMaximum(255)

    rescale_filter_moving = itk.RescaleIntensityImageFilter[
        image_type, image_type].New()
    rescale_filter_moving.SetInput(movingImage)
    rescale_filter_moving.SetOutputMinimum(0)
    rescale_filter_moving.SetOutputMaximum(255)

    registration = itk.ImageRegistrationMethod[image_type, image_type].New()
    imageMetric = itk.MattesMutualInformationImageToImageMetric[
        image_type, image_type].New()
    transform = itk.TranslationTransform.New()
    optimizer = itk.RegularStepGradientDescentOptimizer.New()
    interpolator = itk.LinearInterpolateImageFunction[image_type, itk.D].New()

    def iterationUpdate():
        # Per-iteration progress: metric value and current translation.
        currentParameter = transform.GetParameters()
        print("M: %f P: %f %f %f" % (
            optimizer.GetValue(), currentParameter.GetElement(0),
            currentParameter.GetElement(1), currentParameter.GetElement(2)))

    registration.SetOptimizer(optimizer)
    registration.SetTransform(transform)
    registration.SetInterpolator(interpolator)
    registration.SetMetric(imageMetric)
    registration.SetFixedImage(rescale_filter_fixed.GetOutput())
    registration.SetMovingImage(rescale_filter_moving.GetOutput())
    registration.SetFixedImageRegion(fixedImage.GetBufferedRegion())

    # Initial translation from the two datasets' 'clip' offsets scaled by
    # their resolutions. Index 4 is used for x and index 2 for y; the
    # meaning of the clip layout is not visible here -- TODO confirm.
    para = [
        -clip1[4] * moving_res[0] + clip2[4] * fixed_res[0],
        -clip1[2] * moving_res[1] + clip2[2] * fixed_res[1],
        0,
    ]
    transform.SetParameters(para)
    initialParameters = transform.GetParameters()
    print("Initial Registration Parameters ")
    print(initialParameters.GetElement(0))
    print(initialParameters.GetElement(1))
    print(initialParameters.GetElement(2))
    registration.SetInitialTransformParameters(initialParameters)

    # Unit scales for every translation component; actually handed to the
    # optimizer so the array is not dead code.
    numberOfParameters = transform.GetNumberOfParameters()
    optimizerScales = itk.Array[itk.D](numberOfParameters)
    for i in range(numberOfParameters):
        optimizerScales.SetElement(i, 1.0)
    optimizer.SetScales(optimizerScales)

    imageMetric.SetNumberOfHistogramBins(64)
    imageMetric.SetNumberOfSpatialSamples(800000)

    optimizer.MinimizeOn()
    optimizer.SetMaximumStepLength(2.00)
    optimizer.SetMinimumStepLength(0.001)
    optimizer.SetRelaxationFactor(0.8)
    optimizer.SetNumberOfIterations(200)

    iterationCommand = itk.PyCommand.New()
    iterationCommand.SetCommandCallable(iterationUpdate)
    optimizer.AddObserver(itk.IterationEvent(), iterationCommand)

    # Start the registration process. On failure, report the cause and fall
    # back to a zero translation (best effort, as before).
    try:
        registration.Update()
    except Exception as err:
        print("error: %s" % err)
        transform.SetParameters([0.0, 0.0, 0.0])

    # Read the final translation from the transform rather than from
    # GetLastTransformParameters(). After a failure, the latter may not hold
    # three elements, while the transform always holds either the
    # optimized values or the zero fallback.
    parameters = transform.GetParameters()
    x = float(parameters.GetElement(0))
    y = float(parameters.GetElement(1))
    z = float(parameters.GetElement(2))
    print("Final Registration Parameters ")
    print(x)
    print(y)
    print(z)

    # ResampleImageFilter could not be used here, so the translation is
    # applied with a SimpleITK affine transform: identity matrix followed
    # by the translation vector.
    para = [1, 0, 0, 0, 1, 0, 0, 0, 1] + [x, y, z]
    sitk_transform = sitk.Transform(3, sitk.sitkAffine)
    sitk_transform.SetParameters(para)

    sitk_moving = movingData.getSimpleITKImage()
    sitk_fixed = fixedData.getSimpleITKImage()
    resultImage = sitk.Resample(sitk_moving, sitk_fixed, sitk_transform,
                                sitk.sitkLinear, 0, sitk.sitkFloat32)
    return sitk.GetArrayFromImage(resultImage), {}, para + [0, 0, 0]