예제 #1
0
파일: utils.py 프로젝트: sdss/astra
def get_or_create_parameter_pk(name, value):
    """
    Get or create the primary key for a parameter key/value pair in the database.

    :param name:
        the name of the parameter
    
    :param value:
        the value of the parameter, serialized or not
    
    :returns:
        A two-length tuple containing the integer of the primary key, and a boolean
        indicating whether the entry in the database was created by this function call.

    :raises sqlalchemy.exc.IntegrityError:
        if inserting the new row fails and no matching row can be found afterwards.
    """

    kwds = dict(parameter_name=name, parameter_value=serialize(value))

    def _lookup():
        # Fetch the existing row for this exact (name, serialized value) pair, if any.
        return session.query(astradb.Parameter).filter_by(**kwds).one_or_none()

    instance = _lookup()
    if instance is not None:
        return (instance.pk, False)

    instance = astradb.Parameter(**kwds)
    try:
        with session.begin(subtransactions=True):
            session.add(instance)

    except sqlalchemy.exc.IntegrityError:
        # The insert was rejected, most likely because another writer inserted the
        # same pair between our lookup and our insert. Use that row instead. This
        # call did not create it, so report created=False.
        instance = _lookup()
        if instance is None:
            log.exception(
                f"Cannot create or retrieve parameter with {kwds}")
            raise
        return (instance.pk, False)

    return (instance.pk, True)
예제 #2
0
def _query_task_instances_without_meta(sdss5_only=False):
    """
    Build a query for the primary keys of task instances that have no metadata row.

    :param sdss5_only: [optional]
        If true, further restrict the query to task instances linked to one of a
        hard-coded set of parameter primary keys (a temporary SDSS-V-only selection).

    :returns:
        An unexecuted query over ``astradb.TaskInstance.pk``.
    """
    has_meta = exists().where(
        astradb.TaskInstance.pk == astradb.TaskInstanceMeta.ti_pk)
    query = astra_session.query(astradb.TaskInstance.pk).filter(~has_meta)
    if not sdss5_only:
        return query

    log.warning("Only doing SDSS5 task instances at the moment")
    # Only do sdss5 things so far: these parameter primary keys select them.
    sdss5_parameter_pks = (438829, 494337, 493889)
    link = astradb.TaskInstanceParameter
    return (
        query
        .filter(link.ti_pk == astradb.TaskInstance.pk)
        .filter(link.parameter_pk.in_(sdss5_parameter_pks))
    )
예제 #3
0
파일: utils.py 프로젝트: sdss/astra
def create_task_output(task_instance_or_pk, model, **kwargs):
    """
    Create a new entry in the database for the output of a task.

    :param task_instance_or_pk:
        the task instance (or its primary key) to reference this output to
        
    :param model:
        the database model to store the result (e.g., `astra.database.astradb.Ferre`)
    
    :param kwargs:
        the keyword arguments that will be stored in the database
    
    :returns:
        The newly created output instance (an instance of ``model``).

    :raises ValueError:
        if a primary key is given and no task instance matches it.
    """

    # Get the task instance.
    if isinstance(task_instance_or_pk, astradb.TaskInstance):
        task_instance = task_instance_or_pk
    else:
        task_instance = session.query(astradb.TaskInstance)\
                               .filter(astradb.TaskInstance.pk == task_instance_or_pk)\
                               .one_or_none()

        if task_instance is None:
            raise ValueError(
                f"no task instance found matching primary key {task_instance_or_pk}"
            )

    # Create a new output interface entry.
    with session.begin():
        output_interface = astradb.OutputInterface()
        session.add(output_interface)

    assert output_interface.pk is not None

    # Include the task instance PK so that if the output for that task instance is
    # later updated, then we can still find historical outputs.
    kwds = dict(ti_pk=task_instance.pk, output_pk=output_interface.pk)
    kwds.update(kwargs)

    # Create the instance of the result.
    output_result = model(**kwds)
    with session.begin():
        session.add(output_result)

    # Reference the output to the task instance.
    with session.begin():
        task_instance.output_pk = output_interface.pk

    assert task_instance.output_pk is not None
    return output_result
예제 #4
0
def add_meta_to_new_task_instances_without_meta():
    """
    Add meta to new task instances without meta.

    Only task instances whose primary key is greater than the largest primary key that
    already has metadata are considered. If no task instance has metadata yet, every
    task instance without metadata is considered. A failure for an individual task
    instance is logged and that task instance is skipped.

    :returns:
        None
    """
    latest = astra_session.query(astradb.TaskInstance.pk).filter(
        astradb.TaskInstanceMeta.ti_pk == astradb.TaskInstance.pk).order_by(
            astradb.TaskInstance.pk.desc()).first()

    stmt = exists().where(
        astradb.TaskInstance.pk == astradb.TaskInstanceMeta.ti_pk)
    q = astra_session.query(astradb.TaskInstance.pk).filter(~stmt)
    if latest is not None:
        # Only look at task instances newer than the newest one with metadata.
        latest_pk, = latest
        q = q.filter(astradb.TaskInstance.pk > latest_pk)

    # Single-column query: unpack each row so a scalar primary key is passed on.
    for pk, in tqdm(q.yield_per(1),
                    total=q.count(),
                    desc="Adding metadata to task instances"):
        try:
            add_meta_to_task_instance(pk)
        except Exception:
            # Best-effort: record the failure and carry on with the next one.
            log.exception(f"Unable to add meta to task instance with pk {pk}")
    return None
예제 #5
0
파일: utils.py 프로젝트: sdss/astra
def del_task_instance_parameter(task_instance, key):
    """
    Remove the link between a task instance and one of its parameters.

    :param task_instance:
        the task instance whose parameter should be removed
    
    :param key:
        the name of the parameter to remove
    
    :returns:
        True. Nothing is removed if ``key`` is not among the task instance's parameters.

    :raises ValueError:
        if ``key`` is listed in the task instance's parameters but no matching
        task-instance/parameter link row can be found.
    """
    try:
        value = task_instance.parameters[key]
    except KeyError:
        # That key isn't in there! Nothing to remove.
        pass
    else:
        # Get the PK.
        # NOTE(review): this can create a new Parameter row if the serialized value
        # differs from the stored one.
        parameter_pk, _ = get_or_create_parameter_pk(key, value)

        # Get the TI/PK link row.
        link = session.query(astradb.TaskInstanceParameter).filter(
            astradb.TaskInstanceParameter.ti_pk == task_instance.pk,
            astradb.TaskInstanceParameter.parameter_pk ==
            parameter_pk).one_or_none()

        if link is None:
            raise ValueError(
                f"no link found between task instance {task_instance} and parameter "
                f"{key}: {value} (parameter pk {parameter_pk})"
            )

        session.delete(link)
        log.debug(
            f"Removed key/value pair {key}: {value} from task instance {task_instance}"
        )

    assert key not in task_instance.parameters
    return True
예제 #6
0
def _collect_visit_log_probs(results):
    """
    Gather per-visit log-probabilities from classification query results.

    :param results:
        rows of ``(task_instance, meta, classification)``.

    :returns:
        A dict mapping every ``lp_*`` column name of the classification table to the
        list of per-visit values. Visits where the column is None are skipped.

    :raises ValueError:
        if a column holds more than one value (i.e. a stacked result, not a visit).
    """
    lps = {}
    for j, (_, _, classification) in enumerate(results):
        if j == 0:
            for column_name in classification.__table__.columns.keys():
                if column_name.startswith("lp_"):
                    lps[column_name] = []

        for column_name, collected in lps.items():
            values = getattr(classification, column_name)
            if values is None:
                continue
            if len(values) != 1:
                raise ValueError(
                    "We are getting results from apStars and re-adding to apStars!")
            collected.append(values[0])
    return lps


