Code example #1
0
File: abstract_test.py  Project: iprafols/SQUEzE
    def compare_json_spectra(self, orig_file, new_file):
        """Compare two sets of spectra saved in json files.

        Arguments
        ---------
        orig_file : str
        Name of the json file containing the original spectra

        new_file : str
        Name of the json file containing the new spectra
        """
        orig_spectra = Spectra.from_json(load_json(orig_file))
        orig_spectra_list = orig_spectra.spectra_list()
        new_spectra = Spectra.from_json(load_json(new_file))
        new_spectra_list = new_spectra.spectra_list()

        # bug fix: assertTrue(a, b) only checks the truthiness of `a` and
        # uses `b` as the failure message; assertEqual actually compares
        # the two sizes
        self.assertEqual(orig_spectra.size(), new_spectra.size())
        for index in range(orig_spectra.size()):
            self.assertTrue(np.allclose(orig_spectra_list[index].wave(),
                                        new_spectra_list[index].wave()))
            self.assertTrue(np.allclose(orig_spectra_list[index].flux(),
                                        new_spectra_list[index].flux()))
            self.assertTrue(np.allclose(orig_spectra_list[index].ivar(),
                                        new_spectra_list[index].ivar()))
Code example #2
0
File: aps_squeze.py  Project: iprafols/SQUEzE
def squeze_worker(infiles,
                  model,
                  aps_ids,
                  targsrvy,
                  targclass,
                  mask_aps_ids,
                  area,
                  mask_areas,
                  wlranges,
                  cache_Rcsr,
                  sens_corr,
                  mask_gaps,
                  vacuum,
                  tellurics,
                  fill_gap,
                  arms_ratio,
                  join_arms,
                  quiet=False,
                  save_file=None):
    """Run SQUEzE on the spectra contained in `infiles`.

    Load the classification model, read and format the WEAVE spectra,
    search for quasar candidates and compute their probabilities.

    Arguments
    ---------
    infiles : list
    Input files with the spectra to be analysed

    model : str
    Name of the file containing the trained model (json or fits)

    aps_ids, targsrvy, targclass, mask_aps_ids, area, mask_areas, wlranges,
    sens_corr, mask_gaps, vacuum, tellurics, fill_gap, arms_ratio, join_arms
    Selection and formatting options forwarded verbatim to APSOB

    cache_Rcsr : bool
    NOTE(review): accepted but never used in this function -- confirm
    whether it should be forwarded to APSOB

    quiet : bool - default: False
    If True, suppress non-essential messages

    save_file : str or None - default: None
    If not None, name of the file where candidates are saved

    Return
    ------
    A tuple with the candidates dataframe and the model's Z_PRECISION
    setting
    """

    # manage verbosity
    userprint = verboseprint if not quiet else quietprint

    # load model (json or fits depending on the file extension)
    userprint("================================================")
    userprint("")
    userprint("Loading model")
    if model.endswith(".json"):
        model = Model.from_json(load_json(model))
    else:
        model = Model.from_fits(model)

    # load spectra
    userprint("Loading spectra")
    weave_formatted_spectra = APSOB(infiles,
                                    aps_ids=aps_ids,
                                    targsrvy=targsrvy,
                                    targclass=targclass,
                                    mask_aps_ids=mask_aps_ids,
                                    area=area,
                                    mask_areas=mask_areas,
                                    wlranges=wlranges,
                                    sens_corr=sens_corr,
                                    mask_gaps=mask_gaps,
                                    vacuum=vacuum,
                                    tellurics=tellurics,
                                    fill_gap=fill_gap,
                                    arms_ratio=arms_ratio,
                                    join_arms=join_arms)

    userprint("Formatting spectra to be digested by SQUEzE")
    spectra = Spectra.from_weave(weave_formatted_spectra, userprint=userprint)

    # TODO: split spectra into several sublists so that we can parallelise

    # initialize candidates object; pass `name` only when a save file
    # was requested so the Candidates default name is used otherwise
    userprint("Initialize candidates object")
    candidates_kwargs = {"mode": "operation", "model": model}
    if save_file is not None:
        candidates_kwargs["name"] = save_file
    candidates = Candidates(**candidates_kwargs)

    # look for candidates
    userprint("Looking for candidates")
    candidates.find_candidates(spectra.spectra_list())

    # save the candidates dataframe only when a save file was requested
    columns_candidates = spectra.spectra_list()[0].metadata_names()
    candidates.candidates_list_to_dataframe(columns_candidates,
                                            save=save_file is not None)

    # compute probabilities
    userprint("Computing probabilities")
    candidates.classify_candidates()

    # TODO: if we parallelise, then use 'merging' mode to join the results

    userprint("SQUEzE run is completed, returning")
    userprint("================================================")
    userprint("")
    # return the candidates and the chosen probability threshold
    return candidates.candidates(), model.get_settings().get("Z_PRECISION")
