예제 #1
0
def extractBlockMatches(filepath1, filepath2, params, paramsSIFT, properties,
                        csvDir, exeload, load):
    """
  Extract pointmatches between two section images and store them as a CSV
  file in csvDir. Block matching is tried first; when it yields too few
  matches (or cannot be used because the two images differ in dimensions),
  SIFT feature matching on the unscaled images is used instead.

  filepath1: the file path to an image of a section.
  filepath2: the file path to an image of another section.
  params: dictionary of parameters necessary for BlockMatching. Keys read
          here: "meshResolution", "scale", "blockRadius", "searchRadius",
          "minR", "rod", "maxCurvature", and optionally "max_sd", "max_id".
  paramsSIFT: SIFT parameters object; it is cloned, never modified in place.
  properties: dictionary; the optional "SIFT_max_size" entry (default 2048)
              is read for the same-size SIFT case.
  csvDir: directory where the pointmatches CSV file is looked up and saved.
  exeload: an ExecutorService for parallel loading of image files.
  load: a function that knows how to load the image from the filepath.

  return False if the CSV file already exists, True if it has to be computed.
  If anything fails, the exception is reported via printException() and
  None is returned.
  """

    # Skip if pointmatches CSV file exists already:
    csvpath = os.path.join(
        csvDir,
        basename(filepath1) + '.' + basename(filepath2) + ".pointmatches.csv")
    if os.path.exists(csvpath):
        return False

    try:

        # Load files in parallel
        futures = [
            exeload.submit(Task(load, filepath1)),
            exeload.submit(Task(load, filepath2))
        ]

        fp1 = futures[0].get(
        )  # FloatProcessor, already Gaussian-blurred, contrast-corrected and scaled!
        fp2 = futures[1].get()  # FloatProcessor, idem

        # Define points from the mesh
        sourcePoints = ArrayList()
        # List to fill
        sourceMatches = ArrayList(
        )  # of PointMatch from filepath1 to filepath2

        # Don't use blockmatching if the dimensions are different
        use_blockmatching = (fp1.getWidth() == fp2.getWidth()
                             and fp1.getHeight() == fp2.getHeight())

        # Fill the sourcePoints: needed even when blockmatching is skipped,
        # because the acceptance threshold below depends on their count.
        mesh = TransformMesh(params["meshResolution"], fp1.width, fp1.height)
        PointMatch.sourcePoints(mesh.getVA().keySet(), sourcePoints)

        if use_blockmatching:
            syncPrintQ("Extracting block matches for \n S: " + filepath1 +
                       "\n T: " + filepath2 + "\n  with " +
                       str(sourcePoints.size()) + " mesh sourcePoints.")
            # Run
            BlockMatching.matchByMaximalPMCCFromPreScaledImages(
                fp1,
                fp2,
                params["scale"],  # float
                params["blockRadius"],  # X
                params["blockRadius"],  # Y
                params["searchRadius"],  # X
                params["searchRadius"],  # Y
                params["minR"],  # float
                params["rod"],  # float
                params["maxCurvature"],  # float
                sourcePoints,
                sourceMatches)
        else:
            # sourceMatches stays empty, so the check below falls through
            # to SIFT, which copes with images of different sizes.
            syncPrintQ(
                "Images differ in dimensions: skipping blockmatching for:\n  S: "
                + basename(filepath1) + "\n  T: " + basename(filepath2))

        # At least some should match to accept the translation
        if len(sourceMatches) < max(20, len(sourcePoints) / 5) / 2:
            syncPrintQ(
                "Found only %i blockmatching pointmatches (from %i source points)"
                % (len(sourceMatches), len(sourcePoints)))
            syncPrintQ(
                "... therefore invoking SIFT pointmatching for:\n  S: " +
                basename(filepath1) + "\n  T: " + basename(filepath2))
            # Can fail if there is a shift larger than the searchRadius
            # Try SIFT features, which are location independent
            #
            # Images are now scaled: load originals
            futures = [
                exeload.submit(
                    Task(loadFloatProcessor,
                         filepath1,
                         params,
                         paramsSIFT,
                         scale=False)),
                exeload.submit(
                    Task(loadFloatProcessor,
                         filepath2,
                         params,
                         paramsSIFT,
                         scale=False))
            ]

            fp1 = futures[0].get()  # FloatProcessor, original
            fp2 = futures[1].get()  # FloatProcessor, original

            # Images can be of different size: scale them the same way
            area1 = fp1.width * fp1.height
            area2 = fp2.width * fp2.height

            if area1 == area2:
                paramsSIFT1 = paramsSIFT.clone()
                paramsSIFT1.maxOctaveSize = int(
                    max(properties.get("SIFT_max_size", 2048),
                        fp1.width * params["scale"]))
                paramsSIFT1.minOctaveSize = int(paramsSIFT1.maxOctaveSize /
                                                pow(2, paramsSIFT1.steps))
                paramsSIFT2 = paramsSIFT1
            else:
                fp1_is_bigger = area1 > area2
                bigger, smaller = (fp1, fp2) if fp1_is_bigger else (fp2, fp1)
                target_width_bigger = int(
                    max(1024, bigger.width * params["scale"]))
                if 1024 == target_width_bigger:
                    target_width_smaller = int(1024 * float(smaller.width) /
                                               bigger.width)
                else:
                    # int: maxOctaveSize is assigned as an integer everywhere
                    # else in this function
                    target_width_smaller = int(smaller.width *
                                               params["scale"])
                #
                paramsBigger = paramsSIFT.clone()
                paramsBigger.maxOctaveSize = target_width_bigger
                paramsBigger.minOctaveSize = int(paramsBigger.maxOctaveSize /
                                                 pow(2, paramsBigger.steps))
                paramsSmaller = paramsSIFT.clone()
                paramsSmaller.maxOctaveSize = target_width_smaller
                paramsSmaller.minOctaveSize = int(
                    paramsSmaller.maxOctaveSize / pow(2, paramsSmaller.steps))
                # paramsSIFT1 is applied to fp1 and paramsSIFT2 to fp2 below,
                # so hand each image the params computed for its own size.
                if fp1_is_bigger:
                    paramsSIFT1, paramsSIFT2 = paramsBigger, paramsSmaller
                else:
                    paramsSIFT1, paramsSIFT2 = paramsSmaller, paramsBigger

            ijSIFT1 = SIFT(FloatArray2DSIFT(paramsSIFT1))
            features1 = ArrayList()  # of Point instances
            ijSIFT1.extractFeatures(fp1, features1)

            ijSIFT2 = SIFT(FloatArray2DSIFT(paramsSIFT2))
            features2 = ArrayList()  # of Point instances
            ijSIFT2.extractFeatures(fp2, features2)
            # Vector of PointMatch instances
            sourceMatches = FloatArray2DSIFT.createMatches(
                features1,
                features2,
                params.get(
                    "max_sd",
                    1.5),  # max_sd: maximal difference in size (ratio max/min)
                TranslationModel2D(),
                params.get("max_id", Double.MAX_VALUE
                           ),  # max_id: maximal distance in image space
                params.get("rod", 0.9))  # rod: ratio of best vs second best

        # Store pointmatches
        savePointMatches(os.path.basename(filepath1),
                         os.path.basename(filepath2), sourceMatches, csvDir,
                         params)

        return True
    except:
        # Bare except kept on purpose: under Jython it also traps Java
        # throwables raised by the imaging libraries.
        printException()