def _joint_classification_result(lps, decimals=3):
    """
    Combine per-visit log-probabilities into one joint classification result.

    :param lps:
        dict of ``lp_*`` column name to a list of per-visit log-probabilities.
    
    :param decimals: [optional]
        number of decimals to round the normalized probabilities to.

    :returns:
        A dict with, for every non-empty ``lp_*`` column, the summed log-probability
        under that key, and the normalized probability under the same key without its
        leading character (``lp_x`` -> ``p_x``). Each value is a one-element list.
        Empty if no column has values.
    """
    keys = [key for key, lp in lps.items() if len(lp) > 0]
    if not keys:
        return {}

    # Calculate total log probabilities.
    joint_lps = np.array([np.sum(lps[key]) for key in keys])

    # Calculate normalized probabilities.
    with np.errstate(under="ignore"):
        relative_log_probs = joint_lps - logsumexp(joint_lps)

    # Round for PostgreSQL 'real' type.
    # https://www.postgresql.org/docs/9.1/datatype-numeric.html
    # and
    # https://stackoverflow.com/questions/9556586/floating-point-numbers-of-python-float-and-postgresql-double-precision
    probs = np.round(np.exp(relative_log_probs), decimals)

    joint_result = {k: [float(lp)] for k, lp in zip(keys, joint_lps)}
    joint_result.update({k[1:]: [float(v)] for k, v in zip(keys, probs)})
    return joint_result


def classify_apstar(pks, dag, task, run_id, **kwargs):
    """
    Classify observations of APOGEE (ApStar) sources, given the existing classifications of the
    individual visits.

    For each distinct APOGEE DRP star referenced by the given task instances, the
    ``lp_*`` log-probabilities of its apVisit classifications are summed and
    normalized. The result is stored as the output of a new task instance with apStar
    parameters.

    :param pks:
        The primary keys of task instances where visits have been classified. These primary keys will
        be used to work out which stars need classifying, before a task is created for each star.

    :param dag:
        The DAG (only ``dag.dag_id`` is used).

    :param task:
        The task (only ``task.task_id`` is used).

    :param run_id:
        The identifier of the Apache Airflow execution run.

    :returns:
        A list of primary keys of the task instances created, one per classified star.

    :raises RuntimeError:
        if no ``filetype`` = ``apVisit`` parameter exists in the database.
    """

    pks = deserialize_pks(pks, flatten=True)

    # For each unique apStar object, we need to find all the visits that have been classified.
    star_rows = session.query(
        distinct(astradb.TaskInstanceMeta.apogee_drp_star_pk)).filter(
            astradb.TaskInstance.pk.in_(pks),
            astradb.TaskInstanceMeta.ti_pk == astradb.TaskInstance.pk).all()
    # Rows are one-element tuples; unpack them so we compare against scalar keys.
    star_pks = [star_pk for star_pk, in star_rows]

    # We need to make sure that we will only retrieve results on apVisit objects, and not on apStar objects.
    row = session.query(astradb.Parameter.pk).filter(
        astradb.Parameter.parameter_name == "filetype",
        astradb.Parameter.parameter_value == "apVisit").one_or_none()
    if row is None:
        raise RuntimeError("cannot find parameter with filetype = apVisit")
    parameter_pk, = row

    # Columns needed to construct the parameters for each new apStar task.
    columns = (
        apogee_drpdb.Star.apred_vers.label(
            "apred"),  # TODO: Raise with Nidever
        apogee_drpdb.Star.healpix,
        apogee_drpdb.Star.telescope,
        apogee_drpdb.Star.apogee_id.label(
            "obj"),  # TODO: Raise with Nidever
    )

    created_pks = []
    for star_pk in star_pks:

        results = session.query(
            astradb.TaskInstance, astradb.TaskInstanceMeta,
            astradb.Classification
        ).filter(
            astradb.Classification.output_pk == astradb.TaskInstance.output_pk,
            astradb.TaskInstance.pk == astradb.TaskInstanceMeta.ti_pk,
            astradb.TaskInstanceMeta.apogee_drp_star_pk == star_pk,
            astradb.TaskInstanceParameter.ti_pk == astradb.TaskInstance.pk,
            astradb.TaskInstanceParameter.parameter_pk == parameter_pk).all()

        joint_result = _joint_classification_result(
            _collect_visit_log_probs(results))
        if not joint_result:
            log.warning(
                f"No visit classifications found for apogee_drp star pk {star_pk}; skipping")
            continue

        # Create a task for this classification.
        apred, healpix, telescope, obj = sdss_session.query(*columns).filter(
            apogee_drpdb.Star.pk == star_pk).one()
        parameters = dict(apred=apred,
                          healpix=healpix,
                          telescope=telescope,
                          obj=obj,
                          release="sdss5",
                          filetype="apStar",
                          apstar="stars")

        instance = create_task_instance(dag.dag_id, task.task_id, run_id,
                                        parameters)
        create_task_output(instance.pk, astradb.Classification,
                           **joint_result)
        created_pks.append(instance.pk)

    return created_pks
예제 #7
0
파일: utils.py 프로젝트: sdss/astra
def get_task_instance(
    dag_id: str,
    task_id: str,
    run_id,
    parameters: Dict,
):
    """
    Get a task instance exactly matching the given DAG and task identifiers, and the given parameters.

    :param dag_id:
        The identifier of the directed acyclic graph (DAG).
    
    :param task_id:
        The identifier of the task.

    :param run_id:
        The identifier of the Apache Airflow execution run. If None, the run is not
        used to filter.
    
    :param parameters:
        The parameters of the task, as a dictionary

    :returns:
        The matching task instance, or None if there is no match. ``one_or_none``
        raises if more than one task instance matches.
    """

    # TODO: Profile whether first checking for any task instance matching
    #       dag_id/task_id/run_id (cheaper than checking all parameters) is worthwhile.

    # Get primary keys of the individual parameters, and then check by task.
    q_p = session.query(astradb.Parameter.pk).filter(
        or_(*(and_(astradb.Parameter.parameter_name == k,
                   astradb.Parameter.parameter_value == serialize(v))
              for k, v in parameters.items())))
    # One query gives both the keys and their count.
    parameter_pks = [pk for pk, in q_p.all()]
    N_p = len(parameter_pks)
    if N_p < len(parameters):
        # No task with all of these parameters.
        return None

    # Primary keys of task instances that have all of these parameters.
    sq = (
        session.query(astradb.TaskInstanceParameter.ti_pk)
        .filter(astradb.TaskInstanceParameter.parameter_pk.in_(parameter_pks))
        .group_by(astradb.TaskInstanceParameter.ti_pk)
        .having(func.count(distinct(astradb.TaskInstanceParameter.parameter_pk)) == N_p)
        .subquery()
    )

    # Exact match: of those, keep only task instances with no additional parameters.
    sq = (
        session.query(astradb.TaskInstanceParameter.ti_pk)
        .join(sq, astradb.TaskInstanceParameter.ti_pk == sq.c.ti_pk)
        .group_by(astradb.TaskInstanceParameter.ti_pk)
        .having(func.count(distinct(astradb.TaskInstanceParameter.parameter_pk)) == len(parameters))
        .subquery()
    )

    # Query for task instances that match the subquery and match our additional constraints.
    q = session.query(astradb.TaskInstance).join(
        sq, astradb.TaskInstance.pk == sq.c.ti_pk)
    q = q.filter(astradb.TaskInstance.dag_id == dag_id)\
         .filter(astradb.TaskInstance.task_id == task_id)
    if run_id is not None:
        q = q.filter(astradb.TaskInstance.run_id == run_id)

    return q.one_or_none()
예제 #8
0
파일: astradb.py 프로젝트: sdss/astra
 def get_task_instances(self):
     """Return every TaskInstance whose ``output_pk`` equals ``self.output_pk``."""
     matching = session.query(TaskInstance).filter(
         TaskInstance.output_pk == self.output_pk)
     return matching.all()
예제 #9
0
def _or_default(value, default=np.nan):
    """Return ``default`` when ``value`` is None, otherwise ``value`` (0 and 0.0 are kept)."""
    return default if value is None else value


def _result_columns(result, index, ignore_keys):
    """Yield (column name, value) pairs of ``result``, taking ``index`` from list/tuple columns."""
    for key in result.__table__.columns.keys():
        if key in ignore_keys or key.startswith("_"):
            continue
        value = getattr(result, key)
        if isinstance(value, (tuple, list)):
            value = value[index]
        yield key, _or_default(value)