Code example #3
0
File: format_mini_jpas.py  Project: iprafols/SQUEzE
def main(cmdargs):
    """Format mini-JPAS photometric data into SQUEzE Spectra.

    Load fluxes and relative errors from a fits file, build inverse
    variances, select the requested filters and save the resulting
    spectra as a json file ready to be used by SQUEzE.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments
    """

    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[PARENT_PARSER])

    parser.add_argument("--input-filename",
                        type=str,
                        required=True,
                        help="""Name of the fits file to be loaded.""")
    parser.add_argument("--filters-info",
                        type=str,
                        required=True,
                        help="""Name of the file containing the filters
                        information""")
    parser.add_argument("--filter-wave",
                        type=str,
                        required=False,
                        default="Filter.effective_wavelength",
                        help="""Name of the field containing the wavelength""")
    parser.add_argument(
        "--mag-col",
        type=str,
        required=True,
        help="""Name of the magnitude system to be used. For example,
                        type PSFCOR to use the column FLambdaDualObj.FLUX_PSFCOR for them
                        flux and FLambdaDualObj.MAG_RELERR_PSFCOR for the errors"""
    )
    parser.add_argument("--output-filename",
                        type=str,
                        required=True,
                        help="""Name of the output filename.""")
    parser.add_argument(
        "--keep-cols",
        nargs='+',
        default=None,
        required=False,
        help="""White-spaced list of the columns to be kept in the
                        formatted spectra""")
    parser.add_argument("--trays-t1-t2",
                        action="store_true",
                        required=False,
                        help="""If passed, load only filters in trays T1 and
                        T2.""")

    args = parser.parse_args(cmdargs)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    # bug fix: --keep-cols defaults to None; without this guard the
    # metadata comprehension below crashed when the flag was omitted
    keep_cols = args.keep_cols if args.keep_cols is not None else []

    # load filters information
    userprint("Loading filters")
    hdu_filters = fits.open(args.filters_info)
    filter_names = hdu_filters[1].data["Filter.name"]

    if args.trays_t1_t2:
        # hard-coded indexes of the filters in trays T1 and T2
        select_filters = np.array([
            0, 3, 5, 7, 9, 11, 14, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
            29, 30, 31, 32, 33, 34, 35, 36, 37, 38
        ])
    else:
        select_filters = np.array(
            [i for i, name in enumerate(filter_names) if name.startswith("J")])
    # The 1.0 multiplying is added to change type from >4f to np.float
    # this is required by numba later on
    wave = 1.0 * hdu_filters[1].data[args.filter_wave][select_filters]

    hdu_filters.close()

    # initialize squeze Spectra class
    squeze_spectra = Spectra()

    # loop over spectra
    userprint("Loading spectra")
    hdu = fits.open(args.input_filename)
    for row in hdu[1].data:

        # load data
        # The 1.0 multiplying is added to change type from >4f to np.float
        # this is required by numba later on
        mask = ((row["FLambdaDualObj.FLAGS"] > 0) |
                (row["FLambdaDualObj.MASK_FLAGS"] > 0))
        flux = 1.0 * row["FLambdaDualObj.FLUX_{}".format(args.mag_col)]
        relerr = 1.0 * row["FLambdaDualObj.FLUX_RELERR_{}".format(
            args.mag_col)]
        # NOTE(review): this divides by zero (-> inf with a warning) when
        # flux or relerr is 0; masked entries are zeroed afterwards but
        # unmasked zeros would propagate -- confirm inputs are non-zero
        ivar = 1 / (flux * relerr)**2
        ivar[mask] = 0.0

        # select the filters in select_filters
        flux = flux[select_filters]
        ivar = ivar[select_filters]

        # prepare metadata
        metadata = {col.upper(): row[col] for col in keep_cols}
        # SPECID is built by concatenating tile id and object number
        metadata["SPECID"] = int("{}{}".format(row["FLambdaDualObj.TILE_ID"],
                                               row["FLambdaDualObj.NUMBER"]))

        # format spectrum
        spectrum = SimpleSpectrum(flux, ivar, wave, metadata)

        # append to list
        squeze_spectra.append(spectrum)

    # save formatted spectra
    userprint(f"Saving to file: {args.output_filename}")
    save_json(args.output_filename, squeze_spectra)

    userprint("Done")
