Ejemplo n.º 1
0
    def check_img_sizes(self):
        """ Check if all stored image channels have the same size

        Returns:
            True if every image in self.img_data has the same (ny, nx)
            size (taken from shape[0], shape[1]); False if no image is
            stored or if at least one channel differs from the first one
            (each mismatching channel is logged at debug level).
        """

        # - Return false if no images are stored
        if not self.img_data:
            return False

        # - Compare each channel size against the first channel
        #   NB: the shape must be taken per image; np.shape() of the whole
        #       list either adds a channel axis (wrong nx/ny) or fails for
        #       images of different sizes.
        ref_shape = np.shape(self.img_data[0])
        ny_ref, nx_ref = ref_shape[0], ref_shape[1]

        same_size = True
        for i, img in enumerate(self.img_data[1:], start=1):
            imgsize = np.shape(img)
            nx = imgsize[1]
            ny = imgsize[0]
            if nx != nx_ref or ny != ny_ref:
                logger.debug(
                    "Image %s has different size (%d,%d) wrt to previous images (%d,%d)!"
                    % (self.filepaths[i], nx, ny, nx_ref, ny_ref))
                same_size = False

        return same_size
Ejemplo n.º 2
0
	def __read_sdata(self, index):
		""" Build and read the image and masked-image source data at index

		Two SData objects are created and configured with the same options
		held by this object (refch, kernsize, draw, save_ssim_pars,
		negative_pix_fract_thr, bad_pix_fract_thr). The first is filled from
		self.datalist[index], the second from self.datalist_mask[index], and
		both then read their images.

		Returns:
			(sdata, sdata_mask) on success; None if the index is out of
			range or if filling/reading either object fails.
		"""

		# - Reject out-of-range indices
		if index < 0 or index >= self.datasize:
			logger.error("Invalid index %d given!" % index)
			return None

		# - Create image and mask containers sharing the same settings
		containers = []
		for _ in range(2):
			sd = SData()
			sd.refch = self.refch
			sd.kernsize = self.kernsize
			sd.draw = self.draw
			sd.save_ssim_pars = self.save_ssim_pars
			sd.negative_pix_fract_thr = self.negative_pix_fract_thr
			sd.bad_pix_fract_thr = self.bad_pix_fract_thr
			containers.append(sd)
		sdata, sdata_mask = containers

		# - Fill and read the source image data
		logger.debug("Reading source image data %d ..." % index)
		if sdata.set_from_dict(self.datalist[index]) < 0:
			logger.error("Failed to set source image data %d!" % index)
			return None
		if sdata.read_imgs() < 0:
			logger.error("Failed to read source images %d!" % index)
			return None

		# - Fill and read the source masked image data
		logger.debug("Reading source masked image data %d ..." % index)
		if sdata_mask.set_from_dict(self.datalist_mask[index]) < 0:
			logger.error("Failed to set source masked image data %d!" % index)
			return None
		if sdata_mask.read_imgs() < 0:
			logger.error("Failed to read source masked images %d!" % index)
			return None

		return sdata, sdata_mask
Ejemplo n.º 3
0
    def run_from_datalist(self, datalist, img_group_1, img_group_2):
        """ Run spectral index calculation passing data dict lists as inputs

        Arguments:
            datalist: list of per-source data dicts; each must provide the
                "label", "sname", "id" and "filepaths" keys read below
            img_group_1, img_group_2: non-empty image group index lists of
                equal length, forwarded to the per-source processing

        Returns:
            0 on success; -1 if any input is empty, the group lists have
            different lengths, or the sources do not all have the same
            number of channels. Sources that fail processing are logged
            and skipped.
        """

        # - Check input data
        if not datalist:
            logger.error("Empty data dict list given!")
            return -1

        if not img_group_1 or not img_group_2:
            logger.error("Empty image group index given!")
            return -1

        if len(img_group_1) != len(img_group_2):
            logger.error(
                "Given image group index list have different lengths!")
            return -1

        # - Set data info
        logger.debug("Setting data info ...")
        self.datalist = datalist
        self.datasize = len(self.datalist)
        self.labels = [item["label"] for item in self.datalist]
        self.snames = [item["sname"] for item in self.datalist]
        self.classids = [item["id"] for item in self.datalist]

        # - Check number of channels per image
        #   NB: the offending channel counts go into the log message
        #       rather than to stdout
        nchannels_set = {len(item["filepaths"]) for item in self.datalist}
        if len(nchannels_set) != 1:
            logger.warning(
                "Number of channels in each object instance is different (len(nchannels_set)=%d!=1, nchannels=%s)!"
                % (len(nchannels_set), sorted(nchannels_set)))
            return -1

        self.nchannels = next(iter(nchannels_set))

        # - Loop over data and extract params per each source
        logger.info("Loop over data and extract params per each source ...")
        for i in range(self.datasize):
            if self.__process_source(i, img_group_1, img_group_2) < 0:
                logger.warning("Failed to process source %d, skip to next ..." %
                               (i))
                continue

        # - Save data
        if self.save:
            logger.info("Saving data to file %s ..." % (self.outfile))
            self.__save_data()

        return 0
