Esempio n. 1
0
    def _on_select(self, xmin, xmax):
        wvln = self._ui['textbox'].text().strip()

        if wvln == '':
            self._ui['textbox'].setText('Error: please enter a wavelength value before selecting')
            return

        wave_val = float(wvln)

        # if line_list specified, find closest line from list:
        if self.line_list is not None:
            absdiff = np.abs(self.line_list - wave_val)
            idx = absdiff.argmin()
            if absdiff[idx] > 1.:
                logger.error("Couldn't find precise line corresponding to "
                             "input {:.3f}".format(wave_val))
                return

            logger.info("Snapping input wavelength {:.3f} to line list "
                        "value {:.3f}".format(wave_val, self.line_list[idx]))
            wave_val = self.line_list[idx]
            self._done_wavel_idx.append(idx)

        # line_props, line_cov = self.get_line_props(xmin, xmax)
        line_props,_ = self.get_line_props(xmin, xmax)
        if line_props is None:
            return

        self.draw_line_marker(line_props, wave_val, xmin, xmax)

        self.fig.suptitle('')
        plt.draw()
        self.fig.canvas.draw()

        self._map_dict['wavel'].append(wave_val)
        self._map_dict['pixel'].append(line_props['x0'])
def _sample_line_posterior(lf, flux, fit_pars, pool):
    """Sample the GP model parameters of ``lf`` with `emcee`.

    Starts from the current (maximum-likelihood) parameter vector, runs two
    burn-in stages (128 and 512 steps) and a 1024-step production run. The
    sampler is reset before each new stage, so the returned sampler only
    holds the production chain.
    """
    if fit_pars['std_G'] < 1E-2:
        lf.gp.freeze_parameter('mean:ln_std_G')

    initial = np.array(lf.gp.get_parameter_vector())
    if initial[4] < -10:  # TODO: ???
        initial[4] = -8.
    if initial[5] < -10:  # TODO: ???
        initial[5] = -8.
    ndim, nwalkers = len(initial), 64

    sampler = emcee.EnsembleSampler(nwalkers,
                                    ndim,
                                    log_probability,
                                    pool=pool,
                                    args=(lf.gp, flux))

    logger.debug("Running burn-in...")
    p0 = initial + 1e-6 * np.random.randn(nwalkers, ndim)
    p0, lp, _ = sampler.run_mcmc(p0, 128)

    logger.debug("Running 2nd burn-in...")
    sampler.reset()
    # restart all walkers in a small ball around the best walker so far
    p0 = p0[lp.argmax()] + 1e-3 * np.random.randn(nwalkers, ndim)
    p0, lp, _ = sampler.run_mcmc(p0, 512)

    logger.debug("Running production...")
    sampler.reset()
    sampler.run_mcmc(p0, 1024)

    return sampler


def _summarize_posterior(sampler, par_names):
    """Reduce the flattened chain to per-parameter median and scatter.

    Kernel parameters and background (``bg*``) parameters are skipped.
    Parameters sampled in log space (``ln_`` prefix) are exponentiated first.
    Returns a dict with ``<name>`` and ``<name>_error`` entries.
    """
    fit_kw = dict()
    for i, par_name in enumerate(par_names):
        if 'kernel' in par_name:
            continue

        # remove 'mean:'
        par_name = par_name[5:]

        # skip bg
        if par_name.startswith('bg'):
            continue

        samples = sampler.flatchain[:, i]

        if par_name.startswith('ln_'):
            par_name = par_name[3:]
            samples = np.exp(samples)

        median = np.median(samples)
        MAD = np.median(np.abs(samples - median))
        fit_kw[par_name] = median
        fit_kw[par_name + '_error'] = 1.5 * MAD  # convert to ~stddev

    return fit_kw