def _get_results(database_model_name,
                 spectrum_index=None,
                 filter_by_kwargs=None,
                 limit=None):
    """
    Collect flat rows of results from an ``astradb`` model with source metadata attached.

    :param database_model_name:
        name of the result model in ``astradb`` (e.g. ``"ApogeeNet"``).
    
    :param spectrum_index: [optional]
        if given, take only this index from list/tuple-valued result columns and emit
        one row per result. Otherwise emit one row per entry of ``result.snr``.
    
    :param filter_by_kwargs: [optional]
        keyword arguments passed to ``Query.filter_by``.
    
    :param limit: [optional]
        maximum number of query results.

    :returns:
        A list of ``OrderedDict`` rows. Missing (None) values become NaN, except
        ``gaia_dr2_source_id`` which becomes -1.
    """

    # Get the database model
    database_model = getattr(astradb, database_model_name)
    if filter_by_kwargs is None:
        filter_by_kwargs = dict()

    q = astra_session.query(
        astradb.TaskInstance,
        func.json_object_agg(astradb.Parameter.parameter_name,
                             astradb.Parameter.parameter_value),
        astradb.TaskInstanceMeta,
        database_model,
    ).filter(
        astradb.TaskInstance.output_pk == database_model.output_pk,
        astradb.TaskInstance.pk == astradb.TaskInstanceMeta.ti_pk,
        astradb.TaskInstance.pk == astradb.TaskInstanceParameter.ti_pk,
        astradb.TaskInstanceParameter.parameter_pk == astradb.Parameter.pk,
    ).filter_by(**filter_by_kwargs).group_by(astradb.TaskInstance,
                                             astradb.TaskInstanceMeta,
                                             database_model)

    if limit is not None:
        q = q.limit(limit)

    ignore_keys = ("output_pk", "ti_pk", "associated_ti_pks")

    rows = []
    for ti, parameters, meta, result in tqdm(q.yield_per(1), total=q.count()):

        # Only replace missing values: a coordinate or flag of 0 is real data.
        row = OrderedDict([
            # Source information.
            ("catalogid", meta.catalogid),
            ("ra", _or_default(meta.ra)),
            ("dec", _or_default(meta.dec)),
            ("pmra", _or_default(meta.pmra)),
            ("pmdec", _or_default(meta.pmdec)),
            ("parallax", _or_default(meta.parallax)),
            ("gaia_dr2_source_id", _or_default(meta.gaia_dr2_source_id, -1)),
            # Task information.
            ("ti_pk", ti.pk),
            # Parameters (minimal)
            ("release", parameters.get("release", "")),
            ("obj", parameters.get("obj", "")),
            ("healpix", parameters.get("healpix", -1)),
            ("telescope", parameters.get("telescope", "")),
        ])

        # Add the result information.
        if spectrum_index is None:
            # Get all results: one row per spectrum.
            for i in range(len(result.snr)):
                this_row = row.copy()
                this_row["spectrum_index"] = i
                this_row.update(_result_columns(result, i, ignore_keys))
                rows.append(this_row)
        else:
            # Only single result.
            row.update(_result_columns(result, spectrum_index, ignore_keys))
            rows.append(row)

    return rows
예제 #10
0
파일: operators.py 프로젝트: sdss/astra
def _select_training_set_data_from_database(label_columns,
                                            filter_args=None,
                                            filter_func=None,
                                            limit=None,
                                            **kwargs):
    """
    Query label values and data model identifiers for a training set.

    :param label_columns:
        database columns to use as labels. The first one with a ``class_`` attribute
        determines the primary parent table. Only
        ``catalogdb.SDSSApogeeAllStarMergeR13`` is currently supported.
    
    :param filter_args: [optional]
        positional criteria passed to ``Query.filter``.
    
    :param filter_func: [optional]
        a callable given each result row (label values first, then the additional
        columns). Rows for which it returns a falsy value are skipped. Defaults to
        keeping every row.
    
    :param limit: [optional]
        maximum number of rows to query.

    :returns:
        A two-length tuple ``(labels, data_model_identifiers)``. ``labels`` maps each
        label name to a list of values. ``data_model_identifiers`` is a list of
        dictionaries, one per kept row.

    :raises ValueError:
        if no label column has a ``class_`` attribute.
    
    :raises NotImplementedError:
        if the primary parent table is not supported.
    """
    label_columns = list(label_columns)
    label_names = [column.key for column in label_columns]
    L = len(label_names)

    if filter_func is None:
        filter_func = lambda *_, **__: True

    # Get the label names.
    log.info(f"Querying for label names {label_names} from {label_columns}")

    # Figure out what other columns we will need to identify the input file.
    for column in label_columns:
        try:
            primary_parent = column.class_
        except AttributeError:
            continue
        else:
            break
    else:
        raise ValueError(
            "Can't get primary parent. are you labelling every column?")

    log.debug(f"Identified primary parent table as {primary_parent}")

    if primary_parent == catalogdb.SDSSApogeeAllStarMergeR13:

        log.debug(
            f"Adding columns and setting data_model_func for {primary_parent}")
        additional_columns = [
            catalogdb.SDSSDR16ApogeeStar.apstar_version.label("apstar"),
            catalogdb.SDSSDR16ApogeeStar.field,
            catalogdb.SDSSDR16ApogeeStar.apogee_id.label("obj"),
            catalogdb.SDSSDR16ApogeeStar.file,
            catalogdb.SDSSDR16ApogeeStar.telescope,

            # Things that we might want for filtering on.
            catalogdb.SDSSDR16ApogeeStar.snr
        ]

        columns = label_columns + additional_columns

        q = session.query(*columns).join(
            catalogdb.SDSSApogeeAllStarMergeR13,
            func.trim(catalogdb.SDSSApogeeAllStarMergeR13.apstar_ids) ==
            catalogdb.SDSSDR16ApogeeStar.apstar_id)

        def data_model_func(apstar, field, obj, filename, telescope, *_):
            # Positional order follows `additional_columns`; trailing columns (snr) are ignored.
            return {
                "release": "DR16",
                "filetype": "apStar",
                "apstar": apstar,
                "field": field,
                "obj": obj,
                "prefix": filename[:2],
                "telescope": telescope,
                "apred": filename.split("-")[1]
            }

    else:
        raise NotImplementedError(
            f"Cannot intelligently figure out what data model keywords will be necessary."
        )

    if filter_args is not None:
        q = q.filter(*filter_args)

    if limit is not None:
        q = q.limit(limit)

    log.debug(f"Querying {q}")

    data_model_identifiers = []
    labels = {label_name: [] for label_name in label_names}
    for i, row in enumerate(tqdm(q.yield_per(1), total=q.count())):
        if not filter_func(*row): continue

        for label_name, value in zip(label_names, row[:L]):
            # Check for None first: np.isfinite(None) raises a TypeError.
            if value is None or not np.isfinite(value):
                log.warning(
                    f"Label {label_name} in {i} row is not finite: {value}!")
            labels[label_name].append(value)
        data_model_identifiers.append(data_model_func(*row[L:]))

    return (labels, data_model_identifiers)
예제 #11
0
파일: astradb.py 프로젝트: sdss/astra
 def meta(self):
     """Return the TaskInstanceMeta row whose ``ti_pk`` equals ``self.pk``, or None if absent."""
     matches = session.query(TaskInstanceMeta).filter_by(ti_pk=self.pk)
     return matches.one_or_none()