Ejemplo n.º 4
0
	def read_imgs(self):
		""" Read every image listed in self.filepaths

		For each file, the pixel data and header returned by
		self.__read_fits() are appended to self.img_data and self.img_heads,
		together with a uint8 validity mask in self.img_data_mask
		(1 = non-zero and finite pixel, 0 = zero, inf or nan pixel).
		On success self.nx/self.ny come from the first channel's shape
		(columns/rows) and self.nchannels is the number of stored channels.

		Returns:
			0 on success; -1 if no file path is set, a file cannot be read,
			or the channels do not all have the same size.
		"""

		# - Nothing to read without input paths
		if not self.filepaths:
			logger.error("Empty filelists given!")
			return -1

		self.nchannels = len(self.filepaths)

		for path in self.filepaths:
			logger.debug("Reading file %s ..." % path) 
			try:
				img, head = self.__read_fits(path)
			except Exception as err:
				logger.error("Failed to read image data from file %s (err=%s)!" % (path, str(err)))
				return -1

			# - Good-pixel mask: finite and non-zero values are flagged as 1
			good_pix = np.logical_and(img != 0, np.isfinite(img)).astype(np.uint8)

			# - Store this channel
			self.img_data.append(img)
			self.img_heads.append(head)
			self.img_data_mask.append(good_pix)

		# - All channels must share the same size
		if not self.check_img_sizes():
			logger.error("Image channels for source %s do not have the same size, check your dataset!" % self.sname)
			return -1

		# - Record data shape from the first channel
		first_shape = self.img_data[0].shape
		self.nx = first_shape[1]
		self.ny = first_shape[0]
		self.nchannels = len(self.img_data)

		return 0
Ejemplo n.º 5
0
	def run_from_datalist(self, datalist, datalist_mask):
		""" Run moment calculation passing data dict lists as inputs

		Arguments:
			datalist: list of per-source data dicts; each must provide the
				"label", "sname", "id" and "filepaths" keys read below
			datalist_mask: list of per-source masked data dicts; must have
				the same length as datalist

		Returns:
			0 on success; -1 if an input list is empty, the sources do not
			all have the same number of channels, or the image and mask
			lists have different sizes. Sources that fail processing are
			logged and skipped.
		"""

		# - Check input data (same checks as the other run_from_datalist)
		if not datalist:
			logger.error("Empty data dict list given!")
			return -1

		if not datalist_mask:
			logger.error("Empty mask data dict list given!")
			return -1

		# - Set data info
		logger.debug("Setting data info ...")
		self.datalist= datalist
		self.datalist_mask= datalist_mask

		self.datasize= len(self.datalist)
		self.labels= [item["label"] for item in self.datalist]
		self.snames= [item["sname"] for item in self.datalist]
		self.classids= [item["id"] for item in self.datalist]
		# NB: this maps each class id to its number of instances (counts)
		self.classfract_map= dict(Counter(self.classids))
		datasize_mask= len(self.datalist_mask)

		# - Check number of channels per image
		#   NB: the offending channel counts go into the log message
		#       rather than to stdout
		nchannels_set= {len(item["filepaths"]) for item in self.datalist}
		if len(nchannels_set)!=1:
			logger.warning("Number of channels in each object instance is different (len(nchannels_set)=%d!=1, nchannels=%s)!" % (len(nchannels_set), sorted(nchannels_set)))
			return -1

		self.nchannels= next(iter(nchannels_set))
		
		# - Check data size for imgs and masks
		if self.datasize!=datasize_mask:
			logger.error("Img and mask datalist have different size!")
			return -1

		# - Loop over data and extract params per each source
		logger.info("Loop over data and extract params per each source ...")
		for i in range(self.datasize):
			if self.__process_sdata(i)<0:
				logger.error("Failed to read and process source data %d, skip to next..." % (i))
				continue
			
		# - Save data
		if self.save:
			logger.info("Saving data to file %s ..." % (self.outfile))
			self.__save_data()

		return 0
