def check_img_sizes(self):
    """Check whether all stored image channels have the same size.

    Each channel in ``self.img_data`` is compared with the first channel on
    its first two axes, read as (ny, nx).

    Returns:
        bool: False if no images are stored or if any channel differs in size
        from the first channel, True otherwise. Each mismatch is logged at
        debug level together with the matching entry of ``self.filepaths``.
    """
    # - Return false if no images are stored
    if not self.img_data:
        return False

    # - Compare the shape of each individual channel with the first one.
    #   NB: np.shape must be taken per image; on the whole list it would
    #   prepend the channel axis and never compare per-channel sizes.
    ref_shape = np.shape(self.img_data[0])
    ny_ref, nx_ref = ref_shape[0], ref_shape[1]
    same_size = True
    for i, img in enumerate(self.img_data[1:], start=1):
        imgsize = np.shape(img)
        ny, nx = imgsize[0], imgsize[1]
        if nx != nx_ref or ny != ny_ref:
            logger.debug(
                "Image %s has different size (%d,%d) wrt to previous images (%d,%d)!",
                self.filepaths[i], nx, ny, nx_ref, ny_ref)
            same_size = False
    return same_size
def __read_sdata(self, index):
    """Create and load the image and masked-image SData objects of one source.

    Both objects receive this object's processing options (``refch``,
    ``kernsize``, ``draw``, ``save_ssim_pars``, ``negative_pix_fract_thr``,
    ``bad_pix_fract_thr``). They are then filled from entry ``index`` of
    ``self.datalist`` and ``self.datalist_mask`` respectively, and their images
    are read. The image data is processed first and loading stops at the first
    failure.

    Args:
        index: position of the source in the data lists.

    Returns:
        tuple ``(sdata, sdata_mask)`` on success, or None if ``index`` is out
        of range or any set/read step reports a negative status.
    """
    # Guard against out-of-range indices
    if not 0 <= index < self.datasize:
        logger.error("Invalid index %d given!" % index)
        return None

    option_names = (
        "refch",
        "kernsize",
        "draw",
        "save_ssim_pars",
        "negative_pix_fract_thr",
        "bad_pix_fract_thr",
    )

    def make_configured_sdata():
        # Fresh SData carrying a copy of this object's processing options
        new_sdata = SData()
        for name in option_names:
            setattr(new_sdata, name, getattr(self, name))
        return new_sdata

    sdata = make_configured_sdata()
    sdata_mask = make_configured_sdata()

    # (target, source list attribute, debug msg, set-failure msg, read-failure msg)
    loading_plan = (
        (sdata, "datalist",
         "Reading source image data %d ...",
         "Failed to set source image data %d!",
         "Failed to read source images %d!"),
        (sdata_mask, "datalist_mask",
         "Reading source masked image data %d ...",
         "Failed to set source masked image data %d!",
         "Failed to read source masked images %d!"),
    )
    for target, list_attr, reading_msg, set_fail_msg, read_fail_msg in loading_plan:
        logger.debug(reading_msg % index)
        entry = getattr(self, list_attr)[index]
        if target.set_from_dict(entry) < 0:
            logger.error(set_fail_msg % index)
            return None
        if target.read_imgs() < 0:
            logger.error(read_fail_msg % index)
            return None

    return sdata, sdata_mask
def run_from_datalist(self, datalist, img_group_1, img_group_2):
    """Run the spectral index calculation over a list of source data dicts.

    Args:
        datalist: list of per-source dicts; each entry must provide the keys
            "label", "sname", "id" and "filepaths".
        img_group_1: non-empty image group index list, forwarded to the
            per-source processing step.
        img_group_2: non-empty image group index list with the same length
            as ``img_group_1``, forwarded to the per-source processing step.

    Returns:
        int: 0 on success; -1 if the inputs are empty/inconsistent or if the
        sources do not all have the same number of channels. Failures on
        individual sources are logged and skipped.
    """
    # - Check input data
    if not datalist:
        logger.error("Empty data dict list given!")
        return -1
    if not img_group_1 or not img_group_2:
        logger.error("Empty image group index given!")
        return -1
    if len(img_group_1) != len(img_group_2):
        logger.error("Given image group index list have different lengths!")
        return -1

    # - Set data info
    logger.debug("Setting data info ...")
    self.datalist = datalist
    self.datasize = len(datalist)
    self.labels = [item["label"] for item in datalist]
    self.snames = [item["sname"] for item in datalist]
    self.classids = [item["id"] for item in datalist]

    # - Check number of channels per image (must be the same for all sources)
    nchannels_set = {len(item["filepaths"]) for item in datalist}
    if len(nchannels_set) != 1:
        # NB: use logger.warning (logger.warn is deprecated) and report the
        # channel counts through the logger instead of printing to stdout
        logger.warning(
            "Number of channels in each object instance is different (len(nchannels_set)=%d!=1)!",
            len(nchannels_set))
        logger.warning("Channel counts found: %s", sorted(nchannels_set))
        return -1
    self.nchannels = next(iter(nchannels_set))

    # - Loop over data and extract params per each source
    logger.info("Loop over data and extract params per each source ...")
    for i in range(self.datasize):
        if self.__process_source(i, img_group_1, img_group_2) < 0:
            logger.warning("Failed to process source %d, skip to next ...", i)

    # - Save data
    if self.save:
        logger.info("Saving data to file %s ...", self.outfile)
        self.__save_data()

    return 0