예제 #12
0
def individual_visit_data():
    """
    Return a dictionary of results where the differences between labels derived from 
    individual visits and those derived from the stacked spectra are pre-calculated.

    For each ApogeeNet result, index 0 of each array column is treated as the stacked
    value, and indices 2 onwards as individual visits (index 1 is not used). Every key
    of the returned ``OrderedDict`` maps to a NumPy array with one entry per
    (result, visit) pair. ``date`` is a decimal year derived from the task instance's
    ``run_id``.
    """
    parameters_agg = func.json_object_agg(
        astradb.Parameter.parameter_name,
        astradb.Parameter.parameter_value
    ).label("parameters")
    sq = (
        session.query(astradb.ApogeeNet.output_pk.label("output_pk"), parameters_agg)
        .filter(astradb.ApogeeNet.output_pk == astradb.TaskInstance.output_pk)
        .filter(astradb.TaskInstance.pk == astradb.TaskInstanceParameter.ti_pk)
        .filter(astradb.TaskInstanceParameter.parameter_pk == astradb.Parameter.pk)
        .group_by(astradb.ApogeeNet)
        .subquery(with_labels=True)
    )

    q = (
        session.query(
            astradb.TaskInstance,
            astradb.ApogeeNet,
            func.cardinality(astradb.ApogeeNet.snr),
            sq.c.parameters,
        )
        .filter(sq.c.output_pk == astradb.ApogeeNet.output_pk)
        .filter(sq.c.output_pk == astradb.TaskInstance.output_pk)
    )

    # Total number of spectra, used only for the progress bar.
    total, = session.query(func.sum(func.cardinality(astradb.ApogeeNet.snr))).first()

    keys = (
        "ti_pk", "snr_stacked", "snr_visit", 
        "teff_stacked", "logg_stacked", "fe_h_stacked", 
        "delta_teff", "delta_logg", "delta_fe_h", 
        "bitmask_stacked", "bitmask_visit", "release", "date"
    )
    columns = OrderedDict((key, []) for key in keys)

    with tqdm(total=total, unit="spectra") as progress:
        for task_instance, result, n_spectra, parameters in q.yield_per(1):
            run_date = task_instance.run_id.split("T")[0].split("_")[-1]
            date = datetime.strptime(run_date, "%Y-%m-%d")
            decimal_year = date.year + (int(date.strftime("%j")) / 366)

            for visit in range(2, n_spectra):
                entry = {
                    "ti_pk": task_instance.pk,
                    "snr_stacked": result.snr[0],
                    "snr_visit": result.snr[visit],
                    "teff_stacked": result.teff[0],
                    "logg_stacked": result.logg[0],
                    "fe_h_stacked": result.fe_h[0],
                    "delta_teff": result.teff[visit] - result.teff[0],
                    "delta_logg": result.logg[visit] - result.logg[0],
                    "delta_fe_h": result.fe_h[visit] - result.fe_h[0],
                    "bitmask_stacked": result.bitmask_flag[0],
                    "bitmask_visit": result.bitmask_flag[visit],
                    "release": parameters["release"],
                    "date": decimal_year,
                }
                for key in keys:
                    columns[key].append(entry[key])

            progress.update(n_spectra)

    return OrderedDict((key, np.array(values)) for key, values in columns.items())
예제 #13
0
def branch(task_id_callable, task, ti, **kwargs):
    """
    Decide which downstream tasks to run, based on the header paths of the task
    instances whose primary keys were pushed by the upstream tasks.
    
    :param task_id_callable:
        A Python callable that takes the `header_path` and returns a task ID. For BA
        grids it is also called with `telescope` and `lsf` keyword arguments.
    
    :param task:    
        The task being executed. This is supplied by the DAG context.
    
    :param ti:
        The task instance. This is supplied by the DAG context.
    
    :returns:
        A sorted list of unique task IDs that should execute next.
    """
    # Collect and flatten the primary keys pushed by every upstream task.
    pks = flatten([
        ti.xcom_pull(task_ids=upstream_task.task_id)
        for upstream_task in task.upstream_list
    ])
    log.debug(f"Upstream primary keys: {pks}")
    log.debug(f"Downstream task IDs: {task.downstream_list}")

    # Header paths for the given primary keys.
    # TODO: This query could fail if the number of primary keys provided
    #       is yuuge. May consider changing this query.
    header_path_query = (
        session.query(astradb.TaskInstanceParameter.ti_pk, astradb.Parameter.parameter_value)
        .join(astradb.TaskInstanceParameter,
              astradb.TaskInstanceParameter.parameter_pk == astradb.Parameter.pk)
        .filter(astradb.Parameter.parameter_name == "header_path")
        .filter(astradb.TaskInstanceParameter.ti_pk.in_(pks))
    )

    log.debug("Found:")
    chosen = []
    for pk, header_path in header_path_query.all():
        log.debug(f"\t{pk}: {header_path}")

        telescope, lsf, _ = utils.task_id_parts(header_path)
        if telescope is not None or lsf is not None:
            next_task_id = task_id_callable(header_path)
        else:
            # Special hack for BA grids, where telescope/lsf information cannot be found
            # from the header path, so read it from the task parameters and the data file.
            # TODO: Consider removing this hack entirely. This could be fixed by symbolically
            #       linking the BA grids to locations for each telescope/fibre combination.
            instance = session.query(astradb.TaskInstance)\
                              .filter(astradb.TaskInstance.pk == pk).one_or_none()
            tree = SDSSPath(release=instance.parameters["release"])
            header = getheader(tree.full(**instance.parameters))
            next_task_id = task_id_callable(
                header_path,
                # TODO: This matches the telescope styling in utils.task_id_parts; share it.
                telescope=instance.parameters["telescope"].upper()[:3],
                lsf=utils.get_lsf_grid_name(header["MEANFIB"]),
            )

        chosen.append(next_task_id)
        log.debug(f"\t\tadded {next_task_id}")

    downstream_task_ids = sorted(set(chosen))

    log.debug("Downstream tasks to execute:")
    for task_id in downstream_task_ids:
        log.debug(f"\t{task_id}")

    return downstream_task_ids
예제 #14
0
def _common_parameters(reference, others):
    """Return the parameters of ``reference`` whose value is identical in every instance of ``others``."""
    missing = object()
    # Read each `.parameters` mapping once rather than once per key.
    other_parameters = [other.parameters for other in others]
    return {
        key: value
        for key, value in reference.parameters.items()
        if all(params.get(key, missing) == value for params in other_parameters)
    }


def _abundance_results(element, el_instance, label_names, ignore=("lgvsini", )):
    """Return the ``{element}_h`` / ``u_{element}_h`` entries from a task's thawed labels, or {} if none are thawed."""
    output = el_instance.output

    # Check what is not frozen. `ignore` covers labels such as lgvsini that can be
    # missing from a grid and would otherwise confuse this check.
    thawed_label_names = [
        key for key in label_names
        if key not in ignore and not getattr(output, f"frozen_{key}")
    ]
    if not thawed_label_names:
        log.warning(f"No thawed label names for {element} {el_instance}; skipping")
        return {}

    if len(thawed_label_names) > 1:
        log.warning(f"Multiple thawed label names for {element} {el_instance}: {thawed_label_names}")

    values = np.hstack([getattr(output, ln) for ln in thawed_label_names]).tolist()
    u_values = np.hstack([getattr(output, f"u_{ln}") for ln in thawed_label_names]).tolist()
    return {
        f"{element}_h": values,
        f"u_{element}_h": u_values,
    }


