def train_model(self, model_fits):
    """ Create a model instance and train it. Save the resulting model

    The model file name is derived from the candidates name by replacing
    its trailing ".fits" or ".fits.gz" extension with "_model.fits.gz"
    (fits output) or "_model.json" (json output).

    Parameters
    ----------
    model_fits : bool
        If True, save the model as a fits file. Otherwise, save it as a
        json file.

    Raises
    ------
    Error
        If not running in training mode, if no candidates are loaded, or if
        the candidates name does not end in ".fits" or ".fits.gz"
    """
    # consistency checks
    if self.__mode != "training":
        raise Error("The function train_model is available in the " +
                    f"training mode only. Detected mode is {self.__mode}")
    if self.__candidates is None:
        raise Error("Attempting to run the function train_model " +
                    "but no candidates were found/loaded. Check your " +
                    "formatter")

    # ratio columns, grouped by suffix in the order RATIO_SN, RATIO2, RATIO
    selected_cols = []
    for suffix in ("RATIO_SN", "RATIO2", "RATIO"):
        selected_cols += [
            col.upper()
            for col in self.__candidates.columns
            if col.endswith(suffix)
        ]
    selected_cols += ["PEAK_SIGNIFICANCE"]

    # add extra columns requested through model_options[2]
    if (len(self.__model_options) == 3 and
            self.__model_options[2] is not None):
        selected_cols += [item.upper() for item in self.__model_options[2]]

    # add columns to compute the class in training
    selected_cols += ['CLASS_PERSON', 'CORRECT_REDSHIFT']

    # Only the trailing extension is replaced: str.replace would also
    # rewrite any ".fits" that appears earlier in the path.
    model_suffix = "_model.fits.gz" if model_fits else "_model.json"
    if self.__name.endswith(".fits.gz"):
        model_name = self.__name[:-len(".fits.gz")] + model_suffix
    elif self.__name.endswith(".fits"):
        model_name = self.__name[:-len(".fits")] + model_suffix
    else:
        raise Error("Invalid model name")

    self.__model = Model(model_name,
                         selected_cols,
                         self.__get_settings(),
                         model_options=self.__model_options)
    self.__model.train(self.__candidates)
    self.__model.save_model()
def classify_candidates(self, save=True):
    """ Compute the probabilities of the loaded candidates using the model

    The candidates DataFrame is replaced by the output of the model's
    compute_probability method.

    Parameters
    ----------
    save : bool - Default: True
        If True, save the candidates after computing the probabilities

    Raises
    ------
    Error
        If not running in test or operation mode, if no candidates were
        loaded, or if no model is available
    """
    # consistency checks
    if self.__mode not in ["test", "operation"]:
        raise Error(
            "The function classify_candidates is available in the " +
            f"test and operation modes only. Detected mode is {self.__mode}")
    if self.__candidates is None:
        raise Error("Attempting to run the function classify_candidates " +
                    "but no candidates were found/loaded. Check your " +
                    "formatter")
    if self.__model is None:
        raise Error("Attempting to run the function classify_candidates " +
                    "but no model was loaded")

    self.__candidates = self.__model.compute_probability(self.__candidates)
    if save:
        self.save_candidates()
def find_candidates(self, spectra, columns_candidates):
    """ Find candidates for a given set of spectra and add them to the
    candidates list.

    Whenever the list grows beyond MAX_CANDIDATES_TO_CONVERT entries,
    candidates_list_to_dataframe is called (with save=False).

    Parameters
    ----------
    spectra : list of Spectrum
        The spectra in which candidates will be looked for.

    columns_candidates : list of str
        The column names of the spectral metadata

    Raises
    ------
    Error
        If running in merge mode, or if running in training or test mode
        and the first spectrum does not have the property 'Z_TRUE'
    """
    # Training and test modes need the true redshift. Only the first
    # spectrum is inspected; an empty list has nothing to check.
    if (self.__mode in ("training", "test") and len(spectra) > 0 and
            "Z_TRUE" not in spectra[0].metadata_names()):
        raise Error(f"Mode is set to '{self.__mode}', but spectra do not " +
                    "have the property 'Z_TRUE'.")
    if self.__mode == "merge":
        raise Error("The function find_candidates is not available in " +
                    "merge mode.")

    for spectrum in spectra:
        # locate candidates in this spectrum
        # candidates are appended to self.__candidates_list
        self.__find_candidates(spectrum)

        if len(self.__candidates_list) > MAX_CANDIDATES_TO_CONVERT:
            self.__userprint("Converting candidates to dataframe")
            time0 = time.time()
            self.candidates_list_to_dataframe(columns_candidates, save=False)
            time1 = time.time()
            # end minus start, so the reported duration is positive
            self.__userprint(
                "INFO: time elapsed to convert candidates to dataframe: "
                f"{(time1-time0)/60.0} minutes")