def read_imgs(self):
    """Load every file listed in ``self.filepaths`` as one image channel.

    For each file, the pixel data, the header and a uint8 validity mask
    (1 where the pixel is non-zero and finite, 0 otherwise) are appended to
    ``self.img_data``, ``self.img_heads`` and ``self.img_data_mask``. After
    all files are loaded, channel sizes are checked and ``nx``, ``ny`` and
    ``nchannels`` are set from the loaded data.

    Returns:
        int: 0 on success; -1 if no paths are set, a file cannot be read
        (channels loaded before the failure stay appended), or the channels
        differ in size.
    """
    paths = self.filepaths
    if not paths:
        logger.error("Empty filelists given!")
        return -1

    self.nchannels = len(paths)

    for path in paths:
        logger.debug("Reading file %s ..." % path)
        try:
            pixels, head = self.__read_fits(path)
        except Exception as err:
            logger.error("Failed to read image data from file %s (err=%s)!" % (path, str(err)))
            return -1

        # Good pixels: non-zero and finite (NaN/inf/0 are flagged as bad)
        good_pixels = np.logical_and(pixels != 0, np.isfinite(pixels))
        self.img_data.append(pixels)
        self.img_heads.append(head)
        self.img_data_mask.append(good_pixels.astype(np.uint8))

    if not self.check_img_sizes():
        logger.error("Image channels for source %s do not have the same size, check your dataset!" % self.sname)
        return -1

    first_channel = self.img_data[0]
    self.nx = first_channel.shape[1]
    self.ny = first_channel.shape[0]
    self.nchannels = len(self.img_data)
    return 0
def run_from_datalist(self, datalist, datalist_mask):
    """Run the moment calculation over paired image / masked-image data dict lists.

    Args:
        datalist: list of per-source dicts; each entry must provide the keys
            "label", "sname", "id" and "filepaths".
        datalist_mask: list of per-source dicts for the masked images; must
            have the same length as ``datalist``.

    Returns:
        int: 0 on success; -1 if the sources do not all have the same number
        of channels (this includes an empty ``datalist``) or if the two lists
        differ in length. Failures on individual sources are logged and
        skipped.
    """
    # - Set data info
    logger.debug("Setting data info ...")
    self.datalist = datalist
    self.datalist_mask = datalist_mask
    self.datasize = len(datalist)
    self.labels = [item["label"] for item in datalist]
    self.snames = [item["sname"] for item in datalist]
    self.classids = [item["id"] for item in datalist]
    self.classfract_map = dict(Counter(self.classids))
    datasize_mask = len(datalist_mask)

    # - Check number of channels per image (must be the same for all sources)
    nchannels_set = {len(item["filepaths"]) for item in datalist}
    if len(nchannels_set) != 1:
        # NB: use logger.warning (logger.warn is deprecated) and report the
        # channel counts through the logger instead of printing to stdout
        logger.warning(
            "Number of channels in each object instance is different (len(nchannels_set)=%d!=1)!",
            len(nchannels_set))
        logger.warning("Channel counts found: %s", sorted(nchannels_set))
        return -1
    self.nchannels = next(iter(nchannels_set))

    # - Check data size for imgs and masks
    if self.datasize != datasize_mask:
        logger.error("Img and mask datalist have different size!")
        return -1

    # - Loop over data and extract params per each source
    logger.info("Loop over data and extract params per each source ...")
    for i in range(self.datasize):
        if self.__process_sdata(i) < 0:
            logger.error("Failed to read and process source data %d, skip to next...", i)

    # - Save data
    if self.save:
        logger.info("Saving data to file %s ...", self.outfile)
        self.__save_data()

    return 0