def write_database_outputs(
        task, 
        ti, 
        run_id, 
        element_from_task_id_callable=None,
        **kwargs
    ):
    """
    Collate outputs from upstream FERRE executions and write them to an ASPCAP database table.
    
    :param task:
        This task, as given by the Airflow context dictionary.
    
    :param ti:
        This task instance, as given by the Airflow context dictionary.
    
    :param run_id:
        This run ID, as given by the Airflow context dictionary.
    
    :param element_from_task_id_callable: [optional]
        A Python callable that returns the chemical element, given a task ID. If not
        given, the last dot-separated part of the task ID is used.

    :returns:
        A list of primary keys of the newly created task instances, one per source.

    :raises ValueError:
        if the stellar-parameter task instance for a source cannot be found.
    """

    log.debug(f"Writing ASPCAP database outputs")

    pks = [
        ti.xcom_pull(task_ids=upstream_task.task_id)
        for upstream_task in task.upstream_list
    ]
    log.debug(f"Upstream primary keys: {pks}")

    # Columns copied from the stellar-parameter output (values and uncertainties).
    label_names = ("teff", "logg", "metals", "log10vdop", "o_mg_si_s_ca_ti", "lgvsini", "c", "n")
    keys = ["snr"]
    for key in label_names:
        keys.extend([key, f"u_{key}"])

    # Group them together by source: element i of every upstream list is taken to be
    # the same source.
    instance_pks = []
    for source_pks in zip(*pks):

        # The one with the lowest primary key will be the stellar parameters.
        sp_pk, *abundance_pks = sorted(source_pks)

        sp_instance = session.query(astradb.TaskInstance).filter(astradb.TaskInstance.pk == sp_pk).one_or_none()
        if sp_instance is None:
            raise ValueError(f"no task instance found matching primary key {sp_pk}")
        abundance_instances = session.query(astradb.TaskInstance).filter(astradb.TaskInstance.pk.in_(abundance_pks)).all()

        # Create a task instance with the parameters common to all instances.
        instance = create_task_instance(
            dag_id=task.dag_id, 
            task_id=task.task_id, 
            run_id=run_id,
            parameters=_common_parameters(sp_instance, abundance_instances)
        )

        # Start the results from the stellar parameter output.
        sp_output = sp_instance.output
        results = {key: getattr(sp_output, key) for key in keys}

        # Now update with elemental abundance instances.
        for el_instance in abundance_instances:
            if element_from_task_id_callable is not None:
                element = element_from_task_id_callable(el_instance.task_id).lower()
            else:
                element = el_instance.task_id.split(".")[-1].lower()

            results.update(_abundance_results(element, el_instance, label_names))

        # Include associated primary keys so we can reference back to original parameters, etc.
        results["associated_ti_pks"] = [sp_pk, *abundance_pks]

        log.debug(f"Results entry: {results}")

        # Create an entry in the output interface table.
        # TODO: Should we link back to the original FERRE primary keys?
        output = create_task_output(
            instance,
            astradb.Aspcap,
            **results
        )
        log.debug(f"Created output {output} for instance {instance}")
        instance_pks.append(instance.pk)
        
    return instance_pks
예제 #15
0
            spectrum.meta["snr"] = spectrum.meta["snr"][slices[0]]
        except:
            log.warning(f"Unable to slice 'snr' metadata with {slice_args}")
    

    if median_filter_correction_from_task_id_like is not None:

        upstream_pk = instance.parameters.get("upstream_pk", None)
        if upstream_pk is None:
            raise ValueError(f"cannot do median filter correction because no upstream_pk parameter for {instance}")

        upstream_pk = literal_eval(upstream_pk)

        # There could be many upstream tasks listed, so we should get the matching one.
        q = session.query(astradb.TaskInstance)\
                   .filter(astradb.TaskInstance.pk.in_(upstream_pk))\
                   .filter(astradb.TaskInstance.task_id.like(median_filter_correction_from_task_id_like))

        upstream_instance = q.one_or_none()
        if upstream_instance is None:
            raise RuntimeError(f"cannot find upstream instance in {upstream_pk} matching {median_filter_correction_from_task_id_like}")

        log.info(f"Applying median filtered correction\n\tto {instance}\n\tfrom {upstream_instance}")

        upstream_path = utils.output_data_product_path(upstream_instance.pk)
        with open(upstream_path, "rb") as fp:
            result, data = pickle.load(fp)
        
        # Need number of pixels from header
        n_pixels = [header["NPIX"] for header in utils.read_ferre_headers(utils.expand_path(instance.parameters["header_path"]))][1:]
예제 #16
0
    def data_model_identifiers(self, context):
        """
        Yield data model identifiers from upstream that match this operator's
        header path.

        Primary keys are pulled via XCom from the upstream tasks of
        ``context["task"]``. When an upstream task is a ``BranchPythonOperator``, the
        search continues from that operator's own upstream tasks.

        :param context:
            The Airflow context dictionary (``"task"`` and ``"ti"`` are used).

        :yields:
            One dictionary per matching upstream task instance. It holds ``release``,
            ``filetype``, the path lookup keys, the ``initial_*`` values (from the
            upstream output if there is one, otherwise from its parameters), and an
            ``upstream_pk`` list of upstream primary keys.

        :raises AirflowSkipException:
            if no upstream primary keys were identified.
        """

        pks, task, ti = ([], context["task"], context["ti"])
        while True:
            for upstream_task in task.upstream_list:
                log.debug(f"Considering {upstream_task}")
                if isinstance(upstream_task, BranchPythonOperator):
                    # Jump over branch operators: restart from the branch operator's
                    # own upstream tasks.
                    log.debug(
                        f"Jumping over BranchPythonOperator {upstream_task}")
                    task = upstream_task
                    break

                log.debug(
                    f"Using upstream results from {upstream_task} ({upstream_task.task_id}) and {ti}"
                )
                these_pks = ti.xcom_pull(task_ids=upstream_task.task_id)
                if these_pks is not None:
                    pks.extend(these_pks)
            else:
                # No branch operator among the upstream tasks: stop searching.
                break

        pks = flatten(pks)
        if not pks:
            # This can happen if the BA stellar parameters is executed (because we allow all branches to be skipped),
            # but everything else was skipped.
            raise AirflowSkipException(f"No upstream primary keys identified.")

        log.debug(f"From pks: {pks}")
        log.debug(f"That also match {self.header_path}")

        # Restrict to primary keys that have the same header path.
        matching_pks = session.query(astradb.TaskInstanceParameter.ti_pk)\
                   .distinct(astradb.TaskInstanceParameter.ti_pk)\
                   .join(astradb.Parameter,
                         astradb.TaskInstanceParameter.parameter_pk == astradb.Parameter.pk)\
                   .filter(astradb.Parameter.parameter_name == "header_path")\
                   .filter(astradb.Parameter.parameter_value == self.header_path)\
                   .filter(astradb.TaskInstanceParameter.ti_pk.in_(pks))

        log.debug(f"Restricting to primary keys:")

        first_or_none = lambda item: None if item is None else item[0]
        # (parameter name, getter of that initial value from an upstream output)
        initial_value_getters = [
            ("initial_teff", lambda i: first_or_none(i.output.teff)),
            ("initial_logg", lambda i: first_or_none(i.output.logg)),
            ("initial_metals", lambda i: first_or_none(i.output.metals)),
            ("initial_log10vdop", lambda i: first_or_none(i.output.log10vdop)),
            ("initial_o_mg_si_s_ca_ti",
             lambda i: first_or_none(i.output.o_mg_si_s_ca_ti)),
            ("initial_lgvsini", lambda i: first_or_none(i.output.lgvsini)),
            ("initial_c", lambda i: first_or_none(i.output.c)),
            ("initial_n", lambda i: first_or_none(i.output.n)),
        ]

        trees = {}
        for pk, in matching_pks.all():
            instance = session.query(astradb.TaskInstance)\
                              .filter(astradb.TaskInstance.pk == pk).one_or_none()
            log.debug(f"{instance} with {instance.output}")

            release = instance.parameters["release"]
            filetype = instance.parameters["filetype"]

            parameters = dict(release=release, filetype=filetype)

            # Cache one SDSSPath per release.
            tree = trees.get(release, None)
            if tree is None:
                tree = trees[release] = SDSSPath(release=release)

            for key in tree.lookup_keys(filetype):
                parameters[key] = instance.parameters[key]

            # What other information should we pass on?
            if instance.output is None:
                # Only pass on the data model identifiers, and any initial values.
                # Let everything else be specified in this operator
                for key, _ in initial_value_getters:
                    parameters[key] = instance.parameters[key]

            else:
                # There is an upstream FerreOperator.
                log.debug(
                    f"Taking previous result in {pk} as initial result here")

                # Take final teff/logg/etc as the initial values for this task.
                # TODO: Query whether we should be taking first or none, because if
                #       we are running all visits we may want to use individual visit
                #       results from the previous iteration
                for key, get_initial_value in initial_value_getters:
                    parameters[key] = get_initial_value(instance)

            # Store upstream primary key as a parameter, too.
            # We could decide not to do this, but it makes it much easier to find
            # upstream tasks.
            parameters.setdefault("upstream_pk", [])
            if "upstream_pk" in instance.parameters:
                try:
                    upstream_pk = literal_eval(
                        instance.parameters["upstream_pk"])
                    parameters["upstream_pk"].extend(upstream_pk)
                except Exception:
                    # Best-effort: a malformed value is logged, not fatal.
                    log.exception(
                        f"Cannot add upstream primary keys from {instance}: {instance.parameters['upstream_pk']}"
                    )

            parameters["upstream_pk"].append(pk)

            yield parameters