def set_mode(self, mode):
    """ Change the running mode of this instance

    Parameters
    ----------
    mode : "training", "test", "candidates", "operation", or "merge"
        The new running mode. "training" mode assumes that true redshifts
        are known and provides a series of functions to train the model.

    Raises
    ------
    Error
        If mode is not one of the accepted values
    """
    accepted_modes = ("training", "test", "candidates", "operation", "merge")
    if mode not in accepted_modes:
        raise Error("Invalid mode")
    self.__mode = mode
def __init__(self, spectrum_dict, wave, metadata):
    """ Build the spectrum from separate red and blue CCD arrays

    Parameters
    ----------
    spectrum_dict : dict
        Flux and inverse variance of both CCDs, under the keys "red_flux",
        "red_ivar", "blue_flux" and "blue_ivar".

    wave : dict
        Wavelength information. Must contain the arrays "red_wave" and
        "blue_wave" together with the per-CCD pixel widths (documented as
        "red_delta_wave" and "blue_delta_wave" -- TODO confirm the exact key
        names against get_spectra).

    metadata : dict
        The metadata, keyed by strings. Must contain "SPECID".

    Raises
    ------
    Error
        If "SPECID" is missing from metadata
    """
    # every spectrum must carry its identifier
    if "SPECID" not in metadata.keys():
        raise Error('The property "SPECID" must be present in metadata')

    # get_spectra merges both CCDs into single flux/ivar/wave arrays
    merged_flux, merged_ivar, merged_wave = get_spectra(spectrum_dict, wave)
    super().__init__(merged_flux, merged_ivar, merged_wave, metadata)
def __init__(self, flux, ivar, wave, metadata):
    """ Template initializer for a new spectrum format

    Adapt this method to the new format, or delete it if the parent
    initializer already does everything that is needed. As written, it
    always raises after delegating to the parent class.

    Parameters
    ----------
    flux : np.array
        The flux values

    ivar : np.array
        The inverse variance values

    wave : np.array
        The wavelength values

    metadata : dict
        The metadata, keyed by property name (str)

    Raises
    ------
    Error
        Always, until this template is filled in
    """
    super().__init__(flux, ivar, wave, metadata)

    # TODO: replace this placeholder with the format-specific set-up
    raise Error("Not implemented")
def merge(self, others_list, save=True):
    """ Merge self.__candidates with the candidates stored in other files

    Each file is read as a fits table and converted to a DataFrame. Files
    that raise TypeError while loading are reported and skipped. All
    loaded tables are concatenated in a single step at the end, with the
    index reset.

    Parameters
    ----------
    others_list : list of str
        Names of the fits files containing the candidates to merge

    save : bool - Default: True
        If True, save candidates before exiting

    Raises
    ------
    Error
        If not running in merge mode
    """
    # DataFrame.append was removed in pandas 2.0; pd.concat replaces it
    import pandas as pd

    if self.__mode != "merge":
        raise Error("The function merge is available in the " +
                    f"merge mode only. Detected mode is {self.__mode}")

    loaded_frames = []
    for index, candidates_filename in enumerate(others_list):
        self.__userprint(f"Merging... {index} of {len(others_list)}")
        try:
            # load candidates
            data = Table.read(candidates_filename, format='fits')
            other = data.to_pandas()
            del data
        except TypeError:
            self.__userprint(
                f"Error occured when loading file {candidates_filename}.")
            self.__userprint("Ignoring file")
            continue
        loaded_frames.append(other)

    # A single concatenation avoids the quadratic cost of appending one
    # frame at a time; it also works when no candidates were loaded before.
    if loaded_frames:
        frames = loaded_frames if self.__candidates is None else (
            [self.__candidates] + loaded_frames)
        self.__candidates = pd.concat(frames, ignore_index=True)

    if save:
        self.save_candidates()