def __read_imgs(self):
    """Read every file in ``self.filepaths`` as one image channel.

    For each file, the pixel data, header and a uint8 validity mask
    (1 = non-zero and finite pixel, 0 otherwise) are appended to
    ``self.img_data``, ``self.img_heads`` and ``self.img_data_mask``.

    The channel frequency is taken from header keyword ``CRVAL3`` when
    ``CTYPE3`` is ``FREQ``, and is appended to ``self.img_freqs_head`` (-999
    when unavailable). If any channel lacks this information, the frequency
    list is reset to empty.

    Returns:
        int: 0 on success; -1 if no paths are set, a file cannot be read, or
        the channels differ in size.
    """
    # - Check data filelists
    if not self.filepaths:
        logger.error("Empty filelists given!")
        return -1

    # - Read images
    nimgs = len(self.filepaths)
    self.nchannels = nimgs
    has_freq_data = True

    for filename in self.filepaths:
        # - Read image
        logger.debug("Reading file %s ...", filename)
        try:
            data, header, _ = Utils.read_fits(filename)
        except Exception as e:
            logger.error("Failed to read image data from file %s (err=%s)!", filename, str(e))
            return -1

        # - Compute data mask
        # NB: =1 good values, =0 bad (pix=0 or pix=inf or pix=nan)
        data_mask = np.logical_and(data != 0, np.isfinite(data)).astype(np.uint8)

        # - Extract frequency information from header
        # NB: the axis type of the third axis is CTYPE3 (the keyword whose
        #     presence is checked); FITS string values may carry trailing
        #     blanks, hence the strip().
        freq = -999
        if ('CRVAL3' in header and 'CTYPE3' in header
                and str(header['CTYPE3']).strip() == "FREQ"):
            freq = header['CRVAL3']
        else:
            has_freq_data = False

        # - Append image channel data to list
        self.img_data.append(data)
        self.img_heads.append(header)
        self.img_data_mask.append(data_mask)
        self.img_freqs_head.append(freq)

    # - Reset freq data if one of the channel has no data
    if not has_freq_data:
        self.img_freqs_head = []

    # - Check image sizes
    if not self.check_img_sizes():
        logger.error(
            "Image channels for source %s do not have the same size, check your dataset!",
            self.sname)
        return -1

    # - Set data shapes
    self.nx = self.img_data[0].shape[1]
    self.ny = self.img_data[0].shape[0]
    self.nchannels = len(self.img_data)
    return 0