def _save_mcmc_plots(sampler, lf, plot_path, filebase):
    """Save the walker-trace, posterior-sample-fit and corner plots.

    Note: drawing the sample fits overwrites the parameter vector of
    ``lf.gp`` with posterior samples.
    """
    par_names = lf.gp.get_parameter_names()

    # plot MCMC traces
    fig, axes = plt.subplots(2, 4, figsize=(18, 6))
    for i in range(sampler.dim):
        for walker in sampler.chain[..., i]:
            axes.flat[i].plot(walker,
                              marker='',
                              drawstyle='steps-mid',
                              alpha=0.2)
        axes.flat[i].set_title(par_names[i], fontsize=12)
    fig.tight_layout()
    fig.savefig(path.join(plot_path, '{}_mcmc_trace.png'.format(filebase)),
                dpi=256)
    plt.close(fig)

    # plot samples
    fig, axes = plt.subplots(3, 1, figsize=(10, 10), sharex=True)

    samples = sampler.flatchain
    for s in samples[np.random.randint(len(samples), size=32)]:
        lf.gp.set_parameter_vector(s)
        lf.plot_fit(axes=axes, fit_alpha=0.2)

    fig.tight_layout()
    fig.savefig(path.join(plot_path, '{}_mcmc_fits.png'.format(filebase)),
                dpi=256)
    plt.close(fig)

    # corner plot
    fig = corner.corner(
        sampler.flatchain[::10, :],
        labels=[name.split(':')[1] for name in par_names])
    fig.savefig(path.join(plot_path, '{}_corner.png'.format(filebase)),
                dpi=256)
    plt.close(fig)


def _measure_sky_lines(session, obs, spec, sky_lines, plot_path, filebase):
    """Fit each sky line in the background spectrum and store the result.

    One ``SpectralLineMeasurement`` is committed per sky line, including for
    failed fits (stored with zero amplitude). Returns the centroids as a
    ``Quantity`` in angstroms, with NaN for lines that failed the quality
    cuts.
    """
    sky_centroids = []
    for sky_line in sky_lines:
        wvln = sky_line.wavelength.value
        x, (flux, ivar) = extract_region(
            spec['wavelength'],
            center=wvln,
            width=32.,  # angstroms
            arrs=[spec['background_flux'], spec['background_ivar']])

        lf = GaussianLineFitter(x, flux, ivar,
                                absorp_emiss=1.)  # all emission lines

        try:
            lf.fit()
            fit_pars = lf.get_gp_mean_pars()

        except Exception as e:
            # Deliberately broad: a failed sky-line fit must not abort the
            # whole observation; fall back to placeholder parameters.
            logger.warning("Failed to fit sky line {0}:\n{1}".format(
                sky_line, e))
            lf.success = False
            fit_pars = lf.get_init()
            # OMG this is the biggest effing hack
            fit_pars['amp'] = 0.
            fit_pars['bg_coef'] = None
            fit_pars['x0'] = 0.

        # HACK: hackish signal-to-noise
        max_ = fit_pars['amp'] / np.sqrt(2 * np.pi * fit_pars['std']**2)
        SNR = max_ / np.median(1 / np.sqrt(ivar))

        if (not lf.success or abs(fit_pars['x0'] - wvln) > 4
                or fit_pars['amp'] < 10 or fit_pars['std'] > 4
                or SNR < 2.5):
            # failed
            x0 = np.nan * u.angstrom
            title = 'f****d'
            fit_pars['amp'] = 0.

        else:
            x0 = fit_pars['x0'] * u.angstrom
            title = '{:.2f}'.format(fit_pars['amp'])

        if lf.success:
            fig = lf.plot_fit()
            fig.suptitle(title, y=0.95)
            fig.subplots_adjust(top=0.8)
            fig.savefig(path.join(
                plot_path,
                '{}_maxlike_sky_{:.0f}.png'.format(filebase, wvln)),
                        dpi=256)
            plt.close(fig)

        # store the sky line measurements
        fit_pars['std_G'] = fit_pars.pop('std')  # HACK
        fit_pars.pop('bg_coef')  # HACK
        slm = SpectralLineMeasurement(**fit_pars)
        slm.info = sky_line
        slm.observation = obs
        session.add(slm)
        session.commit()

        sky_centroids.append(x0)

    return u.Quantity(sky_centroids)