def __init__(self,
             spectrum_file,
             metadata,
             sky_mask,
             mask_jpas=False,
             mask_jpas_alt=False,
             rebin_pixels_width=0,
             extend_pixels=0,
             noise_increase=1,
             forbidden_wavelenghts=None):
    """ Initialize class instance

    Parameters
    ----------
    spectrum_file : str
        Name of the fits file containing the spectrum

    metadata : dict
        A dictionary with the metadata. Keys should be strings. Must
        contain "SPECID".

    sky_mask : (np.array, float)
        A tuple containing the array of the wavelengths to mask and the
        margin used in the masking. Wavelengths closer than the margin to
        any wavelength in the array will be masked.

    mask_jpas : bool - Default: False
        If set, mask pixels corresponding to filters in trays T3 and T4.
        Only works if the bin size is 100 Angstroms.

    mask_jpas_alt : bool - Default: False
        If set, mask pixels corresponding to filters in trays T3* and T4.
        Only works if the bin size is 100 Angstroms. Ignored if mask_jpas
        is set.

    rebin_pixels_width : float, >0 - Default: 0
        Width of the new pixel (in Angstroms). No rebinning if 0.

    extend_pixels : float, >0 - Default: 0
        Pixel overlap region (in Angstroms)

    noise_increase : int, >0 - Default: 1
        Adds noise to the spectrum by adding a gaussian random number of
        width equal to (noise_increase-1) times the given variance. Then
        increases the variance by a factor of sqrt(noise_increase). Only
        applied if larger than 1.

    forbidden_wavelenghts : list of tuples or None - Default: None
        If not None, a list of tuples specifying ranges of wavelengths that
        will be masked (both ends included). Each tuple must contain the
        initial and final wavelengths of the range. This is intended to
        complement the sky mask by limiting the wavelength coverage, and
        hard cuts will be applied.

    Raises
    ------
    Error
        If "SPECID" is missing from metadata
    """
    # check that "specid" is present in metadata
    if "SPECID" not in metadata.keys():
        raise Error("""The property "SPECID" must be present in metadata""")

    # Read the arrays and release the file immediately; the context manager
    # closes the handle even if reading fails.
    with fitsio.FITS(spectrum_file) as spectrum_hdul:
        # The 1.0 multiplying is added to change type from >4f to np.float
        # this is required by numba later on
        wave = 10**spectrum_hdul[1]["LOGLAM"][:]
        flux = 1.0 * spectrum_hdul[1]["FLUX"][:]
        ivar = 1.0 * spectrum_hdul[1]["IVAR"][:]

    super().__init__(flux, ivar, wave, metadata)

    # compute sky mask
    masklambda = sky_mask[0]
    margin = sky_mask[1]
    self.__skymask = None
    self.__find_skymask(masklambda, margin)

    # mask forbidden lines
    if forbidden_wavelenghts is not None:
        self.__filter_wavelengths(forbidden_wavelenghts)

    # mask pixels
    self._ivar[self.__skymask] = 0.0

    if noise_increase > 1:
        self.__add_noise(noise_increase)

    if rebin_pixels_width > 0:
        self._flux, self._ivar, self._wave = self.rebin(
            rebin_pixels_width, extend_pixels=extend_pixels)

    # JPAS masks: drop the pixels whose wavelength is in the list, and every
    # pixel at or above 7300 Angstroms
    jpas_masked_waves = None
    if mask_jpas:
        jpas_masked_waves = [3900, 4000, 4300, 4400, 4700, 4800, 5100, 5200]
    elif mask_jpas_alt:
        jpas_masked_waves = [3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200]
    if jpas_masked_waves is not None:
        pos = np.where(~((np.isin(self._wave, jpas_masked_waves)) |
                         (self._wave >= 7300)))
        self._wave = self._wave[pos].copy()
        self._ivar = self._ivar[pos].copy()
        self._flux = self._flux[pos].copy()