Code example #4
0
def main(cmdargs):
    """Run SQUEzE in operation mode.

    Load a trained model, search the input spectra for quasar candidates,
    classify them and optionally save the resulting catalogue.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments parsed with OPERATION_PARSER
    """
    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[OPERATION_PARSER])
    args = parser.parse_args(cmdargs)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    t0 = time.time()
    # load model (json or fits depending on the file extension)
    userprint("Loading model")
    if args.model.endswith(".json"):
        model = Model.from_json(load_json(args.model))
    else:
        model = Model.from_fits(args.model)
    t1 = time.time()
    # bug fix: this message was an f-string without placeholders followed
    # by stray positional arguments; format it like the other timing lines
    userprint(f"INFO: time elapsed to load model: {(t1 - t0) / 60.0} minutes")

    # initialize candidates object
    userprint("Initializing candidates object")
    if args.output_candidates is None:
        candidates = Candidates(mode="operation",
                                model=model,
                                userprint=userprint)
    else:
        candidates = Candidates(mode="operation",
                                name=args.output_candidates,
                                model=model,
                                userprint=userprint)

    # load candidates dataframe if they have previously been looked for
    if args.load_candidates:
        userprint("Loading existing candidates")
        t2 = time.time()
        candidates.load_candidates(args.input_candidates)
        t3 = time.time()
        userprint(
            f"INFO: time elapsed to load candidates: {(t3-t2)/60.0} minutes")

    # load spectra
    if args.input_spectra is not None:
        userprint("Loading spectra")
        t4 = time.time()
        columns_candidates = []
        userprint("There are {} files with spectra to be loaded".format(
            len(args.input_spectra)))
        for index, spectra_filename in enumerate(args.input_spectra):
            userprint("Loading spectra from {} ({}/{})".format(
                spectra_filename, index, len(args.input_spectra)))
            t40 = time.time()
            spectra = Spectra.from_json(load_json(spectra_filename))
            if not isinstance(spectra, Spectra):
                raise Error("Invalid list of spectra")

            # metadata columns are taken from the first spectrum of the
            # first file
            if index == 0:
                columns_candidates += spectra.spectra_list()[0].metadata_names(
                )

            # look for candidates
            userprint("Looking for candidates")
            candidates.find_candidates(spectra.spectra_list(),
                                       columns_candidates)

            t41 = time.time()
            userprint(
                f"INFO: time elapsed to find candidates from {spectra_filename}:"
                f" {(t41-t40)/60.0} minutes")

        t5 = time.time()
        userprint(
            f"INFO: time elapsed to find candidates: {(t5-t4)/60.0} minutes")

        # convert to dataframe
        userprint("Converting candidates to dataframe")
        t6 = time.time()
        candidates.candidates_list_to_dataframe(columns_candidates)
        t7 = time.time()
        userprint(
            f"INFO: time elapsed to convert candidates to dataframe: {(t7-t6)/60.0} minutes"
        )

    # compute probabilities
    userprint("Computing probabilities")
    t8 = time.time()
    candidates.classify_candidates()
    t9 = time.time()
    userprint(
        f"INFO: time elapsed to classify candidates: {(t9-t8)/60.0} minutes")

    # save the catalogue as a fits file
    if not args.no_save_catalogue:
        candidates.save_catalogue(args.output_catalogue, args.prob_cut)

    t10 = time.time()
    userprint(f"INFO: total elapsed time: {(t10-t0)/60.0} minutes")
    userprint("Done")