def _fit_observation(session, obs, Halpha, OI_lines, n_lines, data_root_path,
                     plot_path, overwrite, pool):
    """Fit Halpha and the sky lines for one observation and store the results."""
    measurements = session.query(SpectralLineMeasurement)\
                          .join(Observation)\
                          .filter(Observation.id == obs.id).all()

    if len(measurements) == n_lines and not overwrite:
        logger.debug('All line measurements already complete for object '
                     '{0} in file {1}'.format(obs.object,
                                              obs.filename_raw))
        return

    # Read the spectrum data and get wavelength solution
    filebase, _ = path.splitext(obs.filename_1d)
    filename_1d = obs.path_1d(data_root_path)
    spec = Table.read(filename_1d)
    logger.debug('Loaded 1D spectrum for object {0} from file {1}'.format(
        obs.object, filename_1d))

    # Extract region around Halpha
    halpha_wvln = Halpha.wavelength.value
    x, (flux, ivar) = extract_region(
        spec['wavelength'],
        center=halpha_wvln,
        width=100,
        arrs=[spec['source_flux'], spec['source_ivar']])

    # We start by doing maximum likelihood estimation to fit the line, then
    # use the best-fit parameters to initialize an MCMC run.
    # TODO: need to figure out if it's emission or absorption...for now just
    #   assume absorption
    absorp_emiss = -1.
    lf = VoigtLineFitter(x, flux, ivar, absorp_emiss=absorp_emiss)
    lf.fit()
    fit_pars = lf.get_gp_mean_pars()

    if (not lf.success
            or abs(fit_pars['x0'] - halpha_wvln) > 16.  # 16 Å = ~700 km/s
            or abs(fit_pars['amp']) < 10):  # minimum amplitude - MAGIC NUMBER
        # TODO: should try again with emission line
        logger.error('Halpha fit rejected for object {0} in file {1}: the '
                     'fit failed, the centroid is too far from Halpha, or '
                     'the amplitude is tiny. Did auto-determination of '
                     'absorption/emission fail?'.format(obs.object,
                                                        obs.filename_raw))
        # TODO: what now?
        return

    fig = lf.plot_fit()
    fig.savefig(path.join(plot_path, '{}_maxlike.png'.format(filebase)),
                dpi=256)
    plt.close(fig)

    # Run `emcee` instead to sample over GP model parameters:
    sampler = _sample_line_posterior(lf, flux, fit_pars, pool)
    fit_kw = _summarize_posterior(sampler, lf.gp.get_parameter_names())

    # remove all previous line measurements
    q = session.query(SpectralLineMeasurement).join(Observation)\
               .filter(Observation.id == obs.id)
    if q.count() > 0:
        for meas in q.all():
            session.delete(meas)
        session.commit()

    slm = SpectralLineMeasurement(**fit_kw)
    slm.info = Halpha
    slm.observation = obs
    session.add(slm)
    session.commit()

    _save_mcmc_plots(sampler, lf, plot_path, filebase)

    # compute centroids for sky lines (the returned centroids are currently
    # not used any further here)
    _measure_sky_lines(session, obs, spec, OI_lines, plot_path, filebase)

    logger.info('{} [{}]: x0={x0:.3f} σ={err:.3f}\n--------'.format(
        obs.object, filebase, x0=fit_kw['x0'], err=fit_kw['x0_error']))

    session.commit()


def main(db_path,
         run_name,
         data_root_path=None,
         filename=None,
         overwrite=False,
         pool=None):
    """Measure Halpha and [OI] sky lines for the observations of a run.

    For every selected observation the 1D spectrum is loaded, Halpha is fit
    (maximum likelihood followed by MCMC sampling), the [OI] sky lines are
    fit in the background spectrum, diagnostic plots are written and the
    measurements are stored in the database.

    Parameters
    ----------
    db_path : str
        Path to the database file. Plots are written to
        ``<directory of db_path>/plots/<run_name>``.
    run_name : str
        Name of the run whose observations are processed.
    data_root_path : str, optional
        Root path used to locate the 1D spectra. Defaults to the directory
        containing ``db_path``.
    filename : str, optional
        If given, only the observation with this raw filename is processed.
    overwrite : bool, optional
        Re-measure observations that already have all line measurements.
    pool : optional
        Pool passed on to `emcee`; defaults to a ``schwimmbad.SerialPool``.
        The pool is closed before this function returns, also on error.
    """
    if pool is None:
        pool = schwimmbad.SerialPool()

    try:
        # connect to the database
        engine = db_connect(db_path)
        # engine.echo = True
        logger.debug("Connected to database at '{}'".format(db_path))

        # create a new session for interacting with the database
        session = Session()

        root_path, _ = path.split(db_path)
        if data_root_path is None:
            data_root_path = root_path

        plot_path = path.join(root_path, 'plots', run_name)
        os.makedirs(plot_path, exist_ok=True)

        # TODO: there might be some bugs here...
        n_lines = session.query(SpectralLineInfo).count()
        Halpha = session.query(SpectralLineInfo)\
                        .filter(SpectralLineInfo.name == 'Halpha').one()
        OI_lines = session.query(SpectralLineInfo)\
                          .filter(SpectralLineInfo.name.contains('[OI]')).all()

        obs_query = session.query(Observation).join(Run)\
                           .filter(Run.name == run_name)
        if filename is not None:
            # only process the observation corresponding to this filename
            obs_query = obs_query.filter(Observation.filename_raw == filename)
        observations = obs_query.all()

        for obs in observations:
            _fit_observation(session, obs, Halpha, OI_lines, n_lines,
                             data_root_path, plot_path, overwrite, pool)

    finally:
        # release the pool even if processing an observation raised
        pool.close()