def main(cmdargs):
    """ Run SQUEzE in operation mode

    Parameters
    ----------
    cmdargs : list of str
        Command-line arguments, parsed with OPERATION_PARSER as parent
        parser
    """
    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[OPERATION_PARSER])
    args = parser.parse_args(cmdargs)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    t0 = time.time()
    # load model
    userprint("Loading model")
    if args.model.endswith(".json"):
        model = Model.from_json(load_json(args.model))
    else:
        model = Model.from_fits(args.model)
    t1 = time.time()
    userprint(f"INFO: time elapsed to load model: {(t1-t0)/60.0} minutes")

    # initialize candidates object
    userprint("Initializing candidates object")
    if args.output_candidates is None:
        candidates = Candidates(mode="operation",
                                model=model,
                                userprint=userprint)
    else:
        candidates = Candidates(mode="operation",
                                name=args.output_candidates,
                                model=model,
                                userprint=userprint)

    # load candidates dataframe if they have previously looked for
    if args.load_candidates:
        userprint("Loading existing candidates")
        t2 = time.time()
        candidates.load_candidates(args.input_candidates)
        t3 = time.time()
        userprint(
            f"INFO: time elapsed to load candidates: {(t3-t2)/60.0} minutes")

    # load spectra
    if args.input_spectra is not None:
        userprint("Loading spectra")
        t4 = time.time()
        columns_candidates = []
        userprint("There are {} files with spectra to be loaded".format(
            len(args.input_spectra)))
        for index, spectra_filename in enumerate(args.input_spectra):
            userprint("Loading spectra from {} ({}/{})".format(
                spectra_filename, index, len(args.input_spectra)))
            t40 = time.time()
            spectra = Spectra.from_json(load_json(spectra_filename))
            if not isinstance(spectra, Spectra):
                raise Error("Invalid list of spectra")

            # the metadata column names are taken from the first spectrum
            if index == 0:
                columns_candidates += spectra.spectra_list()[0].metadata_names(
                )

            # look for candidates
            userprint("Looking for candidates")
            candidates.find_candidates(spectra.spectra_list(),
                                       columns_candidates)

            t41 = time.time()
            userprint(
                f"INFO: time elapsed to find candidates from {spectra_filename}:"
                f" {(t41-t40)/60.0} minutes")

        t5 = time.time()
        userprint(
            f"INFO: time elapsed to find candidates: {(t5-t4)/60.0} minutes")

        # convert to dataframe; kept inside this branch because
        # columns_candidates only exists when spectra were loaded
        userprint("Converting candidates to dataframe")
        t6 = time.time()
        candidates.candidates_list_to_dataframe(columns_candidates)
        t7 = time.time()
        userprint(
            f"INFO: time elapsed to convert candidates to dataframe: {(t7-t6)/60.0} minutes"
        )

    # compute probabilities
    userprint("Computing probabilities")
    t8 = time.time()
    candidates.classify_candidates()
    t9 = time.time()
    userprint(
        f"INFO: time elapsed to classify candidates: {(t9-t8)/60.0} minutes")

    # save the catalogue as a fits file
    if not args.no_save_catalogue:
        candidates.save_catalogue(args.output_catalogue, args.prob_cut)

    t10 = time.time()
    userprint(f"INFO: total elapsed time: {(t10-t0)/60.0} minutes")
    userprint("Done")