Code example #5
0
def main(cmdargs):
    """Format WEAVE spectra into SQUEzE Spectra.

    Load red and blue CCD data using the WeaveSpectrum class, attach truth
    redshifts from a quasar catalogue (and magnitudes when available), and
    save the formatted spectra as a json file.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments
    """

    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[PARENT_PARSER, QUASAR_CATALOGUE_PARSER])

    parser.add_argument(
        "--red-spectra",
        nargs='+',
        type=str,
        default=None,
        required=True,
        help="""Name of the fits file containig the red CCD data.
                            Size should be the same as the list in --blue-spectra."""
    )

    parser.add_argument(
        "--blue-spectra",
        nargs='+',
        type=str,
        default=None,
        required=True,
        help="""Name of the fits file containig the blue CCD data.
                            Size should be the same as the list in --blue-spectra."""
    )

    parser.add_argument("--out",
                        type=str,
                        default="spectra.json",
                        required=False,
                        help="""Name of the json file where the list of spectra
                            will be saved""")

    parser.add_argument(
        "--mag-cat",
        type=str,
        default=None,
        required=False,
        help="""Name of the text file containing the measured magnitudes
                            for the observed spectra""")

    args = parser.parse_args(cmdargs)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    # load quasar catalogue
    userprint("loading catalogue from {}".format(args.qso_cat))
    quasar_catalogue = QuasarCatalogue(args.qso_cat, args.qso_cols,
                                       args.qso_specid, args.qso_hdu)
    quasar_catalogue = quasar_catalogue.quasar_catalogue()

    # load magnitudes catalogue (optional)
    # bug fix: mags_catalogue was previously left unbound when --mag-cat
    # was not passed, producing a NameError further down
    mags_catalogue = None
    if args.mag_cat is not None:
        mags_catalogue = pd.read_csv(args.mag_cat, delim_whitespace=True)

    # initialize spectra variable
    spectra = Spectra()

    for red_filename, blue_filename in zip(args.red_spectra,
                                           args.blue_spectra):
        # load red spectra; the wavelength grid is reconstructed from the
        # fits header (CRVAL1 start, CD1_1 step, NAXIS1 length)
        userprint("loading red spectra from {}".format(red_filename))
        observed_red = fits.open(red_filename)
        wave = {
            "red_delta_wave":
            observed_red["RED_DATA"].header["CD1_1"],
            "red_wave":
            np.zeros(observed_red["RED_DATA"].header["NAXIS1"], dtype=float)
        }
        wave["red_wave"][0] = observed_red["RED_DATA"].header["CRVAL1"]
        for index in range(1, wave["red_wave"].size):
            wave["red_wave"][index] = (wave["red_wave"][index - 1] +
                                       wave["red_delta_wave"])
        targid = observed_red["FIBTABLE"].data["TARGID"].astype(str)

        # load blue spectra
        userprint("loading blue spectra from {}".format(blue_filename))
        observed_blue = fits.open(blue_filename)
        wave["blue_delta_wave"] = observed_blue["BLUE_DATA"].header["CD1_1"]
        wave["blue_wave"] = np.zeros(
            observed_blue["BLUE_DATA"].header["NAXIS1"], dtype=float)
        wave["blue_wave"][0] = observed_blue["BLUE_DATA"].header["CRVAL1"]
        for index in range(1, wave["blue_wave"].size):
            wave["blue_wave"][index] = (wave["blue_wave"][index - 1] +
                                        wave["blue_delta_wave"])

        # format spectra
        userprint("formatting red and blue data into a single spectra")
        for index in tqdm.tqdm(range(
                observed_red["RED_DATA"].header["NAXIS2"])):
            # NOTE(review): ivar is multiplied by the sensitivity function,
            # mirroring the flux correction -- confirm this is the intended
            # propagation for inverse variances
            spectrum_dict = {
                "red_flux": (observed_red["RED_DATA"].data[index] *
                             observed_red["RED_SENSFUNC"].data[index]),
                "red_ivar": (observed_red["RED_IVAR"].data[index] *
                             observed_red["RED_SENSFUNC"].data[index]),
                "blue_flux": (observed_blue["BLUE_DATA"].data[index] *
                              observed_blue["BLUE_SENSFUNC"].data[index]),
                "blue_ivar": (observed_blue["BLUE_IVAR"].data[index] *
                              observed_blue["BLUE_SENSFUNC"].data[index]),
            }

            # check that we don't have an empty spectrum
            if np.unique(spectrum_dict.get("red_flux")).size == 1:
                continue

            # add targid to metadata
            metadata = {"TARGID": targid[index], "SPECID": targid[index]}

            # add true redshift to metadata
            if quasar_catalogue[quasar_catalogue["TARGID"] ==
                                targid[index]]["Z"].shape[0] > 0:
                metadata["Z_TRUE"] = quasar_catalogue[
                    quasar_catalogue["TARGID"] == targid[index]]["Z"].values[0]
            else:
                metadata["Z_TRUE"] = np.nan

            # add magnitudes to metadata (only when a magnitudes catalogue
            # was provided)
            if (mags_catalogue is not None and
                    mags_catalogue[mags_catalogue["TARGID"] ==
                                   targid[index]]["GMAG"].shape[0] > 0):
                metadata["GMAG"] = mags_catalogue[mags_catalogue["TARGID"] == \
                                                  targid[index]]["GMAG"].values[0]
                metadata["RMAG"] = mags_catalogue[mags_catalogue["TARGID"] == \
                                                  targid[index]]["RMAG"].values[0]
            else:
                metadata["GMAG"] = np.nan
                metadata["RMAG"] = np.nan

            # add spectrum to list
            spectra.append(WeaveSpectrum(spectrum_dict, wave, metadata))

    # save them as a json file to be used by SQUEzE
    save_json(args.out, spectra)