예제 #17
0
def get_best_result(task, ti, **kwargs):
    """
    When there are numerous FERRE tasks that are upstream, this
    function will return the primary keys of the task instances that gave
    the best result on a per-observation basis.

    :param task:
        The task whose ``upstream_list`` provides the upstream FERRE tasks. The second
        dot-separated component of its ``task_id`` is compared against the first three
        (upper-cased) characters of each result's ``telescope`` parameter.

    :param ti:
        The task instance used to pull upstream primary keys via ``xcom_pull``.

    :returns:
        A list of task instance primary keys, one per unique observation (release,
        filetype and data model lookup keys), giving the lowest penalised log chi-sq.

    :raises AirflowSkipException:
        If no usable upstream result was found among the upstream primary keys.
    """

    # Get the PKs from upstream.
    pks = []
    log.debug(f"Upstream tasks: {task.upstream_list}")
    for upstream_task in task.upstream_list:
        pks.append(ti.xcom_pull(task_ids=upstream_task.task_id))

    pks = flatten(pks)
    log.debug(f"Getting best initial guess among primary keys {pks}")

    # Grid-edge flags on logg/teff are penalised in the chi-sq comparison below.
    param_bit_mask = bitmask.ParamBitMask()
    bad_grid_edge = (param_bit_mask.get_value("GRIDEDGE_WARN") | param_bit_mask.get_value("GRIDEDGE_BAD"))

    # One SDSSPath per release, re-used across task instances.
    trees = {}
    # Maps observation key -> (best penalised log chi-sq so far, task instance pk).
    best_tasks = {}
    for pk in pks:
        q = session.query(astradb.TaskInstance).filter(astradb.TaskInstance.pk==pk)
        instance = q.one_or_none()

        if instance is None:
            log.warning(f"No task instance found for primary key {pk}")
            continue

        if instance.output is None:
            log.warning(f"No output found for task instance {instance}")
            continue

        p = instance.parameters

        # Check that the telescope is the same as what we expect from this task ID.
        # This is a bit of a hack. Let us explain.

        # The "BA" grid does not have a telescope/fiber model, so you can run LCO and APO
        # data through the initial-BA grid. And those outputs go to the "get_best_results"
        # for each of the APO and LCO tasks (e.g., this function).
        # If there is only APO data, then the LCO "get_best_result" will only have one
        # input: the BA results. Then it will erroneously think that's the best result
        # for that source.

        # It's hacky to put this logic in here. It should be in the DAG instead. Same
        # thing for parsing 'telescope' name in the DAG (eg 'APO') from 'apo25m'.
        this_telescope_short_name = p["telescope"][:3].upper()
        expected_telescope_short_name = task.task_id.split(".")[1]
        log.info(f"For instance {instance} we have {this_telescope_short_name} and {expected_telescope_short_name}")
        if this_telescope_short_name != expected_telescope_short_name:
            continue

        try:
            tree = trees[p["release"]]
        except KeyError:
            tree = trees[p["release"]] = SDSSPath(release=p["release"])

        # Uniquely identify the observation from its data model identifiers.
        key = "_".join([
            p['release'],
            p['filetype'],
            *[p[k] for k in tree.lookup_keys(p['filetype'])]
        ])

        # Register the observation even if this result is rejected below, so that
        # observations without any good result are reported at the end.
        best_tasks.setdefault(key, (np.inf, None))

        # TODO: Confirm that this is base10 log. This should also be 'log_reduced_chisq_fit',
        #       according to the documentation.
        log_chisq_fit, *_ = instance.output.log_chisq_fit
        previous_teff, *_ = instance.output.teff
        bitmask_flag, *_ = instance.output.bitmask_flag

        log.debug(f"Result {instance} {instance.output} with log_chisq_fit = {log_chisq_fit} and {previous_teff} and {bitmask_flag}")

        # Note: If FERRE totally fails then it will assign -999 values to the log_chisq_fit. So we have to
        #       check that the log_chisq_fit is actually sensible!
        #       (Or we should only query task instances where the output is sensible!)
        # NOTE(review): a genuine fit with chi-sq < 1 also has a negative log10(chi-sq)
        #       and would be skipped by this check too.
        if log_chisq_fit < 0: # TODO: This is a hack.
            log.debug(f"Skipping result for {instance} {instance.output} as log_chisq_fit = {log_chisq_fit}")
            continue

        parsed_header = utils.parse_header_path(p["header_path"])

        # Penalise chi-sq in the same way they did for DR17.
        # See github.com/sdss/apogee/python/apogee/aspcap/aspcap.py#L658
        if parsed_header["spectral_type"] == "GK" and previous_teff < 3900:
            log.debug(f"Increasing \\chisq because spectral type GK")
            log_chisq_fit += np.log10(10)

        bitmask_flag_logg, bitmask_flag_teff = bitmask_flag[-2:]
        if bitmask_flag_logg & bad_grid_edge:
            log.debug(f"Increasing \\chisq because logg flag is bad edge")
            log_chisq_fit += np.log10(5)

        if bitmask_flag_teff & bad_grid_edge:
            log.debug(f"Increasing \\chisq because teff flag is bad edge")
            log_chisq_fit += np.log10(5)

        # Is this the best so far?
        if log_chisq_fit < best_tasks[key][0]:
            log.debug(f"Assigning this output to best task as {log_chisq_fit} < {best_tasks[key][0]}: {pk}")
            best_tasks[key] = (log_chisq_fit, pk)

    for key, (log_chisq_fit, pk) in best_tasks.items():
        if pk is None:
            log.warning(f"No good task found for key {key}: ({log_chisq_fit}, {pk})")
        else:
            log.info(f"Best task for key {key} with \\chi^2 of {log_chisq_fit:.2f} is primary key {pk}")

    # Observations may be registered without any acceptable result (pk is None);
    # skip rather than return an empty list when nothing usable was found.
    best_pks = [pk for (log_chisq_fit, pk) in best_tasks.values() if pk is not None]
    if best_pks:
        return best_pks
    raise AirflowSkipException(f"no task outputs found from {len(pks)} primary keys")