def main(cmdargs):
    """ Run SQUEzE in test mode

    Parameters
    ----------
    cmdargs : list of str
        Command-line arguments, parsed with TEST_PARSER as parent parser
    """
    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[TEST_PARSER])
    args = parser.parse_args(cmdargs)
    if args.check_statistics:
        quasar_parser_check(parser, args)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    t0 = time.time()
    # load quasar catalogue (only if --check-statistics is passed)
    if args.check_statistics:
        userprint("Loading quasar catalogue")
        if args.qso_dataframe is not None:
            quasar_catalogue = deserialize(load_json(args.qso_dataframe))
            quasar_catalogue["LOADED"] = True
        else:
            quasar_catalogue = QuasarCatalogue(
                args.qso_cat, args.qso_cols, args.qso_specid,
                args.qso_ztrue, args.qso_hdu).quasar_catalogue()
            quasar_catalogue["LOADED"] = False
        t1 = time.time()
        userprint(
            f"INFO: time elapsed to load quasar catalogue: {(t1-t0)/60.0} minutes"
        )

    # load model
    userprint("Loading model")
    t2 = time.time()
    if args.model.endswith(".json"):
        model = Model.from_json(load_json(args.model))
    else:
        model = Model.from_fits(args.model)
    t3 = time.time()
    userprint(f"INFO: time elapsed to load model: {(t3-t2)/60.0} minutes")

    # initialize candidates object
    userprint("Initializing candidates object")
    if args.output_candidates is None:
        candidates = Candidates(mode="test", model=model, userprint=userprint)
    else:
        candidates = Candidates(mode="test",
                                name=args.output_candidates,
                                model=model,
                                userprint=userprint)

    # load candidates dataframe if they have previously looked for
    if args.load_candidates:
        userprint("Loading existing candidates")
        t4 = time.time()
        candidates.load_candidates(args.input_candidates)
        t5 = time.time()
        userprint(
            f"INFO: time elapsed to load candidates: {(t5-t4)/60.0} minutes")

    # load spectra
    if args.input_spectra is not None:
        userprint("Loading spectra")
        t6 = time.time()
        columns_candidates = []
        userprint("There are {} files with spectra to be loaded".format(
            len(args.input_spectra)))
        for index, spectra_filename in enumerate(args.input_spectra):
            userprint("Loading spectra from {} ({}/{})".format(
                spectra_filename, index, len(args.input_spectra)))
            t60 = time.time()
            spectra = Spectra.from_json(load_json(spectra_filename))
            if not isinstance(spectra, Spectra):
                raise Error("Invalid list of spectra")

            # the metadata column names are taken from the first spectrum
            if index == 0:
                columns_candidates += spectra.spectra_list()[0].metadata_names(
                )

            # flag loaded quasars as such
            if args.check_statistics:
                loaded_specids = [
                    spec.metadata_by_key("SPECID")
                    for spec in spectra.spectra_list()
                ]
                # One vectorised pass instead of two catalogue scans per
                # spectrum. As before, only the first catalogue row of each
                # SPECID is flagged.
                first_rows = ~quasar_catalogue["SPECID"].duplicated()
                quasar_catalogue.loc[
                    quasar_catalogue["SPECID"].isin(loaded_specids) &
                    first_rows, "LOADED"] = True

            # look for candidates
            userprint("Looking for candidates")
            candidates.find_candidates(spectra.spectra_list(),
                                       columns_candidates)

            t61 = time.time()
            userprint(
                f"INFO: time elapsed to find candidates from {spectra_filename}:"
                f" {(t61-t60)/60.0} minutes")

        t7 = time.time()
        userprint(
            f"INFO: time elapsed to find candidates: {(t7-t6)/60.0} minutes")

        # convert to dataframe; kept inside this branch because
        # columns_candidates only exists when spectra were loaded
        userprint("Converting candidates to dataframe")
        t8 = time.time()
        candidates.candidates_list_to_dataframe(columns_candidates)
        t9 = time.time()
        userprint(
            f"INFO: time elapsed to convert candidates to dataframe: {(t9-t8)/60.0} minutes"
        )

    # compute probabilities
    userprint("Computing probabilities")
    t10 = time.time()
    candidates.classify_candidates()
    t11 = time.time()
    userprint(
        f"INFO: time elapsed to classify candidates: {(t11-t10)/60.0} minutes")

    # check completeness
    if args.check_statistics:
        if args.check_probs is not None:
            probs = args.check_probs
        else:
            probs = np.arange(0.9, 0.0, -0.05)
        userprint("Check statistics")
        data_frame = candidates.candidates()
        # reset once instead of once per probability threshold
        quasars = quasar_catalogue.reset_index()
        userprint("\n---------------")
        userprint("step 1")
        candidates.find_completeness_purity(quasars, data_frame)
        for prob in probs:
            userprint("\n---------------")
            userprint("proba > {}".format(prob))
            candidates.find_completeness_purity(
                quasars,
                data_frame[(data_frame["PROB"] > prob) &
                           ~(data_frame["DUPLICATED"]) &
                           (data_frame["Z_CONF_PERSON"] == 3)],
            )

    # save the catalogue as a fits file
    if not args.no_save_catalogue:
        candidates.save_catalogue(args.output_catalogue, args.prob_cut)

    t12 = time.time()
    userprint(f"INFO: total elapsed time: {(t12-t0)/60.0} minutes")
    userprint("Done")