Code example #6
0
File: squeze_test.py  Project: iprafols/SQUEzE
def main(cmdargs):
    """Run SQUEzE in test mode.

    Find quasar candidates in the input spectra and classify them using a
    previously trained model. If --check-statistics is passed, also compare
    the classification against a truth quasar catalogue to report
    completeness and purity at several probability thresholds.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments parsed with TEST_PARSER
    """
    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[TEST_PARSER])
    args = parser.parse_args(cmdargs)
    if args.check_statistics:
        quasar_parser_check(parser, args)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    t0 = time.time()
    # load quasar catalogue (only if --check-statistics is passed)
    if args.check_statistics:
        userprint("Loading quasar catalogue")
        if args.qso_dataframe is not None:
            # catalogue passed as an already-serialized dataframe
            quasar_catalogue = deserialize(load_json(args.qso_dataframe))
            quasar_catalogue["LOADED"] = True
        else:
            quasar_catalogue = QuasarCatalogue(
                args.qso_cat, args.qso_cols, args.qso_specid, args.qso_ztrue,
                args.qso_hdu).quasar_catalogue()
            # LOADED is flipped to True below for quasars actually found
            # in the loaded spectra
            quasar_catalogue["LOADED"] = False
        t1 = time.time()
        userprint(
            f"INFO: time elapsed to load quasar catalogue: {(t1-t0)/60.0} minutes"
        )

    # load model (json or fits depending on the file extension)
    userprint("Loading model")
    t2 = time.time()
    if args.model.endswith(".json"):
        model = Model.from_json(load_json(args.model))
    else:
        model = Model.from_fits(args.model)
    t3 = time.time()
    userprint(f"INFO: time elapsed to load model: {(t3-t2)/60.0} minutes")

    # initialize candidates object
    userprint("Initializing candidates object")
    if args.output_candidates is None:
        candidates = Candidates(mode="test", model=model, userprint=userprint)
    else:
        candidates = Candidates(mode="test",
                                name=args.output_candidates,
                                model=model,
                                userprint=userprint)

    # load candidates dataframe if they have previously looked for
    if args.load_candidates:
        userprint("Loading existing candidates")
        t4 = time.time()
        candidates.load_candidates(args.input_candidates)
        t5 = time.time()
        userprint(
            f"INFO: time elapsed to load candidates: {(t5-t4)/60.0} minutes")

    # load spectra
    if args.input_spectra is not None:
        userprint("Loading spectra")
        t6 = time.time()
        columns_candidates = []
        userprint("There are {} files with spectra to be loaded".format(
            len(args.input_spectra)))
        for index, spectra_filename in enumerate(args.input_spectra):
            userprint("Loading spectra from {} ({}/{})".format(
                spectra_filename, index, len(args.input_spectra)))
            t60 = time.time()
            spectra = Spectra.from_json(load_json(spectra_filename))
            if not isinstance(spectra, Spectra):
                raise Error("Invalid list of spectra")

            # metadata columns are taken from the first spectrum of the
            # first file
            if index == 0:
                columns_candidates += spectra.spectra_list()[0].metadata_names(
                )

            # flag loaded quasars as such
            if args.check_statistics:
                for spec in spectra.spectra_list():
                    if quasar_catalogue[quasar_catalogue["SPECID"] ==
                                        spec.metadata_by_key(
                                            "SPECID")].shape[0] > 0:
                        # locate the catalogue row matching this spectrum
                        # and mark it as loaded
                        index2 = quasar_catalogue.index[
                            quasar_catalogue["SPECID"] == spec.metadata_by_key(
                                "SPECID")].tolist()[0]
                        quasar_catalogue.at[index2, "LOADED"] = True

            # look for candidates
            userprint("Looking for candidates")
            candidates.find_candidates(spectra.spectra_list(),
                                       columns_candidates)

            t61 = time.time()
            userprint(
                f"INFO: time elapsed to find candidates from {spectra_filename}:"
                f" {(t61-t60)/60.0} minutes")

        t7 = time.time()
        userprint(
            f"INFO: time elapsed to find candidates: {(t7-t6)/60.0} minutes")

        # convert to dataframe
        userprint("Converting candidates to dataframe")
        t8 = time.time()
        candidates.candidates_list_to_dataframe(columns_candidates)
        t9 = time.time()
        userprint(
            f"INFO: time elapsed to convert candidates to dataframe: {(t9-t8)/60.0} minutes"
        )

    # compute probabilities
    userprint("Computing probabilities")
    t10 = time.time()
    candidates.classify_candidates()
    t11 = time.time()
    userprint(
        f"INFO: time elapsed to classify candidates: {(t11-t10)/60.0} minutes")

    # check completeness
    if args.check_statistics:
        # default probability thresholds: 0.9 down to 0.05 in steps of 0.05
        probs = args.check_probs if args.check_probs is not None else np.arange(
            0.9, 0.0, -0.05)
        userprint("Check statistics")
        data_frame = candidates.candidates()
        userprint("\n---------------")
        userprint("step 1")
        candidates.find_completeness_purity(quasar_catalogue.reset_index(),
                                            data_frame)
        for prob in probs:
            userprint("\n---------------")
            userprint("proba > {}".format(prob))
            # keep only non-duplicated, visually-confirmed (Z_CONF_PERSON
            # == 3) candidates above the probability threshold
            candidates.find_completeness_purity(
                quasar_catalogue.reset_index(),
                data_frame[(data_frame["PROB"] > prob)
                           & ~(data_frame["DUPLICATED"]) &
                           (data_frame["Z_CONF_PERSON"] == 3)],
            )

    # save the catalogue as a fits file
    if not args.no_save_catalogue:
        candidates.save_catalogue(args.output_catalogue, args.prob_cut)

    t12 = time.time()
    userprint(f"INFO: total elapsed time: {(t12-t0)/60.0} minutes")
    userprint("Done")