Ejemplo n.º 6
0
    def __read_imgs(self):
        """ Read image data and channel frequencies from self.filepaths

        For each file read with Utils.read_fits(), this appends:
          - the pixel data to self.img_data
          - the header to self.img_heads
          - a uint8 validity mask (1 = non-zero and finite pixel,
            0 otherwise) to self.img_data_mask
          - the channel frequency to self.img_freqs_head. This is CRVAL3
            when CTYPE3 is "FREQ", else -999.
        If any channel lacks frequency info, self.img_freqs_head is reset
        to an empty list. On success self.nx/self.ny come from the first
        channel's shape (columns/rows) and self.nchannels is the number of
        stored channels.

        Returns:
            0 on success; -1 if no file path is set, a file cannot be read,
            or the channels do not all have the same size.
        """

        # - Check data filelists
        if not self.filepaths:
            logger.error("Empty filelists given!")
            return -1

        # - Read images
        self.nchannels = len(self.filepaths)
        has_freq_data = True

        for filename in self.filepaths:
            # - Read image
            logger.debug("Reading file %s ..." % (filename))
            try:
                data, header, _wcs = Utils.read_fits(filename)
            except Exception as e:
                logger.error(
                    "Failed to read image data from file %s (err=%s)!" %
                    (filename, str(e)))
                return -1

            # - Compute data mask
            #   NB: =1 good values, =0 bad (pix=0 or pix=inf or pix=nan)
            data_mask = np.logical_and(data != 0,
                                       np.isfinite(data)).astype(np.uint8)

            # - Extract frequency information from header
            #   NB: the axis type is the CTYPE3 card (the one checked for),
            #       not the generic 'CTYPE' key
            freq = -999
            if ('CRVAL3' in header and 'CTYPE3' in header
                    and header['CTYPE3'] == "FREQ"):
                freq = header['CRVAL3']
            else:
                has_freq_data = False

            # - Append image channel data to list
            self.img_data.append(data)
            self.img_heads.append(header)
            self.img_data_mask.append(data_mask)
            self.img_freqs_head.append(freq)

        # - Reset freq data if one of the channel has no data
        if not has_freq_data:
            self.img_freqs_head = []

        # - Check image sizes
        if not self.check_img_sizes():
            logger.error(
                "Image channels for source %s do not have the same size, check your dataset!"
                % self.sname)
            return -1

        # - Set data shapes
        self.nx = self.img_data[0].shape[1]
        self.ny = self.img_data[0].shape[0]
        self.nchannels = len(self.img_data)

        return 0