예제 #18
0
파일: utils.py 프로젝트: sdss/astra
def prepare_data(pks):
    """
    Return the task instance, data model path, and spectrum for each given primary key,
    and apply any spectrum callbacks to the spectrum as it is loaded.

    :param pks:
        Primary keys of task instances to load data products for.

    :returns:
        Yields a three length tuple containing the task instance, the spectrum path, and
        the spectrum. If the task instance has a `spectrum_callback` parameter, the
        yielded spectrum is the value returned by that callback. If no task instance
        exists for a primary key, the yielded tuple is `(None, None, None)`. If the
        spectrum cannot be read from the path, the yielded spectrum is `None`.

    :raises:
        Any exception raised while resolving the path (other than the BOSS `spec`
        fallback), evaluating the callback kwargs, or executing the spectrum callback.
    """

    # One SDSSPath per release, re-used across task instances.
    trees = {}

    for pk in deserialize_pks(pks, flatten=True):
        q = session.query(
            astradb.TaskInstance).filter(astradb.TaskInstance.pk == pk)
        instance = q.one_or_none()

        if instance is None:
            log.warning(f"No task instance found for primary key {pk}")
            path = spectrum = None

        else:
            release = instance.parameters["release"]
            tree = trees.get(release, None)
            if tree is None:
                trees[release] = tree = SDSSPath(release=release)

            # Monkey-patch BOSS Spec paths.
            try:
                path = tree.full(**instance.parameters)
            except Exception:
                if instance.parameters["filetype"] == "spec":
                    from astra.utils import monkey_patch_get_boss_spec_path
                    path = monkey_patch_get_boss_spec_path(
                        **instance.parameters)
                else:
                    raise

            try:
                spectrum = Spectrum1D.read(path)
            except Exception:
                # Unreadable spectra are logged and yielded as None rather than aborting.
                log.exception(
                    f"Unable to load Spectrum1D from path {path} on task instance {instance}"
                )
                spectrum = None
            else:
                # Are there any spectrum callbacks?
                spectrum_callback = instance.parameters.get(
                    "spectrum_callback", None)
                if spectrum_callback is not None:
                    # Kwargs are stored as a Python-literal string; default to no kwargs.
                    spectrum_callback_kwargs = instance.parameters.get(
                        "spectrum_callback_kwargs", "{}")
                    try:
                        spectrum_callback_kwargs = literal_eval(
                            spectrum_callback_kwargs)
                    except Exception:
                        log.exception(
                            f"Unable to literally evalute spectrum callback kwargs for {instance}: {spectrum_callback_kwargs}"
                        )
                        raise

                    try:
                        func = string_to_callable(spectrum_callback)

                        spectrum = func(spectrum=spectrum,
                                        path=path,
                                        instance=instance,
                                        **spectrum_callback_kwargs)

                    except Exception:
                        log.exception(
                            f"Unable to execute spectrum callback '{spectrum_callback}' on {instance}"
                        )
                        raise

        yield (instance, path, spectrum)
예제 #19
0
def export_to_table(output_path, overwrite=True):
    """
    Write all APOGEENet results in the database to a table on disk.

    Every APOGEENet output may contain several spectra; the table receives one row per
    entry in the result's ``snr`` array, combining identifying task parameters (release,
    apred, field, healpix, telescope, obj) with that spectrum's results.

    :param output_path:
        The disk location to write the table to. Environment variables and ``~`` are
        expanded.

    :param overwrite: [optional]
        Overwrite an existing file at ``output_path``. If False and the path already
        exists, an ``OSError`` is raised before any database query is made.

    :returns:
        The ordered dictionary of column name -> list of values used to build the table.
    """

    path = os.path.expandvars(os.path.expanduser(output_path))
    if not overwrite and os.path.exists(path):
        raise OSError(f"path '{path}' already exists and asked not to overwrite it")

    # Collapse each result's task parameters into one JSON object keyed by name.
    parameters_agg = func.json_object_agg(
        astradb.Parameter.parameter_name,
        astradb.Parameter.parameter_value,
    ).label("parameters")

    sq = (
        session.query(astradb.ApogeeNet.output_pk.label("output_pk"), parameters_agg)
        .filter(astradb.ApogeeNet.output_pk == astradb.TaskInstance.output_pk)
        .filter(astradb.TaskInstance.pk == astradb.TaskInstanceParameter.ti_pk)
        .filter(astradb.TaskInstanceParameter.parameter_pk == astradb.Parameter.pk)
        .group_by(astradb.ApogeeNet)
        .subquery(with_labels=True)
    )

    query = (
        session.query(
            astradb.TaskInstance,
            astradb.ApogeeNet,
            func.cardinality(astradb.ApogeeNet.snr),
            sq.c.parameters,
        )
        .filter(sq.c.output_pk == astradb.ApogeeNet.output_pk)
        .filter(sq.c.output_pk == astradb.TaskInstance.output_pk)
    )

    # Total number of spectra, used only for the progress bar and logging.
    (total, ) = session.query(func.sum(func.cardinality(astradb.ApogeeNet.snr))).first()

    meta_names = (
        "ti_pk", "run_id", "release", "apred", "field",
        "healpix", "telescope", "obj", "spectrum_index",
    )
    column_names = ("snr", "teff", "u_teff", "logg", "u_logg", "fe_h", "u_fe_h", "bitmask_flag")
    table_columns = OrderedDict((name, []) for name in meta_names + column_names)

    with tqdm(total=total, unit="spectra") as progress:

        for task_instance, result, n_spectra, params in query.yield_per(1):
            for index in range(n_spectra):
                row = {
                    "ti_pk": result.ti_pk,
                    "run_id": task_instance.run_id,
                    "release": params["release"],
                    "apred": params["apred"],
                    "field": params.get("field", ""),
                    "healpix": params.get("healpix", ""),
                    "telescope": params["telescope"],
                    "obj": params["obj"],
                    "spectrum_index": index,
                }
                for name in column_names:
                    row[name] = getattr(result, name)[index]

                for name, value in row.items():
                    table_columns[name].append(value)

                progress.update(1)

    log.info(f"Creating table with {total} rows")
    table = Table(data=table_columns)
    log.info("Table created.")

    log.info(f"Writing to {path}")
    table.write(path, overwrite=overwrite)
    log.info("Done")

    return table_columns
예제 #20
0
def _first_value_or_none(row):
    """Unpack a single-column query result row, returning None when there is no row."""
    if row is None:
        return None
    value, = row
    return value


