def test_hdf5_file_input():
    """Case where an hdf5 file is input. One of these also includes a
    normalized spectrum"""
    catfile = os.path.join(TEST_DATA_DIR, 'point_sources.cat')
    output_hdf5 = os.path.join(TEST_DATA_DIR, 'all_spectra.hdf5')
    sed_file = os.path.join(TEST_DATA_DIR,
                            'sed_file_with_normalized_dataset.hdf5')
    sed_catalog = spec.make_all_spectra(
        catfile,
        input_spectra_file=sed_file,
        normalizing_mag_column='nircam_f444w_magnitude',
        output_filename=output_hdf5)

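    # Compare the constructed spectra against the truth file, checking both
    # values and units for each source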
    comparison = hdf5.open(
        os.path.join(TEST_DATA_DIR,
                     'output_spec_from_hdf5_input_including_normalized.hdf5'))
    constructed = hdf5.open(sed_catalog)
    for key in comparison:
        assert key in constructed.keys()
        assert all(comparison[key]["wavelengths"].value == constructed[key]
                   ["wavelengths"].value)
        assert all(comparison[key]["fluxes"].value == constructed[key]
                   ["fluxes"].value)
        assert comparison[key]["wavelengths"].unit == constructed[key][
            "wavelengths"].unit
        assert comparison[key]["fluxes"].unit == constructed[key][
            "fluxes"].unit

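    # Clean up the outputs from make_all_spectra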
    cat_base = os.path.splitext(catfile)[0]
    flambda_output_catalog = cat_base + '_with_flambda.cat'
    os.remove(flambda_output_catalog)
    os.remove(sed_catalog)


def test_manual_inputs():
    """Case where spectra are input manually alongside an ascii catalog"""
    catfile = os.path.join(TEST_DATA_DIR, 'point_sources.cat')
    output_hdf5 = os.path.join(TEST_DATA_DIR, 'all_spectra.hdf5')
    hdf5file = os.path.join(TEST_DATA_DIR,
                            'sed_file_with_normalized_dataset.hdf5')
    sed_dict = hdf5.open(hdf5file)
    sed_catalog = spec.make_all_spectra(
        catfile,
        input_spectra=sed_dict,
        normalizing_mag_column='nircam_f444w_magnitude',
        output_filename=output_hdf5)
    constructed = hdf5.open(sed_catalog)
    comparison = hdf5.open(
        os.path.join(TEST_DATA_DIR, 'output_spec_from_manual_input.hdf5'))
    for key in comparison:
        assert key in constructed.keys()
        assert all(comparison[key]["wavelengths"].value == constructed[key]
                   ["wavelengths"].value)
        assert all(comparison[key]["fluxes"].value == constructed[key]
                   ["fluxes"].value)
        assert comparison[key]["wavelengths"].unit == constructed[key][
            "wavelengths"].unit
        assert comparison[key]["fluxes"].unit == constructed[key][
            "fluxes"].unit

    cat_base = os.path.splitext(catfile)[0]
    flambda_output_catalog = cat_base + '_with_flambda.cat'
    os.remove(flambda_output_catalog)
    os.remove(sed_catalog)


def test_multiple_ascii_catalogs():
    """Case where multiple ascii catalogs are input"""
    catfile = os.path.join(TEST_DATA_DIR, 'point_sources.cat')
    galfile = os.path.join(TEST_DATA_DIR, 'galaxies.cat')
    catalogs = [catfile, galfile]
    output_hdf5 = os.path.join(TEST_DATA_DIR, 'all_spectra.hdf5')
    sed_catalog = spec.make_all_spectra(catalogs, output_filename=output_hdf5)
    comparison = hdf5.open(
        os.path.join(TEST_DATA_DIR, 'source_sed_file_from_point_sources.hdf5'))
    constructed = hdf5.open(output_hdf5)
    for key in comparison:
        assert key in constructed.keys()
        assert all(comparison[key]["wavelengths"].value == constructed[key]
                   ["wavelengths"].value)
        assert all(comparison[key]["fluxes"].value == constructed[key]
                   ["fluxes"].value)
        assert comparison[key]["wavelengths"].unit == constructed[key][
            "wavelengths"].unit
        assert comparison[key]["fluxes"].unit == constructed[key][
            "fluxes"].unit
    for catfile in catalogs:
        cat_base = os.path.splitext(catfile)[0]
        flambda_output_catalog = cat_base + '_with_flambda.cat'
        os.remove(flambda_output_catalog)
    os.remove(sed_catalog)


def test_single_mag_column():
    """Case where the input ascii catalog contains only one magnitude column.
    In this case extrapolation is necessary."""
    catfile = os.path.join(TEST_DATA_DIR, 'point_sources_one_filter.cat')
    output_hdf5 = os.path.join(TEST_DATA_DIR, 'all_spectra.hdf5')
    sed_catalog = spec.make_all_spectra(catfile, output_filename=output_hdf5)
    comparison = hdf5.open(
        os.path.join(TEST_DATA_DIR, 'output_spec_from_one_filter.hdf5'))
    constructed = hdf5.open(output_hdf5)
    for key in comparison:
        assert key in constructed.keys()
        assert all(comparison[key]["wavelengths"].value == constructed[key]
                   ["wavelengths"].value)
        assert all(comparison[key]["fluxes"].value == constructed[key]
                   ["fluxes"].value)
        assert comparison[key]["wavelengths"].unit == constructed[key][
            "wavelengths"].unit
        assert comparison[key]["fluxes"].unit == constructed[key][
            "fluxes"].unit

    cat_base = os.path.splitext(catfile)[0]
    flambda_output_catalog = cat_base + '_with_flambda.cat'
    os.remove(flambda_output_catalog)
    os.remove(sed_catalog)


def test_manual_plus_file_inputs():
    """Case where spectra are input via an hdf5 file as well as manually"""
    catfile = os.path.join(TEST_DATA_DIR, 'point_sources.cat')
    sed_file = os.path.join(TEST_DATA_DIR, 'sed_file_with_normalized_dataset.hdf5')
    output_hdf5 = os.path.join(TEST_DATA_DIR, 'all_spectra.hdf5')
    manual_sed = {}
    manual_sed[7] = {"wavelengths": [0.9, 1.4, 1.9, 3.5, 5.1]*u.micron,
                     "fluxes": [1e-17, 1.1e-17, 1.5e-17, 1.4e-17, 1.1e-17] * FLAMBDA_CGS_UNITS}
    sed_catalog = spec.make_all_spectra(catfile, input_spectra=manual_sed, input_spectra_file=sed_file,
                                        normalizing_mag_column='nircam_f444w_magnitude',
                                        output_filename=output_hdf5, module='A')
    comparison = hdf5.open(os.path.join(TEST_DATA_DIR, 'output_spec_from_file_plus_manual_input.hdf5'))
    constructed = hdf5.open(output_hdf5)
    for key in comparison:
        assert key in constructed.keys()
        assert all(comparison[key]["wavelengths"].value == constructed[key]["wavelengths"].value)
        assert all(comparison[key]["fluxes"].value == constructed[key]["fluxes"].value)
        assert comparison[key]["wavelengths"].unit == constructed[key]["wavelengths"].unit
        assert comparison[key]["fluxes"].unit == constructed[key]["fluxes"].unit

    cat_base = os.path.splitext(catfile)[0]
    flambda_output_catalog = cat_base + '_with_flambda.cat'
    os.remove(flambda_output_catalog)
    os.remove(sed_catalog)


def test_multiple_mag_columns():
    """Case where an ascii catalog with multiple magnitude columns is input"""
    catfile = os.path.join(TEST_DATA_DIR, 'point_sources.cat')
    output_hdf5 = os.path.join(TEST_DATA_DIR, 'all_spectra.hdf5')
    sed_catalog = spec.make_all_spectra(catfile, output_filename=output_hdf5)
    constructed = hdf5.open(sed_catalog)
    comparison = hdf5.open(os.path.join(TEST_DATA_DIR, 'output_spec_from_multiple_filter.hdf5'))
    for key in comparison:
        assert key in constructed.keys()
        assert all(comparison[key]["wavelengths"].value == constructed[key]["wavelengths"].value)
        assert all(comparison[key]["fluxes"].value == constructed[key]["fluxes"].value)
        assert comparison[key]["wavelengths"].unit == constructed[key]["wavelengths"].unit
        assert comparison[key]["fluxes"].unit == constructed[key]["fluxes"].unit

    cat_base = os.path.splitext(catfile)[0]
    flambda_output_catalog = cat_base + '_with_flambda.cat'
    os.remove(flambda_output_catalog)
    os.remove(sed_catalog)


    def create(self):
        """Main function. Create the simulated grism TSO exposure"""

        # Get parameters necessary to create the TSO data
        orig_parameters = self.get_param_info()
        subarray_table = utils.read_subarray_definition_file(
            orig_parameters['Reffiles']['subarray_defs'])
        orig_parameters = utils.get_subarray_info(orig_parameters,
                                                  subarray_table)
        orig_parameters = utils.read_pattern_check(orig_parameters)

        self.basename = os.path.join(
            orig_parameters['Output']['directory'],
            orig_parameters['Output']['file'][0:-5].split('/')[-1])

        # Determine file splitting information. First get some basic info
        # on the exposure
        self.numints = orig_parameters['Readout']['nint']
        self.numgroups = orig_parameters['Readout']['ngroup']
        self.numframes = orig_parameters['Readout']['nframe']
        self.numskips = orig_parameters['Readout']['nskip']
        self.namps = orig_parameters['Readout']['namp']
        self.numresets = orig_parameters['Readout']['resets_bet_ints']
        self.frames_per_group, self.frames_per_int, self.total_frames = utils.get_frame_count_info(
            self.numints, self.numgroups, self.numframes, self.numskips,
            self.numresets)

        # Get gain map for later unit conversion
        #gainfile = orig_parameters['Reffiles']['gain']
        #gain, gainheader = file_io.read_gain_file(gainfile)

        # Make 2 copies of the input parameter file, separating the TSO
        # source from the other sources
        self.split_param_file(orig_parameters)

        print('Splitting background and TSO source into multiple yaml files')
        print('Running background sources through catalog_seed_image')
        print('background param file is:', self.background_paramfile)

        # Run the catalog_seed_generator on the non-TSO (background) sources
        background_direct = catalog_seed_image.Catalog_seed()
        background_direct.paramfile = self.background_paramfile
        background_direct.make_seed()
        background_segmentation_map = background_direct.seed_segmap

        # A stellar spectrum hdf5 file is required as input. Create an hdf5
        # file containing the spectra of all sources
        self.final_SED_file = spectra_from_catalog.make_all_spectra(
            self.catalog_files,
            input_spectra_file=self.SED_file,
            extrapolate_SED=self.extrapolate_SED,
            output_filename=self.final_SED_file,
            normalizing_mag_column=self.SED_normalizing_catalog_column)

        bkgd_waves, bkgd_fluxes = backgrounds.nircam_background_spectrum(
            orig_parameters, self.detector, self.module)

        # Run the disperser on the background sources. Add the background
        # signal here as well
        print('\n\nDispersing background sources\n\n')

        background_done = False
        background_seed_files = [
            background_direct.ptsrc_seed_filename,
            background_direct.galaxy_seed_filename,
            background_direct.extended_seed_filename
        ]
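        # Any of these seed files may be None; disperse only those that exist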
        for seed_file in background_seed_files:
            if seed_file is not None:
                print("Dispersing seed image:", seed_file)
                disp = self.run_disperser(seed_file,
                                          orders=self.orders,
                                          add_background=not background_done,
                                          background_waves=bkgd_waves,
                                          background_fluxes=bkgd_fluxes,
                                          finalize=True)
                if not background_done:
                    # Background is added at the first opportunity. At this
                    # point, create an array to hold the final combined
                    # dispersed background
                    background_done = True
                    background_dispersed = copy.deepcopy(disp.final)
                else:
                    background_dispersed += disp.final

        # Run the catalog_seed_generator on the TSO source
        tso_direct = catalog_seed_image.Catalog_seed()
        tso_direct.paramfile = self.tso_paramfile
        tso_direct.make_seed()
        tso_segmentation_map = tso_direct.seed_segmap
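        # Fill the region outside the TSO source with the background
        # segmentation map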
        outside_tso_source = tso_segmentation_map == 0
        tso_segmentation_map[outside_tso_source] = background_segmentation_map[
            outside_tso_source]

        # Dimensions are (y, x)
        self.seed_dimensions = tso_direct.nominal_dims

        # Read in the transmission spectrum that goes with the TSO source
        tso_params = utils.read_yaml(self.tso_paramfile)
        tso_catalog_file = tso_params['simSignals']['tso_grism_catalog']

        tso_catalog = ascii.read(tso_catalog_file)

        transmission_file = tso_catalog['Transmission_spectrum'].data
        transmission_spectrum = ascii.read(transmission_file[0])

        # Calculate the total exposure time, including resets, to check
        # against the times provided in the catalog file.
        total_exposure_time = self.calculate_exposure_time() * u.second

        # Check to be sure the start and end times provided in the catalog
        # are enough to cover the length of the exposure.
        tso_catalog = self.tso_catalog_check(tso_catalog, total_exposure_time)

        # Use batman to create lightcurves from the transmission spectrum
        lightcurves, times = self.make_lightcurves(tso_catalog, self.frametime,
                                                   transmission_spectrum)

        # Determine which frames of the exposure will take place with the unaltered stellar
        # spectrum. This will be all frames where the associated lightcurve is 1.0 everywhere.
        transit_frames, unaltered_frames = self.find_transit_frames(
            lightcurves)
        print('Frame numbers containing the transit: {} - {}'.format(
            np.min(transit_frames), np.max(transit_frames)))

        # Run the disperser using the original, unaltered stellar spectrum. Set 'cache=True'
        print('\n\nDispersing TSO source\n\n')
        grism_seed_object = self.run_disperser(tso_direct.seed_file,
                                               orders=self.orders,
                                               add_background=False,
                                               cache=True,
                                               finalize=True)

        # Crop dispersed seed images to correct final subarray size
        #no_transit_signal = grism_seed_object.final
        no_transit_signal = utils.crop_to_subarray(grism_seed_object.final,
                                                   tso_direct.subarray_bounds)
        background_dispersed = utils.crop_to_subarray(
            background_dispersed, tso_direct.subarray_bounds)

        # Multiply the dispersed seed images by the flat field
        no_transit_signal *= tso_direct.flatfield
        background_dispersed *= tso_direct.flatfield

        # Save the dispersed seed images if requested
        if self.save_dispersed_seed:
            h_back = fits.PrimaryHDU(background_dispersed)
            h_back.header['EXTNAME'] = 'BACKGROUND_SOURCES'
            h_tso = fits.ImageHDU(grism_seed_object.final)
            h_tso.header['EXTNAME'] = 'TSO_SOURCE'
            hlist = fits.HDUList([h_back, h_tso])
            disp_filename = '{}_dispersed_seed_images.fits'.format(
                self.basename)
            hlist.writeto(disp_filename, overwrite=True)
            print(
                '\nDispersed seed images (background sources and TSO source) saved to {}.\n\n'
                .format(disp_filename))

        # Calculate file splitting info
        self.file_splitting()

        # Prepare for creating output files
        segment_file_dir = orig_parameters['Output']['directory']
        if orig_parameters['Readout']['pupil'][0].upper() == 'F':
            usefilt = 'pupil'
        else:
            usefilt = 'filter'
        segment_file_base = orig_parameters['Output']['file'].replace(
            '.fits', '_')
        segment_file_base = '{}_{}_'.format(
            segment_file_base, orig_parameters['Readout'][usefilt])
        segment_file_base = os.path.join(segment_file_dir, segment_file_base)

        # Loop over frames and integrations up to the size of the segment
        # file.
        ints_per_segment = self.int_segment_indexes[
            1:] - self.int_segment_indexes[:-1]
        groups_per_segment = self.grp_segment_indexes[
            1:] - self.grp_segment_indexes[:-1]

        total_frame_counter = 0
        previous_segment = 1
        segment_part_number = 0
        segment_starting_int_number = 0

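        # Bookkeeping attributes, updated from split_meta for each segment
        # part in the loop below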
        self.segment_part_number = 0
        self.segment_ints = 0
        self.segment_frames = 0
        self.segment_frame_start_number = 0
        self.segment_int_start_number = 0
        self.part_int_start_number = 0
        self.part_frame_start_number = 0

        # Get split files' metadata
        split_meta = SplitFileMetaData(self.int_segment_indexes,
                                       self.grp_segment_indexes,
                                       self.file_segment_indexes,
                                       self.group_segment_indexes_g,
                                       self.frames_per_int,
                                       self.frames_per_group, self.frametime)

        # List of all output seed files
        self.seed_files = []

        counter = 0
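        # Loop over segment parts: integration blocks in the outer loop,
        # group blocks in the inner loop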
        for i, int_dim in enumerate(ints_per_segment):
            int_start = self.int_segment_indexes[i]
            int_end = self.int_segment_indexes[i + 1]
            for j, grp_dim in enumerate(groups_per_segment):
                initial_frame = self.grp_segment_indexes[j]
                # int_dim and grp_dim are the number of integrations and
                # groups in the current segment PART
                print(
                    "\n\nCurrent segment part contains: {} integrations and {} groups."
                    .format(int_dim, grp_dim))
                print("Creating frame by frame dispersed signal")
                segment_seed = np.zeros(
                    (int_dim, grp_dim, self.seed_dimensions[0],
                     self.seed_dimensions[1]))

                for integ in np.arange(int_dim):
                    overall_integration_number = int_start + integ
                    previous_frame = np.zeros(self.seed_dimensions)

                    for frame in np.arange(grp_dim):
                        #print('TOTAL FRAME COUNTER: ', total_frame_counter)
                        #print('integ and frame: ', integ, frame)
                        # If a frame is from the part of the lightcurve
                        # with no transit, then the signal in the frame
                        # comes from no_transit_signal
                        if total_frame_counter in unaltered_frames:
                            frame_only_signal = (
                                background_dispersed +
                                no_transit_signal) * self.frametime
                        # If the frame is from a part of the lightcurve
                        # where the transit is happening, then call the
                        # cached disperser with the appropriate lightcurve
                        elif total_frame_counter in transit_frames:
                            #print("{} is during the transit".format(total_frame_counter))
                            frame_transmission = lightcurves[
                                total_frame_counter, :]
                            trans_interp = interp1d(
                                transmission_spectrum['Wavelength'],
                                frame_transmission)
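                            # Re-disperse the cached TSO source with this
                            # frame's transmission applied to each order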

                            for order in self.orders:
                                grism_seed_object.this_one[
                                    order].disperse_all_from_cache(
                                        trans_interp)
                            # Here is where we call finalize on the TSO object
                            # This will update grism_seed_object.final to
                            # contain the correct signal
                            grism_seed_object.finalize(Back=None,
                                                       BackLevel=None)
                            cropped_grism_seed_object = utils.crop_to_subarray(
                                grism_seed_object.final,
                                tso_direct.subarray_bounds)
                            frame_only_signal = (
                                background_dispersed +
                                cropped_grism_seed_object) * self.frametime

                        # Now add the signal from this frame to that in the
                        # previous frame in order to arrive at the total
                        # cumulative signal
                        segment_seed[
                            integ,
                            frame, :, :] = previous_frame + frame_only_signal
                        previous_frame = copy.deepcopy(
                            segment_seed[integ, frame, :, :])
                        total_frame_counter += 1

                    # At the end of each integration, increment the
                    # total_frame_counter by the number of resets between
                    # integrations
                    total_frame_counter += self.numresets

                # Use the split files' metadata
                self.segment_number = split_meta.segment_number[counter]
                self.segment_ints = split_meta.segment_ints[counter]
                self.segment_frames = split_meta.segment_frames[counter]
                self.segment_part_number = split_meta.segment_part_number[
                    counter]
                self.segment_frame_start_number = split_meta.segment_frame_start_number[
                    counter]
                self.segment_int_start_number = split_meta.segment_int_start_number[
                    counter]
                self.part_int_start_number = split_meta.part_int_start_number[
                    counter]
                self.part_frame_start_number = split_meta.part_frame_start_number[
                    counter]
                counter += 1

                print('Overall integration number: ',
                      overall_integration_number)
                segment_file_name = '{}seg{}_part{}_seed_image.fits'.format(
                    segment_file_base,
                    str(self.segment_number).zfill(3),
                    str(self.segment_part_number).zfill(3))

                print('Segment int and frame start numbers: {} {}'.format(
                    self.segment_int_start_number,
                    self.segment_frame_start_number))
                #print('Part int and frame start numbers (ints and frames within the segment): {} {}'.format(self.part_int_start_number, self.part_frame_start_number))

                # Disperser output is always full frame. Crop to the
                # requested subarray if necessary
                if orig_parameters['Readout'][
                        'array_name'] not in self.fullframe_apertures:
                    print("Dispersed seed image size: {}".format(
                        segment_seed.shape))
                    segment_seed = utils.crop_to_subarray(
                        segment_seed, tso_direct.subarray_bounds)
                    #gain = utils.crop_to_subarray(gain, tso_direct.subarray_bounds)

                # Segmentation map will be centered in a frame that is larger
                # than full frame by a factor of sqrt(2), so crop appropriately
                print('Cropping segmentation map to appropriate aperture')
                segy, segx = tso_segmentation_map.shape
                dx = int((segx - tso_direct.nominal_dims[1]) / 2)
                dy = int((segy - tso_direct.nominal_dims[0]) / 2)
                segbounds = [
                    tso_direct.subarray_bounds[0] + dx,
                    tso_direct.subarray_bounds[1] + dy,
                    tso_direct.subarray_bounds[2] + dx,
                    tso_direct.subarray_bounds[3] + dy
                ]
                tso_segmentation_map = utils.crop_to_subarray(
                    tso_segmentation_map, segbounds)

                # Convert seed image to ADU/sec to be consistent
                # with other simulator outputs
                gain = MEAN_GAIN_VALUES['nircam']['lw{}'.format(
                    self.module.lower())]
                segment_seed /= gain

                # Update seed image header to reflect the
                # division by the gain
                tso_direct.seedinfo['units'] = 'ADU/sec'

                # Save the seed image. Save in units of ADU/sec
                print('Saving seed image')
                tso_seed_header = fits.getheader(tso_direct.seed_file)
                self.save_seed(segment_seed, tso_segmentation_map,
                               tso_seed_header, orig_parameters)

        # Prepare the dark current exposure if needed
        print('Running dark prep')
        d = dark_prep.DarkPrep()
        d.paramfile = self.paramfile
        d.prepare()

        # Combine into final observation
        print('Running observation generator')
        obs = obs_generator.Observation()
        obs.linDark = d.dark_files
        obs.seed = self.seed_files
        obs.segmap = tso_segmentation_map
        obs.seedheader = tso_direct.seedinfo
        obs.paramfile = self.paramfile
        obs.create()
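

# Minimal usage sketch for the create() method above. The class name
# `GrismTSOSimulator` and its construction are hypothetical; only the
# `paramfile` attribute and the `create()` call appear in this module:
#
#     sim = GrismTSOSimulator()
#     sim.paramfile = 'my_tso_observation.yaml'
#     sim.create()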