Ejemplo n.º 7
0
	def __extract_sources(self, data, bkg, rms, mask=None, seed_thr=4, merge_thr=3, dist_thr=-1):
		""" Find the main source in channel data

		Blobs are built by thresholding the significance map
		zmap=(data-bkg)/rms at merge_thr. A blob is kept only if:
		  - its max significance is >=seed_thr,
		  - its contour polygon can be computed,
		  - it contains a local zmap peak above seed_thr,
		  - when dist_thr>0, that peak lies within dist_thr pixels from
		    the image center.
		If several blobs survive, the one whose peak is closest to the
		image center is selected.

		Arguments:
			data: image array (indexed as [row, col]; presumably 2D)
			bkg, rms: background and noise level; presumably scalars or
				maps broadcastable to data (TODO confirm with callers)
			mask: optional good-pixel mask (True/non-zero = good). If None,
				non-zero finite pixels are considered good.
			seed_thr: significance threshold for peaks and blob maxima
			merge_thr: significance threshold used to build blobs
			dist_thr: max peak distance (pixels) from the image center;
				<=0 disables the check

		Returns:
			(peak, bmap, regprop, enclosing_circle) for the selected source,
			or None if no source is selected.
			  - peak: (row, col) local peak inside the blob
			  - bmap: int32 map set to 1 on the selected blob's pixels only
			  - regprop: skimage region properties of the blob
			  - enclosing_circle: (xc, yc, radius), or None if it could not
			    be computed
		"""
	
		# - Compute image center
		y_c= data.shape[0]/2.
		x_c= data.shape[1]/2.

		# - Compute mask
		#   NB: force a boolean mask, so that ~mask and boolean indexing also
		#       work when a 0/1 integer (e.g. uint8) mask is given
		if mask is None:
			logger.info("Computing image mask ...")
			mask= np.logical_and(data!=0, np.isfinite(data))
		else:
			mask= np.asarray(mask, dtype=bool)

		# - Compute significance map and threshold it at merge_thr to build
		#   blobs (seed_thr is applied later to peaks and blob maxima)
		zmap= (data-bkg)/rms
		binary_map= (zmap>merge_thr).astype(np.int32)
		binary_map[~mask]= 0
		zmap[~mask]= 0
	
		# - Extract source
		logger.info("Extracting sources ...")
		label_map= skimage.measure.label(binary_map)
		regprops= skimage.measure.regionprops(label_map, data)

		nsources= len(regprops)
		logger.info("#%d sources found ..." % nsources)

		# - Extract peaks
		kernsize= 3
		footprint = np.ones((kernsize, ) * data.ndim, dtype=bool)
		peaks= peak_local_max(np.copy(zmap), footprint=footprint, threshold_abs=seed_thr, min_distance=2, exclude_border=True)
		
		if peaks.shape[0]<=0:
			logger.info("No peaks detected in this image, return None ...")
			return None
		
		# - Select candidate sources
		#   Each entry is (peak, bmap, regprop, polygon, contour)
		selected= []

		for regprop in regprops:
			sslice= regprop.slice
			region_pix= regprop.image  # bool mask of this blob inside its bbox

			# - Check if region max is >=seed_thr
			#   NB: use only this blob's pixels (the bbox may contain other
			#       blobs) and guard against regions without finite values
			zregion= zmap[sslice][region_pix]
			zregion= zregion[np.isfinite(zregion)]
			if zregion.size==0:
				logger.info("Skip source as it has no finite significance values ...")
				continue
			zmax= zregion.max()
			if zmax<seed_thr:
				logger.info("Skip source as zmax=%f<thr=%f" % (zmax, seed_thr))
				continue

			# - Set binary map with this source only
			#   NB: copying the whole bbox would also include pixels of other
			#       blobs falling inside it (and corrupt the contour below)
			logger.debug("Get source binary mask  ...")
			bmap= np.zeros_like(binary_map)
			bmap[sslice][region_pix]= 1

			# - Extract contour and polygon from binary mask
			logger.info("Extracting contour and polygon from binary mask ...")
			contour= None
			polygon= None
			try:
				bmap_uint8= bmap.astype(np.uint8) # astype copies, as OpenCV may modify its input
				contours= cv2.findContours(bmap_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
				contours= imutils.grab_contours(contours)
				if len(contours)>0:
					contour= contours[0]
					polygon= Polygon(np.squeeze(contour))
			except Exception as e:
				logger.warning("Failed to compute mask contour (err=%s)!" % (str(e)))
			
			if polygon is None:
				logger.warning("Skip extracted blob as polygon failed to be computed...")
				continue

			# - Check if source has a local peak inside the contour polygon
			peak_sel= None
			for peak in peaks:
				if polygon.contains(Point(peak[1], peak[0])):
					peak_sel= peak
					break
				
			if peak_sel is None:
				logger.info("Skip extracted blob as no peak was found inside source contour polygon!")
				continue

			# - Check for source peak distance wrt image center
			if dist_thr>0:
				dist= np.sqrt( (peak_sel[1]-x_c)**2 + (peak_sel[0]-y_c)**2 )
				if dist>dist_thr:
					logger.info("Skip extracted source as peak-imcenter dist=%f>thr=%f" % (dist, dist_thr))
					continue

			selected.append((peak_sel, bmap, regprop, polygon, contour))
			
		# - Return None if no source is selected
		nsources_sel= len(selected)
		if nsources_sel<=0:
			logger.info("No sources selected for this image ...")
			return None

		# - If more than 1 source is selected, take the one with peak closer to image center
		#   NB: argmin returns the first minimum, i.e. ties keep the earliest source
		index_best= 0
		if nsources_sel>1:
			logger.info("#%d sources selected, going to select the closest to image center ..." % (nsources_sel))
			dists= [np.sqrt( (cand[0][1]-x_c)**2 + (cand[0][0]-y_c)**2 ) for cand in selected]
			index_best= int(np.argmin(dists))
			logger.info("Selected source no. %d as the closest one to image center ..." % (index_best))
		else:
			logger.info("#%d sources selected..." % (nsources_sel))

		peak_final, bmap_final, regprop_final, polygon_final, contour_final= selected[index_best]
			
		# - Compute enclosing circle radius 
		enclosing_circle= None
		try:
			(xc, yc), radius= cv2.minEnclosingCircle(contour_final)
			enclosing_circle= (xc, yc, radius)
		except Exception as e:
			logger.warning("Failed to compute min enclosing circle (err=%s)!" % (str(e)))

		# - Draw figure
		if self.draw:
			fig, ax = plt.subplots()

			# - Draw significance map
			plt.imshow(zmap)
			plt.colorbar()

			# - Draw bbox rectangle
			ymin, xmin, ymax, xmax= regprop_final.bbox
			dx= xmax-xmin-1
			dy= ymax-ymin-1
			rect = patches.Rectangle((xmin,ymin), dx, dy, linewidth=1, edgecolor='r', facecolor='none')
			ax.add_patch(rect)

			# - Draw selected peak
			plt.scatter(peak_final[1], peak_final[0], s=10)

			# - Draw contour polygon
			plt.plot(*polygon_final.exterior.xy)

			# - Draw enclosing circle (only if it was computed)
			if enclosing_circle is not None:
				circ_x, circ_y, circ_r= enclosing_circle
				circle = plt.Circle((circ_x, circ_y), circ_r, color='g', clip_on=False, fill=False)
				ax.add_patch(circle)

			plt.show()

		return (peak_final, bmap_final, regprop_final, enclosing_circle)