Esempio n. 3
0
def _plot_master_frame(frame, title, filename):
    """Save a zscale-stretched image of a master calibration frame."""
    # TODO: this assumes vertical CCD
    assert frame.shape[0] > frame.shape[1]
    aspect_ratio = frame.shape[1] / frame.shape[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 12 * aspect_ratio))
    vmin, vmax = zscaler.get_limits(frame.data)
    # 'lower' is the valid spelling; matplotlib only accepts 'upper'/'lower'
    cs = ax.imshow(frame.data.T,
                   origin='lower',
                   cmap=cmap,
                   vmin=max(0, vmin),
                   vmax=vmax)
    ax.set_title(title)

    fig.colorbar(cs)
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)


def _write_1d_spectrum(header, tbl, fname_1d, overwrite):
    """Write an extracted 1D table plus the frame header to a FITS file."""
    hdu0 = fits.PrimaryHDU(header=header)
    hdu1 = fits.table_to_hdu(tbl)
    fits.HDUList([hdu0, hdu1]).writeto(fname_1d, overwrite=overwrite)


def main(night_path, skip_list_file, mask_file, overwrite=False, plot=False):
    """Reduce one night of raw CCD frames and extract 1D spectra.

    Builds (or reloads) the master bias and master flat, processes every
    object and comparison-lamp frame, and extracts each to a 1D table.
    Output goes to ``<data_path>/processed/<run_name>/<night_name>``, where
    those components are taken from the trailing parts of ``night_path``.

    Parameters
    ----------
    night_path : str
        Directory containing the raw frames of a single night.
    skip_list_file : str or None
        Text file with one filename per line to exclude from processing.
    mask_file : str or None
        YAML file whose entries are looked up by raw filename to get a
        pixel-mask specification for that frame.
    overwrite : bool, optional
        Redo and overwrite products that already exist.
    plot : bool, optional
        Save diagnostic plots to a ``plots`` directory in the output path.

    Raises
    ------
    IOError
        If ``night_path`` does not exist.
    """

    night_path = path.realpath(path.expanduser(night_path))
    if not path.exists(night_path):
        raise IOError("Path '{}' doesn't exist".format(night_path))
    logger.info("Reading data from path: {}".format(night_path))

    base_path, night_name = path.split(night_path)
    data_path, run_name = path.split(base_path)
    output_path = path.realpath(
        path.join(data_path, 'processed', run_name, night_name))
    os.makedirs(output_path, exist_ok=True)
    logger.info("Saving processed files to path: {}".format(output_path))

    if plot:  # if we're making plots
        plot_path = path.realpath(path.join(output_path, 'plots'))
        logger.debug("Will make and save plots to: {}".format(plot_path))
        os.makedirs(plot_path, exist_ok=True)
    else:
        plot_path = None

    # check for files to skip (e.g., saturated or errored exposures)
    if skip_list_file is not None:  # a file containing a list of filenames to skip
        with open(skip_list_file, 'r') as f:
            skip_list = [x.strip() for x in f if x.strip()]
    else:
        skip_list = None

    # look for pixel mask file
    if mask_file is not None:
        # load YAML file specifying pixel masks for nearby sources;
        # safe_load avoids constructing arbitrary Python objects, and plain
        # yaml.load() without a Loader is rejected by current PyYAML
        with open(mask_file, 'r') as f:
            pixel_mask_spec = yaml.safe_load(f)
    else:
        pixel_mask_spec = None

    # generate the raw image file collection to process
    ic = GlobImageFileCollection(night_path, skip_filenames=skip_list)
    logger.info("Frames to process:")
    logger.info("- Bias frames: {}".format(
        len(ic.files_filtered(imagetyp='BIAS'))))
    logger.info("- Flat frames: {}".format(
        len(ic.files_filtered(imagetyp='FLAT'))))
    logger.info("- Comparison lamp frames: {}".format(
        len(ic.files_filtered(imagetyp='COMP'))))
    logger.info("- Object frames: {}".format(
        len(ic.files_filtered(imagetyp='OBJECT'))))

    # HACK:
    ic = GlobImageFileCollection(night_path, skip_filenames=skip_list)

    # ============================
    # Create the master bias frame
    # ============================

    # overscan region of the CCD, using FITS index notation
    oscan_fits_section = "[{}:{},:]".format(oscan_idx, oscan_idx + oscan_size)
    trim_fits_section = "[1:{},:]".format(oscan_idx)

    master_bias_file = path.join(output_path, 'master_bias.fits')

    if not os.path.exists(master_bias_file) or overwrite:
        # get list of overscan-subtracted bias frames as 2D image arrays
        bias_list = []
        for hdu, fname in ic.hdus(return_fname=True, imagetyp='BIAS'):
            logger.debug('Processing Bias frame: {0}'.format(fname))
            ccd = CCDData.read(path.join(ic.location, fname), unit='adu')
            ccd = ccdproc.gain_correct(ccd, gain=ccd_gain)
            ccd = ccdproc.subtract_overscan(ccd, overscan=ccd[:, oscan_idx:])
            ccd = ccdproc.trim_image(ccd, fits_section=trim_fits_section)
            bias_list.append(ccd)

        # combine all bias frames into a master bias frame
        logger.info("Creating master bias frame")
        master_bias = ccdproc.combine(bias_list,
                                      method='average',
                                      clip_extrema=True,
                                      nlow=1,
                                      nhigh=1,
                                      error=True)
        master_bias.write(master_bias_file, overwrite=True)

    else:
        logger.info("Master bias frame file already exists: {}".format(
            master_bias_file))
        master_bias = CCDData.read(master_bias_file)

    if plot:
        _plot_master_frame(master_bias, 'master bias frame [zscale]',
                           path.join(plot_path, 'master_bias.png'))

    # ============================
    # Create the master flat field
    # ============================
    # HACK:
    ic = GlobImageFileCollection(night_path, skip_filenames=skip_list)

    master_flat_file = path.join(output_path, 'master_flat.fits')

    if not os.path.exists(master_flat_file) or overwrite:
        # create a list of flat frames
        flat_list = []
        for hdu, fname in ic.hdus(return_fname=True, imagetyp='FLAT'):
            logger.debug('Processing Flat frame: {0}'.format(fname))
            ccd = CCDData.read(path.join(ic.location, fname), unit='adu')
            ccd = ccdproc.gain_correct(ccd, gain=ccd_gain)
            ccd = ccdproc.ccd_process(ccd,
                                      oscan=oscan_fits_section,
                                      trim=trim_fits_section,
                                      master_bias=master_bias)
            flat_list.append(ccd)

        # combine into a single master flat - use 3*sigma sigma-clipping
        logger.info("Creating master flat frame")
        master_flat = ccdproc.combine(flat_list,
                                      method='average',
                                      sigma_clip=True,
                                      low_thresh=3,
                                      high_thresh=3)
        master_flat.write(master_flat_file, overwrite=True)

    else:
        logger.info("Master flat frame file already exists: {}".format(
            master_flat_file))
        master_flat = CCDData.read(master_flat_file)

    if plot:
        _plot_master_frame(master_flat, 'master flat frame [zscale]',
                           path.join(plot_path, 'master_flat.png'))

    # =====================
    # Process object frames
    # =====================
    # HACK:
    ic = GlobImageFileCollection(night_path, skip_filenames=skip_list)

    logger.info("Beginning object frame processing...")
    for hdu, fname in ic.hdus(return_fname=True, imagetyp='OBJECT'):
        new_fname = path.join(output_path, 'p_{}'.format(fname))

        # -------------------------------------------
        # First do the simple processing of the frame
        # -------------------------------------------

        logger.debug("Processing '{}' [{}]".format(hdu.header['OBJECT'],
                                                   fname))
        if path.exists(new_fname) and not overwrite:
            logger.log(1, "\tAlready processed! {}".format(new_fname))
            # new_fname is already an absolute path inside output_path
            ext = SourceCCDExtractor(filename=new_fname,
                                     plot_path=plot_path,
                                     zscaler=zscaler,
                                     cmap=cmap,
                                     **ccd_props)
            nccd = ext.ccd

            # HACK: drop the 2-character 'p_' prefix of the processed file
            ext._filename_base = ext._filename_base[2:]

        else:
            # process the frame!
            ext = SourceCCDExtractor(filename=path.join(ic.location, fname),
                                     plot_path=plot_path,
                                     zscaler=zscaler,
                                     cmap=cmap,
                                     unit='adu',
                                     **ccd_props)

            _pix_mask = pixel_mask_spec.get(
                fname, None) if pixel_mask_spec is not None else None
            nccd = ext.process_raw_frame(pixel_mask_spec=_pix_mask,
                                         master_bias=master_bias,
                                         master_flat=master_flat)
            nccd.write(new_fname, overwrite=overwrite)

        # -------------------------------------------
        # Now do the 1D extraction
        # -------------------------------------------

        fname_1d = path.join(output_path, '1d_{0}'.format(fname))
        if path.exists(fname_1d) and not overwrite:
            logger.log(1, "\tAlready extracted! {}".format(fname_1d))
            continue

        logger.debug("\tExtracting to 1D")

        # first step is to fit a voigt profile to a middle-ish row to determine LSF
        lsf_p = ext.get_lsf_pars()  # MAGIC NUMBER

        try:
            tbl = ext.extract_1d(lsf_p)
        except Exception as e:
            # best effort: log the failure and move on to the next frame
            logger.error('Failed! {}: {}'.format(e.__class__.__name__,
                                                 str(e)))
            continue

        _write_1d_spectrum(nccd.header, tbl, fname_1d, overwrite)

        del ext

    # ==============================
    # Process comparison lamp frames
    # ==============================
    # HACK:
    ic = GlobImageFileCollection(night_path, skip_filenames=skip_list)

    logger.info("Beginning comp. lamp frame processing...")
    for hdu, fname in ic.hdus(return_fname=True, imagetyp='COMP'):
        new_fname = path.join(output_path, 'p_{}'.format(fname))

        logger.debug("\tProcessing '{}'".format(hdu.header['OBJECT']))

        if path.exists(new_fname) and not overwrite:
            logger.log(1, "\tAlready processed! {}".format(new_fname))
            # new_fname is already an absolute path inside output_path
            ext = CompCCDExtractor(filename=new_fname,
                                   plot_path=plot_path,
                                   zscaler=zscaler,
                                   cmap=cmap,
                                   **ccd_props)
            nccd = ext.ccd

            # HACK: drop the 2-character 'p_' prefix of the processed file
            ext._filename_base = ext._filename_base[2:]

        else:
            # process the frame!
            # zscaler/cmap are passed here too, matching every other
            # extractor construction in this function
            ext = CompCCDExtractor(filename=path.join(ic.location, fname),
                                   plot_path=plot_path,
                                   zscaler=zscaler,
                                   cmap=cmap,
                                   unit='adu',
                                   **ccd_props)

            _pix_mask = pixel_mask_spec.get(
                fname, None) if pixel_mask_spec is not None else None
            nccd = ext.process_raw_frame(
                pixel_mask_spec=_pix_mask,
                master_bias=master_bias,
                master_flat=master_flat,
            )
            nccd.write(new_fname, overwrite=overwrite)

        # -------------------------------------------
        # Now do the 1D extraction
        # -------------------------------------------

        fname_1d = path.join(output_path, '1d_{0}'.format(fname))
        if path.exists(fname_1d) and not overwrite:
            logger.log(1, "\tAlready extracted! {}".format(fname_1d))
            continue

        logger.debug("\tExtracting to 1D")

        try:
            tbl = ext.extract_1d()
        except Exception as e:
            # best effort: log the failure and move on to the next frame
            logger.error('Failed! {}: {}'.format(e.__class__.__name__,
                                                 str(e)))
            continue

        _write_1d_spectrum(nccd.header, tbl, fname_1d, overwrite)