예제 #2
0
	def call(self, save_images=False):
		self.thread_used = threading.currentThread().getName()
		self.started = time.time()
		try:
			IJ.log(time.asctime())
			IJ.log(str(self.patchA))
			IJ.log(str(self.patchB))			
			print str(self.patchA)
			print str(self.patchB)

			# Adapted from ElasticMontage.java
			# https://github.com/trakem2/TrakEM2/blob/master/TrakEM2_/src/main/java/mpicbg/trakem2/align/ElasticMontage.java
			pi1 = self.patchA.createTransformedImage()
			pi2 = self.patchB.createTransformedImage()

			fp1 = pi1.target.convertToFloat()
			mask1 = pi1.getMask()
			if mask1 is None:
				fpMask1 = None
			else:
				fpMask1 = scaleByte(mask1)

			fp2 = pi2.target.convertToFloat()
			mask2 = pi2.getMask()
			if mask2 is None:
				fpMask2 = None
			else:
				fpMask2 = scaleByte(mask1)

			w = self.tileA.getWidth()
			h = self.tileA.getHeight()
			num_x = Math.max(2, int(Math.ceil(w / self.params.spring_length) + 1))
			num_y = Math.max(2, int(Math.ceil(h / self.params.spring_triangle_height_twice) + 1))
			w_mesh = (num_x - 1) * self.params.spring_length
			h_mesh = (num_y - 1) * self.params.spring_triangle_height_twice

			mesh = SpringMesh(num_x,
								num_y,
								w_mesh,
								h_mesh,
								self.params.stiffness,
								self.params.max_stretch * self.params.scale,
								self.params.damp)

			vertices = mesh.getVertices()
			maskSamples = RealPointSampleList(2)
			for vertex in vertices:
				maskSamples.add(RealPoint(vertex.getL()), ARGBType(-1)) # equivalent of 0xffffffff
			pm12 = ArrayList()
			v1 = mesh.getVertices()

			t = self.tileB.getModel().createInverse()
			t.concatenate(self.tileA.getModel())

			BlockMatching.matchByMaximalPMCC(
							fp1,
							fp2,
							fpMask1,
							fpMask2,
							self.params.scale,
							t,
							self.params.block_radius,
							self.params.block_radius,
							self.params.search_radius,
							self.params.search_radius,
							self.params.min_R,
							self.params.rod_R,
							self.params.max_curvature_R,
							v1,
							pm12,
							ErrorStatistic(1))

			pre_smooth_block_matches = len(pm12)
			if self.params.save_data:
				pre_smooth_filename = self.params.output_folder + str(self.patchB) + "_" + str(self.patchA) + "_pre_smooth_pts.txt"
				self.writePointsFile(pm12, pre_smooth_filename)

			if self.params.use_local_smoothness_filter:
				model = Util.createModel(self.params.local_model_index)
				try:
					model.localSmoothnessFilter(pm12, pm12, self.params.local_region_sigma, self.params.max_local_epsilon, self.params.max_local_trust)
					if self.params.save_data:
						post_smooth_filename = self.params.output_folder + str(self.patchB) + "_" + str(self.patchA) + "_post_smooth_pts.txt"
						self.writePointsFile(pm12, post_smooth_filename)
				except:
					pm12.clear()

			color_samples, max_displacement = self.matches2ColorSamples(pm12)
			post_smooth_block_matches = len(pm12)
				
			print time.asctime()
			print str(self.patchB) + "_" + str(self.patchA) + "\tblock_matches\t" + str(pre_smooth_block_matches) + "\tsmooth_filtered\t" + str(pre_smooth_block_matches - post_smooth_block_matches) + "\tmax_displacement\t" + str(max_displacement) + "\trelaxed_length\t" + str(self.params.spring_length) + "\tsigma\t" + str(self.params.local_region_sigma)
			IJ.log(time.asctime())
			IJ.log(str(self.patchB) + "_" + str(self.patchA) + ": block_matches " + str(pre_smooth_block_matches) + ", smooth_filtered " + str(pre_smooth_block_matches - post_smooth_block_matches) + ", max_displacement " + str(max_displacement) + ", relaxed_length " + str(self.params.spring_length) + ", sigma " + str(self.params.local_region_sigma))
			if self.params.save_data and self.wf:
				self.wf.write(str(self.patchB) + 
					"\t" + str(self.patchA) + 
					"\t" + str(pre_smooth_block_matches) + 
					"\t" + str(pre_smooth_block_matches - post_smooth_block_matches) + 
					"\t" + str(max_displacement) + 
					"\t" + str(self.params.spring_length) + 
					"\t" + str(self.params.local_region_sigma) + 
					"\t" + str(num_x) + "\n")

			if self.params.export_point_roi:
				pm12Sources = ArrayList()
				pm12Targets = ArrayList()

				PointMatch.sourcePoints(pm12, pm12Sources)
				PointMatch.targetPoints(pm12, pm12Targets)

				roi1 = pointsToPointRoi(pm12Sources)
				roi2 = pointsToPointRoi(pm12Targets)

				# # Adapted from BlockMatching.java
				# # https://github.com/axtimwalde/mpicbg/blob/master/mpicbg/src/main/java/mpicbg/ij/blockmatching/BlockMatching.java
				# tTarget = TranslationModel2D()
				# sTarget = SimilarityModel2D()
				# tTarget.set(-self.params.search_radius, -self.params.search_radius)
				# sTarget.set(1.0/self.params.scale, 0, 0, 0)
				# lTarget = CoordinateTransformList()
				# lTarget.add(sTarget)
				# lTarget.add(tTarget)
				# lTarget.add(t)
				# targetMapping = TransformMapping(lTarget)

				# mappedScaledTarget = FloatProcessor(fp1.getWidth() + 2*search_radius, fp1.getHeight() + 2*search_radius)

				# targetMapping.mapInverseInterpolated(fp2, mappedScaledTarget)
				# imp1 = tileImagePlus("imp1", fp1)
				# imp1.show()				
				# imp2 = ImagePlus("imp2", mappedScaledTarget)
				# imp2.show()

				imp1 = ImagePlus("imp1", fp1)
				imp1.show()				
				imp2 = ImagePlus("imp2", fp2)
				imp2.show()				

				imp1.setRoi(roi1)
				imp2.setRoi(roi2)

			if self.params.export_displacement_vectors:
				pm12Targets = ArrayList()
				PointMatch.targetPoints(pm12, pm12Targets)

				maskSamples2 = RealPointSampleList(2)
				for point in pm12Targets:
					maskSamples2.add(RealPoint(point.getW()), ARGBType(-1))
				factory = ImagePlusImgFactory()
				kdtreeMatches = KDTree(color_samples)
				kdtreeMask = KDTree(maskSamples)
				
				img = factory.create([fp1.getWidth(), fp1.getHeight()], ARGBType())
				self.drawNearestNeighbor(
							img, 
							NearestNeighborSearchOnKDTree(kdtreeMatches),
							NearestNeighborSearchOnKDTree(kdtreeMask))
				scaled_img = self.scaleIntImagePlus(img, 0.03)
				if self.params.save_data:
					fs = FileSaver(scaled_img)
					fs.saveAsTiff(self.params.output_folder + str(self.patchB) + "_" + str(self.patchA) + ".tif")
				else:
					scaled_img.show()
				print time.asctime()
				print str(self.patchB) + "_" + str(self.patchA) + "\tsaved"
				IJ.log(time.asctime())
				IJ.log(str(self.patchB) + "_" + str(self.patchA) + ": saved")
		except Exception, ex:
			self.exception = ex
			print str(ex)
			IJ.log(str(ex))
			if self.params.save_data and self.wf:
				self.wf.write(str(ex) + "\n")
