Code Example #1
File: sdss_data_product.py Project: sdss/astra
    def query_data_model_identifiers_from_database(self, context):
        """
        Query the SDSS database for ApStar data model identifiers.

        :param context:
            The Airflow DAG execution context.
        """
        if self.release is not None:
            releases = [self.release]
        else:
            releases = self.infer_releases(context)

        mjd_start = parse_as_mjd(context["prev_ds"])
        mjd_end = parse_as_mjd(context["ds"])

        for release in releases:
            if release.upper() == "DR16":
                yield from self.query_sdss4_dr16_data_model_identifiers_from_database(
                    mjd_start, mjd_end)
            elif release.upper() == "SDSS5":
                yield from self.query_sdss5_data_model_identifiers_from_database(
                    mjd_start, mjd_end)
            else:
                log.warning(
                    f"Don't know how to query for data model identifiers for release '{release}'!"
                )
                continue
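
The parse_as_mjd helper is not shown in this snippet. As a rough sketch of what such a conversion could look like, assuming the Airflow context supplies ISO date strings such as "2021-01-01" (the function name and its exact behaviour in astra are assumptions here):

from astropy.time import Time

def parse_as_mjd_sketch(date_string):
    # Airflow context dates ("ds", "prev_ds") are ISO strings like "2021-01-01";
    # astropy converts these to Modified Julian Date directly.
    return Time(date_string).mjd

print(parse_as_mjd_sketch("2021-01-01"))  # 59215.0
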
Code Example #2
def _query_task_instances_without_meta(sdss5_only=False):
    stmt = exists().where(
        astradb.TaskInstance.pk == astradb.TaskInstanceMeta.ti_pk)
    q = astra_session.query(astradb.TaskInstance.pk).filter(~stmt)
    if sdss5_only:
        log.warning("Only doing SDSS5 task instances at the moment")
        # Only do sdss5 things so far.
        q = q.filter(astradb.TaskInstanceParameter.ti_pk == astradb.TaskInstance.pk)\
            .filter(astradb.TaskInstanceParameter.parameter_pk.in_((438829, 494337, 493889)))
    return q
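
The query above relies on SQLAlchemy's ~exists() anti-join pattern to find task instances with no matching meta row. A self-contained sketch of the same pattern against hypothetical tables (not astra's models) shows how it behaves:

from sqlalchemy import Column, Integer, create_engine, exists
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class TaskInstance(Base):
    __tablename__ = "task_instance"
    pk = Column(Integer, primary_key=True)

class TaskInstanceMeta(Base):
    __tablename__ = "task_instance_meta"
    pk = Column(Integer, primary_key=True)
    ti_pk = Column(Integer)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([TaskInstance(pk=1), TaskInstance(pk=2), TaskInstanceMeta(pk=1, ti_pk=1)])
    session.commit()

    # Select task instances for which no meta row exists.
    stmt = exists().where(TaskInstance.pk == TaskInstanceMeta.ti_pk)
    print(session.query(TaskInstance.pk).filter(~stmt).all())  # [(2,)]
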
Code Example #3
File: base.py Project: sdss/astra
    def on_kill(self) -> None:
        if self.slurm_kwargs:
            # Cancel the Slurm job.
            try:
                cancel_slurm_job_given_name(self._slurm_label)
            except AttributeError:
                log.warning(
                    f"Tried to cancel Slurm job but cannot find the Slurm label! Maybe the Slurm job wasn't submitted yet?"
                )

        self._unlink_primary_key_path()
        return None
Code Example #4
File: utils.py Project: sdss/astra
def deserialize_pks(pk, flatten=False):
    """
    Recursively de-serialize input primary keys, which may be integers, or paths to
    temporary files that contain integers.

    :param pk:
        the input primary key(s)
    
    :param flatten: [optional]
        return all primary keys as a single flattened list (default: False)
    
    :returns:
        a list of primary keys as integers
    """
    if isinstance(pk, int):
        v = pk
    elif isinstance(pk, float):
        log.warning(f"Forcing primary key input {pk} as integer")
        v = int(pk)
    elif isinstance(pk, (list, tuple)):
        v = list(map(deserialize_pks, pk))
    elif isinstance(pk, str):
        if os.path.exists(pk):
            with open(pk, "r") as fp:
                contents = json.load(fp)
        else:
            # Need to use double quotes.
            try:
                contents = json.loads(pk.replace("'", '"'))
            except:
                raise ValueError(
                    f"Cannot deserialize primary key of type {type(pk)}: {pk}")
        if isinstance(contents, int):
            v = contents
        else:
            v = list(map(deserialize_pks, contents))
    else:
        raise ValueError(
            f"Cannot deserialize primary key of type {type(pk)}: {pk}")

    return _flatten([v]) if flatten else v
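
The _flatten helper used on the last line is not shown. A minimal sketch of a recursive flattener that would satisfy that call (astra's actual implementation may differ):

def _flatten_sketch(struct):
    # Recursively flatten arbitrarily nested lists/tuples into a single flat list.
    flat = []
    for item in struct:
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten_sketch(item))
        else:
            flat.append(item)
    return flat

print(_flatten_sketch([[1, [2, 3]], 4]))  # [1, 2, 3, 4]
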
Code Example #5
def slice_and_shape(spectrum, slice_args, shape, repeat=None, **kwargs):

    if repeat is not None:
        spectrum._data = np.repeat(spectrum._data, repeat)
        spectrum._uncertainty.array = np.repeat(spectrum._uncertainty.array, repeat)

    if slice_args is not None:
        slices = tuple([slice(*each) for each in slice_args])

        spectrum._data = spectrum._data[slices]
        spectrum._uncertainty.array = spectrum._uncertainty.array[slices]

        try:
            spectrum.meta["snr"] = spectrum.meta["snr"][slices[0]]
        except:
            log.warning(f"Unable to slice 'snr' metadata with {slice_args}")

    spectrum._data = spectrum._data.reshape(shape)
    spectrum._uncertainty.array = spectrum._uncertainty.array.reshape(shape)

    return spectrum
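
The slice_args convention here is a list of argument tuples that are unpacked into slice objects. A standalone numpy sketch of the same construction:

import numpy as np

data = np.arange(24).reshape((6, 4))
slice_args = [(0, 3), (1, None)]                   # rows 0..2, columns 1..end
slices = tuple(slice(*each) for each in slice_args)
print(data[slices].shape)                          # (3, 3)
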
Code Example #6
File: sdss_data_product.py Project: sdss/astra
    def data_model_identifiers(self, context):
        """ 
        Yield the data model identifiers found by this operator.
        
        :param context:
            The Airflow context dictionary.
        """

        if self._data_model_identifiers is not None:
            log.warning(
                "Using data model identifiers specified by _data_model_identifiers. "
                "Ignoring the DAG execution context! Use this at your own risk!"
            )

            sources = self._data_model_identifiers
            if callable(sources):
                sources = sources()

            yield from fulfil_defaults_for_data_model_identifiers(
                sources, context)
        else:
            yield from self.query_data_model_identifiers_from_database(context)
Code Example #7
database = SDSSDatabaseConnection(autoconnect=True)
# Ignore what the documentation says for sdssdb.
# Create a file called ~/.config/sdssdb/sdssdb.yml and put your connection info there.

try:
    database.set_profile("astra")

except AssertionError as e:
    from astra.utils import log
    log.exception(e)
    log.warning(
        """ No database profile named 'astra' found in ~/.config/sdssdb/sdssdb.yml -- it should look like:

        astra:
          user: [SDSSDB_USERNAME]
          host: [SDSSDB_HOSTNAME]
          port: 5432
          domain: [SDSSDB_DOMAIN]

        See https://sdssdb.readthedocs.io/en/stable/intro.html#supported-profiles for more details. 
        """)
    session = None

else:
    try:
        session = database.Session()
    except Exception:
        print("Cannot load database session")
        session = None


def init_process(database):
Code Example #8
def estimate_radial_velocity(pks,
                             verbose=True,
                             mcmc=False,
                             figfile=None,
                             cornername=None,
                             retpmodels=False,
                             plot=False,
                             tweak=True,
                             usepeak=False,
                             maxvel=[-1000, 1000]):
    """
    Estimate radial velocities for the sources that are identified by the task instances
    of the given primary keys.

    :param pks:
        The primary keys of task instances to estimate radial velocities for, which includes
        parameters to identify the source SDSS data model product.

    See `doppler.rv.fit` for more information on other keyword arguments.
    """

    # TODO: Move this to astra/contrib
    import doppler

    log.info(f"Estimating radial velocities for {len(pks)} task instances")

    failures = []
    for instance, path, spectrum in prepare_data(pks):
        if spectrum is None: continue

        log.debug(f"Running Doppler on {instance} from {path}")

        try:
            spectrum = doppler.read(path)
            summary, model_spectrum, modified_input_spectrum = doppler.rv.fit(
                spectrum,
                verbose=verbose,
                mcmc=mcmc,
                figfile=figfile,
                cornername=cornername,
                retpmodels=retpmodels,
                plot=plot,
                tweak=tweak,
                usepeak=usepeak,
                maxvel=maxvel)

        except Exception as exception:
            last_exception = exception
            log.exception(
                f"Exception occurred while running Doppler on {path} for task instance {instance}"
            )
            failures.append(instance.pk)
            continue

        else:
            # Write the output to the database.
            results = prepare_results(summary)

            create_task_output(instance, astradb.Doppler, **results)

    if len(failures) > 0:
        log.warning(
            f"There were {len(failures)} Doppler failures out of a total {len(pks)} executions."
        )
        log.warning(f"Failed primary keys include: {failures}")

        log.warning(f"Raising last exception to indicate failure in pipeline.")
        raise
Code Example #9
def write_database_outputs(
        task, 
        ti, 
        run_id, 
        element_from_task_id_callable=None,
        **kwargs
    ):
    """
    Collate outputs from upstream FERRE executions and write them to an ASPCAP database table.
    
    :param task:
        This task, as given by the Airflow context dictionary.
    
    :param ti:
        This task instance, as given by the Airflow context dictionary.
    
    :param run_id:
        This run ID, as given by the Airflow context dictionary.
    
    :param element_from_task_id_callable: [optional]
        A Python callable that returns the chemical element, given a task ID.
    """

    
    log.debug(f"Writing ASPCAP database outputs")

    pks = []
    for upstream_task in task.upstream_list:
        pks.append(ti.xcom_pull(task_ids=upstream_task.task_id))

    log.debug(f"Upstream primary keys: {pks}")

    # Group them together by source.
    instance_pks = []
    for source_pks in list(zip(*pks)):

        # The one with the lowest primary key will be the stellar parameters.
        sp_pk, *abundance_pks = sorted(source_pks)
        
        sp_instance = session.query(astradb.TaskInstance).filter(astradb.TaskInstance.pk == sp_pk).one_or_none()
        abundance_instances = session.query(astradb.TaskInstance).filter(astradb.TaskInstance.pk.in_(abundance_pks)).all()

        # Get parameters that are in common to all instances.
        keep = {}
        for key, value in sp_instance.parameters.items():
            for instance in abundance_instances:
                if instance.parameters[key] != value:
                    break
            else:
                keep[key] = value

        # Create a task instance.
        instance = create_task_instance(
            dag_id=task.dag_id, 
            task_id=task.task_id, 
            run_id=run_id,
            parameters=keep
        )

        # Create a partial results table.
        keys = ["snr"]
        label_names = ("teff", "logg", "metals", "log10vdop", "o_mg_si_s_ca_ti", "lgvsini", "c", "n")
        for key in label_names:
            keys.extend([key, f"u_{key}"])
        
        results = dict([(key, getattr(sp_instance.output, key)) for key in keys])

        # Now update with elemental abundance instances.
        for el_instance in abundance_instances:
            
            if element_from_task_id_callable is not None:
                element = element_from_task_id_callable(el_instance.task_id).lower()
            else:
                element = el_instance.task_id.split(".")[-1].lower()
            
            # Check what is not frozen.
            thawed_label_names = []
            ignore = ("lgvsini", ) # Ignore situations where lgvsini was missing from grid and it screws up the task
            for key in label_names:
                if key not in ignore and not getattr(el_instance.output, f"frozen_{key}"):
                    thawed_label_names.append(key)

            if len(thawed_label_names) > 1:
                log.warning(f"Multiple thawed label names for {element} {el_instance}: {thawed_label_names}")

            values = np.hstack([getattr(el_instance.output, ln) for ln in thawed_label_names]).tolist()
            u_values = np.hstack([getattr(el_instance.output, f"u_{ln}") for ln in thawed_label_names]).tolist()

            results.update({
                f"{element}_h": values,
                f"u_{element}_h": u_values,
            })

        # Include associated primary keys so we can reference back to original parameters, etc.
        results["associated_ti_pks"] = [sp_pk, *abundance_pks]

        log.debug(f"Results entry: {results}")

        # Create an entry in the output interface table.
        # (We will update this later with any elemental abundance results).
        # TODO: Should we link back to the original FERRE primary keys?
        output = create_task_output(
            instance,
            astradb.Aspcap,
            **results
        )
        log.debug(f"Created output {output} for instance {instance}")
        instance_pks.append(instance.pk)
        
    return instance_pks
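
The loop that builds the keep dictionary uses Python's for/else clause: the else branch runs only when the inner loop finishes without a break, i.e. when every abundance instance agrees with the stellar-parameter instance on that parameter. A small self-contained sketch of the same idea, with made-up data:

# Keep only the key/value pairs shared by every dictionary.
reference = {"release": "sdss5", "apred": "daily", "obj": "A"}
others = [
    {"release": "sdss5", "apred": "daily", "obj": "B"},
    {"release": "sdss5", "apred": "daily", "obj": "C"},
]

keep = {}
for key, value in reference.items():
    for other in others:
        if other.get(key) != value:
            break            # disagreement: drop this key
    else:
        keep[key] = value    # no break: every dict agrees

print(keep)  # {'release': 'sdss5', 'apred': 'daily'}
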
Code Example #10
File: sdss_data_product.py Project: sdss/astra
def fulfil_defaults_for_data_model_identifiers(data_model_identifiers,
                                               context):
    """
    Intelligently set default entries for partially specified data model identifiers.
    
    :param data_model_identifiers:
        A list (or other iterable) of dictionaries, where each dictionary contains keyword
        arguments that specify a data model product.

    :param context:
        The Airflow context dictionary. This is only used to infer the 'release' context,
        if it is not given, based on the execution date.

    :returns:
        Yields data model identifiers in which all required parameters are provided.
    
    :raises RuntimeError:
        If a data model identifier cannot be fulfilled from the available defaults.
    """

    try:
        releases = infer_releases(context["ds"], context["next_ds"])
    except:
        log.exception(f"Could not infer release from context {context}")
        default_release = None
    else:
        # Take the 'most recent' release.
        default_release = releases[-1]

    trees = {}

    defaults = {
        "sdss5": {
            "apStar": {
                "apstar": "stars",
                "apred": "daily",
                "telescope": lambda obj, **_: "apo25m"
                if "+" in obj else "lco25m",
                "healpix": lambda obj, **_: str(healpix(obj)),
            }
        }
    }

    for dmi in data_model_identifiers:

        try:
            filetype = dmi["filetype"]
        except KeyError:
            raise KeyError(
                f"no filetype given for data model identifiers {dmi} "
                f"-- set 'filetype': 'full' and use 'full': <PATH> to set explicit path"
            )
        except:
            raise TypeError(
                f"data model identifiers must be dict-like object (not {type(dmi)}: {dmi}"
            )

        source = dmi.copy()
        release = source.setdefault("release", default_release)

        try:
            tree = trees[release]
        except KeyError:
            trees[release] = tree = SDSSPath(release=release)

        missing_keys = set(tree.lookup_keys(filetype)).difference(dmi)
        for missing_key in missing_keys:
            try:
                default = defaults[release][filetype][missing_key]
            except KeyError:
                raise RuntimeError(
                    f"no default function found for {missing_key} for {release} / {filetype}"
                )

            if callable(default):
                default = default(**source)

            log.warning(
                f"Filling '{missing_key}' with default value '{default}' for {source}"
            )
            source[missing_key] = default

        yield source
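
The defaults table mixes literal values with callables that compute a value from the other identifiers. A stripped-down sketch of that fill-in logic, independent of the SDSS tree (the keys and object name here are purely illustrative):

# Fill missing keys from a defaults table whose values may be literals or callables.
defaults = {
    "apstar": "stars",
    "telescope": lambda obj, **_: "apo25m" if "+" in obj else "lco25m",
}

source = {"obj": "2M00000002+7417074"}
for missing_key in set(defaults).difference(source):
    default = defaults[missing_key]
    if callable(default):
        default = default(**source)
    source[missing_key] = default

print(source)  # 'telescope' resolves to 'apo25m' because the object name contains '+'
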
Code Example #11
File: operators.py Project: sdss/astra
def train_polynomial_model(labels, data, order=2, regularization=0, threads=1):

    log.debug(f'Inputs are: ({type(labels)}) {labels}')
    log.debug(f'Data are: {data}')
    # labels could be in JSON format.
    if isinstance(labels, str):
        labels = json.loads(labels.replace("'", '"'))
        # TODO: use a general deserializer that fixes the single quote issues with json loading

    if isinstance(data, str) and os.path.exists(data):
        with open(data, "rb") as fp:
            data = pickle.load(fp)

    for key in ("dispersion", "wavelength"):
        try:
            dispersion = data[key]
        except KeyError:
            continue
        else:
            break
    else:
        raise ValueError(f"unable to find {key} in data")

    training_set_flux = data["normalized_flux"]
    training_set_ivar = data["normalized_ivar"]

    try:
        num_spectra = data["num_spectra"]
    except:
        log.debug(
            f"Keeping all items in training set; not checking for missing spectra."
        )
    else:
        keep = (num_spectra == 1)
        if not all(keep):
            log.warning(
                f"Excluding {sum(~keep)} objects from the training set that had missing spectra"
            )

            labels = {k: np.array(v)[keep] for k, v in labels.items()}
            training_set_flux = training_set_flux[keep]
            training_set_ivar = training_set_ivar[keep]

    # Set the vectorizer.
    vectorizer = tc.vectorizer.PolynomialVectorizer(
        labels.keys(),
        order=order,
    )

    # Initiate model.
    model = tc.model.CannonModel(labels,
                                 training_set_flux,
                                 training_set_ivar,
                                 vectorizer=vectorizer,
                                 dispersion=dispersion,
                                 regularization=regularization)

    model.train(threads=threads)

    output_path = os.path.join(get_base_output_path(), "thecannon",
                               "model.pkl")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    log.info(f"Writing The Cannon model {model} to disk {output_path}")
    model.write(output_path, include_training_set_spectra=True, overwrite=True)
    return output_path
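
The TODO about single-quoted JSON could be handled with ast.literal_eval, which accepts Python-style quoting directly. A hedged sketch of such a general deserializer (this is not the deserializer astra ships):

import ast
import json

def deserialize_sketch(value):
    # Accept strict JSON first, then fall back to Python literal syntax (single quotes, tuples, ...).
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        return ast.literal_eval(value)

print(deserialize_sketch('{"teff": 5777}'))   # parsed as strict JSON
print(deserialize_sketch("{'teff': 5777}"))   # single-quoted, parsed via literal_eval
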
Code Example #12
File: operators.py Project: sdss/astra
def _select_training_set_data_from_database(label_columns,
                                            filter_args=None,
                                            filter_func=None,
                                            limit=None,
                                            **kwargs):
    label_columns = list(label_columns)
    label_names = [column.key for column in label_columns]
    L = len(label_names)

    if filter_func is None:
        filter_func = lambda *_, **__: True

    # Get the label names.
    log.info(f"Querying for label names {label_names} from {label_columns}")

    # Figure out what other columns we will need to identify the input file.
    for column in label_columns:
        try:
            primary_parent = column.class_
        except AttributeError:
            continue
        else:
            break
    else:
        raise ValueError(
            "Can't get primary parent. are you labelling every column?")

    log.debug(f"Identified primary parent table as {primary_parent}")

    if primary_parent == catalogdb.SDSSApogeeAllStarMergeR13:

        log.debug(
            f"Adding columns and setting data_model_func for {primary_parent}")
        additional_columns = [
            catalogdb.SDSSDR16ApogeeStar.apstar_version.label("apstar"),
            catalogdb.SDSSDR16ApogeeStar.field,
            catalogdb.SDSSDR16ApogeeStar.apogee_id.label("obj"),
            catalogdb.SDSSDR16ApogeeStar.file,
            catalogdb.SDSSDR16ApogeeStar.telescope,

            # Things that we might want for filtering on.
            catalogdb.SDSSDR16ApogeeStar.snr
        ]

        columns = label_columns + additional_columns

        q = session.query(*columns).join(
            catalogdb.SDSSApogeeAllStarMergeR13,
            func.trim(catalogdb.SDSSApogeeAllStarMergeR13.apstar_ids) ==
            catalogdb.SDSSDR16ApogeeStar.apstar_id)

        data_model_func = lambda apstar, field, obj, filename, telescope, *_, : {
            "release": "DR16",
            "filetype": "apStar",
            "apstar": apstar,
            "field": field,
            "obj": obj,
            "prefix": filename[:2],
            "telescope": telescope,
            "apred": filename.split("-")[1]
        }

    else:
        raise NotImplementedError(
            f"Cannot intelligently figure out what data model keywords will be necessary."
        )

    if filter_args is not None:
        q = q.filter(*filter_args)

    if limit is not None:
        q = q.limit(limit)

    log.debug(f"Querying {q}")

    data_model_identifiers = []
    labels = {label_name: [] for label_name in label_names}
    for i, row in enumerate(tqdm(q.yield_per(1), total=q.count())):
        if not filter_func(*row): continue

        for label_name, value in zip(label_names, row[:L]):
            if value is None or not np.isfinite(value):
                log.warning(
                    f"Label {label_name} in row {i} is not finite: {value}!")
            labels[label_name].append(value)
        data_model_identifiers.append(data_model_func(*row[L:]))

    return (labels, data_model_identifiers)
Code Example #13
File: base.py Project: sdss/astra
    def execute_by_slurm(self,
                         context,
                         bash_command,
                         directory=None,
                         poke_interval=60):

        uid = str(uuid.uuid4())[:8]
        label = ".".join([
            context["dag"].dag_id,
            context["task"].task_id,
            context["execution_date"].strftime('%Y-%m-%d'),
            # run_id is None if triggered by command line
            uid
        ])
        if len(label) > 64:
            log.warning(
                f"Truncating Slurm label ({label}) to 64 characters: {label[:64]}"
            )
            label = label[:64]

        self._slurm_label = label

        # It's bad practice to import here, but the slurm package is
        # not easily installable outside of Utah, and is not a "must-have"
        # requirement.
        from slurm import queue

        # TODO: HACK to be able to use local astra installation while in development
        if bash_command.startswith("astra "):
            bash_command = f"/uufs/chpc.utah.edu/common/home/u6020307/.local/bin/astra {bash_command[6:]}"

        slurm_kwargs = (self.slurm_kwargs or dict())

        log.info(
            f"Submitting Slurm job {label} with command:\n\t{bash_command}\nAnd Slurm keyword arguments: {slurm_kwargs}"
        )
        q = queue(verbose=True)
        q.create(label=label, dir=directory, **slurm_kwargs)
        q.append(bash_command)
        try:
            q.commit(hard=True, submit=True)
        except CalledProcessError as e:
            log.exception(
                f"Exception occurred when committing Slurm job with output:\n{e.output}"
            )
            raise

        log.info(
            f"Slurm job submitted with {q.key} and keywords {slurm_kwargs}")
        log.info(f"\tJob directory: {directory or q.job_dir}")

        stdout_path = os.path.join(directory or q.job_dir, f"{label}_01.o")
        stderr_path = os.path.join(directory or q.job_dir, f"{label}_01.e")

        # Now we wait until the Slurm job is complete.
        t_submitted, t_started = (time(), None)
        while 100 > q.get_percent_complete():

            sleep(poke_interval)

            t = time() - t_submitted

            if not os.path.exists(stderr_path) and not os.path.exists(
                    stdout_path):
                log.info(
                    f"Waiting on job {q.key} to start (elapsed: {t / 60:.0f} min)"
                )

            else:
                # Check if this is the first time it has started.
                if t_started is None:
                    t_started = time()
                    log.debug(
                        f"Recording job {q.key} as starting at {t_started} (took {t / 60:.0f} min to start)"
                    )

                log.info(
                    f"Waiting on job {q.key} to finish (elapsed: {t / 60:.0f} min)"
                )
                # Open last line of stdout path?

                # If this has been going much longer than the walltime, then something went wrong.
                # TODO: Check on the status of the job from Slurm.

        log.info(
            f"Job {q.key} in {q.job_dir} is complete after {(time() - t_submitted)/60:.0f} minutes."
        )

        with open(stderr_path, "r", newline="\n") as fp:
            stderr = fp.read()
        log.info(f"Contents of {stderr_path}:\n{stderr}")

        with open(stdout_path, "r", newline="\n") as fp:
            stdout = fp.read()
        log.info(f"Contents of {stdout_path}:\n{stdout}")

        # TODO: Better parsing for critical errors.
        if "Error" in stdout.rstrip().split("\n")[-1] \
        or "Error" in stderr.rstrip().split("\n")[-1]:
            raise RuntimeError(f"detected exception at task end-point")

        # TODO: Get exit codes from squeue

        return None
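
The wait loop above has no upper bound on wall time (as the TODO notes). A minimal sketch of the same poll-until-complete pattern with a timeout, written against a stand-in progress callable rather than the slurm queue object:

from time import sleep, time

def wait_until_complete(get_percent_complete, poke_interval=60, timeout=24 * 3600):
    # Poll a progress callable until it reports 100%, or give up after `timeout` seconds.
    t_submitted = time()
    while get_percent_complete() < 100:
        if time() - t_submitted > timeout:
            raise TimeoutError(f"job not complete after {timeout} seconds")
        sleep(poke_interval)
    return time() - t_submitted
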
Code Example #14
def get_best_result(task, ti, **kwargs):
    """
    When there are numerous FERRE tasks that are upstream, this
    function will return the primary keys of the task instances that gave
    the best result on a per-observation basis.
    """

    # Get the PKs from upstream.
    pks = []
    log.debug(f"Upstream tasks: {task.upstream_list}")
    for upstream_task in task.upstream_list:
        pks.append(ti.xcom_pull(task_ids=upstream_task.task_id))

    pks = flatten(pks)
    log.debug(f"Getting best initial guess among primary keys {pks}")

    # Need to uniquely identify observations.
    param_bit_mask = bitmask.ParamBitMask()
    bad_grid_edge = (param_bit_mask.get_value("GRIDEDGE_WARN") | param_bit_mask.get_value("GRIDEDGE_BAD"))

    trees = {}
    best_tasks = {}
    for i, pk in enumerate(pks):
        q = session.query(astradb.TaskInstance).filter(astradb.TaskInstance.pk==pk)
        instance = q.one_or_none()

        if instance.output is None:
            log.warning(f"No output found for task instance {instance}")
            continue

        p = instance.parameters

        # Check that the telescope is the same as what we expect from this task ID.
        # This is a bit of a hack. Let us explain.

        # The "BA" grid does not have a telescope/fiber model, so you can run LCO and APO
        # data through the initial-BA grid. And those outputs go to the "get_best_results"
        # for each of the APO and LCO tasks (e.g., this function).
        # If there is only APO data, then the LCO "get_best_result" will only have one
        # input: the BA results. Then it will erroneously think that's the best result
        # for that source.

        # It's hacky to put this logic in here. It should be in the DAG instead. Same
        # thing for parsing 'telescope' name in the DAG (eg 'APO') from 'apo25m'.
        this_telescope_short_name = p["telescope"][:3].upper()
        expected_telescope_short_name = task.task_id.split(".")[1]
        log.info(f"For instance {instance} we have {this_telescope_short_name} and {expected_telescope_short_name}")
        if this_telescope_short_name != expected_telescope_short_name:
            continue

        try:
            tree = trees[p["release"]]                
        except KeyError:
            tree = trees[p["release"]] = SDSSPath(release=p["release"])
        
        key = "_".join([
            p['release'],
            p['filetype'],
            *[p[k] for k in tree.lookup_keys(p['filetype'])]
        ])
        
        best_tasks.setdefault(key, (np.inf, None))
        
        # TODO: Confirm that this is base10 log. This should also be 'log_reduced_chisq_fit',
        #       according to the documentation.
        log_chisq_fit, *_ = instance.output.log_chisq_fit
        previous_teff, *_ = instance.output.teff
        bitmask_flag, *_ = instance.output.bitmask_flag
        
        log.debug(f"Result {instance} {instance.output} with log_chisq_fit = {log_chisq_fit} and {previous_teff} and {bitmask_flag}")
        
        # Note: If FERRE totally fails then it will assign -999 values to the log_chisq_fit. So we have to
        #       check that the log_chisq_fit is actually sensible!
        #       (Or we should only query task instances where the output is sensible!)
        if log_chisq_fit < 0: # TODO: This is a f*****g hack.
            log.debug(f"Skipping result for {instance} {instance.output} as log_chisq_fit = {log_chisq_fit}")
            continue
            
        parsed_header = utils.parse_header_path(p["header_path"])
        
        # Penalise chi-sq in the same way they did for DR17.
        # See github.com/sdss/apogee/python/apogee/aspcap/aspcap.py#L658
        if parsed_header["spectral_type"] == "GK" and previous_teff < 3900:
            log.debug(f"Increasing \chisq because spectral type GK")
            log_chisq_fit += np.log10(10)

        bitmask_flag_logg, bitmask_flag_teff = bitmask_flag[-2:]
        if bitmask_flag_logg & bad_grid_edge:
            log.debug(f"Increasing \chisq because logg flag is bad edge")
            log_chisq_fit += np.log10(5)
            
        if bitmask_flag_teff & bad_grid_edge:
            log.debug(f"Increasing \chisq because teff flag is bad edge")
            log_chisq_fit += np.log10(5)
        
        # Is this the best so far?
        if log_chisq_fit < best_tasks[key][0]:
            log.debug(f"Assigning this output to best task as {log_chisq_fit} < {best_tasks[key][0]}: {pk}")
            best_tasks[key] = (log_chisq_fit, pk)
    
    for key, (log_chisq_fit, pk) in best_tasks.items():
        if pk is None:
            log.warning(f"No good task found for key {key}: ({log_chisq_fit}, {pk})")
        else:
            log.info(f"Best task for key {key} with \chi^2 of {log_chisq_fit:.2f} is primary key {pk}")

    if best_tasks:
        return [pk for (log_chisq_fit, pk) in best_tasks.values() if pk is not None]
    else:
        raise AirflowSkipException(f"no task outputs found from {len(pks)} primary keys")
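
The best_tasks.setdefault(key, (np.inf, None)) bookkeeping is a compact way of keeping a running minimum per observation. In isolation the pattern looks like this sketch (keys and costs are made up):

import numpy as np

# (key, cost, pk) triples standing in for (observation, penalised chi-sq, task primary key).
results = [("obs_a", 12.3, 1), ("obs_a", 4.5, 2), ("obs_b", 7.1, 3)]

best = {}
for key, cost, pk in results:
    best.setdefault(key, (np.inf, None))
    if cost < best[key][0]:
        best[key] = (cost, pk)

print(best)  # {'obs_a': (4.5, 2), 'obs_b': (7.1, 3)}
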
Code Example #15
def sines_and_cosines(
        dispersion, 
        flux, 
        ivar,
        continuum_pixels,
        L=1400,
        order=3,
        regions=None,
        fill_value=1.0,
        **kwargs
    ):
    """
    Fit the flux values of pre-defined continuum pixels using a sum of sine and
    cosine functions.

    :param dispersion:
        The dispersion values.

    :param flux:
        The flux values for all pixels, as they correspond to the `dispersion`
        array.

    :param ivar:
        The inverse variances for all pixels, as they correspond to the
        `dispersion` array.

    :param continuum_pixels:
        A mask that selects pixels that should be considered as 'continuum'.

    :param L: [optional]
        The length scale for the sines and cosines.

    :param order: [optional]
        The number of sine/cosine functions to use in the fit.

    :param regions: [optional]
        Specify sections of the spectra that should be fitted separately in each
        star. This may be due to gaps between CCDs, or some other physically-
        motivated reason. These values should be specified in the same units as
        the `dispersion`, and should be given as a list of `[(start, end), ...]`
        values. For example, APOGEE spectra have gaps near the following
        wavelengths which could be used as `regions`:

        >>> regions = ([15090, 15822], [15823, 16451], [16452, 16971])

    :param fill_value: [optional]
        The continuum value to use for when no continuum was calculated for that
        particular pixel (e.g., the pixel is outside of the `regions`).

    :returns:
        A two-length tuple containing the continuum values for all pixels, and a list of
        per-object metadata about the fits.
    """

    scalar = kwargs.pop("__magic_scalar", 1e-6) # MAGIC
    flux, ivar = np.atleast_2d(flux), np.atleast_2d(ivar)

    bad = ~np.isfinite(ivar) + ~np.isfinite(flux) + (ivar == 0)
    ivar[bad] = 0
    flux[bad] = 1

    if regions is None:
        regions = [(dispersion[0], dispersion[-1])]

    region_masks = []
    region_matrices = []
    continuum_masks = []
    continuum_matrices = []
    pixel_included_in_regions = np.zeros_like(flux).astype(int)
    for i, (start, end) in enumerate(regions):
        # Build the masks for this region.
        si, ei = np.searchsorted(dispersion, (start, end))

        if si == ei:
            # No pixels. Not a valid region.
            continue
        
        region_mask = (end >= dispersion) * (dispersion >= start)
        region_masks.append(region_mask)
        pixel_included_in_regions[:, region_mask] += 1

        continuum_masks.append(continuum_pixels[
            (ei >= continuum_pixels) * (continuum_pixels >= si)])

        # Build the design matrices for this region.
        region_matrices.append(
            _continuum_design_matrix(dispersion[region_masks[-1]], L, order))
        continuum_matrices.append(
            _continuum_design_matrix(dispersion[continuum_masks[-1]], L, order))

        # TODO: ISSUE: Check for overlapping regions and raise an warning.

    # Check for non-zero pixels (e.g. ivar > 0) that are not included in a
    # region. We should warn about this very loudly!
    warn_on_pixels = (pixel_included_in_regions == 0) * (ivar > 0)

    metadata = []
    continuum = np.ones_like(flux) * fill_value
    for i in range(flux.shape[0]):

        warn_indices = np.where(warn_on_pixels[i])[0]
        if warn_indices.size > 0:
            # Split by deltas so that we give useful warning messages.
            segment_indices = np.where(np.diff(warn_indices) > 1)[0]
            segment_indices = np.sort(np.hstack(
                [0, segment_indices, segment_indices + 1, len(warn_indices)]))
            segment_indices = segment_indices.reshape(-1, 2)

            segments = ", ".join(["{:.1f} to {:.1f}".format(
                dispersion[s], dispersion[e], e-s) for s, e in segment_indices])

            log.warning(f"Some pixels in have measured flux values (e.g., ivar > 0) but "
                        f"are not included in any specified region ({segments}).")

        # Get the flux and inverse variance for this object.
        object_metadata = []
        object_flux, object_ivar = (flux[i], ivar[i])

        # Normalize each region.
        for region_mask, region_matrix, continuum_mask, continuum_matrix in \
        zip(region_masks, region_matrices, continuum_masks, continuum_matrices):
            if continuum_mask.size == 0:
                # Skipping..
                object_metadata.append([order, L, fill_value, scalar, [], None])
                continue

            # We will fit to continuum pixels only.   
            continuum_disp = dispersion[continuum_mask] 
            continuum_flux, continuum_ivar \
                = (object_flux[continuum_mask], object_ivar[continuum_mask])

            # Solve for the amplitudes.
            M = continuum_matrix
            MTM = np.dot(M, continuum_ivar[:, None] * M.T)
            MTy = np.dot(M, (continuum_ivar * continuum_flux).T)

            eigenvalues = np.linalg.eigvalsh(MTM)
            MTM[np.diag_indices(len(MTM))] += scalar * np.max(eigenvalues)
            eigenvalues = np.linalg.eigvalsh(MTM)
            condition_number = max(eigenvalues)/min(eigenvalues)

            amplitudes = np.linalg.solve(MTM, MTy)
            continuum[i, region_mask] = np.dot(region_matrix.T, amplitudes)
            object_metadata.append(
                (order, L, fill_value, scalar, amplitudes, condition_number))

        metadata.append(object_metadata)

    return (continuum, metadata) 
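
The _continuum_design_matrix helper is not shown here. A sketch of a typical sine-and-cosine design matrix consistent with the solve above, with one row per basis function and one column per pixel (the exact astra implementation may differ):

import numpy as np

def _continuum_design_matrix_sketch(dispersion, L, order):
    # Constant term plus `order` pairs of sine/cosine basis functions with length scale L,
    # arranged as (n_basis, n_pixels) to match M, MTM, and MTy above.
    scale = 2 * np.pi * dispersion / L
    return np.vstack(
        [np.ones_like(dispersion)]
        + [f(o * scale) for o in range(1, order + 1) for f in (np.sin, np.cos)]
    )

M = _continuum_design_matrix_sketch(np.linspace(15100, 15800, 500), L=1400, order=3)
print(M.shape)  # (7, 500): constant term plus 3 sine/cosine pairs
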
Code Example #16
File: utils.py Project: sdss/astra
def prepare_data(pks):
    """
    Return the task instance, data model path, and spectrum for each given primary key,
    and apply any spectrum callbacks to the spectrum as it is loaded.

    :param pks:
        Primary keys of task instances to load data products for.

    :returns:
        Yields a three-length tuple containing the task instance, the spectrum path, and the
        spectrum, after any spectrum callback has been applied. If the spectrum could not be
        loaded (or no task instance was found), the spectrum will be `None`.
    """

    trees = {}

    for pk in deserialize_pks(pks, flatten=True):
        q = session.query(
            astradb.TaskInstance).filter(astradb.TaskInstance.pk == pk)
        instance = q.one_or_none()

        if instance is None:
            log.warning(f"No task instance found for primary key {pk}")
            path = spectrum = None

        else:
            release = instance.parameters["release"]
            tree = trees.get(release, None)
            if tree is None:
                trees[release] = tree = SDSSPath(release=release)

            # Monkey-patch BOSS Spec paths.
            try:
                path = tree.full(**instance.parameters)
            except:
                if instance.parameters["filetype"] == "spec":
                    from astra.utils import monkey_patch_get_boss_spec_path
                    path = monkey_patch_get_boss_spec_path(
                        **instance.parameters)
                else:
                    raise

            try:
                spectrum = Spectrum1D.read(path)
            except:
                log.exception(
                    f"Unable to load Spectrum1D from path {path} on task instance {instance}"
                )
                spectrum = None
            else:
                # Are there any spectrum callbacks?
                spectrum_callback = instance.parameters.get(
                    "spectrum_callback", None)
                if spectrum_callback is not None:
                    spectrum_callback_kwargs = instance.parameters.get(
                        "spectrum_callback_kwargs", "{}")
                    try:
                        spectrum_callback_kwargs = literal_eval(
                            spectrum_callback_kwargs)
                    except:
                        log.exception(
                            f"Unable to literally evalute spectrum callback kwargs for {instance}: {spectrum_callback_kwargs}"
                        )
                        raise

                    try:
                        func = string_to_callable(spectrum_callback)

                        spectrum = func(spectrum=spectrum,
                                        path=path,
                                        instance=instance,
                                        **spectrum_callback_kwargs)

                    except:
                        log.exception(
                            f"Unable to execute spectrum callback '{spectrum_callback}' on {instance}"
                        )
                        raise

        yield (instance, path, spectrum)
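
The string_to_callable utility resolves a dotted path string (the spectrum_callback parameter) into a Python callable. A minimal sketch of how such a resolver is commonly written (astra's version may differ):

import importlib

def string_to_callable_sketch(dotted_path):
    # "package.module.function" -> the function object.
    module_name, _, function_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, function_name)

func = string_to_callable_sketch("os.path.basename")
print(func("/tmp/example.fits"))  # example.fits
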
Code Example #17
    spectrum._uncertainty.array[bad] = ivar_bad_pixel
    
    # Ensure a minimum error.
    # TODO: This seems like a pretty bad idea!
    spectrum._uncertainty.array = np.clip(spectrum._uncertainty.array, ivar_min, ivar_max) # sigma = 5e-3

    if slice_args is not None:
        slices = tuple([slice(*each) for each in slice_args])

        spectrum._data = spectrum._data[slices]
        spectrum._uncertainty.array = spectrum._uncertainty.array[slices]

        try:
            spectrum.meta["snr"] = spectrum.meta["snr"][slices[0]]
        except:
            log.warning(f"Unable to slice 'snr' metadata with {slice_args}")
    

    if median_filter_correction_from_task_id_like is not None:

        upstream_pk = instance.parameters.get("upstream_pk", None)
        if upstream_pk is None:
            raise ValueError(f"cannot do median filter correction because no upstream_pk parameter for {instance}")

        upstream_pk = literal_eval(upstream_pk)

        # There could be many upstream tasks listed, so we should get the matching one.
        q = session.query(astradb.TaskInstance)\
                   .filter(astradb.TaskInstance.pk.in_(upstream_pk))\
                   .filter(astradb.TaskInstance.task_id.like(median_filter_correction_from_task_id_like))
Code Example #18
def _parse_names_and_initial_and_frozen_parameters(
        names,
        initial_parameters,
        frozen_parameters,
        headers,
        flux,
        clip_initial_parameters_to_boundary_edges=True,
        clip_epsilon_percent=1,
        **kwargs):

    # Read the labels from the first header path
    parameter_names = headers["LABEL"]

    # Need the number of spectra, which we will take from the flux array.
    N = len(flux)
    mid_point = _grid_mid_point(headers)
    parsed_initial_parameters = np.tile(mid_point, N).reshape((N, -1))

    log.debug(f"parsed initial parameters before {parsed_initial_parameters}")

    compare_parameter_names = list(
        map(sanitise_parameter_name, parameter_names))

    log.debug(f"Initial parameters passed for parsing {initial_parameters}")

    if initial_parameters is not None:
        log.debug(f"Comparison names {compare_parameter_names}")
        for i, (parameter_name,
                values) in enumerate(initial_parameters.items()):
            spn = sanitise_parameter_name(parameter_name)
            log.debug(f"{parameter_name} {values} {spn}")

            try:
                index = compare_parameter_names.index(spn)
            except ValueError:
                log.warning(
                    f"Ignoring initial parameters for {parameter_name} as they are not in {parameter_names}"
                )
                log.debug(
                    f"Nothing matched for {spn} {parameter_name} {compare_parameter_names}"
                )
            else:
                log.debug(f"Matched to index {index}")
                # Replace non-finite values with the mid point.
                finite = np.isfinite(values)
                if not np.all(finite):
                    log.warning(
                        f"Missing or non-finite initial values given for {parameter_name}. Defaulting to the grid mid-point."
                    )

                values = np.array(values)
                values[~finite] = mid_point[index]

                log.debug(f"values are {values} {type(values[0])} {finite}")
                parsed_initial_parameters[:, index] = values

    log.debug(f"parsed initial parameters after {parsed_initial_parameters}")

    kwds = dict()
    frozen_parameters = (frozen_parameters or dict())
    if frozen_parameters:
        # Ensure we have a dict-like thing.
        if isinstance(frozen_parameters, (list, tuple, np.ndarray)):
            frozen_parameters = {
                sanitise_parameter_name(k): True
                for k in frozen_parameters
            }
        elif isinstance(frozen_parameters, dict):
            # Exclude things that have boolean False.
            frozen_parameters = {
                sanitise_parameter_name(k): v for k, v in frozen_parameters.items() \
                if not (isinstance(v, bool) and not v)
            }
        else:
            raise TypeError(
                f"frozen_parameters must be list-like or dict-like")

        unknown_parameters = set(frozen_parameters).difference(
            compare_parameter_names)
        if unknown_parameters:
            raise ValueError(
                f"unknown parameter(s): {unknown_parameters} (available: {parameter_names})"
            )

        indices = [
            i for i, pn in enumerate(compare_parameter_names, start=1)
            if pn not in frozen_parameters
        ]

        if len(indices) == 0:
            raise ValueError(f"all parameters frozen?!")

        # Over-ride initial values with the frozen ones if given.
        for parameter_name, value in frozen_parameters.items():
            if not isinstance(value, bool):
                log.debug(
                    f"Over-writing initial values for {parameter_name} with frozen value of {value}"
                )
                zero_index = compare_parameter_names.index(parameter_name)
                parsed_initial_parameters[:, zero_index] = value
    else:
        # No frozen parameters.
        indices = 1 + np.arange(len(parameter_names), dtype=int)

    # Build a frozen parameters dict for result metadata.
    parsed_frozen_parameters = {
        pn: (pn in frozen_parameters)
        for pn in compare_parameter_names
    }

    L = len(indices)
    kwds.update(
        ndim=headers["N_OF_DIM"],
        nov=L,
        indv=" ".join([f"{i:.0f}" for i in indices]),
        # We will always provide an initial guess, even if it is the grid mid point.
        init=0,
        indini=" ".join(["1"] * L))

    # Now deal with names.
    if names is None:
        names = [f"{i:.0f}" for i in range(len(parsed_initial_parameters))]
    else:
        if len(names) != len(parsed_initial_parameters):
            raise ValueError(
                f"names and initial parameters does not match ({len(names)} != {len(parsed_initial_parameters)})"
            )

    # Let's check the initial values are all within the grid boundaries.
    lower_limit, upper_limit = _get_grid_limits(headers)
    try:
        _check_initial_parameters_within_grid_limits(parsed_initial_parameters,
                                                     lower_limit, upper_limit,
                                                     parameter_names)
    except ValueError as e:
        log.exception(
            f"Exception when checking initial parameters within grid boundaries:"
        )
        log.critical(e, exc_info=True)

        if clip_initial_parameters_to_boundary_edges:
            log.info(
                f"Clipping initial parameters to boundary edges (use clip_initial_parameters_to_boundary_edges=False to raise exception instead)"
            )

            clip = clip_epsilon_percent * (upper_limit - lower_limit) / 100.
            parsed_initial_parameters = np.round(
                np.clip(parsed_initial_parameters, lower_limit + clip,
                        upper_limit - clip), 3)
        else:
            raise

    return (kwds, names, parsed_initial_parameters, parsed_frozen_parameters)
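
The indv/nov keywords encode the 1-based indices of the thawed (non-frozen) grid dimensions for FERRE. A small sketch of just that bookkeeping, with illustrative parameter names:

parameter_names = ["teff", "logg", "metals", "lgvsini"]
frozen_parameters = {"lgvsini": True}

indices = [
    i for i, pn in enumerate(parameter_names, start=1)
    if pn not in frozen_parameters
]
kwds = {"nov": len(indices), "indv": " ".join(f"{i:.0f}" for i in indices)}
print(kwds)  # {'nov': 3, 'indv': '1 2 3'}
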
Code Example #19
def _estimate_stellar_labels(pk):

    # TODO: It would be great if these were stored with the network,
    #       instead of being hard-coded.
    label_names = ["teff", "logg", "vsini", "v_micro", "m_h"]
    # Translate:
    _t = {
        "teff": "T_eff",
        "logg": "log(g)",
        "m_h": "[M/H]",
        "vsini": "v*sin(i)",
    }

    # TODO: This implicitly assumes that the same constraints and network path are used by all the
    #       primary keys given. This is the usual case, but we should check this, and code around it.

    # TODO: This implementation requires knowing the observed spectrum before loading data.
    #       This is fine for ApStar objects since they all have the same dispersion sampling,
    #       but will not be fine for dispersion sampling that differs in each observation.

    # Let's peek ahead at the first valid spectrum we can find.
    instance, _, spectrum = next(prepare_data([pk]))
    if spectrum is None:
        # No valid spectrum.
        log.warning(
            f"Cannot build LSF for fitter because no spectrum found for primary key {pk}"
        )
        return None

    network = Network()
    network.read_in(instance.parameters["network_path"])

    constraints = json.loads(instance.parameters.get("constraints", "{}"))
    fitted_label_names = [
        ln for ln in label_names \
            if network.grid[_t.get(ln, ln)][0] != network.grid[_t.get(ln, ln)][1]
    ]
    L = len(fitted_label_names)

    bounds_unscaled = np.zeros((2, L))
    for i, ln in enumerate(fitted_label_names):
        bounds_unscaled[:,
                        i] = constraints.get(ln, network.grid[_t.get(ln,
                                                                     ln)][:2])

    fit = Fit(network, int(instance.parameters["N_chebyshev"]))
    fit.bounds_unscaled = bounds_unscaled

    spectral_resolution = int(instance.parameters["spectral_resolution"])
    fit.lsf = LSF_Fixed_R(spectral_resolution, spectrum.wavelength.value,
                          network.wave)

    # Note the Stramut code uses inconsistent naming for "presearch", but in the operator interface we use
    # 'pre_search' in all situations. That's why there is some funny naming translation here.
    fit.N_presearch_iter = int(instance.parameters["N_pre_search_iter"])
    fit.N_pre_search = int(instance.parameters["N_pre_search"])

    fitter = UncertFit(fit, spectral_resolution)
    N, P = spectrum.flux.shape

    keys = []
    keys.extend(fitted_label_names)
    keys.extend([f"u_{ln}" for ln in fitted_label_names])
    keys.extend(["v_rad", "u_v_rad", "chi2", "theta"])

    result = {key: [] for key in keys}
    result["snr"] = spectrum.meta["snr"]

    model_fluxes = []
    log.info(f"Running ThePayne-Che on {N} spectra for {instance}")

    for i in range(N):

        flux = spectrum.flux.value[i]
        error = spectrum.uncertainty.array[0]**-0.5

        # TODO: No NaNs/infs are allowed, but it doesn't seem like that was an issue for Stramut's code.
        #       Possibly due to different versions of scipy. In any case, raise this as a potential bug,
        #       since the errors do not always seem to be believed by ThePayne-Che.
        bad = (~np.isfinite(flux)) | (error <= 0)
        flux[bad] = 0
        error[bad] = 1e10

        fit_result = fitter.run(
            spectrum.wavelength.value,
            flux,
            error,
        )

        # The `popt` attribute is length: len(label_names) + 1 (for radial velocity) + N_chebyshev

        # Relevent attributes are:
        # - fit_result.popt
        # - fit_result.uncert
        # - fit_result.RV_uncert
        # - fit_result.model

        for j, label_name in enumerate(fitted_label_names):
            result[label_name].append(fit_result.popt[j])
            result[f"u_{label_name}"].append(fit_result.uncert[j])

        result["theta"].append(fit_result.popt[L + 1:].tolist())
        result["chi2"].append(fit_result.chi2_func(fit_result.popt))
        result["v_rad"].append(fit_result.popt[L])
        result["u_v_rad"].append(fit_result.RV_uncert)

        model_fluxes.append(fit_result.model)

    # Write database result.
    create_task_output(instance, astradb.ThePayneChe, **result)

    # TODO: Write AstraSource object here.
    return None