def classify(pks, **kwargs):
    """
    Classify sources given the primary keys of task instances.

    :param pks:
        The primary keys of the task instances in the database that need classification.
    """
    models = {}
    results = {}
    for instance, path, spectrum in prepare_data(pks):
        if spectrum is None:
            continue

        model_path = instance.parameters["model_path"]
        try:
            model, factory = models[model_path]
        except KeyError:
            network_factory = model_path.split("_")[-2]
            factory = getattr(networks, network_factory)

            log.info(f"Loading model from {model_path} using {factory}")
            model = utils.read_network(factory, model_path)
            model.eval()

            models[model_path] = (model, factory)

        flux = torch.from_numpy(spectrum.flux.value.astype(np.float32))

        with torch.no_grad():
            prediction = model.forward(flux)
            log_probs = prediction.cpu().numpy().flatten()

        results[instance.pk] = log_probs

    # Note: this assumes all task instances used the same network factory, since `factory`
    # from the last loop iteration is used for the class names.
    for pk, log_probs in tqdm(results.items(), desc="Writing results"):
        result = _prepare_log_prob_result(factory.class_names, log_probs)

        # Write the output to the database.
        create_task_output(pk, astradb.Classification, **result)
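# A quick illustration (hypothetical filename) of the naming convention assumed by the
# factory lookup in `classify` above: the network class name is taken to be the
# second-to-last underscore-separated token of the model path.
_example_model_path = "/path/to/classifier_NIRCNN_2021-10-01.pt"  # made-up path
_example_network_factory = _example_model_path.split("_")[-2]     # -> "NIRCNN"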
def estimate_stellar_labels(pks,
                            default_num_uncertainty_draws=100,
                            default_large_error=1e10):
    """
    Estimate the stellar parameters for APOGEE ApStar observations, where task instances
    have been created with the given primary keys (`pks`).

    :param pks:
        The primary keys of task instances that include information about what ApStar
        observation to load.

    :param default_num_uncertainty_draws: [optional]
        The number of random draws to make of the flux uncertainties, which will be
        propagated into the estimate of the stellar parameter uncertainties (default: 100).

    :param default_large_error: [optional]
        An arbitrarily large error value to assign to bad pixels (default: 1e10).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    log.info(f"Running APOGEENet on device {device} with:")
    log.info(f"\tpks: {pks}")
    log.debug(f"CUDA_VISIBLE_DEVICES = '{os.environ.get('CUDA_VISIBLE_DEVICES')}'")
    log.debug(f"Using torch version {torch.__version__} in {torch.__path__}")

    models = {}

    pks = deserialize_pks(pks, flatten=True)
    total = len(pks)

    log.info(f"There are {total} primary keys to process: {pks}")

    for instance, path, spectrum in tqdm(prepare_data(pks), total=total):
        if spectrum is None:
            continue

        model_path = instance.parameters["model_path"]

        # Load the model (once per unique model path).
        try:
            model = models[model_path]
        except KeyError:
            log.info(f"Loading model from {model_path}")
            models[model_path] = model = Model(model_path, device)

        N, P = spectrum.flux.shape

        # Build metadata array.
        metadata_keys, metadata, metadata_norm = get_metadata(spectrum)

        flux = np.nan_to_num(spectrum.flux.value).astype(np.float32).reshape((N, 1, P))
        meta = np.tile(metadata_norm, N).reshape((N, -1))

        flux = torch.from_numpy(flux).to(device)
        meta = torch.from_numpy(meta).to(device)

        with torch.set_grad_enabled(False):
            predictions = model.predict_spectra(flux, meta)
            if device != "cpu":
                predictions = predictions.cpu().data.numpy()

        # Replace non-finite predictions with NaN.
        predictions[~np.isfinite(predictions)] = np.nan

        # Create results array.
        log_g, log_teff, fe_h = predictions.T
        teff = 10**log_teff

        result = dict(
            snr=spectrum.meta["snr"],
            teff=teff.tolist(),
            logg=log_g.tolist(),
            fe_h=fe_h.tolist(),
        )

        num_uncertainty_draws = int(
            instance.parameters.get("num_uncertainty_draws", default_num_uncertainty_draws))

        if num_uncertainty_draws > 0:
            large_error = float(
                instance.parameters.get("large_error", default_large_error))

            flux_error = np.nan_to_num(
                spectrum.uncertainty.array**-0.5).astype(np.float32).reshape((N, 1, P))
            median_error = 5 * np.median(flux_error, axis=(1, 2))

            for j, value in enumerate(median_error):
                bad_pixel = (flux_error[j] == large_error) | (flux_error[j] >= value)
                flux_error[j][bad_pixel] = value

            flux_error = torch.from_numpy(flux_error).to(device)

            inputs = torch.randn(
                (num_uncertainty_draws, N, 1, P), device=device) * flux_error + flux
            inputs = inputs.reshape((num_uncertainty_draws * N, 1, P))

            meta_error = meta.repeat(num_uncertainty_draws, 1)
            with torch.set_grad_enabled(False):
                draws = model.predict_spectra(inputs, meta_error)
                if device != "cpu":
                    draws = draws.cpu().data.numpy()

            draws = draws.reshape((num_uncertainty_draws, N, -1))

            # Convert the log(teff) draws to teff before calculating the standard deviation.
            draws[:, :, 1] = 10**draws[:, :, 1]

            median_draw_predictions = np.nanmedian(draws, axis=0)
            std_draw_predictions = np.nanstd(draws, axis=0)

            log_g_median, teff_median, fe_h_median = median_draw_predictions.T
            log_g_std, teff_std, fe_h_std = std_draw_predictions.T

            result.update(
                _teff_median=teff_median.tolist(),
                _logg_median=log_g_median.tolist(),
                _fe_h_median=fe_h_median.tolist(),
                u_teff=teff_std.tolist(),
                u_logg=log_g_std.tolist(),
                u_fe_h=fe_h_std.tolist())
        else:
            median_draw_predictions, std_draw_predictions = (None, None)

        # Add the bitmask flag.
        bitmask_flag = create_bitmask(
            predictions,
            median_draw_predictions=median_draw_predictions,
            std_draw_predictions=std_draw_predictions)

        result.update(bitmask_flag=bitmask_flag.tolist())

        # Write the result to the database.
        create_task_output(instance, astradb.ApogeeNet, **result)

    log.info(f"Completed processing of {total} primary keys")
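def _example_propagate_uncertainty(flux, flux_error, predict, num_draws=100, seed=0):
    """
    Illustrative sketch (not pipeline code) of the Monte Carlo uncertainty propagation used
    by `estimate_stellar_labels` above: perturb the flux with Gaussian noise scaled by the
    flux errors, push each draw through the model, and summarise the spread of predictions.
    `predict` is any callable mapping a flux array to an array of labels.
    """
    import numpy as np
    rng = np.random.default_rng(seed)
    draws = np.array([
        predict(flux + rng.normal(size=flux.shape) * flux_error)
        for _ in range(num_draws)
    ])
    # Median and standard deviation of the predicted labels across draws.
    return np.nanmedian(draws, axis=0), np.nanstd(draws, axis=0)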
def estimate_radial_velocity(pks,
                             verbose=True,
                             mcmc=False,
                             figfile=None,
                             cornername=None,
                             retpmodels=False,
                             plot=False,
                             tweak=True,
                             usepeak=False,
                             maxvel=[-1000, 1000]):
    """
    Estimate radial velocities for the sources that are identified by the task instances
    of the given primary keys.

    :param pks:
        The primary keys of task instances to estimate radial velocities for, which include
        parameters to identify the source SDSS data model product.

    See `doppler.rv.fit` for more information on the other keyword arguments.
    """
    # TODO: Move this to astra/contrib
    import doppler

    log.info(f"Estimating radial velocities for {len(pks)} task instances")

    failures = []
    for instance, path, spectrum in prepare_data(pks):
        if spectrum is None:
            continue

        log.debug(f"Running Doppler on {instance} from {path}")

        try:
            spectrum = doppler.read(path)
            summary, model_spectrum, modified_input_spectrum = doppler.rv.fit(
                spectrum,
                verbose=verbose,
                mcmc=mcmc,
                figfile=figfile,
                cornername=cornername,
                retpmodels=retpmodels,
                plot=plot,
                tweak=tweak,
                usepeak=usepeak,
                maxvel=maxvel)

        except:
            log.exception(
                f"Exception occurred on Doppler on {path} with task instance {instance}")
            failures.append(instance.pk)
            continue

        else:
            # Write the output to the database.
            results = prepare_results(summary)
            create_task_output(instance, astradb.Doppler, **results)

    if len(failures) > 0:
        log.warning(
            f"There were {len(failures)} Doppler failures out of a total {len(pks)} executions.")
        log.warning(f"Failed primary keys include: {failures}")

        log.warning(f"Raising last exception to indicate failure in pipeline.")
        raise
def estimate_stellar_labels(pks, **kwargs):
    """
    Estimate stellar labels given a single-layer neural network.

    :param pks:
        The primary keys of task instances to estimate stellar labels for. The task
        instances include information to identify the source SDSS data product.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    log.info(f"Running ThePayne on device {device} with:")
    log.info(f"CUDA_VISIBLE_DEVICES = '{os.environ.get('CUDA_VISIBLE_DEVICES')}'")
    log.info(f"Using torch version {torch.__version__} in {torch.__path__}")

    states = {}

    log.info(f"Estimating stellar labels for task instances")

    results = {}
    for instance, path, spectrum in prepare_data(pks):
        if spectrum is None:
            continue

        model_path = instance.parameters["model_path"]
        try:
            state = states[model_path]
        except KeyError:
            log.info(f"Loading model from {model_path}")
            state = states[model_path] = test.load_state(model_path)

            label_names = state["label_names"]
            L = len(label_names)
            log.info(f"Estimating these {L} label names: {label_names}")

        # Run optimization.
        t_init = time()
        p_opt, p_cov, model_flux, meta = test.test(
            spectrum.wavelength.value,
            spectrum.flux.value,
            spectrum.uncertainty.array,
            **state)
        t_opt = time() - t_init

        # Prepare outputs.
        result = dict(zip(label_names, p_opt.T))
        result.update(snr=spectrum.meta["snr"])

        # Include uncertainties from the diagonal of each covariance matrix.
        result.update(
            dict(
                zip((f"u_{ln}" for ln in label_names),
                    np.sqrt(p_cov[:, np.arange(p_opt.shape[1]),
                                  np.arange(p_opt.shape[1])].T))))

        results[instance.pk] = result

        log.info(f"Result for {instance} took {t_opt} seconds")

    # Write database outputs.
    for pk, result in tqdm(results.items(), desc="Writing database outputs"):
        create_task_output(pk, astradb.ThePayne, **result)

    return None
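def _example_label_uncertainties(p_cov):
    """
    Illustrative sketch (not pipeline code) of the indexing used above to pull per-label
    uncertainties out of a stack of covariance matrices: `p_cov` has shape (N, L, L), and
    the result has shape (N, L), i.e. the square root of each matrix diagonal.
    """
    import numpy as np
    idx = np.arange(p_cov.shape[1])
    return np.sqrt(p_cov[:, idx, idx])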
def _estimate_stellar_labels(pk):

    # TODO: It would be great if these were stored with the network, instead of being
    #       hard-coded.
    label_names = ["teff", "logg", "vsini", "v_micro", "m_h"]
    # Translate to the names used by the network grid:
    _t = {
        "teff": "T_eff",
        "logg": "log(g)",
        "m_h": "[M/H]",
        "vsini": "v*sin(i)",
    }

    # TODO: This implicitly assumes that the same constraints and network path are used by
    #       all the primary keys given. This is the usual case, but we should check this,
    #       and code around it.

    # TODO: This implementation requires knowing the observed spectrum before loading data.
    #       This is fine for ApStar objects since they all have the same dispersion sampling,
    #       but will not be fine for dispersion sampling that differs in each observation.

    # Let's peek ahead at the first valid spectrum we can find.
    instance, _, spectrum = next(prepare_data([pk]))
    if spectrum is None:
        # No valid spectrum.
        log.warning(
            f"Cannot build LSF for fitter because no spectrum found for primary key {pk}")
        return None

    network = Network()
    network.read_in(instance.parameters["network_path"])

    constraints = json.loads(instance.parameters.get("constraints", "{}"))

    fitted_label_names = [
        ln for ln in label_names
        if network.grid[_t.get(ln, ln)][0] != network.grid[_t.get(ln, ln)][1]
    ]
    L = len(fitted_label_names)

    bounds_unscaled = np.zeros((2, L))
    for i, ln in enumerate(fitted_label_names):
        bounds_unscaled[:, i] = constraints.get(ln, network.grid[_t.get(ln, ln)][:2])

    fit = Fit(network, int(instance.parameters["N_chebyshev"]))
    fit.bounds_unscaled = bounds_unscaled

    spectral_resolution = int(instance.parameters["spectral_resolution"])
    fit.lsf = LSF_Fixed_R(spectral_resolution, spectrum.wavelength.value, network.wave)

    # Note: the Stramut code uses inconsistent naming for "presearch", but in the operator
    # interface we use 'pre_search' in all situations. That's why there is some naming
    # translation here.
    fit.N_presearch_iter = int(instance.parameters["N_pre_search_iter"])
    fit.N_pre_search = int(instance.parameters["N_pre_search"])

    fitter = UncertFit(fit, spectral_resolution)

    N, P = spectrum.flux.shape

    keys = []
    keys.extend(fitted_label_names)
    keys.extend([f"u_{ln}" for ln in fitted_label_names])
    keys.extend(["v_rad", "u_v_rad", "chi2", "theta"])

    result = {key: [] for key in keys}
    result["snr"] = spectrum.meta["snr"]

    model_fluxes = []

    log.info(f"Running ThePayne-Che on {N} spectra for {instance}")
    for i in range(N):
        flux = spectrum.flux.value[i]
        error = spectrum.uncertainty.array[i]**-0.5

        # TODO: No NaNs/infs are allowed, but it doesn't seem like that was an issue for
        #       Stramut's code. Possibly due to different versions of scipy. In any case,
        #       raise this as a potential bug, since the errors do not always seem to be
        #       believed by ThePayne-Che.
        bad = (~np.isfinite(flux)) | (error <= 0)
        flux[bad] = 0
        error[bad] = 1e10

        fit_result = fitter.run(
            spectrum.wavelength.value,
            flux,
            error,
        )

        # The `popt` attribute has length len(fitted_label_names) + 1 (for radial velocity)
        # + N_chebyshev. Relevant attributes are:
        #   - fit_result.popt
        #   - fit_result.uncert
        #   - fit_result.RV_uncert
        #   - fit_result.model
        for j, label_name in enumerate(fitted_label_names):
            result[label_name].append(fit_result.popt[j])
            result[f"u_{label_name}"].append(fit_result.uncert[j])

        result["theta"].append(fit_result.popt[L + 1:].tolist())
        result["chi2"].append(fit_result.chi2_func(fit_result.popt))
        result["v_rad"].append(fit_result.popt[L])
        result["u_v_rad"].append(fit_result.RV_uncert)

        model_fluxes.append(fit_result.model)

    # Write database result.
    create_task_output(instance, astradb.ThePayneChe, **result)

    # TODO: Write AstraSource object here.
    return None
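def _example_error_from_ivar(ivar, large_error=1e10):
    """
    Illustrative sketch (not pipeline code) of a pattern used by several of these functions:
    convert inverse variance to a 1-sigma error with `ivar**-0.5`, then assign an arbitrarily
    large error to bad pixels so that the fit effectively ignores them.
    """
    import numpy as np
    with np.errstate(divide="ignore", invalid="ignore"):
        error = ivar**-0.5
    bad = ~np.isfinite(error) | (error <= 0)
    error[bad] = large_error
    return error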
def classify_apstar(pks, dag, task, run_id, **kwargs):
    """
    Classify observations of APOGEE (ApStar) sources, given the existing classifications of
    the individual visits.

    :param pks:
        The primary keys of task instances where visits have been classified. These primary
        keys will be used to work out which stars need classifying, before tasks are created.
    """
    pks = deserialize_pks(pks, flatten=True)

    # For each unique apStar object, we need to find all the visits that have been classified.
    distinct_apogee_drp_star_pk = session.query(
        distinct(astradb.TaskInstanceMeta.apogee_drp_star_pk)).filter(
            astradb.TaskInstance.pk.in_(pks),
            astradb.TaskInstanceMeta.ti_pk == astradb.TaskInstance.pk).all()

    # We need to make sure that we will only retrieve results on apVisit objects, and not on
    # apStar objects.
    parameter_pk, = session.query(astradb.Parameter.pk).filter(
        astradb.Parameter.parameter_name == "filetype",
        astradb.Parameter.parameter_value == "apVisit").one_or_none()

    for star_pk in distinct_apogee_drp_star_pk:

        results = session.query(
            astradb.TaskInstance, astradb.TaskInstanceMeta, astradb.Classification
        ).filter(
            astradb.Classification.output_pk == astradb.TaskInstance.output_pk,
            astradb.TaskInstance.pk == astradb.TaskInstanceMeta.ti_pk,
            astradb.TaskInstanceMeta.apogee_drp_star_pk == star_pk,
            astradb.TaskInstanceParameter.ti_pk == astradb.TaskInstance.pk,
            astradb.TaskInstanceParameter.parameter_pk == parameter_pk).all()

        column_func = lambda column_name: column_name.startswith("lp_")

        lps = {}
        for j, (ti, meta, classification) in enumerate(results):
            if j == 0:
                for column_name in classification.__table__.columns.keys():
                    if column_func(column_name):
                        lps[column_name] = []

            for column_name in lps.keys():
                values = getattr(classification, column_name)
                if values is None:
                    continue
                assert len(values) == 1, \
                    "We are getting results from apStars and re-adding to apStars!"
                lps[column_name].append(values[0])

        # Calculate total log probabilities.
        joint_lps = np.array([np.sum(lp) for lp in lps.values() if len(lp) > 0])
        keys = [key for key, lp in lps.items() if len(lp) > 0]

        # Calculate normalized probabilities.
        with np.errstate(under="ignore"):
            relative_log_probs = joint_lps - logsumexp(joint_lps)

        # Round for PostgreSQL 'real' type.
        # https://www.postgresql.org/docs/9.1/datatype-numeric.html
        # and
        # https://stackoverflow.com/questions/9556586/floating-point-numbers-of-python-float-and-postgresql-double-precision
        decimals = 3
        probs = np.round(np.exp(relative_log_probs), decimals)

        joint_result = {k: [float(lp)] for k, lp in zip(keys, joint_lps)}
        # "lp_<class>" -> "p_<class>"
        joint_result.update({k[1:]: [float(v)] for k, v in zip(keys, probs)})

        # Create a task for this classification.
        # To do that we need to construct the parameters for the task.
        columns = (
            apogee_drpdb.Star.apred_vers.label("apred"),  # TODO: Raise with Nidever
            apogee_drpdb.Star.healpix,
            apogee_drpdb.Star.telescope,
            apogee_drpdb.Star.apogee_id.label("obj"),  # TODO: Raise with Nidever
        )
        apred, healpix, telescope, obj = sdss_session.query(*columns).filter(
            apogee_drpdb.Star.pk == star_pk).one()

        parameters = dict(
            apred=apred,
            healpix=healpix,
            telescope=telescope,
            obj=obj,
            release="sdss5",
            filetype="apStar",
            apstar="stars")

        args = (dag.dag_id, task.task_id, run_id)

        instance = create_task_instance(*args, parameters)
        output = create_task_output(instance.pk, astradb.Classification, **joint_result)
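def _example_joint_classification(lps):
    """
    Illustrative sketch (not pipeline code) of how `classify_apstar` combines visit-level
    classifications: sum the per-visit log-probabilities for each class, then normalise in
    log space with `logsumexp`. `lps` maps "lp_<class>" keys to lists of per-visit values,
    e.g. {"lp_yso": [-0.2, -0.1], "lp_fgkm": [-2.0, -1.5]} (values made up).
    """
    import numpy as np
    from scipy.special import logsumexp
    keys = [k for k, v in lps.items() if len(v) > 0]
    joint_lps = np.array([np.sum(lps[k]) for k in keys])
    probs = np.exp(joint_lps - logsumexp(joint_lps))
    # "lp_<class>" -> "p_<class>", as in the joint_result construction above.
    return dict(zip((k[1:] for k in keys), np.round(probs, 3)))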
def execute(self, context):
    """
    Execute the operator.

    :param context:
        The Airflow DAG context.
    """

    # Load spectra.
    instances, Ns = ([], [])
    wavelength, flux, sigma, spectrum_meta = ([], [], [], [])
    for instance, path, spectrum in self.prepare_data():
        if spectrum is None:
            continue

        N, P = spectrum.flux.shape
        wavelength.append(np.tile(spectrum.wavelength.value, N).reshape((N, -1)))
        flux.append(spectrum.flux.value)
        sigma.append(spectrum.uncertainty.array**-0.5)

        spectrum_meta.append(dict(snr=spectrum.meta["snr"]))

        Ns.append(N)
        instances.append(instance)

    Ns = np.array(Ns, dtype=int)
    wavelength, flux, sigma = tuple(map(np.vstack, (wavelength, flux, sigma)))

    # Create names for easy debugging in FERRE outputs.
    names = create_names(instances, Ns, "{star_index}_{telescope}_{obj}_{spectrum_index}")

    # Load initial parameters, taking into account the number of spectra per task instance.
    initial_parameters = create_initial_parameters(instances, Ns)

    # Directory.
    directory = os.path.join(
        get_base_output_path(), "ferre", "tasks",
        f"{context['ds']}-{context['dag'].dag_id}-{context['task'].task_id}-{context['run_id']}")
    os.makedirs(directory, exist_ok=True)
    log.info(f"Working directory for task is {directory}")

    # Prepare FERRE.
    args = prepare_ferre(
        directory,
        dict(wavelength=wavelength,
             flux=flux,
             sigma=sigma,
             header_path=self.header_path,
             names=names,
             initial_parameters=initial_parameters,
             frozen_parameters=self.frozen_parameters,
             interpolation_order=self.interpolation_order,
             input_weights_path=self.input_weights_path,
             input_lsf_shape_path=self.input_lsf_shape_path,
             lsf_shape_flag=self.lsf_shape_flag,
             error_algorithm_flag=self.error_algorithm_flag,
             wavelength_interpolation_flag=self.wavelength_interpolation_flag,
             optimization_algorithm_flag=self.optimization_algorithm_flag,
             continuum_flag=self.continuum_flag,
             continuum_order=self.continuum_order,
             continuum_segment=self.continuum_segment,
             continuum_reject=self.continuum_reject,
             continuum_observations_flag=self.continuum_observations_flag,
             full_covariance=self.full_covariance,
             pca_project=self.pca_project,
             pca_chi=self.pca_chi,
             n_threads=self.n_threads,
             f_access=self.f_access,
             f_format=self.f_format,
             ferre_kwargs=self.ferre_kwargs))

    # Execute FERRE (via Slurm).
    log.debug(f"FERRE ready to roll in {directory}")
    assert self.slurm_kwargs
    self.execute_by_slurm(
        context,
        bash_command="/uufs/chpc.utah.edu/common/home/sdss09/software/apogee/Linux/apogee/trunk/bin/ferre.x",
        directory=directory,
    )
    # Unbelievably, FERRE sends a '1' exit code every time it is executed, even if it succeeds.
    # TODO: Ask Carlos or Jon to remove this insanity.

    # Parse outputs.
    # TODO: Clean up this function.
    param, param_err, output_meta = parse_ferre_outputs(directory, self.header_path, *args)

    results = group_results_by_instance(param, param_err, output_meta, spectrum_meta, Ns)

    for instance, (result, data) in zip(instances, results):
        if result is None:
            continue

        create_task_output(instance, astradb.Ferre, **result)

        log.debug(f"{instance}")
        log.debug(f"{result}")
        log.debug(f"{data}")

        # TODO: Write a data model product for this intermediate output!
        output_path = utils.output_data_product_path(instance.pk)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "wb") as fp:
            pickle.dump((result, data), fp)

        log.info(f"Wrote outputs of task instance {instance} to {output_path}")

    # Always return the primary keys that were worked on!
    return self.pks
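def _example_split_by_instance(param, Ns):
    """
    Illustrative sketch only: FERRE returns one stacked array of results covering all
    instances, and `group_results_by_instance` (implementation not shown here) presumably
    splits it back using the per-instance spectrum counts `Ns`. This is the basic splitting
    pattern with `np.cumsum`, not the actual pipeline function.
    """
    import numpy as np
    edges = np.cumsum(Ns)[:-1]
    return np.split(param, edges, axis=0)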
def estimate_stellar_labels(pks,
                            model_path,
                            dwave_slam=10.,
                            p_slam=(1E-8, 1E-7),
                            q_slam=0.7,
                            ivar_block_slam=None,
                            eps_slam=1E-19,
                            rsv_frac_slam=2.,
                            n_jobs_slam=1,
                            verbose_slam=5):
    """
    Estimate the stellar parameters for APOGEE ApStar observations, where task instances
    have been created with the given primary keys (`pks`).

    :param pks:
        The primary keys of task instances that include information about what ApStar
        observation to load.

    :param model_path:
        The disk path of the pre-trained model.

    :param dwave_slam: float
        Binning width.

    :param p_slam: tuple of 2 floats [optional]
        Smoothing parameter between 0 and 1 (default: (1E-8, 1E-7)):
        0 -> least-squares straight line, 1 -> cubic spline interpolant.

    :param q_slam: float [optional]
        Percentile, between 0 and 1 (default: 0.7).

    :param ivar_block_slam: ndarray (n_pix, ) | None [optional]
        ivar array (default: None).

    :param eps_slam: float [optional]
        The ivar threshold (default: 1E-19).

    :param rsv_frac_slam: float [optional]
        The fraction of pixels reserved in terms of std (default: 2).

    :param n_jobs_slam: int [optional]
        Number of processes launched by joblib (default: 1).

    :param verbose_slam: int / bool [optional]
        Verbose level (default: 5).
    """
    # Load the model.
    model = Slam.load_dump(model_path)
    # TODO: Check that this is the correct wavelength grid to resample onto
    #       (previously loaded separately, e.g. from wave_interp_R1800.npz).
    wave_interp = model.wave

    log.info(f"Loaded model from {model_path}")

    pks = deserialize_pks(pks, flatten=True)
    total = len(pks)

    log.info(f"There are {total} primary keys to process: {pks}")

    for instance, path, spectrum in tqdm(prepare_data(pks), total=total):
        if spectrum is None:
            continue

        N, P = spectrum.flux.shape

        # TODO: Check the expected array formats here (the uncertainty is a nddata object;
        #       compare with the APOGEENet implementation, which uses `.value` / `.array`
        #       and reshapes to (N, 1, P)).
        wave = spectrum.spectral_axis
        fluxes = spectrum.flux
        invars = spectrum.uncertainty

        # Resample each spectrum onto the model wavelength grid.
        fluxes_resamp, invars_resamp = [], []
        for i in tqdm(range(N)):
            fluxes_temp, invars_temp = resample(wave[i], fluxes[i], invars[i], wave_interp)
            fluxes_resamp += [fluxes_temp]
            invars_resamp += [invars_temp]
        fluxes_resamp, invars_resamp = np.array(fluxes_resamp), np.array(invars_resamp)

        # Normalize each spectrum.
        fluxes_norm, fluxes_cont = normalize_spectra_block(
            wave_interp,
            fluxes_resamp,
            (6147., 8910.),
            dwave_slam,
            p=p_slam,
            q=q_slam,
            ivar_block=ivar_block_slam,
            eps=eps_slam,
            rsv_frac=rsv_frac_slam,
            n_jobs=n_jobs_slam,
            verbose=verbose_slam)

        invars_norm = fluxes_cont**2 * invars_resamp

        # Initial estimation: get an initial estimate of the parameters by chi2 best match.
        label_init = model.predict_labels_quick(fluxes_norm, invars_norm, n_jobs=1)

        # SLAM prediction: optimize the parameters.
        results_pred = model.predict_labels_multi(label_init, fluxes_norm, invars_norm)
        label_pred = np.array([label['x'] for label in results_pred])
        std_pred = np.array([label['pstd'] for label in results_pred])

        # Create results array.
        teff = label_pred[:, 0]
        m_h = label_pred[:, 1]
        log_g = label_pred[:, 2]
        alpha_m = label_pred[:, 3]
        u_teff = std_pred[:, 0]
        u_m_h = std_pred[:, 1]
        u_log_g = std_pred[:, 2]
        u_alpha_m = std_pred[:, 3]

        result = dict(
            snr=spectrum.meta["snr"],
            teff=teff.tolist(),
            m_h=m_h.tolist(),
            logg=log_g.tolist(),
            alpha_m=alpha_m.tolist(),
            u_teff=u_teff.tolist(),
            u_m_h=u_m_h.tolist(),
            u_logg=u_log_g.tolist(),
            u_alpha_m=u_alpha_m.tolist(),
        )

        # Write the result to the database.
        create_task_output(instance, astradb.SLAM, **result)

    log.info(f"Completed processing of {total} primary keys")
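def _example_resample_linear(wave, flux, ivar, wave_new):
    """
    Illustrative sketch only: the pipeline above uses the SLAM `resample` helper, whose
    implementation is not shown here. A simple linear-interpolation version of the same idea
    maps flux and inverse variance onto the model grid and zeroes the weight of anything
    outside the observed wavelength range.
    """
    import numpy as np
    flux_new = np.interp(wave_new, wave, flux, left=np.nan, right=np.nan)
    ivar_new = np.interp(wave_new, wave, ivar, left=0.0, right=0.0)
    ivar_new[~np.isfinite(flux_new)] = 0.0
    flux_new = np.nan_to_num(flux_new)
    return flux_new, ivar_new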
def write_database_outputs(
        task,
        ti,
        run_id,
        element_from_task_id_callable=None,
        **kwargs):
    """
    Collate outputs from upstream FERRE executions and write them to an ASPCAP database table.

    :param task:
        This task, as given by the Airflow context dictionary.

    :param ti:
        This task instance, as given by the Airflow context dictionary.

    :param run_id:
        This run ID, as given by the Airflow context dictionary.

    :param element_from_task_id_callable: [optional]
        A Python callable that returns the chemical element, given a task ID.
    """
    log.debug(f"Writing ASPCAP database outputs")

    pks = []
    for upstream_task in task.upstream_list:
        pks.append(ti.xcom_pull(task_ids=upstream_task.task_id))

    log.debug(f"Upstream primary keys: {pks}")

    # Group them together by source.
    instance_pks = []
    for source_pks in list(zip(*pks)):

        # The instance with the lowest primary key will be the stellar parameters task.
        sp_pk, *abundance_pks = sorted(source_pks)

        sp_instance = session.query(astradb.TaskInstance).filter(
            astradb.TaskInstance.pk == sp_pk).one_or_none()
        abundance_instances = session.query(astradb.TaskInstance).filter(
            astradb.TaskInstance.pk.in_(abundance_pks)).all()

        # Get parameters that are common to all instances.
        keep = {}
        for key, value in sp_instance.parameters.items():
            for instance in abundance_instances:
                if instance.parameters[key] != value:
                    break
            else:
                keep[key] = value

        # Create a task instance.
        instance = create_task_instance(
            dag_id=task.dag_id,
            task_id=task.task_id,
            run_id=run_id,
            parameters=keep)

        # Create a partial results table.
        keys = ["snr"]
        label_names = ("teff", "logg", "metals", "log10vdop", "o_mg_si_s_ca_ti",
                       "lgvsini", "c", "n")
        for key in label_names:
            keys.extend([key, f"u_{key}"])

        results = dict([(key, getattr(sp_instance.output, key)) for key in keys])

        # Now update with elemental abundance instances.
        for el_instance in abundance_instances:

            if element_from_task_id_callable is not None:
                element = element_from_task_id_callable(el_instance.task_id).lower()
            else:
                element = el_instance.task_id.split(".")[-1].lower()

            # Check what is not frozen.
            thawed_label_names = []
            # Ignore situations where lgvsini was missing from the grid and it screws up the task.
            ignore = ("lgvsini", )
            for key in label_names:
                if key not in ignore and not getattr(el_instance.output, f"frozen_{key}"):
                    thawed_label_names.append(key)

            if len(thawed_label_names) > 1:
                log.warning(
                    f"Multiple thawed label names for {element} {el_instance}: {thawed_label_names}")

            values = np.hstack(
                [getattr(el_instance.output, ln) for ln in thawed_label_names]).tolist()
            u_values = np.hstack(
                [getattr(el_instance.output, f"u_{ln}") for ln in thawed_label_names]).tolist()

            results.update({
                f"{element}_h": values,
                f"u_{element}_h": u_values,
            })

        # Include associated primary keys so we can reference back to original parameters, etc.
        results["associated_ti_pks"] = [sp_pk, *abundance_pks]

        log.debug(f"Results entry: {results}")

        # Create an entry in the output interface table.
        # TODO: Should we link back to the original FERRE primary keys?
        output = create_task_output(instance, astradb.Aspcap, **results)
        log.debug(f"Created output {output} for instance {instance}")

        instance_pks.append(instance.pk)

    return instance_pks
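def _example_group_sources(pks):
    """
    Illustrative sketch (values made up) of the grouping used above: each upstream FERRE task
    pushes its primary keys in the same source order, so transposing with `zip(*pks)` pairs
    the stellar parameters task with its abundance tasks for each source, and `sorted(...)`
    puts the stellar parameters pk (assumed lowest) first.

    e.g. list(_example_group_sources([[101, 201], [102, 202], [103, 203]]))
         -> [(101, [102, 103]), (201, [202, 203])]
    """
    for source_pks in zip(*pks):
        sp_pk, *abundance_pks = sorted(source_pks)
        yield sp_pk, abundance_pks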