예제 #3
0
	def call(self, save_images=False):
		self.thread_used = threading.currentThread().getName()
		self.started = time.time()
		try:
			filenames = os.listdir(self.params.input_folder)
			filenames.sort()

			imp1 = IJ.openImage(os.path.join(self.params.input_folder, filenames[self.imgA]))
			imp2 = IJ.openImage(os.path.join(self.params.input_folder, filenames[self.imgB]))
			IJ.log(time.asctime())
			IJ.log(str(self.imgA) + ": " + str(imp1))
			IJ.log(str(self.imgB) + ": " + str(imp2))			
			print str(self.imgA) + ": " + str(imp1)
			print str(self.imgB) + ": " + str(imp2)

			mesh_resolution = int(imp1.getWidth() / self.params.point_distance)
			effective_point_distance = imp1.getWidth() / mesh_resolution

			mesh = SpringMesh(mesh_resolution, imp1.getWidth(), imp1.getHeight(), 1, 1000, 0.9)
			vertices = mesh.getVertices()
			maskSamples = RealPointSampleList(2)
			for vertex in vertices:
				maskSamples.add(RealPoint(vertex.getL()), ARGBType(-1)) # equivalent of 0xffffffff
			pm12 = ArrayList()
			v1 = mesh.getVertices()

			ip1 = imp1.getProcessor().convertToFloat().duplicate()
			ip2 = imp2.getProcessor().convertToFloat().duplicate()

			ip1Mask = self.createMask(imp1.getProcessor().convertToRGB())
			ip2Mask = self.createMask(imp2.getProcessor().convertToRGB())

			ct = TranslationModel2D()

			BlockMatching.matchByMaximalPMCC(
							ip1,
							ip2,
							ip1Mask,
							ip2Mask,
							self.params.scale,
							ct,
							self.params.block_radius,
							self.params.block_radius,
							self.params.search_radius,
							self.params.search_radius,
							self.params.min_R,
							self.params.rod_R,
							self.params.max_curvature_R,
							v1,
							pm12,
							ErrorStatistic(1))

			pre_smooth_block_matches = len(pm12)
			if self.params.save_data:
				pre_smooth_filename = self.params.output_folder + str(imp2.getTitle())[:-4] + "_" + str(imp1.getTitle())[:-4] + "_pre_smooth_pts.txt"
				self.writePointsFile(pm12, pre_smooth_filename)

			if self.params.use_local_smoothness_filter:
				model = Util.createModel(self.params.local_model_index)
				try:
					model.localSmoothnessFilter(pm12, pm12, self.params.local_region_sigma, self.params.max_local_epsilon, self.params.max_local_trust)
					if self.params.save_data:
						post_smooth_filename = self.params.output_folder + str(imp2.getTitle())[:-4] + "_" + str(imp1.getTitle())[:-4] + "_post_smooth_pts.txt"
						self.writePointsFile(pm12, post_smooth_filename)
				except:
					pm12.clear()

			color_samples, max_displacement = self.matches2ColorSamples(pm12)
			post_smooth_block_matches = len(pm12)
				
			print time.asctime()
			print str(self.imgB) + "_" + str(self.imgA) + "\tblock_matches\t" + str(pre_smooth_block_matches) + "\tsmooth_filtered\t" + str(pre_smooth_block_matches - post_smooth_block_matches) + "\tmax_displacement\t" + str(max_displacement) + "\trelaxed_length\t" + str(effective_point_distance) + "\tsigma\t" + str(self.params.local_region_sigma)
			IJ.log(time.asctime())
			IJ.log(str(self.imgB) + "_" + str(self.imgA) + ": block_matches " + str(pre_smooth_block_matches) + ", smooth_filtered " + str(pre_smooth_block_matches - post_smooth_block_matches) + ", max_displacement " + str(max_displacement) + ", relaxed_length " + str(effective_point_distance) + ", sigma " + str(self.params.local_region_sigma))
			if self.params.save_data and self.wf:
				self.wf.write(str(self.imgB) + 
					"\t" + str(self.imgA) + 
					"\t" + str(imp2.getTitle())[:-4] + 
					"\t" + str(imp1.getTitle())[:-4] + 
					"\t" + str(pre_smooth_block_matches) + 
					"\t" + str(pre_smooth_block_matches - post_smooth_block_matches) + 
					"\t" + str(max_displacement) + 
					"\t" + str(effective_point_distance) + 
					"\t" + str(self.params.local_region_sigma) + 
					"\t" + str(mesh_resolution) + "\n")

			if self.params.export_point_roi:
				pm12Sources = ArrayList()
				pm12Targets = ArrayList()

				PointMatch.sourcePoints(pm12, pm12Sources)
				PointMatch.targetPoints(pm12, pm12Targets)

				roi1 = pointsToPointRoi(pm12Sources)
				roi2 = pointsToPointRoi(pm12Targets)

				imp1.setRoi(roi1)
				imp2.setRoi(roi2)

			if self.params.export_displacement_vectors:
				pm12Targets = ArrayList()
				PointMatch.targetPoints(pm12, pm12Targets)

				maskSamples2 = RealPointSampleList(2)
				for point in pm12Targets:
					maskSamples2.add(RealPoint(point.getW()), ARGBType(-1))
				factory = ImagePlusImgFactory()
				kdtreeMatches = KDTree(color_samples)
				kdtreeMask = KDTree(maskSamples)
				
				img = factory.create([imp1.getWidth(), imp1.getHeight()], ARGBType())
				self.drawNearestNeighbor(
							img, 
							NearestNeighborSearchOnKDTree(kdtreeMatches),
							NearestNeighborSearchOnKDTree(kdtreeMask))
				scaled_img = self.scaleIntImagePlus(img, 0.1)
				if self.params.save_data:
					fs = FileSaver(scaled_img)
					fs.saveAsTiff(self.params.output_folder + str(imp2.getTitle())[:-4] + "_" + str(imp1.getTitle()))
				else:
					scaled_img.show()
				print time.asctime()
				print str(self.imgB) + "_" + str(self.imgA) + "\tsaved\t" + filenames[self.imgB]
				IJ.log(time.asctime())
				IJ.log(str(self.imgB) + "_" + str(self.imgA) + ": saved " + filenames[self.imgB])
		except Exception, ex:
			self.exception = ex
			print str(ex)
			IJ.log(str(ex))
			if self.params.save_data and self.wf:
				self.wf.write(str(ex) + "\n")