def __extract_sources(self, data, bkg, rms, mask=None, seed_thr=4, merge_thr=3, dist_thr=-1):
    """Find the source closest to the image center in channel data.

    Pixels whose significance ``zmap = (data - bkg) / rms`` exceeds
    ``merge_thr`` are grouped into connected blobs. A blob is kept only if all
    of the following hold:

    * its bounding-box maximum reaches ``seed_thr``;
    * its outer contour can be turned into a polygon;
    * that polygon contains one of the local peaks detected above ``seed_thr``;
    * when ``dist_thr > 0``, that peak lies within ``dist_thr`` pixels of the
      image center.

    Among the kept blobs, the one whose peak is closest to the image center
    is returned.

    Args:
        data: 2D image array.
        bkg, rms: background and noise level; presumably scalars or arrays
            broadcastable to ``data`` (TODO confirm with callers).
        mask: optional mask of valid pixels (any non-zero value counts as
            valid); defaults to the finite, non-zero pixels of ``data``.
        seed_thr: significance threshold for peaks and blob maxima.
        merge_thr: significance threshold used to build the blob map.
        dist_thr: maximum peak distance from the image center in pixels;
            ignored when <= 0.

    Returns:
        ``(peak, binary_map, regionprops, enclosing_circle)`` for the selected
        source, or None if no source is selected. ``peak`` is a (row, col)
        pair. ``enclosing_circle`` is ``(xc, yc, radius)``, or None if it could
        not be computed.
    """
    # - Compute image center
    y_c = data.shape[0] / 2.
    x_c = data.shape[1] / 2.

    # - Compute mask
    if mask is None:
        logger.info("Computing image mask ...")
        mask = np.logical_and(data != 0, np.isfinite(data))
    else:
        # NB: coerce to bool, otherwise `~mask` on an integer (e.g. uint8 0/1)
        # mask yields bitwise-inverted values used as fancy indices
        mask = np.asarray(mask, dtype=bool)

    # - Compute significance map and threshold it at merge_thr
    #   (seed_thr is applied later to peaks and to each blob maximum)
    zmap = (data - bkg) / rms
    binary_map = (zmap > merge_thr).astype(np.int32)
    binary_map[~mask] = 0
    zmap[~mask] = 0

    # - Extract source blobs
    logger.info("Extracting sources ...")
    label_map = skimage.measure.label(binary_map)
    regprops = skimage.measure.regionprops(label_map, data)
    nsources = len(regprops)
    logger.info("#%d sources found ..." % nsources)

    # - Extract local peaks above seed_thr (coordinates are (row, col))
    kernsize = 3
    footprint = np.ones((kernsize, ) * data.ndim, dtype=bool)
    peaks = peak_local_max(np.copy(zmap), footprint=footprint, threshold_abs=seed_thr, min_distance=2, exclude_border=True)
    if peaks.shape[0] <= 0:
        logger.info("No peaks detected in this image, return None ...")
        return None

    # - Select candidate sources: (peak, bmap, regprop, polygon, contour)
    candidates = []
    for regprop in regprops:
        # - Check if region max is >=seed_thr
        sslice = regprop.slice
        zmask = zmap[sslice]
        zmask_1d = zmask[np.logical_and(zmask != 0, np.isfinite(zmask))]
        if zmask_1d.size == 0:
            # NB: .max() on an empty array would raise ValueError
            logger.info("Skip source as no finite non-zero pixels were found in its bounding box")
            continue
        zmax = zmask_1d.max()
        if zmax < seed_thr:
            logger.info("Skip source as zmax=%f<thr=%f" % (zmax, seed_thr))
            continue

        # - Set binary map with this source
        # NB: the whole bounding box is copied, so pixels of other blobs
        #     overlapping it are included as well
        logger.debug("Get source binary mask ...")
        bmap = np.zeros_like(binary_map)
        bmap[sslice] = binary_map[sslice]

        # - Extract contour and polygon from binary mask
        logger.info("Extracting contour and polygon from binary mask ...")
        contours = []
        polygon = None
        try:
            bmap_uint8 = bmap.astype(np.uint8)  # astype copies; OpenCV may modify its input
            contours = cv2.findContours(bmap_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            contours = imutils.grab_contours(contours)
            if len(contours) > 0:
                contour = np.squeeze(contours[0])
                polygon = Polygon(contour)
        except Exception as e:
            logger.warning("Failed to compute mask contour (err=%s)!" % (str(e)))

        if polygon is None:
            logger.warning("Skip extracted blob as polygon failed to be computed...")
            continue

        # - Check if source has a local peak inside its contour polygon
        #   (peaks are (row, col) while the polygon is in (x, y))
        peak_sel = None
        for peak in peaks:
            if polygon.contains(Point(peak[1], peak[0])):
                peak_sel = peak
                break
        if peak_sel is None:
            logger.info("Skip extracted blob as no peak was found inside source contour polygon!")
            continue

        # - Check for source peak distance wrt image center
        if dist_thr > 0:
            dist = np.sqrt((peak_sel[1] - x_c)**2 + (peak_sel[0] - y_c)**2)
            if dist > dist_thr:
                logger.info("Skip extracted source as peak-imcenter dist=%f>thr=%f" % (dist, dist_thr))
                continue

        candidates.append((peak_sel, bmap, regprop, polygon, contours[0]))

    # - Return None if no source is selected
    nsources_sel = len(candidates)
    if nsources_sel <= 0:
        logger.info("No sources selected for this image ...")
        return None

    # - If more than 1 source is selected, take the one with peak closer to image center
    index_best = 0
    if nsources_sel > 1:
        logger.info("#%d sources selected, going to select the closest to image center ..." % (nsources_sel))
        dist_best = 1.e+99
        for j, candidate in enumerate(candidates):
            peak = candidate[0]
            dist = np.sqrt((peak[1] - x_c)**2 + (peak[0] - y_c)**2)
            if dist < dist_best:
                dist_best = dist
                index_best = j
        logger.info("Selected source no. %d as the closest one to image center ..." % (index_best))
    else:
        logger.info("#%d sources selected..." % (nsources_sel))

    peak_final, bmap_final, regprop_final, polygon_final, contour_final = candidates[index_best]

    # - Compute enclosing circle radius
    try:
        (xc, yc), radius = cv2.minEnclosingCircle(contour_final)
        enclosing_circle = (xc, yc, radius)
    except Exception as e:
        logger.warning("Failed to compute min enclosing circle (err=%s)!" % (str(e)))
        enclosing_circle = None

    # - Draw figure
    if self.draw:
        _, ax = plt.subplots()

        # - Draw significance map
        plt.imshow(zmap)
        plt.colorbar()

        # - Draw bbox rectangle (bbox is (min_row, min_col, max_row, max_col))
        ymin, xmin, ymax, xmax = regprop_final.bbox
        rect = patches.Rectangle((xmin, ymin), xmax - xmin - 1, ymax - ymin - 1, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # - Draw selected peak
        plt.scatter(peak_final[1], peak_final[0], s=10)

        # - Draw contour polygon
        plt.plot(*polygon_final.exterior.xy)

        # - Draw enclosing circle (only if it was computed)
        if enclosing_circle is not None:
            circle = plt.Circle((enclosing_circle[0], enclosing_circle[1]), enclosing_circle[2], color='g', clip_on=False, fill=False)
            ax.add_patch(circle)

        plt.show()

    return (peak_final, bmap_final, regprop_final, enclosing_circle)