    def __compute_spectral_index(self, img_group_1, img_group_2):
        """ Compute spectral index alpha """

        # - Check first if frequency data are available
        print("self.img_freqs")
        print(self.img_freqs)
        print("len(self.img_freqs)")
        print(len(self.img_freqs))
        print("len(self.img_data)")
        print(len(self.img_data))
        print("img_group_1")
        print(img_group_1)
        print("img_group_2")
        print(img_group_2)

        freqs = []
        if self.img_freqs and len(self.img_freqs) == len(self.img_data):
            freqs = self.img_freqs
        elif self.img_freqs_head and len(self.img_freqs_head) == len(self.img_data):
            freqs = self.img_freqs_head
        else:
            logger.error("No frequency data given (user/header)!")
            return -1

        # - Check group indexes
        if len(img_group_1) != len(img_group_2):
            logger.error("Group indexes do not have the same length!")
            return -1

        # - Check group indices are within available channels
        for group_id, group in enumerate([img_group_1, img_group_2], start=1):
            for index in group:
                if index < 0 or index >= self.nchannels:
                    logger.error(
                        "Invalid index (%d) in group %d, must be in range [0,%d]!" %
                        (index, group_id, self.nchannels - 1))
                    return -1

        # - Loop over img combinations and compute spectral indices
        logger.info("Computing spectral index (#%d combinations) ..." %
                    (len(img_group_1)))
        alphas = []
        rcoeffs = []

        smask = self.img_data_mask[self.refch]

        for i in range(len(img_group_1)):
            index_1 = img_group_1[i]
            index_2 = img_group_2[i]
            data_1 = self.img_data[index_1]
            data_2 = self.img_data[index_2]

            # - Find frequency from header
            nu1 = freqs[index_1]
            nu2 = freqs[index_2]
            #alpha12, alpha21= compute_alpha(data_1, data_2, nu1, nu2, smask, draw_plots)
            #alpha= 0.5*(alpha12+alpha21)
            outtuple = self.__compute_alpha(data_1, data_2, nu1, nu2, smask)
            if outtuple is None:
                logger.warn(
                    "alpha calculation failed for map combination %d-%d, skip to next ..."
                    % (index_1, index_2))
                continue

            alpha = outtuple[0]
            r = outtuple[1]
            alphas.append(alpha)
            rcoeffs.append(r)

        logger.info("Computing average spectral index ...")
        print(alphas)

        alphas = np.array(alphas)
        alphas_safe = alphas[np.isfinite(alphas)]
        alphas = alphas_safe
        if alphas.size == 0:
            logger.warn(
                "No alpha measurement left (all nans), will set alpha values to -999 ..."
            )
            alpha_mean = -999
            alpha_median = -999
            alpha_min = -999
            alpha_max = -999
        else:
            alpha_mean = np.mean(alphas)
            alpha_median = np.median(alphas)
            alpha_min = np.min(alphas)
            alpha_max = np.max(alphas)

        rcoeffs = np.array(rcoeffs)
        rcoeffs_safe = rcoeffs[np.isfinite(rcoeffs)]
        rcoeffs = rcoeffs_safe
        if rcoeffs.size == 0:
            logger.warn(
                "No rcoeff measurement left (all nans), will set rcoeff values to -999 ..."
            )
            rcoeff_mean = -999
            rcoeff_median = -999
            rcoeff_min = -999
            rcoeff_max = -999
        else:
            rcoeff_mean = np.mean(rcoeffs)
            rcoeff_median = np.median(rcoeffs)
            rcoeff_min = np.min(rcoeffs)
            rcoeff_max = np.max(rcoeffs)

        # - Set spectral index
        self.alpha = alpha_mean
        self.rcoeff = rcoeff_mean
        if self.alpha != -999 and self.rcoeff >= self.rcoeff_thr:
            self.has_good_alpha = True
        else:
            self.has_good_alpha = False

        return 0
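
    # - Hedged usage sketch (not part of the original code): one plausible way to build the
    #   parallel channel-index groups consumed by __compute_spectral_index(), pairing every
    #   unique channel combination. The helper name is hypothetical.
    def __build_channel_pair_groups(self):
        """ Hypothetical helper: build parallel index lists covering all unique channel pairs """
        import itertools

        img_group_1 = []
        img_group_2 = []
        for i, j in itertools.combinations(range(self.nchannels), 2):
            img_group_1.append(i)
            img_group_2.append(j)
        return img_group_1, img_group_2
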
    def make_masked_cutouts(self,
                            region_sky,
                            dilatemask=False,
                            kernsize=5,
                            maskval=0):
        """ Produce masked cutouts """

        # - Find cutout files produced
        logger.info("Searching for produced cutouts for source %s ..." %
                    (self.sname))
        cutout_dir = os.path.join(self.datadir, self.sname)
        file_pattern = os.path.join(cutout_dir, "*.fits")
        files = glob.glob(file_pattern)

        nfiles = len(files)
        if nfiles == 0 or nfiles != self.nsurveys:
            logger.warn(
                "Number of cutout files produced (%d) different wrt expected (%d)!"
                % (nfiles, self.nsurveys))
            return -1

        # - Create directory for masked cutouts
        masked_cutout_dir = os.path.join(self.datadir_mask, self.sname)
        if not os.path.exists(masked_cutout_dir):
            logger.info("Creating cutout masked data dir %s ..." %
                        (masked_cutout_dir))
            Utils.mkdir(masked_cutout_dir, delete_if_exists=False)

        # - Retrieve FITS header & wcs
        logger.info("Retrieving cutout FITS header & WCS for source %s ..." %
                    (self.sname))
        try:
            header = fits.getheader(files[0])
            data_shape = fits.getdata(files[0]).shape
            wcs = WCS(header)
        except Exception as e:
            logger.error(
                "Failed to retrieve file %s header/WCS for source %s (err=%s)!"
                % (files[0], self.sname, str(e)))
            return -1

        # - Convert region to pixel coords
        logger.info(
            "Converting sky region for source %s to pixel coordinates ..." %
            (self.sname))
        try:
            region = region_sky.to_pixel(wcs)
        except Exception as e:
            logger.error(
                "Failed to convert sky region for source %s to pixel coordinates (err=%s)!"
                % (self.sname, str(e)))
            return -1

        # - Compute mask
        logger.info("Computing mask for source %s ..." % (self.sname))
        try:
            mask = region.to_mask(mode='center')
        except Exception as e:
            logger.error(
                "Failed to get mask from region for source %s (err=%s)!" %
                (self.sname, str(e)))
            return -1

        if mask is None:
            logger.warn("mask obtained from region for source %s is None!" %
                        (self.sname))
            return -1

        # - Compute image mask
        logger.info("Computing image mask for source %s ..." % (self.sname))
        maskimg = mask.to_image(data_shape)
        if maskimg is None:
            logger.error(
                "maskimg is None for source %s, this shouldn't occur at this stage!"
                % (self.sname))
            return -1

        maskimg[maskimg != 0] = 1
        maskimg = maskimg.astype(np.uint8)

        # - Dilate image mask to enlarge area around source
        if dilatemask:
            logger.info(
                "Dilating image mask to enlarge area around source %s ..." %
                (self.sname))
            structel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                                 (kernsize, kernsize))
            maskimg_dil = cv2.dilate(maskimg, structel, iterations=1)
            maskimg = maskimg_dil

        # - Loop over files and create masked cutouts
        for i in range(nfiles):
            filename = files[i]
            filename_base = os.path.basename(filename)
            filename_base_noext = os.path.splitext(filename_base)[0]
            filename_mask = os.path.join(masked_cutout_dir,
                                         filename_base_noext + '_masked.fits')

            logger.info("Creating masked cutout file %s from file %s ..." %
                        (filename_mask, filename_base))
            try:
                header = fits.getheader(filename)
                data = fits.getdata(filename)
                data[maskimg == 0] = maskval

                hdu_out = fits.PrimaryHDU(data, header)
                hdul = fits.HDUList([hdu_out])
                hdul.writeto(filename_mask, overwrite=True)

            except Exception as e:
                logger.error("Failed to create masked file %s for source %s!" %
                             (filename_mask, self.sname))
                return -1

        return 0
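
    # - Hedged usage sketch (not part of the original code): how the `region_sky` argument of
    #   make_masked_cutouts() might be built with the astropy-affiliated `regions` package.
    #   The helper name and the ra/dec/radius parameters are hypothetical.
    def __make_circle_sky_region(self, ra_deg, dec_deg, radius_arcsec):
        """ Hypothetical helper: build a circular sky region for make_masked_cutouts() """
        from astropy import units as u
        from astropy.coordinates import SkyCoord
        from regions import CircleSkyRegion

        center = SkyCoord(ra=ra_deg * u.deg, dec=dec_deg * u.deg, frame='icrs')
        return CircleSkyRegion(center=center, radius=radius_arcsec * u.arcsec)
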
    def __compute_alpha(self, data_1, data_2, nu1, nu2, smask):
        """ Compute alpha """

        # - Get array of pixels !=0 & finite in both maps
        cond_img1 = np.logical_and(data_1 != 0, np.isfinite(data_1))
        cond_img2 = np.logical_and(data_2 != 0, np.isfinite(data_2))
        cond_img12 = np.logical_and(cond_img1, cond_img2)
        cond_final = np.logical_and(cond_img12, smask == 1)

        indexes = np.where(cond_final)
        img_1d_1 = data_1[indexes]
        img_1d_2 = data_2[indexes]

        logger.info("#%d pixels in image 1 ..." % (len(img_1d_1)))
        logger.info("#%d pixels in image 2 ..." % (len(img_1d_2)))

        if len(img_1d_1) == 0 or len(img_1d_2) == 0:
            logger.warn(
                "No pixels left for T-T analysis after applying conditions (finite+mask) (hint: check if source is outside one or more channels)"
            )
            return None

        # - Perform fit 1-2
        logger.info("Compute spectral index from T-T fit  ...")
        res_12 = linregress(img_1d_1, img_1d_2)
        slope_12 = res_12.slope
        intercept_12 = res_12.intercept
        alpha_12 = self.__slope2alpha(slope_12, nu1, nu2)
        r_12 = res_12.rvalue

        print("== FIT RES 1-2 ==")
        print(res_12)
        print("alpha_12=%f" % (alpha_12))

        # - Perform fit 2-1
        res_21 = linregress(img_1d_2, img_1d_1)
        slope_21 = res_21.slope
        intercept_21 = res_21.intercept
        alpha_21 = self.__slope2alpha(slope_21, nu2, nu1)
        r_21 = res_21.rvalue

        print("== FIT RES 2-1 ==")
        print(res_21)
        print("alpha_21=%f" % (alpha_21))

        # - Flag each fit as good only if its slope is finite and positive
        goodvalues_12 = np.isfinite(slope_12) and slope_12 > 0
        goodvalues_21 = np.isfinite(slope_21) and slope_21 > 0

        # - Add some goodness of fit criteria
        obs_12 = img_1d_2
        pred_12 = slope_12 * img_1d_1 + intercept_12
        residuals_12 = obs_12 - pred_12
        residuals_mean_12 = np.mean(residuals_12)
        residuals_std_12 = np.std(residuals_12)
        residuals_min_12 = np.min(residuals_12)
        residuals_max_12 = np.max(residuals_12)

        obs_21 = img_1d_1
        pred_21 = slope_21 * img_1d_2 + intercept_21
        residuals_21 = obs_21 - pred_21
        residuals_mean_21 = np.mean(residuals_21)
        residuals_std_21 = np.std(residuals_21)
        residuals_min_21 = np.min(residuals_21)
        residuals_max_21 = np.max(residuals_21)

        # - Set return tuple
        outtuple = ()
        if goodvalues_12 and not goodvalues_21:
            outtuple = (alpha_12, r_12, residuals_mean_12, residuals_std_12,
                        residuals_min_12, residuals_max_12)
        elif goodvalues_21 and not goodvalues_12:
            outtuple = (alpha_21, r_21, residuals_mean_21, residuals_std_21,
                        residuals_min_21, residuals_max_21)
        else:
            # - Select best model (selection below uses the correlation coefficient; residual bias/std are computed for reference)
            best_resbias_id = 1
            best_resstd_id = 1
            best_rcoeff_id = 1
            if np.abs(residuals_mean_21) < np.abs(
                    residuals_mean_12):  # check smallest residual bias
                best_resbias_id = 2
            if np.abs(residuals_std_21) < np.abs(
                    residuals_std_12):  # check smallest residual std dev
                best_resstd_id = 2
            if np.abs(r_21) > np.abs(
                    r_12):  # check larger (closer to 1) correlation coeff
                best_rcoeff_id = 2

            if best_rcoeff_id == 1:
                outtuple = (alpha_12, r_12, residuals_mean_12,
                            residuals_std_12, residuals_min_12,
                            residuals_max_12)
            else:
                outtuple = (alpha_21, r_21, residuals_mean_21,
                            residuals_std_21, residuals_min_21,
                            residuals_max_21)

        return outtuple
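
    # - Hedged sketch (assumption): the __slope2alpha() helper called above is not shown in this
    #   snippet; a plausible form follows from the T-T plot relation S2 ~ (nu2/nu1)**alpha * S1,
    #   i.e. alpha = log10(slope) / log10(nu2/nu1). The "_sketch" suffix marks it as illustrative.
    def __slope2alpha_sketch(self, slope, nu1, nu2):
        """ Hypothetical helper: convert a T-T fit slope to a spectral index """
        if slope <= 0 or nu1 <= 0 or nu2 <= 0 or nu1 == nu2:
            return np.nan
        return np.log10(slope) / np.log10(nu2 / nu1)
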
	def __extract_sources(self, data, bkg, rms, mask=None, seed_thr=4, merge_thr=3, dist_thr=-1):
		""" Find sources in channel data """
	
		# - Compute image center
		data_shape= data.shape
		y_c= data_shape[0]/2.
		x_c= data_shape[1]/2.

		# - Compute mask
		if mask is None:
			logger.info("Computing image mask ...")
			mask= np.logical_and(data!=0, np.isfinite(data))	

		data_1d= data[mask]
	
		# - Threshold significance map at merge_thr (seed_thr is applied to peaks below)
		zmap= (data-bkg)/rms
		binary_map= (zmap>merge_thr).astype(np.int32)
		binary_map[~mask]= 0
		zmap[~mask]= 0
	
		# - Extract source
		logger.info("Extracting sources ...")
		label_map= skimage.measure.label(binary_map)
		regprops= skimage.measure.regionprops(label_map, data)

		nsources= len(regprops)
		logger.info("#%d sources found ..." % nsources)

		# - Extract peaks
		kernsize= 3
		footprint = np.ones((kernsize, ) * data.ndim, dtype=bool)
		peaks= peak_local_max(np.copy(zmap), footprint=footprint, threshold_abs=seed_thr, min_distance=2, exclude_border=True)
		#print(peaks)
		
		if peaks.shape[0]<=0:
			logger.info("No peaks detected in this image, return None ...")
			return None
		
		# - Select best source
		regprops_sel= []
		peaks_sel= []
		binary_maps_sel= []
		polygons_sel= []
		contours_sel= []
		#binary_maps_sel= []
		#binary_map_sel= np.zeros_like(binary_map)

		for regprop in regprops:
			# - Check if region max is >=seed_thr
			sslice= regprop.slice
			zmask= zmap[sslice]
			zmask_1d= zmask[np.logical_and(zmask!=0, np.isfinite(zmask))]	
			zmax= zmask_1d.max()
			if zmax<seed_thr:
				logger.info("Skip source as zmax=%f<thr=%f" % (zmax, seed_thr))
				continue

			# - Set binary map with this source
			logger.debug("Get source binary mask  ...")
			bmap= np.zeros_like(binary_map)
			bmap[sslice]= binary_map[sslice]

			# - Extract contour and polygon from binary mask
			logger.info("Extracting contour and polygon from binary mask ...")
			contours= []
			polygon= None
			try:
				bmap_uint8= bmap.copy() # copy since OpenCV may modify the original mask internally
				bmap_uint8= bmap_uint8.astype(np.uint8)
				contours= cv2.findContours(bmap_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
				contours= imutils.grab_contours(contours)
				if len(contours)>0:
					contour= np.squeeze(contours[0])
					polygon = Polygon(contour)
			except Exception as e:
				logger.warn("Failed to compute mask contour (err=%s)!" % (str(e)))
			
			if polygon is None:
				logger.warn("Skip extracted blob as polygon failed to be computed...")
				continue

			# - Check if source has a local peak in the mask
			#   NB: Check if polygon is computed
			has_peak= False
			peak_sel= None
			if polygon is not None:
				for peak in peaks:
					point = Point(peak[1], peak[0])
					has_peak= polygon.contains(point)
					if has_peak:
						peak_sel= peak
						break
				
			if not has_peak: 
				logger.info("Skip extracted blob as no peak was found inside source contour polygon!")
				continue

			# - Check for source peak distance wrt image center
			if dist_thr>0:
				dist= np.sqrt( (peak_sel[1]-x_c)**2 + (peak_sel[0]-y_c)**2 )
				if dist>dist_thr:
					logger.info("Skip extracted source as peak-imcenter dist=%f<thr=%f" % (dist, dist_thr))
					continue

			# - Update global binary mask and regprops
			#binary_map_sel[sslice]= binary_map[sslice]
			regprops_sel.append(regprop)
			peaks_sel.append(peak_sel)
			binary_maps_sel.append(bmap)	
			polygons_sel.append(polygon)
			contours_sel.append(contours[0])
			
		# - Return None if no source is selected
		nsources_sel= len(regprops_sel)
		if nsources_sel<=0:
			logger.info("No sources selected for this image ...")
			return None

		# - If more than 1 source is selected, take the one whose peak is closest to the image center
		peak_final= peaks_sel[0]
		bmap_final= binary_maps_sel[0]
		regprop_final= regprops_sel[0]
		polygon_final= polygons_sel[0]
		contour_final= contours_sel[0]

		if nsources_sel>1:
			logger.info("#%d sources selected, going to select the closest to image center ..." % (nsources_sel))
			
			dist_best= 1.e+99
			index_best= -1
			for j in range(len(peaks_sel)):
				peak= peaks_sel[j]
				bmap= binary_maps_sel[j]
				regprop= regprops_sel[j]
				polygon= polygons_sel[j]
				contour= contours_sel[j]
				dist= np.sqrt( (peak[1]-x_c)**2 + (peak[0]-y_c)**2 )
				if dist<dist_best:
					dist_best= dist
					peak_final= peak
					bmap_final= bmap
					regprop_final= regprop
					polygon_final= polygon
					contour_final= contour
					index_best= j
			
			logger.info("Selected source no. %d as the closest one to image center ..." % (index_best))							
		else:
			logger.info("#%d sources selected..." % (nsources_sel))
			
		# - Compute enclosing circle radius 
		try:
			(xc, yc), radius= cv2.minEnclosingCircle(contour_final)
			enclosing_circle= (xc,yc,radius)
		except Exception as e:
			logger.warn("Failed to compute min enclosing circle (err=%s)!" % (str(e)))
			enclosing_circle= None

		# - Draw figure
		if self.draw:
			fig, ax = plt.subplots()

			# - Draw map
			#plt.imshow(label_map)
			#plt.imshow(data)
			plt.imshow(zmap)
			#plt.imshow(bmap_final)
			plt.colorbar()

			# - Draw bbox rectangle
			bbox= regprop_final.bbox
			ymin= bbox[0]
			ymax= bbox[2]
			xmin= bbox[1]
			xmax= bbox[3]
			dx= xmax-xmin-1
			dy= ymax-ymin-1
			rect = patches.Rectangle((xmin,ymin), dx, dy, linewidth=1, edgecolor='r', facecolor='none')
			ax.add_patch(rect)

			# - Draw selected peak
			if peak_final is not None:
				plt.scatter(peak_final[1], peak_final[0], s=10)

			# - Draw contour polygon
			if polygon_final is not None:
				plt.plot(*polygon_final.exterior.xy)

			# - Draw enclosing circle (only if successfully computed)
			if enclosing_circle is not None:
				circle = plt.Circle((xc, yc), radius, color='g', clip_on=False, fill=False)
				ax.add_patch(circle)

			plt.show()


		return (peak_final, bmap_final, regprop_final, enclosing_circle)
	def compute_ssim_pars(self, winsize=3):
		""" Compute SSIM params """

		# - Loop over images and compute params
		index= 0
		for i in range(self.nchannels-1):
			
			img_i= self.img_data[i]
			cond_i= np.logical_and(img_i!=0, np.isfinite(img_i))

			img_max_i= np.nanmax(img_i[cond_i])
			img_min_i= np.nanmin(img_i[cond_i])
			
			img_norm_i= (img_i-img_min_i)/(img_max_i-img_min_i)
			img_norm_i[~cond_i]= 0

			# - Compute SSIM maps
			for j in range(i+1,self.nchannels):
				img_j= self.img_data[j]
				cond_j= np.logical_and(img_j!=0, np.isfinite(img_j))
				img_max_j= np.nanmax(img_j[cond_j])
				img_min_j= np.nanmin(img_j[cond_j])
				
				img_norm_j= (img_j-img_min_j)/(img_max_j-img_min_j)
				img_norm_j[~cond_j]= 0

				cond= np.logical_and(cond_i, cond_j)
				
				# - Compute SSIM moments
				#   NB: Need to normalize images to max otherwise the returned values are always ~1.
				logger.info("Computing SSIM for image %s (id=%s, ch=%d-%d) ..." % (self.sname, self.label, i+1, j+1))
				_, ssim_2d= structural_similarity(img_norm_i, img_norm_j, full=True, win_size=winsize, data_range=1)

				ssim_2d[ssim_2d<0]= 0
				ssim_2d[~cond]= 0
				self.ssim_maps.append(ssim_2d)

				ssim_1d= ssim_2d[cond]

				#if self.draw:
				#	plt.subplot(1, 3, 1)
				#	plt.imshow(img_norm_i, origin='lower')
				#	plt.colorbar()

				#	plt.subplot(1, 3, 2)
				#	plt.imshow(img_norm_j, origin='lower')
				#	plt.colorbar()
					
				#	plt.subplot(1, 3, 3)
				#	plt.imshow(ssim_2d, origin='lower')
				#	plt.colorbar()

				#	plt.show()

				if ssim_1d.size>0:
					ssim_mean= np.nanmean(ssim_1d)
					ssim_median= np.nanmedian(ssim_1d)
					ssim_avg= ssim_median
					self.ssim_avg.append(ssim_avg)
					
					logger.info("Image %s (chan=%d-%d): <SSIM>=%f" % (self.sname, i+1, j+1, ssim_avg))

				else:
					logger.warn("Image %s (chan=%d-%d): SSIM array is empty, setting estimators to -999..." % (self.sname, i+1, j+1))
					self.ssim_avg.append(-999)
					

		return 0
	def __process_sdata(self, index):
		""" Process source data """

		#===========================
		#==    READ DATA
		#===========================
		# - Read source data
		logger.info("Reading source and source masked data %d ..." % (index))
		ret= self.__read_sdata(index)
		if ret is None:
			logger.error("Failed to read source data %d!" % (index))
			return -1

		sdata= ret[0]
		sdata_mask= ret[1]

		#===========================
		#==    MODIFY MASKS
		#===========================
		# - Shrink img & mask in masked sdata?		
		if self.shrink_masks: 
			logger.info("Shrinking img+mask on source masked data %d ..." % (index))
			if sdata_mask.shrink_masks(self.erode_kernels)<0:
				logger.warn("Failed to shrink mask for source masked data %d!" % (index))
				return -1

		# - Expand img & mask in masked sdata?
		if self.grow_masks:
			logger.info("Expanding img+mask on source masked data %d ..." % (index))
			if sdata_mask.grow_masks(self.dilate_kernels)<0:
				logger.warn("Failed to expand mask for source masked data %d!" % (index))
				return -1

		masks= sdata_mask.img_data_mask
		#mask_ref= masks[self.refch]

		#===========================
		#==  CHECK DATA INTEGRITY
		#===========================
		# - Check non-masked data
		has_good_data= sdata.has_good_data(check_mask=False, check_bad=True, check_neg=False, check_same=True)
		if not has_good_data:
			logger.warn("Source data %d are bad (too may NANs or equal pixel values)!" % (index))
			return -1

		# - Check masked data
		has_good_mask_data= sdata_mask.has_good_data(check_mask=False, check_bad=True, check_neg=True, check_same=True)
		if not has_good_mask_data:
			logger.warn("Source mask data %d are bad (too may NANs/negative or equal pixel values)!" % (index))
			return -1

		#===========================
		#==  CHECK AE RECO ACCURACY
		#===========================
		# ...
		# ...

		#===========================
		#==    COMPUTE BKG/FLUX
		#===========================
		# - Compute bkg on img over non-masked pixels
		logger.info("Computing bkg on source data %d ..." % (index))
		if sdata.compute_bkg(masks)<0:
			logger.warn("Failed to compute bkg for source data %d!" % (index))
			return -1

		bkg_levels= sdata.bkg_levels

		#print("--> bkg levels")
		#print(bkg_levels)

		# - Apply masks to sdata
		#   NB: Do this after the bkg calculation (otherwise all non-masked pixels would be set to 0, making the bkg 0) and before subtracting the bkg
		logger.info("Applying masks to source data %d ..." % (index))
		sdata.apply_masks(masks)

		# - Subtract bkg on img
		#if self.subtract_bkg:
		#	logger.info("Subtracting bkg on source data %d ..." % (index))
		#	if sdata.subtract_bkg(bkg_levels, self.subtract_bkg_only_refch)<0:
		#		logger.warn("Failed to subtract bkg for source data %d!" % (index))
		#		return -1

		# - Compute integrated flux (no source extraction here, only sum of pixel fluxes in mask)
		logger.info("Computing flux on source data %d ..." % (index))
		sdata.compute_fluxes(subtract_bkg=self.subtract_bkg, subtract_only_refch=self.subtract_bkg_only_refch)

		# - Extract sources and compute pars 
		#   NB: source extraction may fail or not be accurate (e.g. miss source, contour not accurate, etc)
		logger.info("Extracting source blobs on source data %d ..." % (index))
		sdata.find_sources(
			seed_thr=self.seed_thr, merge_thr=self.merge_thr, dist_thr=self.dist_thr, 
			subtract_bkg=self.subtract_bkg, subtract_only_refch=self.subtract_bkg_only_refch
		)		

		#===========================
		#==    COMPUTE MOMENTS
		#===========================
		# - Compute centroids and moments on images (NB: masked before)
		logger.info("Computing moments on source data %d ..." % (index))	
		if sdata.compute_img_moments()<0:
			logger.warn("Failed to compute moments for source data %d!" % (index))
			return -1

		#===========================
		#==    COMPUTE SSIM
		#===========================
		if self.save_ssim_pars:
			logger.info("Computing ssim pars on source data %d ..." % (index))	
			if sdata.compute_ssim_pars(self.ssim_winsize)<0:
				logger.warn("Failed to compute SSIM pars for source data %d!" % (index))
				return -1

		#===========================
		#==   FILL SOURCE OUT DATA
		#===========================
		# - Fill and append features
		logger.info("Filling feature dict for source data %d ..." % (index))	
		sdata.fill_features()

		par_dict= sdata.param_dict
		if par_dict is None or not par_dict:
			logger.warn("Feature dict for source data %d is empty or None, skip it ..." % (index))
			
		else:
			# - Select features?
			if self.select_feat and self.selfeatids:
				ret= sdata.select_features(self.selfeatids)
				par_dict= sdata.param_dict

				if ret==0:
					self.par_dict_list.append(par_dict)
				else:
					logger.warn("Failed to select features for source data %d, skip it ..." % (index))

			else:
				self.par_dict_list.append(par_dict)
		
		return 0
	def fill_features(self):
		""" Fill the source feature parameter dict """

		# - Save name
		self.param_dict["sname"]= self.sname

		# - Save source flux
		flux_ref= self.fluxes[self.refch]
		
		for j in range(len(self.fluxes)):
			flux= self.fluxes[j]
			parname= "flux_ch" + str(j+1)
			self.param_dict[parname]= flux

		# - Save source flux log ratios Fj/F_radio (i.e. colors)
		lgFluxRatio_safe= 0
		is_good_flux_ref= (flux_ref>0) and (np.isfinite(flux_ref))
		if not is_good_flux_ref:
			logger.warn("Flux for ref chan (%d) is <=0 or nan for image %s (id=%s),  will set all color index to %d..." % (self.refch, self.sname, self.label, lgFluxRatio_safe))

		for j in range(len(self.fluxes)):
			if j==self.refch:
				continue
			flux= self.fluxes[j] # if source is not detected this is the background level
			is_good_flux= (flux>0) and (np.isfinite(flux))
			
			lgFluxRatio= 0
			if is_good_flux_ref:
				if is_good_flux:
					lgFluxRatio= np.log10(flux/flux_ref)
				else:
					logger.warn("Flux for chan %d is <=0 or nan for image %s (id=%s),  will set this color index to %d..." % (self.refch, self.sname, self.label, lgFluxRatio_safe))
					lgFluxRatio= lgFluxRatio_safe
			else:
				lgFluxRatio= lgFluxRatio_safe
			 
			parname= "lgFratio_ch" + str(self.refch+1) + "_" + str(j+1)
			self.param_dict[parname]= lgFluxRatio

		
		# - Save source color indices log10(Fj/F_ref) computed from extracted source fluxes
		cind_safe= 0
		sflux_ref= self.sfluxes[self.refch]
		is_good_flux_ref= (sflux_ref is not None) and (sflux_ref>0) and (np.isfinite(sflux_ref))
		if not is_good_flux_ref:
			logger.warn("Flux for ref chan (%d) is <=0 or nan for image %s (id=%s),  will set all color index to %d..." % (self.refch, self.sname, self.label, cind_safe))

		for j in range(len(self.sfluxes)):
			if j==self.refch:
				continue
			sflux= self.sfluxes[j] 
			flux= self.fluxes[j]
			if sflux is None: # source is not detected, take sum of pixel fluxes inside ref source aperture (e.g. the background)
				logger.info("Source is not detected in chan %d, taking pixel sum over ref source aperture %f ..." % (j+1, flux))
				sflux= flux
				
			is_good_flux= (sflux>0) and (np.isfinite(sflux))
			
			cind= 0
			if is_good_flux_ref:
				if is_good_flux:
					cind= np.log10(sflux/sflux_ref)
				else:
					logger.warn("Flux for chan %d is <=0 or nan for image %s (id=%s),  will set this color index to %d..." % (self.refch, self.sname, self.label, cind_safe))
					cind= cind_safe
			else:
				cind= cind_safe
			
			parname= "color_ch" + str(self.refch+1) + "_" + str(j+1)
			self.param_dict[parname]= cind


		# - Save source IOU
		for j in range(len(self.sious)):
			ch_i, ch_j= self.__get_triu_indices(j, self.nchannels)
			iou= self.sious[j]
			parname= "iou_ch" + str(ch_i) + "_" + str(ch_j)
			self.param_dict[parname]= iou
			
		# - Save source peak dist
		for j in range(len(self.speaks_dists)):
			ch_i, ch_j= self.__get_triu_indices(j, self.nchannels)
			peak_dist= self.speaks_dists[j]
			parname= "dpeak_ch" + str(ch_i) + "_" + str(ch_j)
			self.param_dict[parname]= peak_dist


		# - Save img moments
		for i in range(len(self.moments_zern)):
			for j in range(len(self.moments_zern[i])):
				if j==0:
					continue # Skip as mom0 is always the same
				m= self.moments_zern[i][j]
				parname= "zernmom" + str(j+1) + "_ch" + str(i+1)
				self.param_dict[parname]= m

		# - Save ssim parameters
		if self.save_ssim_pars:
			for j in range(len(self.ssim_avg)):
				ch_i, ch_j= self.__get_triu_indices(j, self.nchannels)
				parname= "ssim_avg_ch{}_{}".format(ch_i,ch_j)
				self.param_dict[parname]= self.ssim_avg[j]
				
		# - Save class id
		self.param_dict["id"]= self.id