def add_meta_to_task_instance(ti_pk):
    """
    Update the task instance meta table for a given task instance.

    Looks up identifying information for the data product referenced by the task
    instance: the catalog identifier and Gaia DR2 source identifier, plus, for SDSS-V
    data, the APOGEE DRP star/visit primary keys (and visit NAXIS1 values) or the BOSS
    spAll primary key. If a positive catalog identifier is known, position/astrometry
    columns from the catalog are added too. The result is merged into the
    ``TaskInstanceMeta`` table.

    :param ti_pk:
        The primary key of the task instance.

    :returns:
        The ``astradb.TaskInstanceMeta`` instance that was merged into the session.

    :raises KeyError:
        If the task instance is missing the `release` or `filetype` parameter.

    :raises ValueError:
        If no task instance exists with the given primary key, if the release or
        filetype is not understood, or if no APOGEE DRP star matches an SDSS-V
        APOGEE product.

    :raises NotImplementedError:
        For SDSS-IV products that are not APOGEE products.
    """

    ti = astra_session.query(astradb.TaskInstance).filter_by(pk=ti_pk).first()
    if ti is None:
        raise ValueError(f"No task instance found with primary key {ti_pk}")
    parameters = ti.parameters

    try:
        release = parameters["release"]
        filetype = parameters["filetype"]
    except KeyError:
        raise KeyError(
            f"Either missing `release` or `filetype` parameter for task instance {ti}"
        )

    # Only normalise case for strings, so a None release reaches the check below
    # instead of failing on `None.lower()`.
    if isinstance(release, str):
        release = release.lower()

    if release in ("sdss5", None, "null"):
        is_sdss5 = True
    elif release in ("dr17", "dr16"):
        is_sdss5 = False
    else:
        raise ValueError(
            f"Cannot figure out if {ti} is SDSS V or not based on release value '{release}'"
        )

    # Is it a SDSS-V object? If so we can get information from apogee_drpdb / boss tables first.
    is_apogee = filetype in ("apVisit", "apStar")
    is_boss = filetype in ("spec", )

    if not is_apogee and not is_boss:
        raise ValueError(
            f"Don't know what to do with filetype of '{filetype}'")

    tree = SDSSPath(release=release)

    meta = dict(ti_pk=ti_pk)
    if is_sdss5:
        if is_apogee:
            # Need the apogee_drp_star_pk and apogee_drp_visit_pks.
            star = sdssdb_session.query(
                apogee_drpdb.Star.pk, apogee_drpdb.Star.catalogid,
                apogee_drpdb.Star.gaiadr2_sourceid).filter_by(
                    apogee_id=parameters["obj"]).first()
            if star is None:
                raise ValueError(
                    f"No APOGEE DRP star found with apogee_id '{parameters['obj']}' for task instance {ti}"
                )
            star_pk, catalogid, gaia_dr2_source_id = star

            if filetype == "apVisit":
                # Match on a single visit.
                visit_pks = sdssdb_session.query(
                    apogee_drpdb.Visit.pk
                ).filter_by(
                    apogee_id=parameters["obj"],  # TODO: raise with Nidever
                    telescope=parameters["telescope"],
                    fiberid=parameters["fiber"],  # TODO: raise with Nidever
                    plate=parameters["plate"],
                    field=parameters["field"],
                    mjd=parameters["mjd"],
                    apred_vers=parameters["apred"]  # TODO: raise with Nidever
                ).one_or_none()

                visit_paths = [tree.full(**parameters)]

            elif filetype == "apStar":
                # Get all visits.
                visits = sdssdb_session.query(
                    apogee_drpdb.Visit.pk,
                    apogee_drpdb.Visit.mjd,
                    apogee_drpdb.Visit.field,
                    # We have apred.
                    apogee_drpdb.Visit.fiberid.label("fiber"),
                    # We have telescope
                    apogee_drpdb.Visit.plate).filter_by(
                        apogee_id=parameters["obj"]).all()

                # We will need the paths too.
                visit_pks = []
                visit_paths = []
                for visit_pk, mjd, field, fiber, plate in visits:
                    visit_pks.append(visit_pk)
                    visit_paths.append(
                        tree.full(filetype="apVisit",
                                  apred=parameters["apred"],
                                  telescope=parameters["telescope"],
                                  mjd=mjd,
                                  field=field,
                                  fiber=fiber,
                                  plate=plate))

            else:
                raise ValueError(
                    f"Don't know what to do with SDSS-V APOGEE filetype of '{filetype}'"
                )

            # For apVisit, visit_pks is a single-column row (or None); flatten turns it
            # into a list. NOTE(review): confirm flatten handles a None row.
            visit_pks = flatten(visit_pks)

            # For the visit files we have to open them to get the number of pixels.
            apogee_drp_visit_naxis1 = []
            for visit_path in visit_paths:
                try:
                    apogee_drp_visit_naxis1.append(
                        fits.getval(visit_path, "NAXIS1", ext=1))
                except Exception:
                    # Missing/unreadable visit files are recorded as -1.
                    log.exception(
                        f"Could not get NAXIS1 from path {visit_path}")
                    apogee_drp_visit_naxis1.append(-1)

            meta.update(
                catalogid=catalogid,
                gaia_dr2_source_id=gaia_dr2_source_id,
                apogee_drp_star_pk=star_pk,
                apogee_drp_visit_pks=visit_pks,
                apogee_drp_visit_naxis1=apogee_drp_visit_naxis1,
            )

        elif is_boss:
            # We need catalogid, gaia_dr2_source_id, and catalogdb_sdssv_boss_spall_pkey.
            catalogid = parameters["catalogid"]
            if catalogid is not None and catalogid > 0:
                gaia_dr2_source_id = _first_value_or_none(
                    sdssdb_session.query(catalogdb.TICV8.gaia_int)
                    .filter(catalogdb.TICV8.id == catalogdb.CatalogToTICV8.target_id)
                    .filter(catalogdb.CatalogToTICV8.catalogid == catalogid)
                    .one_or_none()
                )
            else:
                gaia_dr2_source_id = None

            pkey = _first_value_or_none(
                sdssdb_session.query(catalogdb.SDSSVBossSpall.pkey)
                .filter(
                    catalogdb.SDSSVBossSpall.catalogid == catalogid,
                    catalogdb.SDSSVBossSpall.run2d == parameters["run2d"],
                    catalogdb.SDSSVBossSpall.plate == parameters["plate"],
                    catalogdb.SDSSVBossSpall.mjd == parameters["mjd"],
                    catalogdb.SDSSVBossSpall.fiberid == parameters["fiberid"]
                )
                .one_or_none()
            )
            meta.update(catalogid=catalogid,
                        gaia_dr2_source_id=gaia_dr2_source_id,
                        catalogdb_sdssv_boss_spall_pkey=pkey)

    else:
        # Need to get information from elsewhere.
        if is_apogee:
            if filetype == "apVisit":
                # NOTE(review): this joins on SDSSDR16ApogeeStar.target_id, whereas the
                # apStar branch below joins on SDSSDR16ApogeeStar.apstar_id -- confirm
                # which column is correct.
                catalogid = _first_value_or_none(
                    sdssdb_session.query(
                        catalogdb.CatalogToSDSSDR16ApogeeStar.catalogid).filter(
                            catalogdb.SDSSDR16ApogeeStar.target_id ==
                            catalogdb.CatalogToSDSSDR16ApogeeStar.target_id,
                            catalogdb.SDSSDR16ApogeeStar.apstar_id ==
                            catalogdb.SDSSDR16ApogeeStarVisit.apstar_id,
                            catalogdb.SDSSDR16ApogeeStarVisit.visit_id ==
                            catalogdb.SDSSDR16ApogeeVisit.visit_id,
                            catalogdb.SDSSDR16ApogeeVisit.plate ==
                            parameters["plate"],
                            catalogdb.SDSSDR16ApogeeVisit.mjd == parameters["mjd"],
                            catalogdb.SDSSDR16ApogeeVisit.fiberid ==
                            parameters["fiber"],  # Raise with Nidever
                            catalogdb.SDSSDR16ApogeeVisit.location_id ==
                            parameters["field"],
                            catalogdb.SDSSDR16ApogeeVisit.apred_version ==
                            parameters["apred"]  # Raise with Nidever
                        ).first()
                )

            elif filetype == "apStar":
                catalogid = _first_value_or_none(
                    sdssdb_session.query(
                        catalogdb.CatalogToSDSSDR16ApogeeStar.catalogid).filter(
                            catalogdb.SDSSDR16ApogeeStar.apogee_id.label("obj") ==
                            parameters["obj"], catalogdb.SDSSDR16ApogeeStar.field
                            == parameters["field"],
                            catalogdb.SDSSDR16ApogeeStar.telescope ==
                            parameters["telescope"],
                            catalogdb.SDSSDR16ApogeeStar.apstar_id == catalogdb.
                            CatalogToSDSSDR16ApogeeStar.target_id).one_or_none()
                )

            else:
                raise ValueError(
                    f"Don't know what to do with SDSS-IV APOGEE filetype of '{filetype}'"
                )

            if catalogid is not None and catalogid > 0:
                gaia_dr2_source_id = _first_value_or_none(
                    sdssdb_session.query(catalogdb.TICV8.gaia_int)
                    .filter(catalogdb.TICV8.id == catalogdb.CatalogToTICV8.target_id)
                    .filter(catalogdb.CatalogToTICV8.catalogid == catalogid)
                    .one_or_none()
                )
            else:
                gaia_dr2_source_id = None

            meta.update(catalogid=catalogid,
                        gaia_dr2_source_id=gaia_dr2_source_id)

        else:
            raise NotImplementedError(
                f"Can only retrieve metadata for APOGEE products in SDSS-IV.")

    # Add generic information from the catalog, if we have a valid catalogid.
    catalogid = meta["catalogid"]
    if catalogid is not None and catalogid > 0:
        row = sdssdb_session.query(catalogdb.Catalog).filter(
            catalogdb.Catalog.catalogid == catalogid).one_or_none()

        if row is None:
            log.warning(f"No catalog entry found for catalogid {catalogid} (task instance {ti})")
        else:
            meta.update(
                iauname=row.iauname,
                ra=row.ra,
                dec=row.dec,
                pmra=row.pmra,
                pmdec=row.pmdec,
                parallax=row.parallax,
                catalog_lead=row.lead,
                catalog_version_id=row.version_id,
            )

    # Create or update an entry in the database.
    instance = astradb.TaskInstanceMeta(**meta)
    with astra_session.begin(subtransactions=True):
        astra_session.merge(instance)

    return instance
예제 #21
0
파일: astradb.py 프로젝트: sdss/astra
 def parameters(self):
     """
     Return this task instance's parameters as a dictionary.

     Every ``Parameter`` linked to this task instance through ``TaskInstanceParameter``
     is fetched, and each parameter name is mapped to its stored value.
     """
     linked = session.query(Parameter).join(TaskInstanceParameter)
     query = linked.filter(TaskInstanceParameter.ti_pk == self.pk)
     return {param.parameter_name: param.parameter_value for param in query.all()}