Code example #7
0
File: format_desi_minisv.py  Project: iprafols/SQUEzE
def main(cmdargs):
    """Format DESI mini-SV spectra into SQUEzE Spectra.

    Load DESI spectra with read_spectra, group them by TARGETID, wrap each
    target's data in a DesiSpectrum and save the collection as a json file.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments
    """

    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--input-filename",
        type=str,
        required=True,
        help="""Name of the filename to be loaded.""")
    parser.add_argument("--output-filename",
                        type=str,
                        required=True,
                        help="""Name of the output filename.""")
    parser.add_argument(
        "--single-exp",
        action="store_true",
        help="""Load only the first reobservation for each spectrum""")
    parser.add_argument(
        "--metadata",
        nargs='+',
        required=False,
        default=["TARGETID", "TARGET_RA", "TARGET_DEC", "CMX_TARGET"],
        help="""White-spaced list of the list of columns to keep as metadata"""
    )
    args = parser.parse_args(cmdargs)

    # read desi spectra
    desi_spectra = read_spectra(args.input_filename)

    # initialize squeze Spectra class
    squeze_spectra = Spectra()

    # get the unique target ids
    targetid = np.unique(desi_spectra.fibermap["TARGETID"])

    # loop over targetid
    for targid in targetid:

        # select all fibermap rows for this target
        pos = np.where(desi_spectra.fibermap["TARGETID"] == targid)

        # prepare metadata from the first matching fibermap row
        metadata = {
            col.upper(): desi_spectra.fibermap[col][pos[0][0]]
            for col in args.metadata
        }

        # add specid
        metadata["SPECID"] = targid

        # extract per-band flux, wavelength, inverse variance and mask
        flux = {}
        wave = {}
        ivar = {}
        mask = {}
        for band in desi_spectra.bands:
            flux[band] = desi_spectra.flux[band][pos]
            wave[band] = desi_spectra.wave[band]
            ivar[band] = desi_spectra.ivar[band][pos]
            mask[band] = desi_spectra.mask[band][pos]

        # format spectrum
        spectrum = DesiSpectrum(flux, wave, ivar, mask, metadata,
                                args.single_exp)

        # append to list
        squeze_spectra.append(spectrum)

    # save formatted spectra
    save_json(args.output_filename, squeze_spectra)