def main(cmdargs):
    """ Run SQUEzE in training mode

    Candidates are searched for in the input spectra (and/or loaded from a
    previous run). Then a model is trained on them and saved.

    Parameters
    ----------
    cmdargs : list of str
        Command-line arguments, parsed with TRAINING_PARSER as parent parser
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[TRAINING_PARSER])
    args = parser.parse_args(cmdargs)

    # pick the print function according to the requested verbosity
    userprint = quietprint if args.quiet else verboseprint

    start_time = time.time()

    def option_or_default(value, default):
        """Return value, or default when value is None"""
        return default if value is None else value

    # lines and the remaining settings fall back to the package defaults
    userprint("Loading lines")
    if args.lines is None:
        lines = LINES
    else:
        lines = deserialize(load_json(args.lines))
    try_line = option_or_default(args.try_lines, TRY_LINES)
    z_precision = option_or_default(args.z_precision, Z_PRECISION)
    peakfind_width = option_or_default(args.peakfind_width, PEAKFIND_WIDTH)
    peakfind_sig = option_or_default(args.peakfind_sig, PEAKFIND_SIG)
    if args.random_forest_options is None:
        random_forest_options = RANDOM_FOREST_OPTIONS
    else:
        random_forest_options = load_json(args.random_forest_options)
    random_state = option_or_default(args.random_state, RANDOM_STATE)

    # "name" is only forwarded when given, so Candidates keeps its own
    # default otherwise
    userprint("Initializing candidates object")
    candidates_kwargs = {
        "lines_settings": (lines, try_line),
        "z_precision": z_precision,
        "mode": "training",
        "peakfind": (peakfind_width, peakfind_sig),
        "model": None,
        "userprint": userprint,
        "model_options": (random_forest_options, random_state,
                          args.pass_cols_to_rf),
    }
    if args.output_candidates is not None:
        candidates_kwargs["name"] = args.output_candidates
    candidates = Candidates(**candidates_kwargs)

    # reuse candidates from a previous run if requested
    if args.load_candidates:
        userprint("Loading existing candidates")
        load_start = time.time()
        candidates.load_candidates(args.input_candidates)
        load_end = time.time()
        userprint("INFO: time elapsed to load candidates: "
                  f"{(load_end-load_start)/60.0} minutes")

    if args.input_spectra is not None:
        userprint("Loading spectra")
        search_start = time.time()
        columns_candidates = []
        num_files = len(args.input_spectra)
        userprint(
            "There are {} files with spectra to be loaded".format(num_files))
        for index, spectra_filename in enumerate(args.input_spectra):
            userprint("Loading spectra from {} ({}/{})".format(
                spectra_filename, index, num_files))
            file_start = time.time()
            spectra = Spectra.from_json(load_json(spectra_filename))
            if not isinstance(spectra, Spectra):
                raise Error("Invalid list of spectra")

            # metadata column names come from the first spectrum of the
            # first file
            if index == 0:
                first_spectrum = spectra.spectra_list()[0]
                columns_candidates += first_spectrum.metadata_names()

            userprint("Looking for candidates")
            candidates.find_candidates(spectra.spectra_list(),
                                       columns_candidates)

            file_end = time.time()
            userprint(
                f"INFO: time elapsed to find candidates from {spectra_filename}: "
                f"{(file_end-file_start)/60.0} minutes")

        search_end = time.time()
        userprint("INFO: time elapsed to find candidates: "
                  f"{(search_end-search_start)/60.0} minutes")

        userprint("Converting candidates to dataframe")
        convert_start = time.time()
        candidates.candidates_list_to_dataframe(columns_candidates)
        convert_end = time.time()
        userprint("INFO: time elapsed to convert candidates to dataframe: "
                  f"{(convert_end-convert_start)/60.0} minutes")

    userprint("Training model")
    train_start = time.time()
    candidates.train_model(args.model_fits)
    train_end = time.time()
    userprint("INFO: time elapsed to train model: "
              f"{(train_end-train_start)/60.0} minutes")
    userprint(
        f"INFO: total elapsed time: {(train_end-start_time)/60.0} minutes")
    userprint("Done")
def find_completeness_purity(self, quasars_data_frame, data_frame=None):
    """ Given a DataFrame with candidates and another one with the catalogued
    quasars, compute the completeness and the purity.

    A catalogued quasar counts as found if at least one candidate with the
    same SPECID has IS_CORRECT set. Any ratio whose denominator is zero is
    reported as np.nan.

    Parameters
    ----------
    quasars_data_frame : pd.DataFrame
        DataFrame containing the quasar catalogue. It must contain the
        columns "SPECID" and "Z_TRUE". Any index is accepted.

    data_frame : pd.DataFrame - Default: self.__candidates
        DataFrame with the candidates. Must contain the columns
        "IS_CORRECT" and "SPECID". If non-empty, "Z_TRUE" and "IS_LINE"
        are also read.

    Returns
    -------
    purity : float
        The computed purity

    completeness: float
        The computed completeness

    found_quasars : int
        The total number of found quasars.

    Raises
    ------
    Error
        If not in training or test mode, or if data_frame lacks the columns
        "IS_CORRECT" or "SPECID"
    """
    # consistency checks
    if self.__mode not in ["training", "test"]:
        raise Error(
            "The function find_completeness_purity is available in the " +
            f"training and test modes only. Detected mode is {self.__mode}"
        )
    if data_frame is None:
        data_frame = self.__candidates
    if "IS_CORRECT" not in data_frame.columns:
        raise Error(
            "find_completeness_purity: invalid DataFrame, the column " +
            "'IS_CORRECT' is missing")
    if "SPECID" not in data_frame.columns:
        raise Error(
            "find_completeness_purity: invalid DataFrame, the column " +
            "'SPECID' is missing")

    def safe_ratio(numerator, denominator):
        """Return numerator/denominator, or np.nan if denominator <= 0"""
        if float(denominator) > 0.0:
            return float(numerator) / float(denominator)
        return np.nan

    is_correct = data_frame["IS_CORRECT"].astype(bool)

    # Vectorised matching: one lookup per quasar against the SPECIDs of
    # correct candidates, instead of scanning all candidates per quasar.
    correct_specids = data_frame.loc[is_correct, "SPECID"].unique()
    found_mask = quasars_data_frame["SPECID"].isin(correct_specids)
    quasars_z = quasars_data_frame["Z_TRUE"]

    num_quasars = quasars_data_frame.shape[0]
    num_quasars_zge1 = int((quasars_z >= 1.0).sum())
    num_quasars_zge2_1 = int((quasars_z >= 2.1).sum())
    found_quasars = int(found_mask.sum())
    found_quasars_zge1 = int((found_mask & (quasars_z >= 1.0)).sum())
    found_quasars_zge2_1 = int((found_mask & (quasars_z >= 2.1)).sum())

    completeness = safe_ratio(found_quasars, num_quasars)
    completeness_zge1 = safe_ratio(found_quasars_zge1, num_quasars_zge1)
    completeness_zge2_1 = safe_ratio(found_quasars_zge2_1,
                                     num_quasars_zge2_1)

    num_candidates = data_frame.shape[0]
    if num_candidates > 0:
        candidates_zge1 = data_frame["Z_TRUE"] >= 1
        candidates_zge2_1 = data_frame["Z_TRUE"] >= 2.1
        purity = safe_ratio(is_correct.sum(), num_candidates)
        # redshift-restricted purities may have no candidates at all
        purity_zge1 = safe_ratio(is_correct[candidates_zge1].sum(),
                                 candidates_zge1.sum())
        purity_zge2_1 = safe_ratio(is_correct[candidates_zge2_1].sum(),
                                   candidates_zge2_1.sum())
        line_purity = safe_ratio(data_frame["IS_LINE"].sum(), num_candidates)
    else:
        purity = np.nan
        purity_zge1 = np.nan
        purity_zge2_1 = np.nan
        line_purity = np.nan

    self.__userprint(f"There are {data_frame.shape[0]} candidates ", )
    self.__userprint(f"for {num_quasars} catalogued quasars")
    self.__userprint(f"number of quasars = {num_quasars}")
    self.__userprint(f"found quasars = {found_quasars}")
    self.__userprint(f"completeness = {completeness:.2%}")
    self.__userprint(f"completeness z>=1 = {completeness_zge1:.2%}")
    self.__userprint(f"completeness z>=2.1 = {completeness_zge2_1:.2%}")
    self.__userprint(f"purity = {purity:.2%}")
    self.__userprint(f"purity z >=1 = {purity_zge1:.2%}")
    self.__userprint(f"purity z >=2.1 = {purity_zge2_1:.2%}")
    self.__userprint(f"line purity = {line_purity:.2%}")

    return purity, completeness, found_quasars
def __init__(self,
             lines_settings=(LINES, TRY_LINES),
             z_precision=Z_PRECISION,
             mode="operation",
             name="SQUEzE_candidates.fits.gz",
             peakfind=(PEAKFIND_WIDTH, PEAKFIND_SIG),
             model=None,
             model_options=(RANDOM_FOREST_OPTIONS, RANDOM_STATE,
                            PASS_COLS_TO_RF),
             userprint=verboseprint):
    """ Initialize class instance.

    Parameters
    ----------
    lines_settings : (pandas.DataFrame, list) - Default: (LINES, TRY_LINES)
        A tuple with a DataFrame with the information of the lines used to
        compute the ratios, and the names of the lines to assume for each
        of the found peaks. These names must be included in the index of
        the DataFrame. This will be overridden if model is not None.

    z_precision : float - Default: Z_PRECISION
        A true candidate is defined as a candidate with an absolute value
        of Delta_z lower than or equal to z_precision. Ignored if mode is
        "operation". This will be overridden if model is not None.

    mode : "training", "test", "operation", "candidates", or "merge"
        - Default: "operation"
        Running mode. "training" mode assumes that true redshifts are known
        and provides a series of functions to train the model.

    name : string - Default: "SQUEzE_candidates.fits.gz"
        Name of the candidates sample. Must end in ".fits" or ".fits.gz".

    peakfind : (float, float) - Default: (PEAKFIND_WIDTH, PEAKFIND_SIG)
        Width and significance passed to the PeakFinder

    model : Model or None - Default: None
        Instance of the Model class defined in squeze_model or None. In
        test and operation mode, it is supposed to be the quasar model used
        to construct the catalogue. In training mode, it is supposed to be
        None initially, and the model will be trained and given as an
        output of the code.

    model_options : (dict, int, list or None)
        - Default: (RANDOM_FOREST_OPTIONS, RANDOM_STATE, PASS_COLS_TO_RF)
        The first dictionary sets the options to be passed to the random
        forest constructor. If a high-low split of the training is desired,
        the dictionary must contain the entries "high" and "low", and the
        corresponding values must be dictionaries with the options for each
        of the classifiers. The second int is the random state passed to
        the random forest classifiers. The third list contains columns to
        be passed to the random forest classifiers (None for no columns).
        In training mode, they are passed to the model instance before
        training. Otherwise they are ignored.

    userprint : function - Default: verboseprint
        Print function to use

    Raises
    ------
    Error
        If mode or name is invalid, or if a try_line is not present in the
        lines DataFrame
    """
    if mode in ["training", "test", "operation", "candidates", "merge"]:
        self.__mode = mode
    else:
        raise Error("Invalid mode")

    if name.endswith((".fits.gz", ".fits")):
        self.__name = name
    else:
        message = (
            "Candidates name should have .fits or .fits.gz extensions. "
            f"Given name was {name}")
        raise Error(message)

    # printing function
    self.__userprint = userprint

    # initialize empty catalogue
    self.__candidates_list = []
    self.__candidates = None

    # main settings
    self.__lines = lines_settings[0]
    self.__try_lines = lines_settings[1]
    self.__z_precision = z_precision

    # options to be passed to the peak finder
    self.__peakfind_width = peakfind[0]
    self.__peakfind_sig = peakfind[1]

    # model; its settings are loaded after the defaults above are set
    if model is None:
        self.__model = None
    else:
        self.__model = model
        self.__load_model_settings()
    self.__model_options = model_options

    # initialize peak finder
    self.__peak_finder = PeakFinder(self.__peakfind_width,
                                    self.__peakfind_sig)

    # make sure fields in self.__lines are properly sorted
    self.__lines = self.__lines[[
        'WAVE', 'START', 'END', 'BLUE_START', 'BLUE_END', 'RED_START',
        'RED_END'
    ]]

    # convert the try_lines names into row positions in self.__lines;
    # report unknown names explicitly instead of failing with IndexError
    missing_lines = [
        item for item in self.__try_lines if item not in self.__lines.index
    ]
    if missing_lines:
        raise Error("The following try_lines are not present in the lines "
                    f"DataFrame: {missing_lines}")
    self.__try_lines_indexs = np.array([
        np.where(self.__lines.index == item)[0][0]
        for item in self.__try_lines
    ])
    self.__try_lines_dict = dict(
        zip(self.__try_lines, self.__try_lines_indexs))
    self.__try_lines_dict["none"] = -1
def append(self, spectrum):
    """ Add a spectrum to the list

    Parameters
    ----------
    spectrum : Spectrum
        The spectrum to store

    Raises
    ------
    Error
        If spectrum is not a Spectrum instance
    """
    if isinstance(spectrum, Spectrum):
        self.__spectra_list.append(spectrum)
    else:
        raise Error("Invalid spectrum")