Esempio n. 4
0
    def auto_identify(self):
        """Automatically fit the remaining lines of the line list.

        The already-identified (wavelength, pixel) pairs are used to build a
        piecewise-linear prediction of the pixel position of every entry in
        ``self.line_list``. Each line that has not been fit yet and whose
        predicted position lies within pixels 200-1600 is fit in a window of
        +/- 5 line widths around the prediction. Successful fits with
        amplitude >= 100 are marked on the figure and merged into
        ``self._map_dict`` together with the previously identified lines.

        Returns
        -------
        None
            Also returns early (after showing an error in the text box) when
            fewer than 4 lines have been identified so far.

        Raises
        ------
        ValueError
            If no line list is available.
        """
        if self.line_list is None:
            raise ValueError("Can't auto-identify lines without a line list.")

        if len(self._map_dict['wavel']) < 4:
            msg = "Please identify at least 4 lines before trying auto-identify."
            logger.error(msg)
            self._ui['textbox'].setText("ERROR: {}".format(msg))
            return None

        known_wavels = np.asarray(self._map_dict['wavel'])
        known_pixels = np.asarray(self._map_dict['pixel'])
        order = np.argsort(known_wavels)

        # build an approximate wavelength solution to predict where lines are
        spl = InterpolatedUnivariateSpline(known_wavels[order],
                                           known_pixels[order],
                                           k=1)  # use linear interp.

        predicted_pixels = spl(self.line_list)

        new_wavels = []
        new_pixels = []

        # approximate Voigt FWHM from the Gaussian and Lorentzian widths
        # from Wikipedia: https://en.wikipedia.org/wiki/Voigt_profile
        fG = 2*self._line_std_G*np.sqrt(2*np.log(2))
        fL = 2*self._line_hwhm_L
        lw = 0.5346*fL + np.sqrt(0.2166*fL**2 + fG**2)

        for wave_idx, (wave, pix_ctr) in enumerate(zip(self.line_list,
                                                       predicted_pixels)):

            if pix_ctr < 200 or pix_ctr > 1600:  # skip if outside good rows
                continue

            if wave_idx in self._done_wavel_idx:  # skip if already fit
                continue

            xmin = pix_ctr - 5*lw
            xmax = pix_ctr + 5*lw

            logger.debug("Fitting line at predicted pix={:.2f}, λ={:.2f}"
                         .format(pix_ctr, wave))
            try:
                lp, gp = self.get_line_props(xmin, xmax,
                                             std_G0=self._line_std_G,
                                             hwhm_L0=self._line_hwhm_L)
            except Exception as e:
                # best effort: a single failed fit should not stop the scan
                logger.error("Failed to auto-fit line at {} ({msg})"
                             .format(wave, msg=str(e)))
                continue

            # check for a failed fit *before* touching the fit results
            if lp is None:
                continue

            logger.debug("Fit result: amp={}, x0={}".format(lp['amp'],
                                                            lp['x0']))
            if lp['amp'] < 100.:  # HACK
                continue

            self.draw_line_marker(lp, wave, xmin, xmax, gp=gp)
            new_wavels.append(wave)
            # store the fitted centroid (as _on_select does), not the
            # prediction the fit was seeded with
            new_pixels.append(lp['x0'])
            self._done_wavel_idx.append(wave_idx)

        self.fig.canvas.draw()

        # Merge the new lines with the previously identified ones and keep
        # plain lists, so later selections can still append to the map.
        all_wavels = np.concatenate((known_wavels, new_wavels))
        all_pixels = np.concatenate((known_pixels, new_pixels))
        order = np.argsort(all_wavels)
        self._map_dict['wavel'] = all_wavels[order].tolist()
        self._map_dict['pixel'] = all_pixels[order].tolist()