Code example #8
0
def main(cmdargs):
    """Run SQUEzE in training mode.

    Load the training configuration (lines, redshift precision, peak finder
    and random forest options), find candidates in the input spectra and
    train the classification model.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments parsed with TRAINING_PARSER
    """
    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[TRAINING_PARSER])
    args = parser.parse_args(cmdargs)

    # manage verbosity
    userprint = verboseprint if not args.quiet else quietprint

    t0 = time.time()
    # load lines (fall back to the packaged defaults when not given)
    userprint("Loading lines")
    lines = LINES if args.lines is None else deserialize(load_json(args.lines))

    # load try_line
    try_line = TRY_LINES if args.try_lines is None else args.try_lines

    # load redshift precision
    z_precision = Z_PRECISION if args.z_precision is None else args.z_precision

    # load peakfinder options
    peakfind_width = PEAKFIND_WIDTH if args.peakfind_width is None else args.peakfind_width
    peakfind_sig = PEAKFIND_SIG if args.peakfind_sig is None else args.peakfind_sig

    # load random forest options
    random_forest_options = RANDOM_FOREST_OPTIONS if args.random_forest_options is None else load_json(
        args.random_forest_options)
    random_state = RANDOM_STATE if args.random_state is None else args.random_state

    # initialize candidates object; the two previous construction branches
    # differed only in `name`, so pass it conditionally instead
    userprint("Initializing candidates object")
    candidates_kwargs = {
        "lines_settings": (lines, try_line),
        "z_precision": z_precision,
        "mode": "training",
        "peakfind": (peakfind_width, peakfind_sig),
        "model": None,
        "userprint": userprint,
        "model_options": (random_forest_options, random_state,
                          args.pass_cols_to_rf),
    }
    if args.output_candidates is not None:
        candidates_kwargs["name"] = args.output_candidates
    candidates = Candidates(**candidates_kwargs)

    # load candidates dataframe if they have previously been looked for
    if args.load_candidates:
        userprint("Loading existing candidates")
        t1 = time.time()
        candidates.load_candidates(args.input_candidates)
        t2 = time.time()
        userprint(
            f"INFO: time elapsed to load candidates: {(t2-t1)/60.0} minutes")

    # load spectra
    if args.input_spectra is not None:
        userprint("Loading spectra")
        t3 = time.time()
        columns_candidates = []
        userprint("There are {} files with spectra to be loaded".format(
            len(args.input_spectra)))
        for index, spectra_filename in enumerate(args.input_spectra):
            userprint("Loading spectra from {} ({}/{})".format(
                spectra_filename, index, len(args.input_spectra)))
            t30 = time.time()
            spectra = Spectra.from_json(load_json(spectra_filename))
            if not isinstance(spectra, Spectra):
                raise Error("Invalid list of spectra")

            # metadata columns are taken from the first spectrum of the
            # first file
            if index == 0:
                columns_candidates += spectra.spectra_list()[0].metadata_names(
                )

            # look for candidates
            userprint("Looking for candidates")
            candidates.find_candidates(spectra.spectra_list(),
                                       columns_candidates)

            t31 = time.time()
            userprint(
                f"INFO: time elapsed to find candidates from {spectra_filename}: "
                f"{(t31-t30)/60.0} minutes")

        t4 = time.time()
        userprint(
            f"INFO: time elapsed to find candidates: {(t4-t3)/60.0} minutes")

        # convert to dataframe
        userprint("Converting candidates to dataframe")
        t5 = time.time()
        candidates.candidates_list_to_dataframe(columns_candidates)
        t6 = time.time()
        userprint(
            f"INFO: time elapsed to convert candidates to dataframe: {(t6-t5)/60.0} minutes"
        )

    # train model
    userprint("Training model")
    t7 = time.time()
    candidates.train_model(args.model_fits)
    t8 = time.time()
    userprint(f"INFO: time elapsed to train model: {(t8-t7)/60.0} minutes")

    userprint(f"INFO: total elapsed time: {(t8-t0)/60.0} minutes")
    userprint("Done")