def extractBlockMatches(filepath1,
                        filepath2,
                        params,
                        csvDir,
                        exeload,
                        load=loadFPMem):
    """
  Extract pointmatches between two section images and store them as a CSV
  file in csvDir. Block matching is tried first; when it yields too few
  matches, SIFT feature matching on the unscaled images is used instead.

  filepath1: the file path to an image of a section.
  filepath2: the file path to an image of another section.
  params: dictionary of parameters necessary for BlockMatching. Keys read
          here: "meshResolution", "scale", "blockRadius", "searchRadius",
          "minR", "rod", "maxCurvature".
  csvDir: directory where the pointmatches CSV file is looked up and saved.
  exeload: an ExecutorService for parallel loading of image files.
  load: a function that knows how to load the image from the filepath.

  NOTE(review): `dimensions` and `paramsSIFT` are read below but are neither
  parameters nor locals -- presumably module-level globals; verify.

  return False if the CSV file already exists, True if it has to be computed.
  If anything fails, the exception and its traceback are printed via
  syncPrint and None is returned.
  """

    # Skip if pointmatches CSV file exists already:
    csvpath = os.path.join(
        csvDir,
        basename(filepath1) + '.' + basename(filepath2) + ".pointmatches.csv")
    if os.path.exists(csvpath):
        return False

    try:

        # Load files in parallel
        futures = [
            exeload.submit(Task(load, filepath1)),
            exeload.submit(Task(load, filepath2))
        ]

        # Define points from the mesh
        sourcePoints = ArrayList()
        mesh = TransformMesh(params["meshResolution"], dimensions[0],
                             dimensions[1])
        PointMatch.sourcePoints(mesh.getVA().keySet(), sourcePoints)
        # List to fill
        sourceMatches = ArrayList(
        )  # of PointMatch from filepath1 to filepath2

        syncPrint("Extracting block matches for \n S: " + filepath1 +
                  "\n T: " + filepath2 + "\n  with " +
                  str(sourcePoints.size()) + " mesh sourcePoints.")

        BlockMatching.matchByMaximalPMCCFromPreScaledImages(
            futures[0].get(),  # FloatProcessor
            futures[1].get(),  # FloatProcessor
            params["scale"],  # float
            params["blockRadius"],  # X
            params["blockRadius"],  # Y
            params["searchRadius"],  # X
            params["searchRadius"],  # Y
            params["minR"],  # float
            params["rod"],  # float
            params["maxCurvature"],  # float
            sourcePoints,
            sourceMatches)

        # At least some should match to accept the translation
        if len(sourceMatches) < max(20, len(sourcePoints) / 5) / 2:
            syncPrint(
                "Found only %i blockmatching pointmatches (from %i source points)"
                % (len(sourceMatches), len(sourcePoints)))
            syncPrint("... therefore invoking SIFT pointmatching for:\n  S: " +
                      basename(filepath1) + "\n  T: " + basename(filepath2))
            # Can fail if there is a shift larger than the searchRadius
            # Try SIFT features, which are location independent
            #
            # Images are now scaled: load originals
            futures = [
                exeload.submit(Task(loadFloatProcessor, filepath1,
                                    scale=False)),
                exeload.submit(Task(loadFloatProcessor, filepath2,
                                    scale=False))
            ]
            ijSIFT = SIFT(FloatArray2DSIFT(paramsSIFT))
            features1 = ArrayList()  # of Point instances
            ijSIFT.extractFeatures(futures[0].get(), features1)
            features2 = ArrayList()  # of Point instances
            ijSIFT.extractFeatures(futures[1].get(), features2)
            # Vector of PointMatch instances
            sourceMatches = FloatArray2DSIFT.createMatches(
                features1,
                features2,
                1.5,  # max_sd
                TranslationModel2D(),
                Double.MAX_VALUE,
                params["rod"])  # rod: ratio of best vs second best

        # Store pointmatches
        savePointMatches(os.path.basename(filepath1),
                         os.path.basename(filepath2), sourceMatches, csvDir,
                         params)

        return True
    except:
        # Bare except kept on purpose: under Jython it also traps Java
        # throwables raised by the imaging libraries.
        exc_info = sys.exc_info()
        syncPrint(exc_info)
        # format_exception needs (type, value, traceback); calling it with no
        # arguments raised a TypeError here and hid the original error.
        syncPrint("".join(traceback.format_exception(*exc_info)), out="stderr")