Example #1
    def __init__(self, spectrum_library_to_analyse, workspace, pipeline):
        """
        Open the spectrum library containing the spectra that we are to analyse.

        :param spectrum_library_to_analyse:
            The name of the spectrum library we are to analyse
        :type spectrum_library_to_analyse:
            str
        :param workspace:
            Directory where we expect to find spectrum libraries.
        :type workspace:
            str
        :param pipeline:
            The Pipeline we are to run spectra through.
        :type pipeline:
            Pipeline
        """

        # Initialise pipeline manager
        super(PipelineManagerReadFromSpectrumLibrary,
              self).__init__(pipeline=pipeline)

        # Open the spectrum library we are reading from
        self.spectrum_library_to_analyse = spectrum_library_to_analyse

        spectra = SpectrumLibrarySqlite.open_and_search(
            library_spec=spectrum_library_to_analyse,
            workspace=workspace,
            extra_constraints={})
        input_library, input_library_items = [
            spectra[i] for i in ("library", "items")
        ]

        input_library_ids = [i["specId"] for i in input_library_items]

        self.input_library = input_library
        self.input_library_items = input_library_items
        self.input_library_ids = input_library_ids
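        # Counter tracking how far through the input spectra we have got; presumably advanced
        # each time the manager fetches the next spectrum to analyse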
        self.spectrum_counter = 0
                    default="/tmp/gaussian_convolution_{}.log".format(pid),
                    dest="log_to",
                    help="Specify a log file where we log our progress.")
args = parser.parse_args()

logger.info("Calculating SNR ratios of spectra in <{}>".format(
    args.input_library))

# Set path to workspace where we create libraries of spectra
workspace = args.workspace if args.workspace else os_path.join(
    our_path, "../../../workspace")
os.system("mkdir -p {}".format(workspace))

# Open input SpectrumLibrary, and search for flux normalised spectra meeting our filtering constraints
spectra = SpectrumLibrarySqlite.open_and_search(
    library_spec=args.input_library,
    workspace=workspace,
    extra_constraints={"continuum_normalised": 0})

# Get a list of the spectrum IDs which we were returned
input_library, input_spectra_ids, input_spectra_constraints = [
    spectra[i] for i in ("library", "items", "constraints")
]

# Parse any definitions of SNR we were supplied on the command line
if (args.snr_definitions is None) or (len(args.snr_definitions) < 1):
    snr_definitions = None
else:
    snr_definitions = []
    for snr_definition in args.snr_definitions:
        words = snr_definition.split(",")
        snr_definitions.append([words[0], float(words[1]), float(words[2])])
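        # Each definition is a comma-separated triple "name,min,max" (presumably a wavelength
        # range); e.g. a hypothetical "A4261_4270,4261,4270" parses to ["A4261_4270", 4261.0, 4270.0].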
# Set path to workspace where we expect to find libraries of spectra
our_path = os_path.split(os_path.abspath(__file__))[0]
workspace = os_path.join(our_path, "../../../../workspace")

# Create directory to store output files in
os.system("mkdir -p {}".format(args.output_stub))

# Fetch title for this Cannon run
cannon_output = json.loads(
    gzip.open(args.cannon + ".summary.json.gz", "rt").read())
description = cannon_output['description']

# Open spectrum library we originally trained the Cannon on
training_spectra_info = SpectrumLibrarySqlite.open_and_search(
    library_spec=cannon_output["train_library"],
    workspace=workspace,
    extra_constraints={"continuum_normalised": 1})

training_library, training_library_items = [
    training_spectra_info[i] for i in ("library", "items")
]

# Load training set
training_library_ids = [i["specId"] for i in training_library_items]
training_spectra = training_library.open(ids=training_library_ids)

# Recreate a Cannon instance, using the saved state
censoring_masks = cannon_output["censoring_mask"]
if censoring_masks is not None:
    for key, value in censoring_masks.items():
        censoring_masks[key] = np.asarray(value)
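        # (The JSON summary stores each censoring mask as a plain list of flags; converting to a
        #  numpy array lets it be applied directly as a per-pixel mask when the Cannon is rebuilt.)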
from fourgp_speclib import SpectrumLibrarySqlite

# Start logging our progress
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s:%(filename)s:%(message)s',
                    datefmt='%d/%m/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
logger.info("Synthesizing spectra of pepsi")

# Instantiate base synthesizer
synthesizer = Synthesizer(library_name="turbo_pepsi_replica_3label",
                          logger=logger,
                          docstring=__doc__)

spectra = SpectrumLibrarySqlite.open_and_search(
    library_spec='pepsi_4fs_hrs/',
    workspace='/home/travegre/Projects/4GP/4most-4gp-scripts/workspace/',
    extra_constraints={"continuum_normalised": True}
)
pepsi_library, pepsi_library_items = [spectra[i] for i in ("library", "items")]
# Load test set
pepsi_library_ids = [i["specId"] for i in pepsi_library_items]

spectra = [pepsi_library.open(ids=i).extract_item(0) for i in pepsi_library_ids]
print(spectra)

star_list = []
for spectrum in spectra:
    try:
        #star_list.append({'name': spectrum.metadata['Starname'], 'Teff': spectrum.metadata['Teff'], 'logg': spectrum.metadata['logg'], '[Fe/H]': spectrum.metadata['[Fe/H]'], 'microturbulence': spectrum.metadata['vmic_GES'], 'extra_metadata': {'set_id': 1}})
        star_list.append({'name': spectrum.metadata['Starname'],
                          'Teff': spectrum.metadata['Teff'],
                          'logg': spectrum.metadata['logg'],
                          '[Fe/H]': spectrum.metadata['[Fe/H]'],
                          'extra_metadata': {'set_id': 1}})
def main():
    """
    Main entry point for running the Payne.
    """
    global logger

    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s:%(filename)s:%(message)s',
        datefmt='%d/%m/%Y %H:%M:%S')
    logger = logging.getLogger(__name__)

    # Read input parameters
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--test',
        required=True,
        dest='test_library',
        help=
        "Library of spectra to test the trained Payne on. Stars may be filtered by parameters by "
        "placing a comma-separated list of constraints in [] brackets after the name of the "
        "library. Use the syntax [Teff=3000] to demand equality, or [0<[Fe/H]<0.2] to specify a "
        "range.")
    parser.add_argument(
        '--train',
        required=True,
        dest='train_library',
        help=
        "Library of labelled spectra to train the Payne on. Stars may be filtered by parameters "
        "by placing a comma-separated list of constraints in [] brackets after the name of the "
        "library. Use the syntax [Teff=3000] to demand equality, or [0<[Fe/H]<0.2] to specify a "
        "range.")
    parser.add_argument(
        '--workspace',
        dest='workspace',
        default="",
        help="Directory where we expect to find spectrum libraries.")
    parser.add_argument(
        '--train-batch-number',
        required=False,
        dest='train_batch_number',
        type=int,
        default=0,
        help=
        "If training pixels in multiple batches on different machines, then this is the number of "
        "the batch of pixels we are to train. It should be in the range 0 .. batch_count-1 "
        "inclusive. If it is -1, then we skip training to move straight to testing."
    )
    parser.add_argument(
        '--test-batch-number',
        required=False,
        dest='test_batch_number',
        type=int,
        default=0,
        help=
        "If testing spectra in multiple batches on different machines, then this is the number of "
        "the batch of spectra we are to test. It should be in the range 0 .. test_batch_count-1 "
        "inclusive.")
    parser.add_argument(
        '--num-training-workers',
        required=False,
        dest='train_batch_count',
        type=int,
        default=1,
        help=
        "If training pixels in multiple batches on different machines, then this is the number "
        "of nodes/workers/batches.")
    parser.add_argument(
        '--num-testing-workers',
        required=False,
        dest='test_batch_count',
        type=int,
        default=1,
        help=
        "If testing spectra in multiple batches on different machines, then this is the number "
        "of nodes/workers/batches.")
    parser.add_argument(
        '--reload-payne',
        required=False,
        dest='reload_payne',
        default=None,
        help=
        "Skip training step, and reload a Payne that we've previously trained."
    )
    parser.add_argument('--description',
                        dest='description',
                        help="A description of this fitting run.")
    parser.add_argument(
        '--labels',
        dest='labels',
        default="Teff,logg,[Fe/H]",
        help="List of the labels the Payne is to learn to estimate.")
    parser.add_argument(
        '--label-expressions',
        dest='label_expressions',
        default="",
        help="List of the algebraic labels the Payne is to learn to estimate "
        "(e.g. photometry_B - photometry_V).")
    parser.add_argument(
        '--labels-individual',
        dest='labels_individual',
        default="",
        help="List of the labels the Payne is to fit in separate fitting runs."
    )
    parser.add_argument('--censor-scheme',
                        default="1",
                        dest='censor_scheme',
                        help="Censoring scheme version to use (1, 2 or 3).")
    parser.add_argument(
        '--censor',
        default="",
        dest='censor_line_list',
        help=
        "Optional list of line positions for the Payne to fit, ignoring continuum between."
    )
    parser.add_argument('--output-file',
                        default="./test_cannon.out",
                        dest='output_file',
                        help="Data file to write output to.")
    parser.add_argument(
        '--assume-scaled-solar',
        action='store_true',
        dest="assume_scaled_solar",
        help=
        "Assume scaled solar abundances for any elements which don't have abundances individually "
        "specified. Useful for working with incomplete data sets.")
    parser.add_argument(
        '--no-assume-scaled-solar',
        action='store_false',
        dest="assume_scaled_solar",
        help=
        "Do not assume scaled solar abundances; throw an error if training set is has missing "
        "labels.")
    parser.set_defaults(assume_scaled_solar=False)
    parser.add_argument('--multithread',
                        action='store_true',
                        dest="multithread",
                        help="Use multiple thread to speed Payne up.")
    parser.add_argument(
        '--nothread',
        action='store_false',
        dest="multithread",
        help="Do not use multiple threads - use only one CPU core.")
    parser.set_defaults(multithread=True)
    parser.add_argument(
        '--interpolate',
        action='store_true',
        dest="interpolate",
        help=
        "Interpolate the test spectra on the training spectra's wavelength raster. DANGEROUS!"
    )
    parser.add_argument(
        '--nointerpolate',
        action='store_false',
        dest="interpolate",
        help="Do not interpolate the test spectra onto a different raster.")
    parser.set_defaults(interpolate=False)
    parser.add_argument(
        '--train-wavelength-window',
        dest="train_wavelength_window",
        help="Use only the selected wavelength region for the training")
    parser.set_defaults(train_wavelength_window=False)
    parser.add_argument(
        '--neuron-count',
        dest="neuron_count",
        help="Number of neurons in each of the Payne NN layers")
    parser.set_defaults(neuron_count=10)
    args = parser.parse_args()

    logger.info("Testing Payne with arguments <{}> <{}> <{}> <{}>".format(
        args.test_library, args.train_library, args.censor_line_list,
        args.output_file))

    # List of labels over which we are going to test the performance of the Payne
    test_label_fields = args.labels.split(",")

    # List of labels we're going to fit individually
    if args.labels_individual:
        test_labels_individual = [
            i.split("+") for i in args.labels_individual.split(",")
        ]
    else:
        test_labels_individual = [[]]
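    # For example (hypothetical value), --labels-individual "[Ca/H]+[Mg/H],[Ti/H]" yields
    # [['[Ca/H]', '[Mg/H]'], ['[Ti/H]']]: two separate fitting runs, the first fitting Ca and Mg together.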

    # Set path to workspace where we expect to find libraries of spectra
    our_path = os_path.split(os_path.abspath(__file__))[0]
    workspace = args.workspace if args.workspace else os_path.join(
        our_path, "../../../workspace")

    # Open training set
    spectra = SpectrumLibrarySqlite.open_and_search(
        library_spec=args.train_library,
        workspace=workspace,
        extra_constraints={"continuum_normalised": True})
    training_library, training_library_items = [
        spectra[i] for i in ("library", "items")
    ]

    # Open test set
    spectra = SpectrumLibrarySqlite.open_and_search(
        library_spec=args.test_library,
        workspace=workspace,
        extra_constraints={"continuum_normalised": True})
    test_library, test_library_items = [
        spectra[i] for i in ("library", "items")
    ]

    # Load training set
    training_library_ids_all = [i["specId"] for i in training_library_items]
    training_spectra_all = training_library.open(ids=training_library_ids_all)

    raster = training_spectra_all.wavelengths

    # Load test set
    test_library_ids = [i["specId"] for i in test_library_items]

    # Fit each set of labels we're fitting individually, one by one
    for labels_individual_batch_count, test_labels_individual_batch in enumerate(
            test_labels_individual):

        # Create filename for the output from this Payne run
        output_filename = args.output_file
        # If we're fitting elements individually, individually number the runs to fit each element
        if len(test_labels_individual) > 1:
            output_filename += "-{:03d}".format(labels_individual_batch_count)

        # If requested, fill in any missing labels on the training set by assuming scaled-solar abundances
        if args.assume_scaled_solar:
            training_spectra = autocomplete_scaled_solar_abundances(
                input_spectra=training_spectra_all,
                label_list=test_label_fields + test_labels_individual_batch)
        else:
            training_spectra = filter_training_spectra(
                input_spectra=training_spectra_all,
                label_list=test_label_fields + test_labels_individual_batch,
                input_library=training_library,
                input_spectrum_ids=training_library_ids_all)

        # Evaluate labels which are calculated via metadata expressions
        test_labels_expressions = []
        if args.label_expressions.strip():
            test_labels_expressions = args.label_expressions.split(",")
            evaluate_computed_labels(label_expressions=test_labels_expressions,
                                     spectra=training_spectra)

        # Make combined list of all labels the Payne is going to fit
        test_labels = test_label_fields + test_labels_individual_batch + test_labels_expressions
        logger.info("Beginning fit of labels <{}>.".format(
            ",".join(test_labels)))

        # If required, generate the censoring masks
        censoring_masks = create_censoring_masks(
            censoring_scheme=int(args.censor_scheme),
            raster=raster,
            censoring_line_list=args.censor_line_list,
            label_fields=test_label_fields + test_labels_individual_batch,
            label_expressions=test_labels_expressions,
            logger=logger)

        # Construct and train a model
        time_training_start = time.time()

        if args.train_wavelength_window:
            train_window_mask = (training_spectra.wavelengths > float(args.train_wavelength_window.split('-')[0])) \
                                & (training_spectra.wavelengths < float(args.train_wavelength_window.split('-')[1]))
            training_spectra.wavelengths = training_spectra.wavelengths[
                train_window_mask]
            training_spectra.values = training_spectra.values[:,
                                                              train_window_mask]
            training_spectra.value_errors = training_spectra.value_errors[:,
                                                                          train_window_mask]
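        # For example, a --train-wavelength-window of "5300-5600" (hypothetical value) restricts
        # training to pixels with wavelengths strictly between 5300 and 5600 Angstroms.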

        if args.reload_payne == 'true':
            model = PayneInstanceTing(
                training_set=training_spectra,
                label_names=test_labels,
                neuron_count=int(args.neuron_count),
                batch_number=args.train_batch_number,
                batch_count=args.train_batch_count,
                censors=censoring_masks,
                threads=None if args.multithread else 1,
                training_data_archive=output_filename,
                load_from_archive=True,
            )
        else:
            model = PayneInstanceTing(training_set=training_spectra,
                                      label_names=test_labels,
                                      neuron_count=int(args.neuron_count),
                                      batch_number=args.train_batch_number,
                                      batch_count=args.train_batch_count,
                                      censors=censoring_masks,
                                      threads=None if args.multithread else 1,
                                      training_data_archive=output_filename)

        time_training_end = time.time()

        # Plot some characteristic spectra from the Payne generative model (debugging block, disabled by default)
        if False:

            def sigmoid_def(z):
                return 1.0 / (1.0 + np.exp(-z))

            payne_status = model._payne_status
            w_array_0 = payne_status["w_array_0"]
            w_array_1 = payne_status["w_array_1"]
            w_array_2 = payne_status["w_array_2"]
            b_array_0 = payne_status["b_array_0"]
            b_array_1 = payne_status["b_array_1"]
            b_array_2 = payne_status["b_array_2"]
            x_min = payne_status["x_min"]
            x_max = payne_status["x_max"]

            # logging.info(w_array_0.shape) # dim1 - N pixels, dim2 - N neurons, dim3 - N labels
            labels = np.array([[5000, 4, 0], [5500, 4, 0], [5000, 4, -1.],
                               [5000, 5, 0], [4000, 4, 0]])
            labels = (labels - x_min) / (x_max - x_min) - 0.5
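            # This rescales each set of labels from its physical range [x_min, x_max] into roughly
            # [-0.5, +0.5], which presumably matches the normalisation used when the network was trained.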

            import matplotlib
            #matplotlib.use('Agg')
            import matplotlib.pyplot as plt
            fig = plt.figure(figsize=(12, 8), dpi=200)
            for i in training_spectra.values:
                plt.plot(training_spectra.wavelengths, i, lw=0.5, alpha=0.2)
            colors = ['yellow', 'orange', 'red', 'green', 'blue']
            for j, i in enumerate(labels):
                predict_flux = w_array_2 * sigmoid_def(
                    np.sum(w_array_1 *
                           (sigmoid_def(np.dot(w_array_0, i) + b_array_0)),
                           axis=1) + b_array_1) + b_array_2
                plt.plot(training_spectra.wavelengths,
                         predict_flux,
                         lw=2,
                         c=colors[j])
            plt.show()
            fig.savefig("{:s}.characteristic_gen_model_plot.png".format(
                output_filename),
                        format='png')

        # Test the model
        if not os.path.exists(
                os.path.join(
                    output_filename,
                    "batch_{:04d}_of_{:04d}.full.json.gz".format(
                        args.test_batch_number, args.test_batch_count))):
            N = len(test_library_ids)
            time_taken = np.zeros(N)
            results = []

            spec_start = (N // args.test_batch_count +
                          1) * args.test_batch_number
            spec_end = min(N, (N // args.test_batch_count + 1) *
                           (args.test_batch_number + 1))

            threads = 1  #cpu_count()
            #srng = split_seq(range(spec_start, spec_end), threads)
            if not args.train_wavelength_window:
                train_window_mask = False

            params = [
                spec_start, spec_end, args, training_spectra, censoring_masks,
                model, test_labels, train_window_mask
            ]
            manager = multiprocessing.Manager()

            for batch in range(spec_start, spec_end)[::threads]:
                batch = [[
                    i,
                    test_library.open(ids=test_library_ids[i]).extract_item(0)
                ] + params for i in range(batch, batch + threads)
                         if i < len(test_library_ids)]

                dicti = manager.list()

                ps = []
                for i in batch:
                    spectrum = i[1]
                    ref_labels = []
                    for j in test_labels:
                        try:
                            print(j, ' ', spectrum.metadata[j])
                            ref_labels.append(spectrum.metadata[j])
                        except KeyError:
                            # This label is missing from the spectrum's metadata; skip it
                            continue

                    #if len(ref_labels) != 3:
                    #    continue

                    #for j in test_labels[3:]:
                    #    ref_labels.append(0)

                    p = multiprocessing.Process(target=parallel_fit,
                                                args=(i, dicti,
                                                      np.array(ref_labels)))
                    ps.append(p)
                    p.start()

                for p in ps:
                    p.join()

                #with Pool(threads) as pool:
                #    batch_results = pool.map(parallel_fit, [[i, test_library.open(ids=test_library_ids[i]).extract_item(0)]+params for i in range(batch, batch+threads) if i<len(test_library_ids)])

                for result in dicti:
                    results.append(result)

            # Report time taken
            logger.info(
                "Fitting of {:d} spectra completed. Took {:.2f} +/- {:.2f} sec / spectrum."
                .format((spec_end - spec_start), np.mean(time_taken),
                        np.std(time_taken)))

            # Create output data structure
            censoring_output = None
            if censoring_masks is not None:
                censoring_output = dict([
                    (label, tuple([int(i) for i in mask]))
                    for label, mask in censoring_masks.items()
                ])

            output_data = {
                "hostname": os.uname()[1],
                "generator": __file__,
                "4gp_version": fourgp_version,
                "cannon_version": None,
                "payne_version": model.payne_version,
                "start_time": time_training_start,
                "end_time": time.time(),
                "training_time": time_training_end - time_training_start,
                "description": args.description,
                "train_library": args.train_library,
                "test_library": args.test_library,
                "tolerance": None,
                "assume_scaled_solar": args.assume_scaled_solar,
                "line_list": args.censor_line_list,
                "labels": test_labels,
                "wavelength_raster": tuple(raster),
                "censoring_mask": censoring_output
            }

            # Write brief summary of run to JSON file, without masses of data
            #with gzip.open("{:s}.summary.json.gz".format(output_filename), "wt") as f:
            #    f.write(json.dumps(output_data, indent=2))
            with gzip.open(
                    os.path.join(
                        output_filename,
                        "batch_{:04d}_of_{:04d}.summary.json.gz".format(
                            args.test_batch_number, args.test_batch_count)),
                    "wt") as f:
                f.write(json.dumps(output_data, indent=2))

            # Write full results to JSON file
            output_data["spectra"] = results
            #with gzip.open("{:s}.full.json.gz".format(output_filename), "wt") as f:
            #    f.write(json.dumps(output_data, indent=2))
            with gzip.open(
                    os.path.join(
                        output_filename,
                        "batch_{:04d}_of_{:04d}.full.json.gz".format(
                            args.test_batch_number, args.test_batch_count)),
                    "wt") as f:
                f.write(json.dumps(output_data, indent=2))

            logging.info(
                "Saving results, batch {:04d} of {:04d} completed".format(
                    args.test_batch_number, args.test_batch_count))
        else:
            # Load the test results from the batches and join them
            logging.info("Loading Payne results from disk")
            payne_batches_summary = {}
            payne_batches_full = {}
            for i in range(args.test_batch_count):
                filename_summary = os.path.join(
                    output_filename,
                    "batch_{:04d}_of_{:04d}.summary.json.gz".format(
                        i, args.test_batch_count))
                filename_full = os.path.join(
                    output_filename,
                    "batch_{:04d}_of_{:04d}.full.json.gz".format(
                        i, args.test_batch_count))

                with gzip.open(filename_summary, "r") as f:
                    payne_batches_summary.update(
                        json.loads(f.read().decode('utf-8')))

                with gzip.open(filename_full, "r") as f:
                    if i == 0:
                        payne_batches_full.update(
                            json.loads(f.read().decode('utf-8')))
                    else:
                        payne_batches_full['spectra'].extend(
                            json.loads(f.read().decode('utf-8'))['spectra'])


                assert os.path.exists(filename_summary), "Could not proceed with joinning results, because " \
                                                              "test data for batch {:d} of spectra is not present " \
                                                              "on this server.".format(i)

            logging.info("Payne results loaded successfully")

            with gzip.open("{:s}.full.json.gz".format(output_filename),
                           "wt") as f:
                f.write(json.dumps(payne_batches_full, indent=2))

            with gzip.open("{:s}.summary.json.gz".format(output_filename),
                           "wt") as f:
                f.write(json.dumps(payne_batches_summary, indent=2))

            logging.info("Payne batches merged successfully")
Example #6
def main():
    """
    Main entry point for running the Cannon.
    """
    logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s:%(filename)s:%(message)s',
                        datefmt='%d/%m/%Y %H:%M:%S')
    logger = logging.getLogger(__name__)

    # Read input parameters
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--test', required=True, dest='test_library',
                        help="Library of spectra to test the trained Cannon on. Stars may be filtered by parameters by "
                             "placing a comma-separated list of constraints in [] brackets after the name of the "
                             "library. Use the syntax [Teff=3000] to demand equality, or [0<[Fe/H]<0.2] to specify a "
                             "range.")
    parser.add_argument('--train', required=False, dest='train_library', default=None,
                        help="Library of labelled spectra to train the Cannon on. Stars may be filtered by parameters "
                             "by placing a comma-separated list of constraints in [] brackets after the name of the "
                             "library. Use the syntax [Teff=3000] to demand equality, or [0<[Fe/H]<0.2] to specify a "
                             "range.")
    parser.add_argument('--workspace', dest='workspace', default="",
                        help="Directory where we expect to find spectrum libraries.")
    parser.add_argument('--cannon-version', default="casey_old", dest='cannon_version',
                        choices=("casey_old", "casey_new", "anna_ho"),
                        help="Select which implementation of the Cannon to use: Andy Casey's or Anna Ho's.")
    parser.add_argument('--polynomial-order', default=2, dest='polynomial_order', type=int,
                        help="The maximum order of polynomials to use as basis functions in the Cannon.")
    parser.add_argument('--continuum-normalisation', default="none", dest='continuum_normalisation',
                        help="Select continuum normalisation method: none, running_mean or polynomial.")
    parser.add_argument('--reload-cannon', required=False, dest='reload_cannon', default=None,
                        help="Skip training step, and reload a Cannon that we've previously trained. Specify the full "
                             "path to the .cannon file containing the trained Cannon, but without the .cannon suffix.")
    parser.add_argument('--description', dest='description',
                        help="A description of this fitting run.")
    parser.add_argument('--labels', dest='labels',
                        default="Teff,logg,[Fe/H]",
                        help="List of the labels the Cannon is to learn to estimate.")
    parser.add_argument('--label-expressions', dest='label_expressions',
                        default="",
                        help="List of the algebraic labels the Cannon is to learn to estimate "
                             "(e.g. photometry_B - photometry_V).")
    parser.add_argument('--labels-individual', dest='labels_individual',
                        default="",
                        help="List of the labels the Cannon is to fit in separate fitting runs.")
    parser.add_argument('--censor-scheme', default="1", dest='censor_scheme',
                        help="Censoring scheme version to use (1, 2 or 3).")
    parser.add_argument('--censor', default="", dest='censor_line_list',
                        help="Optional list of line positions for the Cannon to fit, ignoring continuum between.")
    parser.add_argument('--tolerance', default=None, dest='tolerance', type=float,
                        help="The parameter xtol which is passed to the scipy optimisation routines as xtol to "
                             "determine whether they have converged.")
    parser.add_argument('--output-file', default="./test_cannon.out", dest='output_file',
                        help="Data file to write output to.")
    parser.add_argument('--assume-scaled-solar',
                        action='store_true',
                        dest="assume_scaled_solar",
                        help="Assume scaled solar abundances for any elements which don't have abundances individually "
                             "specified. Useful for working with incomplete data sets.")
    parser.add_argument('--no-assume-scaled-solar',
                        action='store_false',
                        dest="assume_scaled_solar",
                        help="Do not assume scaled solar abundances; throw an error if training set is has missing "
                             "labels.")
    parser.set_defaults(assume_scaled_solar=False)
    parser.add_argument('--multithread',
                        action='store_true',
                        dest="multithread",
                        help="Use multiple thread to speed Cannon up.")
    parser.add_argument('--nothread',
                        action='store_false',
                        dest="multithread",
                        help="Do not use multiple threads - use only one CPU core.")
    parser.set_defaults(multithread=True)
    parser.add_argument('--interpolate',
                        action='store_true',
                        dest="interpolate",
                        help="Interpolate the test spectra on the training spectra's wavelength raster. DANGEROUS!")
    parser.add_argument('--nointerpolate',
                        action='store_false',
                        dest="interpolate",
                        help="Do not interpolate the test spectra onto a different raster.")
    parser.set_defaults(interpolate=False)
    args = parser.parse_args()

    logging.info("Testing Cannon with arguments <{}> <{}> <{}> <{}>".format(args.test_library,
                                                                            args.train_library,
                                                                            args.censor_line_list,
                                                                            args.output_file))

    # Pick which Cannon version to use
    cannon_class, continuum_normalised_testing, continuum_normalised_training = \
        select_cannon(cannon_version=args.cannon_version,
                      continuum_normalisation=args.continuum_normalisation)

    # List of labels over which we are going to test the performance of the Cannon
    test_label_fields = args.labels.split(",")

    # List of labels we're going to fit individually
    if args.labels_individual:
        test_labels_individual = [i.split("+") for i in args.labels_individual.split(",")]
    else:
        test_labels_individual = [[]]

    # Set path to workspace where we expect to find libraries of spectra
    our_path = os_path.split(os_path.abspath(__file__))[0]
    workspace = args.workspace if args.workspace else os_path.join(our_path, "../../../workspace")

    # Find out whether we're reloading a previously saved Cannon
    reloading_cannon = args.reload_cannon is not None

    # Open training set
    training_library = training_library_ids_all = None
    if not reloading_cannon:
        spectra = SpectrumLibrarySqlite.open_and_search(
            library_spec=args.train_library,
            workspace=workspace,
            extra_constraints={"continuum_normalised": continuum_normalised_training}
        )
        training_library, training_library_items = [spectra[i] for i in ("library", "items")]

        # Make list of IDs of all spectra in the training set
        training_library_ids_all = [i["specId"] for i in training_library_items]

    # Open test set
    spectra = SpectrumLibrarySqlite.open_and_search(
        library_spec=args.test_library,
        workspace=workspace,
        extra_constraints={"continuum_normalised": continuum_normalised_testing}
    )
    test_library, test_library_items = [spectra[i] for i in ("library", "items")]

    # Make list of IDs of all spectra in the test set
    test_library_ids = [i["specId"] for i in test_library_items]

    # Fit each set of labels we're fitting individually, one by one
    for labels_individual_batch_count, test_labels_individual_batch in enumerate(test_labels_individual):

        # Create filename for the output from this Cannon run
        output_filename = args.output_file

        # If we're fitting elements individually, individually number the runs to fit each element
        if len(test_labels_individual) > 1:
            output_filename += "-{:03d}".format(labels_individual_batch_count)

        # Sequence of tasks if we're reloading a pre-saved Cannon from disk
        if reloading_cannon:

            # Load the JSON data that summarises the Cannon training that we're about to reload
            json_summary_filename = "{}.summary.json.gz".format(args.reload_cannon)
            cannon_pickle_filename = "{}.cannon".format(args.reload_cannon)

            with gzip.open(json_summary_filename, "rt") as f:
                summary_json = json.loads(f.read())

            raster = np.array(summary_json['wavelength_raster'])
            test_labels = summary_json['labels']
            training_library_ids = summary_json['training_spectra_ids']
            training_library_string = summary_json['train_library']
            assume_scaled_solar = summary_json['assume_scaled_solar']
            tolerance = summary_json['tolerance']
            line_list = summary_json['line_list']
            censoring_masks = None

            # If we're doing our own continuum normalisation, we need to treat each wavelength arm separately
            wavelength_arm_breaks = SpectrumProperties(raster).wavelength_arms()['break_points']

            time_training_start = time.time()
            model = cannon_class(training_set=None,
                                 wavelength_arms=wavelength_arm_breaks,
                                 load_from_file=cannon_pickle_filename,
                                 label_names=test_labels,
                                 tolerance=args.tolerance,
                                 polynomial_order=args.polynomial_order,
                                 censors=None,
                                 threads=None if args.multithread else 1
                                 )
            time_training_end = time.time()

        # Sequence of tasks if we're training a Cannon from scratch
        else:

            training_library_string = args.train_library
            assume_scaled_solar = args.assume_scaled_solar
            tolerance = args.tolerance
            line_list = args.censor_line_list

            # If requested, fill in any missing labels on the training set by assuming scaled-solar abundances
            if args.assume_scaled_solar:
                training_library_ids, training_spectra = autocomplete_scaled_solar_abundances(
                    training_library=training_library,
                    training_library_ids_all=training_library_ids_all,
                    label_list=test_label_fields + test_labels_individual_batch
                )

            # Otherwise we reject any training spectra which have incomplete labels
            else:
                training_library_ids, training_spectra = filter_training_spectra(
                    training_library=training_library,
                    training_library_ids_all=training_library_ids_all,
                    label_list=test_label_fields + test_labels_individual_batch
                )

            # Look up the raster on which the training spectra are sampled
            raster = training_spectra.wavelengths

            # Evaluate labels which are calculated via metadata expressions
            test_labels_expressions = []
            if args.label_expressions.strip():
                test_labels_expressions = args.label_expressions.split(",")
                evaluate_computed_labels(label_expressions=test_labels_expressions, spectra=training_spectra)

            # Make combined list of all labels the Cannon is going to fit
            test_labels = test_label_fields + test_labels_individual_batch + test_labels_expressions
            logging.info("Beginning fit of labels <{}>.".format(",".join(test_labels)))

            # If required, generate the censoring masks
            censoring_masks = create_censoring_masks(
                censoring_scheme=int(args.censor_scheme),
                raster=raster,
                censoring_line_list=args.censor_line_list,
                label_fields=test_label_fields + test_labels_individual_batch,
                label_expressions=test_labels_expressions
            )

            # If we're doing our own continuum normalisation, we need to treat each wavelength arm separately
            wavelength_arm_breaks = SpectrumProperties(raster).wavelength_arms()['break_points']

            # Construct and train a model
            time_training_start = time.time()
            model = cannon_class(training_set=training_spectra,
                                 wavelength_arms=wavelength_arm_breaks,
                                 label_names=test_labels,
                                 tolerance=args.tolerance,
                                 polynomial_order=args.polynomial_order,
                                 censors=censoring_masks,
                                 threads=None if args.multithread else 1
                                 )
            time_training_end = time.time()

            # Save the model
            model.save_model(filename="{:s}.cannon".format(output_filename),
                             overwrite=True)

        # Test the model
        N = len(test_library_ids)
        time_taken = np.zeros(N)
        results = []
        for index in range(N):
            test_spectrum_array = test_library.open(ids=test_library_ids[index])
            spectrum = test_spectrum_array.extract_item(0)
            logging.info("Testing {}/{}: {}".format(index + 1, N, spectrum.metadata['Starname']))

            # Calculate the time taken to process this spectrum
            time_start = time.time()

            # If requested, interpolate the test set onto the same raster as the training set. DANGEROUS!
            if args.interpolate:
                spectrum = resample_spectrum(spectrum=spectrum, training_spectra=training_spectra)

            # Pass spectrum to the Cannon
            labels, cov, meta = model.fit_spectrum(spectrum=spectrum)

            # Check whether Cannon failed
            if labels is None:
                continue

            # Measure the time taken
            time_end = time.time()
            time_taken[index] = time_end - time_start

            # Identify which star it is and what the SNR is
            star_name = spectrum.metadata["Starname"] if "Starname" in spectrum.metadata else ""
            uid = spectrum.metadata["uid"] if "uid" in spectrum.metadata else ""

            # From the label covariance matrix extract the standard deviation in each label value
            # (diagonal terms in the matrix are variances)
            if args.cannon_version == "anna_ho":
                err_labels = cov[0]
            else:
                err_labels = np.sqrt(np.diag(cov[0]))

            # Turn list of label values into a dictionary
            cannon_output = dict(list(zip(test_labels, labels[0])))

            # Add the standard deviations of each label into the dictionary
            cannon_output.update(dict(list(zip(["E_{}".format(label_name) for label_name in test_labels], err_labels))))
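            # The resulting dictionary (illustrative values only) looks something like
            # {"Teff": 5120.3, "logg": 4.1, "[Fe/H]": -0.2, "E_Teff": 45.0, "E_logg": 0.08, "E_[Fe/H]": 0.05}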

            # Add the star name and the SNR ratio of the test spectrum
            result = {"Starname": star_name,
                      "uid": uid,
                      "time": time_taken[index],
                      "spectrum_metadata": spectrum.metadata,
                      "cannon_output": cannon_output
                      }
            results.append(result)

        # Report time taken
        logging.info("Fitting of {:d} spectra completed. Took {:.2f} +/- {:.2f} sec / spectrum.".
                     format(N,
                            np.mean(time_taken),
                            np.std(time_taken)))

        # Create output data structure
        censoring_output = None
        if reloading_cannon:
            censoring_output = summary_json['censoring_mask']
        else:
            if censoring_masks is not None:
                censoring_output = dict([(label, tuple([int(i) for i in mask]))
                                         for label, mask in censoring_masks.items()])

        output_data = {
            "hostname": os.uname()[1],
            "generator": __file__,
            "4gp_version": fourgp_version,
            "cannon_version": model.cannon_version,
            "start_time": time_training_start,
            "end_time": time.time(),
            "training_time": time_training_end - time_training_start,
            "description": args.description,
            "train_library": training_library_string,
            "test_library": args.test_library,
            "training_spectra_ids": training_library_ids,
            "tolerance": tolerance,
            "assume_scaled_solar": assume_scaled_solar,
            "line_list": line_list,
            "labels": test_labels,
            "wavelength_raster": tuple(raster),
            "censoring_mask": censoring_output
        }

        # Write brief summary of run to JSON file, without masses of data
        with gzip.open("{:s}.summary.json.gz".format(output_filename), "wt") as f:
            f.write(json.dumps(output_data, indent=2))

        # Write full results to JSON file
        output_data["spectra"] = results
        with gzip.open("{:s}.full.json.gz".format(output_filename), "wt") as f:
            f.write(json.dumps(output_data, indent=2))
Example #7
    output_select = random.uniform(a=0, b=weights_sum)
    for index, weight in enumerate(weights):
        output_select -= weight
        if output_select <= 0:
            selected_index = index
            break
    return selected_index
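
# For example, with weights [1, 3, 6] (so weights_sum = 10), the weighted draw above returns
# index 2 about 60% of the time, index 1 about 30%, and index 0 about 10%.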


# Open input spectrum library(s), and fetch a list of all the flux-normalised spectra within each
input_libraries = []

if args.input_library is not None:
    input_libraries = [
        SpectrumLibrarySqlite.open_and_search(
            library_spec=item,
            workspace=workspace,
            extra_constraints={"continuum_normalised": 0})
        for item in args.input_library
    ]

# Report to user how many spectra we have just found
logger.info("Opening {:d} input libraries. These contain {:s} spectra.".format(
    len(input_libraries), str([len(x['items']) for x in input_libraries])))

# Open contaminating spectrum library(s), if any, and fetch a list of all the flux-normalised spectra within each
contamination_libraries = []
if args.contamination_library is not None:
    contamination_libraries = [
        SpectrumLibrarySqlite.open_and_search(
            library_spec=item,
            workspace=workspace,
def resample_templates(args, logger):
    """
    Resample a spectrum library of templates onto a fixed logarithmic stride, representing each of the 4MOST arms in
    turn. We use 4FS to down-sample the templates to the resolution of 4MOST observations, and automatically detect
    the list of arms contained within each 4FS mock observation. We then resample the 4FS output onto a new raster
    with fixed logarithmic stride.

    :param args:
        Object containing arguments supplied by the user, for example the names of the spectrum libraries we use
        for input and output. The required fields are defined by the user interface above.
    :param logger:
        A python logging object.
    :return:
        None.
    """
    # Set path to workspace where we expect to find libraries of spectra
    workspace = args.workspace if args.workspace else os_path.join(
        args.our_path, "../../../workspace")

    # Open input template spectra
    spectra = SpectrumLibrarySqlite.open_and_search(
        library_spec=args.templates_in,
        workspace=workspace,
        extra_constraints={"continuum_normalised": 0})

    templates_library, templates_library_items, templates_spectra_constraints = \
        [spectra[i] for i in ("library", "items", "constraints")]

    # Create new SpectrumLibrary to hold the resampled output templates
    library_path = os_path.join(workspace, args.templates_out)
    output_library = SpectrumLibrarySqlite(path=library_path, create=True)

    # Instantiate 4FS wrapper
    etc_wrapper = FourFS(path_to_4fs=os_path.join(args.binary_path,
                                                  "OpSys/ETC"),
                         snr_list=[250.],
                         magnitude=13,
                         snr_per_pixel=True)

    for input_spectrum_id in templates_library_items:
        logger.info("Working on <{}>".format(input_spectrum_id['filename']))

        # Open Spectrum data from disk
        input_spectrum_array = templates_library.open(
            ids=input_spectrum_id['specId'])

        # Load template spectrum (flux normalised)
        template_flux_normalised = input_spectrum_array.extract_item(0)

        # Look up the unique ID of the star we've just loaded
        # Newer spectrum libraries have a uid field which is guaranteed unique; for older spectrum libraries use
        # Starname instead.

        # Work out which field we're using (uid or Starname)
        spectrum_matching_field = 'uid' if 'uid' in template_flux_normalised.metadata else 'Starname'

        # Look up the unique ID of this object
        object_name = template_flux_normalised.metadata[
            spectrum_matching_field]

        # Search for the continuum-normalised version of this same object (which will share the same uid / name)
        search_criteria = {
            spectrum_matching_field: object_name,
            'continuum_normalised': 1
        }

        continuum_normalised_spectrum_id = templates_library.search(
            **search_criteria)

        # Check that continuum-normalised spectrum exists and is unique
        assert len(continuum_normalised_spectrum_id) == 1, \
            "Could not find continuum-normalised spectrum."

        # Load the continuum-normalised version
        template_continuum_normalised_arr = templates_library.open(
            ids=continuum_normalised_spectrum_id[0]['specId'])

        # Turn the SpectrumArray we got back into a single Spectrum
        template_continuum_normalised = template_continuum_normalised_arr.extract_item(
            0)

        # Now create a mock observation of this template using 4FS
        logger.info("Passing template through 4FS")
        mock_observed_template = etc_wrapper.process_spectra(
            spectra_list=((template_flux_normalised,
                           template_continuum_normalised), ))

        # Loop over LRS and HRS
        for mode in mock_observed_template:

            # Loop over the spectra we simulated (there was only one!)
            for index in mock_observed_template[mode]:

                # Loop over the various SNRs we simulated (there was only one!)
                for snr in mock_observed_template[mode][index]:

                    # Create a unique ID for this arm's data
                    unique_id = hashlib.md5(os.urandom(32)).hexdigest()[:16]

                    # Import the flux- and continuum-normalised spectra separately, but give them the same ID
                    for spectrum_type in mock_observed_template[mode][index][
                            snr]:

                        # Extract continuum-normalised mock observation
                        logger.info("Resampling {} spectrum".format(mode))
                        mock_observed = mock_observed_template[mode][index][
                            snr][spectrum_type]

                        # Replace errors which are nans with a large value
                        mock_observed.value_errors[np.isnan(
                            mock_observed.value_errors)] = 1000.

                        # Check for NaN values in spectrum itself
                        if not np.all(np.isfinite(mock_observed.values)):
                            print(
                                "Warning: NaN values in template <{}>".format(
                                    template_flux_normalised.
                                    metadata['Starname']))
                            mock_observed.value_errors[np.isnan(
                                mock_observed.values)] = 1000.
                            mock_observed.values[np.isnan(
                                mock_observed.values)] = 1.

                        # Resample template onto a logarithmic raster of fixed step
                        resampler = SpectrumResampler(mock_observed)

                        # Construct the raster for each wavelength arm
                        wavelength_arms = SpectrumProperties(
                            mock_observed.wavelengths).wavelength_arms()

                        # Resample 4FS output for each arm onto a fixed logarithmic stride
                        for arm_count, arm in enumerate(
                                wavelength_arms["wavelength_arms"]):
                            arm_raster, mean_pixel_width = arm
                            name = "{}_{}".format(mode, arm_count)

                            arm_info = {
                                "lambda_min": arm_raster[0],
                                "lambda_max": arm_raster[-1],
                                "lambda_step": mean_pixel_width
                            }

                            arm_raster = logarithmic_raster(
                                lambda_min=arm_info['lambda_min'],
                                lambda_max=arm_info['lambda_max'],
                                lambda_step=arm_info['lambda_step'])
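                            # A logarithmic raster has a constant fractional pixel size, i.e.
                            # consecutive wavelength points satisfy lambda[i+1] / lambda[i] = const,
                            # so each arm is sampled at (roughly) uniform velocity resolution.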

                            # Resample 4FS output onto a fixed logarithmic step
                            mock_observed_arm = resampler.onto_raster(
                                arm_raster)

                            # Save it into output spectrum library
                            output_library.insert(
                                spectra=mock_observed_arm,
                                filenames=input_spectrum_id['filename'],
                                metadata_list={
                                    "uid": unique_id,
                                    "template_id": object_name,
                                    "mode": mode,
                                    "arm_name":
                                    "{}_{}".format(mode, arm_count),
                                    "lambda_min": arm_raster[0],
                                    "lambda_max": arm_raster[-1],
                                    "lambda_step": mean_pixel_width
                                })
                    help="Separator to use between fields in the CSV output.")
parser.add_argument(
    '--workspace',
    dest='workspace',
    default="",
    help="Directory where we expect to find spectrum libraries.")
args = parser.parse_args()

# Set path to workspace where we expect to find libraries of spectra
our_path = os_path.split(os_path.abspath(__file__))[0]
workspace = args.workspace if args.workspace else os_path.join(
    our_path, "../../../workspace")

# Open spectrum library we're going to export from, and search for flux-normalised spectra meeting our filtering
# constraints
input_library_info = SpectrumLibrarySqlite.open_and_search(
    library_spec=args.library, workspace=workspace, extra_constraints={})

# Get a list of the spectrum IDs which we were returned
input_library, library_items = [
    input_library_info[i] for i in ("library", "items")
]
library_ids = [i["specId"] for i in library_items]

# Fetch list of all metadata fields, and sort it alphabetically
fields = [i.strip() for i in input_library.list_metadata_fields()]
fields.sort()

# At the top of the CSV file, write column headings with the field names
line = args.separator.join(fields)
print(line)
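
# A possible continuation (not part of this snippet): print one CSV row per spectrum by opening
# each spectrum in turn and reading its metadata dictionary, e.g.
#
# for item_id in library_ids:
#     metadata = input_library.open(ids=item_id).extract_item(0).metadata
#     print(args.separator.join(str(metadata.get(field, "")) for field in fields))
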
# Start logging our progress
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s:%(filename)s:%(message)s',
                    datefmt='%d/%m/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
logger.info("Creating synthetic versions of stars from <{}>".format(input_library))

# Instantiate base synthesizer
synthesizer = Synthesizer(library_name="pepsi_synthetic",
                          logger=logger,
                          docstring=__doc__)

star_list = []

# Open input spectrum library
spectra = SpectrumLibrarySqlite.open_and_search(library_spec=input_library,
                                                workspace=synthesizer.workspace,
                                                extra_constraints={"continuum_normalised": 1}
                                                )

# Get a list of the spectrum IDs which we were returned
input_library, input_spectra_ids, input_spectra_constraints = [spectra[i] for i in ("library", "items", "constraints")]

# Loop over input spectra
for input_spectrum_id in input_spectra_ids:
    logger.info("Working on <{}>".format(input_spectrum_id['filename']))
    # Open Spectrum data from disk
    input_spectrum_array = input_library.open(ids=input_spectrum_id['specId'])

    # Turn SpectrumArray object into a Spectrum object
    input_spectrum = input_spectrum_array.extract_item(0)
    metadata = input_spectrum.metadata
def main():
    """
    Main entry point for running the Payne.
    """
    global logger

    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s:%(filename)s:%(message)s',
        datefmt='%d/%m/%Y %H:%M:%S')
    logger = logging.getLogger(__name__)

    # Read input parameters
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--test',
        required=True,
        dest='test_library',
        help=
        "Library of spectra to test the trained Payne on. Stars may be filtered by parameters by "
        "placing a comma-separated list of constraints in [] brackets after the name of the "
        "library. Use the syntax [Teff=3000] to demand equality, or [0<[Fe/H]<0.2] to specify a "
        "range.")
    parser.add_argument(
        '--train',
        required=True,
        dest='train_library',
        help=
        "Library of labelled spectra to train the Payne on. Stars may be filtered by parameters "
        "by placing a comma-separated list of constraints in [] brackets after the name of the "
        "library. Use the syntax [Teff=3000] to demand equality, or [0<[Fe/H]<0.2] to specify a "
        "range.")
    parser.add_argument(
        '--workspace',
        dest='workspace',
        default="",
        help="Directory where we expect to find spectrum libraries.")
    parser.add_argument(
        '--train-batch-number',
        required=False,
        dest='batch_number',
        type=int,
        default=0,
        help=
        "If training pixels in multiple batches on different machines, then this is the number of "
        "the batch of pixels we are to train. It should be in the range 0 .. batch_count-1 "
        "inclusive. If it is -1, then we skip training to move straight to testing."
    )
    parser.add_argument(
        '--train-batch-count',
        required=False,
        dest='batch_count',
        type=int,
        default=1,
        help=
        "If training pixels in multiple batches on different machines, then this is the number "
        "of batches.")
    parser.add_argument('--description',
                        dest='description',
                        help="A description of this fitting run.")
    parser.add_argument(
        '--labels',
        dest='labels',
        default="Teff,logg,[Fe/H]",
        help="List of the labels the Payne is to learn to estimate.")
    parser.add_argument(
        '--label-expressions',
        dest='label_expressions',
        default="",
        help="List of the algebraic labels the Payne is to learn to estimate "
        "(e.g. photometry_B - photometry_V).")
    parser.add_argument(
        '--labels-individual',
        dest='labels_individual',
        default="",
        help="List of the labels the Payne is to fit in separate fitting runs."
    )
    parser.add_argument('--censor-scheme',
                        default="1",
                        dest='censor_scheme',
                        help="Censoring scheme version to use (1, 2 or 3).")
    parser.add_argument(
        '--censor',
        default="",
        dest='censor_line_list',
        help=
        "Optional list of line positions for the Payne to fit, ignoring continuum between."
    )
    parser.add_argument('--output-file',
                        default="./test_cannon.out",
                        dest='output_file',
                        help="Data file to write output to.")
    parser.add_argument(
        '--assume-scaled-solar',
        action='store_true',
        dest="assume_scaled_solar",
        help=
        "Assume scaled solar abundances for any elements which don't have abundances individually "
        "specified. Useful for working with incomplete data sets.")
    parser.add_argument(
        '--no-assume-scaled-solar',
        action='store_false',
        dest="assume_scaled_solar",
        help=
        "Do not assume scaled solar abundances; throw an error if training set is has missing "
        "labels.")
    parser.set_defaults(assume_scaled_solar=False)
    parser.add_argument('--multithread',
                        action='store_true',
                        dest="multithread",
                        help="Use multiple thread to speed Payne up.")
    parser.add_argument(
        '--nothread',
        action='store_false',
        dest="multithread",
        help="Do not use multiple threads - use only one CPU core.")
    parser.set_defaults(multithread=True)
    parser.add_argument(
        '--interpolate',
        action='store_true',
        dest="interpolate",
        help=
        "Interpolate the test spectra on the training spectra's wavelength raster. DANGEROUS!"
    )
    parser.add_argument(
        '--nointerpolate',
        action='store_false',
        dest="interpolate",
        help="Do not interpolate the test spectra onto a different raster.")
    parser.set_defaults(interpolate=False)
    args = parser.parse_args()

    logger.info("Testing Payne with arguments <{}> <{}> <{}> <{}>".format(
        args.test_library, args.train_library, args.censor_line_list,
        args.output_file))

    # List of labels over which we are going to test the performance of the Payne
    test_label_fields = args.labels.split(",")

    # List of labels we're going to fit individually
    if args.labels_individual:
        test_labels_individual = [
            i.split("+") for i in args.labels_individual.split(",")
        ]
    else:
        test_labels_individual = [[]]
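
    # For example (illustrative values), --labels-individual "[Fe/H]+[Mg/Fe],[Ti/Fe]" parses to
    # [['[Fe/H]', '[Mg/Fe]'], ['[Ti/Fe]']], i.e. two separate fitting runs.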

    # Set path to workspace where we expect to find libraries of spectra
    our_path = os_path.split(os_path.abspath(__file__))[0]
    workspace = args.workspace if args.workspace else os_path.join(
        our_path, "../../../workspace")

    # Open training set
    spectra = SpectrumLibrarySqlite.open_and_search(
        library_spec=args.train_library,
        workspace=workspace,
        extra_constraints={"continuum_normalised": True})
    training_library, training_library_items = [
        spectra[i] for i in ("library", "items")
    ]

    # Open test set
    spectra = SpectrumLibrarySqlite.open_and_search(
        library_spec=args.test_library,
        workspace=workspace,
        extra_constraints={"continuum_normalised": True})
    test_library, test_library_items = [
        spectra[i] for i in ("library", "items")
    ]

    # Load training set
    training_library_ids_all = [i["specId"] for i in training_library_items]
    training_spectra_all = training_library.open(ids=training_library_ids_all)
    raster = training_spectra_all.wavelengths

    # Load test set
    test_library_ids = [i["specId"] for i in test_library_items]

    # Fit each set of labels we're fitting individually, one by one
    for labels_individual_batch_count, test_labels_individual_batch in enumerate(
            test_labels_individual):

        # Create filename for the output from this Payne run
        output_filename = args.output_file
        # If we're fitting sets of labels individually, number the output file from each run
        if len(test_labels_individual) > 1:
            output_filename += "-{:03d}".format(labels_individual_batch_count)

        # If requested, fill in any missing labels on the training set by assuming scaled-solar abundances
        if args.assume_scaled_solar:
            training_spectra = autocomplete_scaled_solar_abundances(
                input_spectra=training_spectra_all,
                label_list=test_label_fields + test_labels_individual_batch)
        else:
            training_spectra = filter_training_spectra(
                input_spectra=training_spectra_all,
                label_list=test_label_fields + test_labels_individual_batch,
                input_library=training_library,
                input_spectrum_ids=training_library_ids_all)

        # Evaluate labels which are calculated via metadata expressions
        test_labels_expressions = []
        if args.label_expressions.strip():
            test_labels_expressions = args.label_expressions.split(",")
            evaluate_computed_labels(label_expressions=test_labels_expressions,
                                     spectra=training_spectra)

        # Make combined list of all labels the Payne is going to fit
        test_labels = test_label_fields + test_labels_individual_batch + test_labels_expressions
        logger.info("Beginning fit of labels <{}>.".format(
            ",".join(test_labels)))

        # If required, generate the censoring masks
        censoring_masks = create_censoring_masks(
            censoring_scheme=int(args.censor_scheme),
            raster=raster,
            censoring_line_list=args.censor_line_list,
            label_fields=test_label_fields + test_labels_individual_batch,
            label_expressions=test_labels_expressions)
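
        # create_censoring_masks is expected to return either None or a dictionary mapping each
        # label name to a per-pixel mask over the wavelength raster; the JSON summary written
        # below serialises each mask as a tuple of 0/1 flags.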

        # Construct and train a model
        time_training_start = time.time()

        model = PayneInstanceTing(training_set=training_spectra,
                                  label_names=test_labels,
                                  batch_number=args.batch_number,
                                  batch_count=args.batch_count,
                                  censors=censoring_masks,
                                  threads=None if args.multithread else 1,
                                  training_data_archive=output_filename)

        time_training_end = time.time()

        # Test the model
        N = len(test_library_ids)
        time_taken = np.zeros(N)
        results = []
        for index in range(N):
            test_spectrum_array = test_library.open(
                ids=test_library_ids[index])
            spectrum = test_spectrum_array.extract_item(0)
            logger.info("Testing {}/{}: {}".format(
                index + 1, N, spectrum.metadata['Starname']))

            # Calculate the time taken to process this spectrum
            time_start = time.time()

            # If requested, interpolate the test set onto the same raster as the training set. DANGEROUS!
            if args.interpolate:
                spectrum = resample_spectrum(spectrum=spectrum,
                                             training_spectra=training_spectra)

            # Pass spectrum to the Payne
            fit_data = model.fit_spectrum(spectrum=spectrum)

            # Check whether the Payne failed to return a fit for this spectrum
            # if fit_data is None:
            #     continue

            # Measure the time taken
            time_end = time.time()
            time_taken[index] = time_end - time_start

            # Identify which star this is
            star_name = spectrum.metadata.get("Starname", "")
            uid = spectrum.metadata.get("uid", "")

            # Error estimates are not yet computed, so set them all to zero as placeholders
            err_labels = [0 for item in test_labels]

            # Turn list of label values into a dictionary
            payne_output = dict(zip(test_labels, fit_data['results'][0]))

            # Add the uncertainty of each label into the dictionary, under the key E_<label>
            payne_output.update({
                "E_{}".format(label_name): err_label
                for label_name, err_label in zip(test_labels, err_labels)
            })
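
            # For example, with test_labels = ['Teff', 'logg', '[Fe/H]'] (illustrative values),
            # payne_output would look like:
            # {'Teff': 5777.0, 'logg': 4.44, '[Fe/H]': 0.0, 'E_Teff': 0, 'E_logg': 0, 'E_[Fe/H]': 0}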

            # Record the star name, uid and time taken, together with the spectrum metadata and fitted labels
            result = {
                "Starname": star_name,
                "uid": uid,
                "time": time_taken[index],
                "spectrum_metadata": spectrum.metadata,
                "cannon_output": payne_output
            }
            results.append(result)

        # Report time taken
        logger.info(
            "Fitting of {:d} spectra completed. Took {:.2f} +/- {:.2f} sec / spectrum."
            .format(N, np.mean(time_taken), np.std(time_taken)))

        # Create output data structure
        censoring_output = None
        if censoring_masks is not None:
            censoring_output = dict([
                (label, tuple([int(i) for i in mask]))
                for label, mask in censoring_masks.items()
            ])

        output_data = {
            "hostname": os.uname()[1],
            "generator": __file__,
            "4gp_version": fourgp_version,
            "cannon_version": None,
            "payne_version": model.payne_version,
            "start_time": time_training_start,
            "end_time": time.time(),
            "training_time": time_training_end - time_training_start,
            "description": args.description,
            "train_library": args.train_library,
            "test_library": args.test_library,
            "tolerance": None,
            "assume_scaled_solar": args.assume_scaled_solar,
            "line_list": args.censor_line_list,
            "labels": test_labels,
            "wavelength_raster": tuple(raster),
            "censoring_mask": censoring_output
        }

        # Write brief summary of run to JSON file, without masses of data
        with gzip.open("{:s}.summary.json.gz".format(output_filename),
                       "wt") as f:
            f.write(json.dumps(output_data, indent=2))

        # Write full results to JSON file
        output_data["spectra"] = results
        with gzip.open("{:s}.full.json.gz".format(output_filename), "wt") as f:
            f.write(json.dumps(output_data, indent=2))
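

# A script defining main() like this is normally executed via the standard entry-point guard
# (not shown in the excerpt above):
if __name__ == "__main__":
    main()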
synthesizer.set_star_list(star_list)

# Create new SpectrumLibrary
synthesizer.create_spectrum_library()

# Iterate over the spectra we're supposed to be synthesizing
synthesizer.do_synthesis()

# Close TurboSpectrum synthesizer instance
synthesizer.clean_up()

# Load spectrum
spectra = SpectrumLibrarySqlite.open_and_search(
    library_spec=synthesizer.args.library,
    workspace=synthesizer.workspace,
    extra_constraints={
        "Starname": "Sun",
        "continuum_normalised": 0
    })
input_library, input_spectra_ids, input_spectra_constraints = [
    spectra[i] for i in ("library", "items", "constraints")
]

input_spectrum_array = input_library.open(ids=input_spectra_ids[0]['specId'])
input_spectrum = input_spectrum_array.extract_item(0)

# Process spectra through reddening model
reddener = SpectrumReddener(input_spectrum=input_spectrum)

# Instantiate 4FS wrapper
etc_wrapper = FourFS(path_to_4fs=os_path.join(synthesizer.args.binary_path,
def tabulate_labels(library_list, label_list, output_file, workspace=None):
    """
    Take a list of SpectrumLibraries and tabulate the stellar parameters of the stars within them.

    :param library_list:
        A list of the SpectrumLibraries whose contents we are to tabulate
    :param label_list:
        A list of the labels whose values we are to tabulate
    :param output_file:
        The filename of the ASCII output file we are to produce
    :param workspace:
        Path to the workspace where we expect to find SpectrumLibraries stored
    :return:
        None
    """

    # Set path to workspace where we expect to find libraries of spectra
    if workspace is None:
        our_path = os_path.split(os_path.abspath(__file__))[0]
        workspace = os_path.join(our_path, "../../../../workspace")

    # Open output data file
    with open(output_file, "w") as output:
        # Loop over each spectrum library in turn
        for library in library_list:

            # Extract the name of the spectrum library we are to open, stripping off any constraints which follow the name in [] brackets
            test = re.match(r"([^\[]*)\[(.*)\]$", library)
            if test is None:
                library_name = library
            else:
                library_name = test.group(1)

            # Open spectrum library and extract list of metadata fields which are defined on this library
            library_path = os_path.join(workspace, library_name)
            library_object = SpectrumLibrarySqlite(path=library_path,
                                                   create=False)
            metadata_fields = library_object.list_metadata_fields()

            # Now search library for spectra matching any input constraints, with additional constraint on only
            # returning continuum normalised spectra, if that field is defined for this library
            constraints = {}
            if "continuum_normalised" in metadata_fields:
                constraints["continuum_normalised"] = 1

            library_spectra = SpectrumLibrarySqlite.open_and_search(
                library_spec=library,
                workspace=workspace,
                extra_constraints=constraints)

            # Write column headers at the top of the output
            columns = (label_list if label_list is not None
                       else library_object.list_metadata_fields())
            output.write("# ")
            for label in columns:
                output.write("{} ".format(label))
            output.write("\n")

            # Loop over objects in each spectrum library
            for item in library_spectra["items"]:
                metadata = library_object.get_metadata(ids=item['specId'])[0]

                for label in columns:
                    output.write("{} ".format(metadata.get(label, "-")))
                output.write("\n")