Code example #9
0
def main(cmdargs):
    """Run SQUEzE on DESI spectra and save the resulting catalogue.

    Load DESI spectra, format them into SQUEzE Spectra, classify them
    with a trained model and write the non-duplicated candidates
    (TARGETID, Z_SQ, Z_SQ_CONF) to a fits file.

    Arguments
    ---------
    cmdargs : list of str
    Command-line arguments
    """

    # load options
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "-i",
        "--input-filename",
        type=str,
        required=True,
        help="""Name of the filename to be loaded.""")
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="""Name of the file containing the trained model.""")
    parser.add_argument("-o",
                        "--output-filename",
                        type=str,
                        required=True,
                        help="""Name of the output fits file.""")
    parser.add_argument(
        "-e",
        "--single-exp",
        action="store_true",
        help="""Load only the first reobservation for each spectrum""")
    parser.add_argument(
        "--metadata",
        nargs='+',
        required=False,
        default=["TARGETID"],
        help="""White-spaced list of the list of columns to keep as metadata"""
    )
    parser.add_argument("-v",
                        "--verbose",
                        action="store_true",
                        help="""Print messages""")
    args = parser.parse_args(cmdargs)

    # prepare variables; validate with an explicit raise rather than an
    # assert (asserts are stripped under `python -O`)
    if not args.output_filename.endswith(("fits", "fits.gz")):
        raise ValueError("output filename must end with 'fits' or "
                         f"'fits.gz', found {args.output_filename}")
    if args.verbose:
        userprint = verboseprint
    else:
        userprint = quietprint
    # bug fix: this parser defines --metadata, not --keep-cols; the
    # original `args.keep_cols` raised AttributeError unconditionally
    args.metadata = [col.upper() for col in args.metadata]

    # read desi spectra
    userprint("Reading spectra")
    desi_spectra = read_spectra(args.input_filename)

    # initialize squeze Spectra class
    squeze_spectra = Spectra()

    # get the unique target ids
    targetid = np.unique(desi_spectra.fibermap["TARGETID"])

    # loop over targetid
    for targid in targetid:

        # select all fibermap rows for this target
        pos = np.where(desi_spectra.fibermap["TARGETID"] == targid)

        # prepare metadata from the first matching fibermap row
        metadata = {
            col.upper(): desi_spectra.fibermap[col][pos[0][0]]
            for col in args.metadata
        }

        # add specid
        metadata["SPECID"] = targid

        # extract per-band flux, wavelength, inverse variance and mask
        flux = {}
        wave = {}
        ivar = {}
        mask = {}
        for band in desi_spectra.bands:
            flux[band] = desi_spectra.flux[band][pos]
            wave[band] = desi_spectra.wave[band]
            ivar[band] = desi_spectra.ivar[band][pos]
            mask[band] = desi_spectra.mask[band][pos]

        # format spectrum
        spectrum = DesiSpectrum(flux, wave, ivar, mask, metadata,
                                args.single_exp)

        # append to list
        squeze_spectra.append(spectrum)

    # load model (json or fits depending on the file extension)
    userprint("Reading model")
    if args.model.endswith(".json"):
        model = Model.from_json(load_json(args.model))
    else:
        model = Model.from_fits(args.model)

    # initialize candidates object
    userprint("Initialize candidates object")
    candidates = Candidates(mode="operation",
                            model=model,
                            name=args.output_filename)

    # look for candidates
    userprint('Looking for candidates')
    candidates.find_candidates(squeze_spectra.spectra_list())
    columns_candidates = squeze_spectra.spectra_list()[0].metadata_names()
    candidates.candidates_list_to_dataframe(columns_candidates, save=False)

    # compute probabilities
    userprint("Computing probabilities")
    candidates.classify_candidates(save=False)

    # filter out duplicated entries
    data_frame = candidates.candidates()
    data_frame = data_frame[~data_frame["DUPLICATED"]]

    # save results
    data_out = np.zeros(len(data_frame),
                        dtype=[('TARGETID', 'int64'), ('Z_SQ', 'float64'),
                               ('Z_SQ_CONF', 'float64')])
    data_out['TARGETID'] = data_frame['TARGETID'].values
    data_out['Z_SQ'] = data_frame['Z_TRY'].values
    data_out['Z_SQ_CONF'] = data_frame['PROB'].values

    data_hdu = fits.BinTableHDU.from_columns(data_out, name='SQZ_CAT')
    # NOTE(review): writeto raises if the output file already exists;
    # pass overwrite=True if clobbering is the intended policy -- confirm
    data_hdu.writeto(args.output_filename)