Esempio n. 1
0
def rectwv_coeff_to_ds9(rectwv_coeff,
                        limits=None,
                        rectified=False,
                        numpix=100):
    """Generate ds9 region output with requested slitlet limits

    For every slitlet not listed in ``rectwv_coeff.missing_slitlets`` the
    lower and upper limits are drawn, together with a text label showing
    the slitlet number placed half-way between both limits.

    Parameters
    ----------
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration.
    limits : str
        Region to be saved: the only two possibilities are 'boundaries'
        and 'frontiers'. Any other value, including the default None,
        raises ValueError.
    rectified : bool
        If True, the regions correspond to the rectified image.
    numpix : int
        Number of points in which the X-range interval is subdivided
        in order to save each boundary as a connected set of line
        segments. This option is only relevant when rectified=False.

    Returns
    -------
    output : str
        String containing the full output to be exported as a ds9 region
        file.

    Raises
    ------
    ValueError
        If `limits` is neither 'boundaries' nor 'frontiers'.
    """

    # protections
    if limits not in ['boundaries', 'frontiers']:
        raise ValueError('Unexpected limits=' + str(limits))
    frontiers = limits == 'frontiers'

    # retrieve relevant wavelength calibration parameters
    grism_name = rectwv_coeff.tags['grism']
    filter_name = rectwv_coeff.tags['filter']
    wv_parameters = set_wv_parameters(filter_name, grism_name)
    crpix1_enlarged = wv_parameters['crpix1_enlarged']
    crval1_enlarged = wv_parameters['crval1_enlarged']
    cdelt1_enlarged = wv_parameters['cdelt1_enlarged']

    def wave_to_pixel_enlarged(wave):
        # X pixel of `wave` in the enlarged linear wavelength frame
        return (wave - crval1_enlarged) / cdelt1_enlarged + crpix1_enlarged

    def ds9_line(x1, y1, x2, y2, color):
        # single ds9 line segment with its color
        return 'line {0} {1} {2} {3} # color={4}\n'.format(
            x1, y1, x2, y2, color)

    def ds9_label(x, y, islitlet, color):
        # ds9 text label displaying the slitlet number
        return ('text {0} {1} {{{2}}} # color={3} '
                'font="helvetica 10 bold roman"\n').format(
            x, y, islitlet, color)

    # region fragments are accumulated in a list and joined at the end
    ds9_output = [
        '# Region file format: DS9 version 4.1\n'
        'global color=green dashlist=2 4 width=2 '
        'font="helvetica 10 normal roman" select=1 '
        'highlite=1 dash=1 fixed=0 edit=1 '
        'move=1 delete=1 include=1 source=1\nphysical\n#\n',
        '#\n# uuid (rectwv_coeff): {0}\n'.format(rectwv_coeff.uuid),
        '# grism...............: {0}\n'.format(grism_name),
        '# filter..............: {0}\n'.format(filter_name),
    ]

    for islitlet in range(1, EMIR_NBARS + 1):
        if islitlet in rectwv_coeff.missing_slitlets:
            continue
        dumdict = rectwv_coeff.contents[islitlet - 1]
        # frontiers are always drawn in blue; boundaries alternate between
        # magenta (even slitlets) and cyan (odd slitlets)
        if frontiers:
            colorbox = '#0000ff'
        elif islitlet % 2 == 0:
            colorbox = '#ff00ff'
        else:
            colorbox = '#00ffff'

        ds9_output.append('#\n# islitlet...........: {0}\n'.format(islitlet))
        ds9_output.append('# csu_bar_slit_center: {0}\n'.format(
            dumdict['csu_bar_slit_center']))

        if rectified:
            # horizontal lines spanning the linear wavelength range of the
            # slitlet, expressed in pixels of the enlarged frame
            crpix1_linear = 1.0
            crval1_linear = dumdict['crval1_linear']
            cdelt1_linear = dumdict['cdelt1_linear']
            if frontiers:
                ydum_lower = dumdict['y0_frontier_lower_expected']
                ydum_upper = dumdict['y0_frontier_upper_expected']
            else:
                ydum_lower = dumdict['y0_reference_lower_expected']
                ydum_upper = dumdict['y0_reference_upper_expected']
            wave_ini = crval1_linear + (0.5 - crpix1_linear) * cdelt1_linear
            wave_end = crval1_linear + \
                (EMIR_NAXIS1 + 0.5 - crpix1_linear) * cdelt1_linear
            xdum_ini = wave_to_pixel_enlarged(wave_ini)
            xdum_end = wave_to_pixel_enlarged(wave_end)
            for ydum in (ydum_lower, ydum_upper):
                ds9_output.append(
                    ds9_line(xdum_ini, ydum, xdum_end, ydum, colorbox))
            # slitlet label
            ydum_label = (ydum_lower + ydum_upper) / 2.0
            xdum_label = EMIR_NAXIS1 / 2 + 0.5
            wave_center = crval1_linear + \
                (xdum_label - crpix1_linear) * cdelt1_linear
            xdum_label = wave_to_pixel_enlarged(wave_center)
        else:
            # polynomial limits sampled at `numpix` points and drawn as a
            # connected set of line segments
            poly_key = 'frontier' if frontiers else 'spectrail'
            pol_lower = Polynomial(dumdict[poly_key]['poly_coef_lower'])
            pol_upper = Polynomial(dumdict[poly_key]['poly_coef_upper'])
            xdum = np.linspace(1, EMIR_NAXIS1, num=numpix)
            for pol in (pol_lower, pol_upper):
                ydum = pol(xdum)
                for x1, y1, x2, y2 in zip(xdum[:-1], ydum[:-1],
                                          xdum[1:], ydum[1:]):
                    ds9_output.append(ds9_line(x1, y1, x2, y2, colorbox))
            # slitlet label
            xdum_label = EMIR_NAXIS1 / 2 + 0.5
            ydum_label = (pol_lower(xdum_label) + pol_upper(xdum_label)) / 2.0
        ds9_output.append(
            ds9_label(xdum_label, ydum_label, islitlet, colorbox))

    return ''.join(ds9_output)
Esempio n. 2
0
def _load_refinement_catalogue(refine_wavecalib_mode, grism_name):
    """Return (wave, flux) arrays of the line catalogue for a refine mode.

    Modes 1 and 2 read the empirical argon/neon/xenon ARC catalogue (a
    dedicated file is used for the LR grism); modes 11 and 12 read the OH
    catalogue 'Oliva_etal_2013.dat', whose first two columns are both used
    as line wavelengths, each sharing the flux given in the third column.
    """
    if refine_wavecalib_mode in [1, 2]:  # ARC lines
        if grism_name == 'LR':
            catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
        else:
            catlines_file = 'lines_argon_neon_xenon_empirical.dat'
        dumdata = pkgutil.get_data('emirdrp.instrument.configs', catlines_file)
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        return catlines[:, 0], catlines[:, 1]
    if refine_wavecalib_mode in [11, 12]:  # OH lines
        dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                   'Oliva_etal_2013.dat')
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        catlines_all_wave = np.concatenate((catlines[:, 1], catlines[:, 0]))
        catlines_all_flux = np.concatenate((catlines[:, 2], catlines[:, 2]))
        return catlines_all_wave, catlines_all_flux
    raise ValueError('Unexpected mode={}'.format(refine_wavecalib_mode))


def refine_rectwv_coeff(input_image,
                        rectwv_coeff,
                        refine_wavecalib_mode,
                        minimum_slitlet_width_mm,
                        maximum_slitlet_width_mm,
                        save_intermediate_results=False,
                        debugplot=0):
    """Refine RectWaveCoeff object using a catalogue of lines

    The median spectrum of the input image is cross-correlated with a
    synthetic spectrum built from a catalogue of ARC or OH lines. The
    resulting offset(s) are subtracted from the zero-order coefficient of
    'wpoly_coeff' in a deep copy of `rectwv_coeff` (the input object is
    not modified). In the individual-offset modes, each useful slitlet is
    additionally refined with check_wlcalib_sp.

    Parameters
    ----------
    input_image : HDUList object
        Input 2D image.
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration.
    refine_wavecalib_mode : int
        Integer, indicating the type of refinement:
        0 : no refinement (not accepted by this function, which raises
            ValueError; skip the call in that case)
        1 : apply the same global offset to all the slitlets (using ARC lines)
        2 : apply individual offset to each slitlet (using ARC lines)
        11 : apply the same global offset to all the slitlets (using OH lines)
        12 : apply individual offset to each slitlet (using OH lines)
    minimum_slitlet_width_mm : float
        Minimum slitlet width (mm) for a valid slitlet.
    maximum_slitlet_width_mm : float
        Maximum slitlet width (mm) for a valid slitlet.
    save_intermediate_results : bool
        If True, save plots in the PDF file 'crosscorrelation.pdf'.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot.

    Returns
    -------
    refined_rectwv_coeff : RectWaveCoeff instance
        Refined rectification and wavelength calibration coefficients
        for the particular CSU configuration.
    expected_cat_image : HDUList object
        Output 2D image with the expected catalogued lines.

    Raises
    ------
    ValueError
        If `refine_wavecalib_mode` is not one of 1, 2, 11 or 12, or if no
        catalogued line falls within the useful wavelength range.

    """

    logger = logging.getLogger(__name__)

    # protections: validate the mode before opening any output file, so
    # that an invalid value does not leave an unclosed PDF file behind
    if refine_wavecalib_mode not in [1, 2, 11, 12]:
        logger.error('Wavelength calibration refinement mode={}'.format(
            refine_wavecalib_mode))
        raise ValueError("Invalid wavelength calibration refinement mode")

    if save_intermediate_results:
        from matplotlib.backends.backend_pdf import PdfPages
        pdf = PdfPages('crosscorrelation.pdf')
    else:
        pdf = None

    # the PDF file is closed even if any of the computations below fails
    try:
        # image header
        main_header = input_image[0].header
        filter_name = main_header['filter']
        grism_name = main_header['grism']

        # read tabulated lines; modes 11/12 behave as modes 1/2 but using
        # OH lines, hence the internal mode is reduced to 1 (global offset)
        # or 2 (individual offsets)
        catlines_all_wave, catlines_all_flux = _load_refinement_catalogue(
            refine_wavecalib_mode, grism_name)
        if refine_wavecalib_mode in [11, 12]:
            mode = refine_wavecalib_mode - 10
        else:
            mode = refine_wavecalib_mode

        # initialize output
        refined_rectwv_coeff = deepcopy(rectwv_coeff)

        logger.info('Computing median spectrum')
        # compute median spectrum and normalize it
        sp_median = median_slitlets_rectified(
            input_image,
            mode=2,
            minimum_slitlet_width_mm=minimum_slitlet_width_mm,
            maximum_slitlet_width_mm=maximum_slitlet_width_mm)[0].data
        sp_median /= sp_median.max()

        # determine minimum and maximum useful wavelength
        jmin, jmax = find_pix_borders(sp_median, 0)
        naxis1 = main_header['naxis1']
        naxis2 = main_header['naxis2']
        crpix1 = main_header['crpix1']
        crval1 = main_header['crval1']
        cdelt1 = main_header['cdelt1']
        xwave = crval1 + (np.arange(naxis1) + 1.0 - crpix1) * cdelt1
        if grism_name == 'LR':
            wv_parameters = set_wv_parameters(filter_name, grism_name)
            wave_min = wv_parameters['wvmin_useful']
            wave_max = wv_parameters['wvmax_useful']
        else:
            wave_min = crval1 + (jmin + 1 - crpix1) * cdelt1
            wave_max = crval1 + (jmax + 1 - crpix1) * cdelt1
        logger.info('Setting wave_min to {}'.format(wave_min))
        logger.info('Setting wave_max to {}'.format(wave_max))

        # extract subset of catalogue lines within current wavelength range
        lok_range = (catlines_all_wave >= wave_min) & \
            (catlines_all_wave <= wave_max)
        catlines_reference_wave = catlines_all_wave[lok_range]
        catlines_reference_flux = catlines_all_flux[lok_range]
        if catlines_reference_wave.size == 0:
            raise ValueError(
                'No catalogued lines within the useful wavelength range '
                '[{}, {}]'.format(wave_min, wave_max))
        catlines_reference_flux /= catlines_reference_flux.max()

        # estimate sigma to broaden catalogue lines
        csu_config = CsuConfiguration.define_from_header(main_header)
        # segregate slitlets
        list_useful_slitlets = csu_config.widths_in_range_mm(
            minwidth=minimum_slitlet_width_mm,
            maxwidth=maximum_slitlet_width_mm)
        list_not_useful_slitlets = [
            i for i in range(1, EMIR_NBARS + 1)
            if i not in list_useful_slitlets
        ]
        logger.info('list of useful slitlets: {}'.format(
            list_useful_slitlets))
        logger.info('list of not useful slitlets: {}'.format(
            list_not_useful_slitlets))
        tempwidths = np.array([
            csu_config.csu_bar_slit_width(islitlet)
            for islitlet in list_useful_slitlets
        ])
        widths_summary = summary(tempwidths)
        logger.info('Statistics of useful slitlet widths (mm):')
        logger.info('- npoints....: {0:d}'.format(widths_summary['npoints']))
        logger.info('- mean.......: {0:7.3f}'.format(widths_summary['mean']))
        logger.info('- median.....: {0:7.3f}'.format(
            widths_summary['median']))
        logger.info('- std........: {0:7.3f}'.format(widths_summary['std']))
        logger.info('- robust_std.: {0:7.3f}'.format(
            widths_summary['robust_std']))
        # empirical transformation of slit width (mm) to pixels
        sigma_broadening = cdelt1 * widths_summary['median']

        # convolve location of catalogue lines to generate expected spectrum
        xwave_reference, sp_reference = convolve_comb_lines(
            catlines_reference_wave, catlines_reference_flux,
            sigma_broadening, crpix1, crval1, cdelt1, naxis1)
        sp_reference /= sp_reference.max()

        # generate image2d with expected lines
        image2d_expected_lines = np.tile(sp_reference, (naxis2, 1))
        hdu = fits.PrimaryHDU(data=image2d_expected_lines, header=main_header)
        expected_cat_image = fits.HDUList([hdu])

        if (abs(debugplot) % 10 != 0) or (pdf is not None):
            ax = ximplotxy(xwave,
                           sp_median,
                           'C1-',
                           xlabel='Wavelength (Angstroms, in vacuum)',
                           ylabel='Normalized number of counts',
                           title='Median spectrum',
                           label='observed spectrum',
                           show=False)
            # overplot reference catalogue lines
            ax.stem(catlines_reference_wave,
                    catlines_reference_flux,
                    'C4-',
                    markerfmt=' ',
                    basefmt='C4-',
                    label='tabulated lines')
            # overplot convolved reference lines
            ax.plot(xwave_reference,
                    sp_reference,
                    'C0-',
                    label='expected spectrum')
            ax.legend()
            if pdf is not None:
                pdf.savefig()
            else:
                pause_debugplot(debugplot=debugplot, pltshow=True)

        # compute baseline signal in sp_median
        baseline = np.percentile(sp_median[sp_median > 0], q=10)
        if (abs(debugplot) % 10 != 0) or (pdf is not None):
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.hist(sp_median, bins=1000, log=True)
            ax.set_xlabel('Normalized number of counts')
            ax.set_ylabel('Number of pixels')
            ax.set_title('Median spectrum')
            ax.axvline(float(baseline), linestyle='--', color='grey')
            if pdf is not None:
                pdf.savefig()
            else:
                geometry = (0, 0, 640, 480)
                set_window_geometry(geometry)
                plt.show()
        # subtract baseline to sp_median (only pixels with signal above zero)
        lok = np.where(sp_median > 0)
        sp_median[lok] -= baseline

        # compute global offset through periodic correlation
        logger.info('Computing global offset')
        global_offset, fpeak = periodic_corr1d(
            sp_reference=sp_reference,
            sp_offset=sp_median,
            fminmax=None,
            naround_zero=50,
            plottitle='Median spectrum (cross-correlation)',
            pdf=pdf,
            debugplot=debugplot)
        logger.info('Global offset: {} pixels'.format(-global_offset))

        missing_slitlets = rectwv_coeff.missing_slitlets

        if mode == 1:
            # apply computed offset to obtain refined_rectwv_coeff_global
            for islitlet in range(1, EMIR_NBARS + 1):
                if islitlet not in missing_slitlets:
                    i = islitlet - 1
                    dumdict = refined_rectwv_coeff.contents[i]
                    dumdict['wpoly_coeff'][0] -= global_offset * cdelt1

        elif mode == 2:
            # compute individual offset for each slitlet
            logger.info('Computing individual offsets')
            median_55sp = median_slitlets_rectified(input_image, mode=1)
            offset_array = np.zeros(EMIR_NBARS)
            xplot = []
            yplot = []
            xplot_skipped = []
            yplot_skipped = []
            # progress string: '.' for processed slitlets, 'i' for ignored
            cout = '0'
            for islitlet in range(1, EMIR_NBARS + 1):
                if islitlet in list_useful_slitlets:
                    i = islitlet - 1
                    # NOTE: basic slicing returns a view (assuming .data is
                    # a numpy array), so the baseline subtraction and
                    # normalization below also modify the row later passed
                    # to periodic_corr1d and check_wlcalib_sp
                    sp_median = median_55sp[0].data[i, :]
                    lok = np.where(sp_median > 0)
                    baseline = np.percentile(sp_median[lok], q=10)
                    sp_median[lok] -= baseline
                    sp_median /= sp_median.max()
                    offset_array[i], fpeak = periodic_corr1d(
                        sp_reference=sp_reference,
                        sp_offset=median_55sp[0].data[i, :],
                        fminmax=None,
                        naround_zero=50,
                        plottitle='slitlet #{0} (cross-correlation)'.format(
                            islitlet),
                        pdf=pdf,
                        debugplot=debugplot)
                    dumdict = refined_rectwv_coeff.contents[i]
                    dumdict['wpoly_coeff'][0] -= offset_array[i] * cdelt1
                    xplot.append(islitlet)
                    yplot.append(-offset_array[i])
                    # second correction
                    wpoly_coeff_refined = check_wlcalib_sp(
                        sp=median_55sp[0].data[i, :],
                        crpix1=crpix1,
                        crval1=crval1 - offset_array[i] * cdelt1,
                        cdelt1=cdelt1,
                        wv_master=catlines_reference_wave,
                        coeff_ini=dumdict['wpoly_coeff'],
                        naxis1_ini=EMIR_NAXIS1,
                        title='slitlet #{0} (after applying offset)'.format(
                            islitlet),
                        ylogscale=False,
                        pdf=pdf,
                        debugplot=debugplot)
                    dumdict['wpoly_coeff'] = wpoly_coeff_refined
                    cout += '.'

                else:
                    xplot_skipped.append(islitlet)
                    yplot_skipped.append(0)
                    cout += 'i'

                if islitlet % 10 == 0:
                    # restart the progress string with the tens digit
                    cout = str(islitlet // 10)

                logger.info(cout)

            # show offsets with opposite sign
            stat_summary = summary(np.array(yplot))
            logger.info('Statistics of individual slitlet offsets (pixels):')
            logger.info('- npoints....: {0:d}'.format(
                stat_summary['npoints']))
            logger.info('- mean.......: {0:7.3f}'.format(
                stat_summary['mean']))
            logger.info('- median.....: {0:7.3f}'.format(
                stat_summary['median']))
            logger.info('- std........: {0:7.3f}'.format(
                stat_summary['std']))
            logger.info('- robust_std.: {0:7.3f}'.format(
                stat_summary['robust_std']))
            if (abs(debugplot) % 10 != 0) or (pdf is not None):
                ax = ximplotxy(xplot,
                               yplot,
                               linestyle='',
                               marker='o',
                               color='C0',
                               xlabel='slitlet number',
                               ylabel='-offset (pixels) = offset to be '
                                      'applied',
                               title='cross-correlation result',
                               show=False,
                               **{'label': 'individual slitlets'})
                if len(xplot_skipped) > 0:
                    ax.plot(xplot_skipped, yplot_skipped, 'mx')
                ax.axhline(-global_offset,
                           linestyle='--',
                           color='C1',
                           label='global offset')
                ax.legend()
                if pdf is not None:
                    pdf.savefig()
                else:
                    pause_debugplot(debugplot=debugplot, pltshow=True)
        else:
            raise ValueError('Unexpected mode={}'.format(mode))
    finally:
        # close output PDF file
        if pdf is not None:
            pdf.close()

    # return result
    return refined_rectwv_coeff, expected_cat_image
Esempio n. 3
0
def spectral_lines_to_ds9(rectwv_coeff,
                          spectral_lines=None,
                          rectified=False):
    """Generate ds9 region output with requested spectral lines

    For every slitlet not listed in ``rectwv_coeff.missing_slitlets`` a
    line segment is drawn for each catalogued line whose wavelength lies
    within the wavelength range of the slitlet. A text label showing the
    slitlet number is also added.

    Parameters
    ----------
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration.
    spectral_lines : str
        Spectral lines to be saved: the only two possibilities are 'arc'
        and 'oh'. Any other value, including the default None, raises
        ValueError.
    rectified : bool
        If True, the regions correspond to the rectified image.

    Returns
    -------
    output : str
        String containing the full output to be exported as a ds9 region
        file.

    Raises
    ------
    ValueError
        If `spectral_lines` is neither 'arc' nor 'oh'.
    """

    # protections
    if spectral_lines not in ['arc', 'oh']:
        raise ValueError('Unexpected spectral lines=' + str(spectral_lines))

    grism_name = rectwv_coeff.tags['grism']
    filter_name = rectwv_coeff.tags['filter']

    # read the wavelengths of the catalogued lines (fluxes are not needed)
    if spectral_lines == 'arc':
        if grism_name == 'LR':
            catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
        else:
            catlines_file = 'lines_argon_neon_xenon_empirical.dat'
        dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                   catlines_file)
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        catlines_all_wave = catlines[:, 0]
    else:  # spectral_lines == 'oh'
        dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                   'Oliva_etal_2013.dat')
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # both of the first two columns are used as line wavelengths
        catlines_all_wave = np.concatenate((catlines[:, 1],
                                            catlines[:, 0]))

    # retrieve relevant wavelength calibration parameters
    wv_parameters = set_wv_parameters(filter_name, grism_name)
    crpix1_enlarged = wv_parameters['crpix1_enlarged']
    crval1_enlarged = wv_parameters['crval1_enlarged']
    cdelt1_enlarged = wv_parameters['cdelt1_enlarged']

    # region fragments are accumulated in a list and joined at the end
    ds9_output = [
        '# Region file format: DS9 version 4.1\n'
        'global color=#00ffff dashlist=0 0 width=2 '
        'font="helvetica 10 normal roman" select=1 '
        'highlite=1 dash=0 fixed=0 edit=1 move=1 delete=1 '
        'include=1 source=1\nphysical\n',
        '#\n# uuid (rectwv_coeff): {0}\n'.format(rectwv_coeff.uuid),
        '# grism...............: {0}\n'.format(grism_name),
        '# filter..............: {0}\n'.format(filter_name),
    ]

    global_integer_offset_x_pix = rectwv_coeff.global_integer_offset_x_pix
    global_integer_offset_y_pix = rectwv_coeff.global_integer_offset_y_pix

    for islitlet in range(1, EMIR_NBARS + 1):
        if islitlet in rectwv_coeff.missing_slitlets:
            continue
        dumdict = rectwv_coeff.contents[islitlet - 1]
        # alternate colors: magenta for even slitlets, cyan for odd ones
        colorbox = '#ff00ff' if islitlet % 2 == 0 else '#00ffff'

        ds9_output.append('#\n# islitlet...........: {0}\n'.format(islitlet))
        ds9_output.append('# csu_bar_slit_center: {0}\n'.format(
            dumdict['csu_bar_slit_center']))
        # wavelength range covered by the slitlet (linear calibration)
        crpix1_linear = 1.0
        crval1_linear = dumdict['crval1_linear']
        cdelt1_linear = dumdict['cdelt1_linear']
        wave_ini = crval1_linear + (0.5 - crpix1_linear) * cdelt1_linear
        wave_end = crval1_linear + \
            (EMIR_NAXIS1 + 0.5 - crpix1_linear) * cdelt1_linear

        if rectified:
            ydum_lower = dumdict['y0_reference_lower_expected']
            ydum_upper = dumdict['y0_reference_upper_expected']
            # spectral lines: vertical segments in the enlarged frame
            for wave in catlines_all_wave:
                if wave_ini <= wave <= wave_end:
                    xdum = (wave - crval1_enlarged) / cdelt1_enlarged
                    xdum += crpix1_enlarged
                    ds9_output.append(
                        'line {0} {1} {2} {3} # color={4}\n'.format(
                            xdum, ydum_lower, xdum, ydum_upper, colorbox))
            # slitlet label
            ydum_label = (ydum_lower + ydum_upper) / 2.0
            xdum_label = EMIR_NAXIS1 / 2 + 0.5
            wave_center = crval1_linear + \
                (xdum_label - crpix1_linear) * cdelt1_linear
            xdum_label = (wave_center - crval1_enlarged) / cdelt1_enlarged
            xdum_label += crpix1_enlarged
        else:
            bb_ns1_orig = dumdict['bb_ns1_orig']
            ttd_order = dumdict['ttd_order']
            aij = dumdict['ttd_aij']
            bij = dumdict['ttd_bij']
            min_row_rectified = float(dumdict['min_row_rectified'])
            max_row_rectified = float(dumdict['max_row_rectified'])
            wpoly_coeff = dumdict['wpoly_coeff']
            # X pixels of the catalogued lines, obtained as the real roots
            # of wpoly(x) - wave lying within [1, EMIR_NAXIS1]
            xroots = []
            for wave in catlines_all_wave:
                if wave_ini <= wave <= wave_end:
                    tmp_coeff = np.copy(wpoly_coeff)
                    tmp_coeff[0] -= wave
                    for dum in np.polynomial.Polynomial(tmp_coeff).roots():
                        if np.isreal(dum) and 1 <= dum.real <= EMIR_NAXIS1:
                            xroots.append(dum.real)
            # transform both ends of each line (rows min_row_rectified and
            # max_row_rectified) with fmap, then apply the slitlet and
            # global integer offsets
            segment_ends = []
            for row in (min_row_rectified, max_row_rectified):
                xx, yy = fmap(ttd_order, aij, bij, np.array(xroots),
                              np.array([row] * len(xroots)))
                xx -= global_integer_offset_x_pix
                yy += bb_ns1_orig
                yy -= global_integer_offset_y_pix
                segment_ends.append((xx, yy))
            (xx1, yy1), (xx2, yy2) = segment_ends
            for xx1_, xx2_, yy1_, yy2_ in zip(xx1, xx2, yy1, yy2):
                ds9_output.append(
                    'line {0} {1} {2} {3} # color={4}\n'.format(
                        xx1_, yy1_, xx2_, yy2_, colorbox))
            # slitlet label
            pol_lower = Polynomial(dumdict['spectrail']['poly_coef_lower'])
            pol_upper = Polynomial(dumdict['spectrail']['poly_coef_upper'])
            xdum_label = EMIR_NAXIS1 / 2 + 0.5
            ydum_label = (pol_lower(xdum_label) + pol_upper(xdum_label)) / 2.0
        ds9_output.append(
            ('text {0} {1} {{{2}}} # color={3} '
             'font="helvetica 10 bold roman"\n').format(
                xdum_label, ydum_label, islitlet, colorbox))

    return ''.join(ds9_output)
Esempio n. 4
0
def display_slitlet_arrangement(fileobj,
                                grism=None,
                                spfilter=None,
                                bbox=None,
                                adjust=None,
                                geometry=None,
                                debugplot=0):
    """Display slitlet arrangement from a FITS header or a TXT file.

    The CSU bar positions are read either from a TXT file (when the file
    name ends with '.txt') or from the CSUP keywords of a FITS header.
    For FITS input the grism and filter are taken from the 'grism' and
    'filter' header keywords, replacing any values passed as arguments.

    Parameters
    ----------
    fileobj : file object
        FITS or TXT file object. A TXT file object is closed and the
        file is reopened (by name) in text mode.
    grism : str
        Grism. Mandatory for TXT input.
    spfilter : str
        Filter. Mandatory for TXT input.
    bbox : tuple of 4 floats
        If not None, values for xmin, xmax, ymin and ymax.
    adjust : bool
        Adjust X range according to minimum and maximum csu_bar_left
        and csu_bar_right (note that this option is overridden by 'bbox').
    geometry : tuple (4 integers) or None
        x, y, dx, dy values employed to set the Qt backend geometry.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot. Here a table is printed
        when abs(debugplot) >= 10 and a plot is drawn when
        abs(debugplot) % 10 != 0.

    Returns
    -------
    csu_bar_left : list of floats
        Location (mm) of the left bar for each slitlet.
    csu_bar_right : list of floats
        Location (mm) of the right bar for each slitlet, using the
        same origin employed for csu_bar_left (which is not the
        value stored in the FITS keywords).
    csu_bar_slit_center : list of floats
        Middle point (mm) in between the two bars defining a slitlet.
    csu_bar_slit_width : list of floats
        Slitlet width (mm), computed as the distance between the two
        bars defining the slitlet.

    Raises
    ------
    ValueError
        If grism or filter is undefined for TXT input, if the TXT file
        is malformed (unexpected bar number, missing position, or not
        exactly EMIR_NBARS left and right bars), or if the grism + filter
        combination is not supported.

    """

    # 341.5 mm is the mirror point used to convert right-bar positions
    # read from the TXT file, and the upper X limit of the default plot
    x_max_mm = 341.5

    if fileobj.name.endswith(".txt"):
        if grism is None:
            raise ValueError("Undefined grism!")
        if spfilter is None:
            raise ValueError("Undefined filter!")
        # define CsuConfiguration object
        csu_config = CsuConfiguration()
        csu_config._csu_bar_left = []
        csu_config._csu_bar_right = []
        csu_config._csu_bar_slit_center = []
        csu_config._csu_bar_slit_width = []

        # since the input filename has been opened with argparse in binary
        # mode, it is necessary to close it and open it in text mode
        fileobj.close()
        with open(fileobj.name, mode='rt') as f:
            file_content = f.read().splitlines()

        # Each data line is "id_bar position". Bars must come in pairs:
        # left bar i (1 <= i <= EMIR_NBARS) immediately followed by its
        # right bar i + EMIR_NBARS, for i = 1, 2, ...
        next_id_bar = 1
        for line in file_content:
            line_contents = line.split()
            # skip empty/whitespace-only lines and comment lines
            if not line_contents or line_contents[0].startswith('#'):
                continue
            if len(line_contents) < 2:
                raise ValueError("Missing bar position in line: " +
                                 repr(line))
            id_bar = int(line_contents[0])
            position = float(line_contents[1])
            if id_bar != next_id_bar:
                raise ValueError("Unexpected id_bar:" + str(id_bar))
            if id_bar <= EMIR_NBARS:
                csu_config._csu_bar_left.append(position)
                next_id_bar = id_bar + EMIR_NBARS
            else:
                # express the right bar in the same origin as the left one
                csu_config._csu_bar_right.append(x_max_mm - position)
                next_id_bar = id_bar - EMIR_NBARS + 1

        # every slitlet must be defined by exactly one pair of bars
        nleft = len(csu_config._csu_bar_left)
        nright = len(csu_config._csu_bar_right)
        if nleft != EMIR_NBARS or nright != EMIR_NBARS:
            raise ValueError(
                "Expected {0} left and right bars, found {1} left and "
                "{2} right bars".format(EMIR_NBARS, nleft, nright))

        # compute slit width and center
        for left, right in zip(csu_config._csu_bar_left,
                               csu_config._csu_bar_right):
            csu_config._csu_bar_slit_center.append((left + right) / 2)
            csu_config._csu_bar_slit_width.append(right - left)

    else:
        # read input FITS header (the context manager guarantees the file
        # is closed even if reading the header fails)
        with fits.open(fileobj.name) as hdulist:
            image_header = hdulist[0].header

        # additional info from header
        grism = image_header['grism']
        spfilter = image_header['filter']

        # define slitlet arrangement
        csu_config = CsuConfiguration.define_from_fits(fileobj)

    # determine calibration (the grism may also be "OPEN")
    if grism in ["J", "OPEN"] and spfilter == "J":
        wv_parameters = set_wv_parameters("J", "J")
    elif grism in ["H", "OPEN"] and spfilter == "H":
        wv_parameters = set_wv_parameters("H", "H")
    elif grism in ["K", "OPEN"] and spfilter == "Ksp":
        wv_parameters = set_wv_parameters("Ksp", "K")
    elif grism in ["LR", "OPEN"] and spfilter == "YJ":
        wv_parameters = set_wv_parameters("YJ", "LR")
    elif grism in ["LR", "OPEN"] and spfilter == "HK":
        wv_parameters = set_wv_parameters("HK", "LR")
    else:
        raise ValueError("Invalid grism + filter configuration")

    # callables evaluated at the slit center to obtain crval1 / cdelt1
    crval1 = wv_parameters['poly_crval1_linear']
    cdelt1 = wv_parameters['poly_cdelt1_linear']

    wvmin_useful = wv_parameters['wvmin_useful']
    wvmax_useful = wv_parameters['wvmax_useful']

    # display arrangement
    if abs(debugplot) >= 10:
        print("slit     left    right   center   width   min.wave   max.wave")
        print("====  =======  =======  =======   =====   ========   ========")
        for i in range(EMIR_NBARS):
            ibar = i + 1
            csu_crval1 = crval1(csu_config.csu_bar_slit_center(ibar))
            csu_cdelt1 = cdelt1(csu_config.csu_bar_slit_center(ibar))
            csu_crvaln = csu_crval1 + (EMIR_NAXIS1 - 1) * csu_cdelt1
            # clip the wavelength range to the useful interval, if defined
            if wvmin_useful is not None:
                csu_crval1 = np.amax([csu_crval1, wvmin_useful])
            if wvmax_useful is not None:
                csu_crvaln = np.amin([csu_crvaln, wvmax_useful])
            print("{0:4d} {1:8.3f} {2:8.3f} {3:8.3f} {4:7.3f}   "
                  "{5:8.2f}   {6:8.2f}".format(
                      ibar, csu_config.csu_bar_left(ibar),
                      csu_config.csu_bar_right(ibar),
                      csu_config.csu_bar_slit_center(ibar),
                      csu_config.csu_bar_slit_width(ibar), csu_crval1,
                      csu_crvaln))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (all)".format(
            np.mean(csu_config._csu_bar_left),
            np.mean(csu_config._csu_bar_right),
            np.mean(csu_config._csu_bar_slit_center),
            np.mean(csu_config._csu_bar_slit_width)))
        # slitlets 1, 3, 5, ... are at even list indices
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (odd)".format(
            np.mean(csu_config._csu_bar_left[::2]),
            np.mean(csu_config._csu_bar_right[::2]),
            np.mean(csu_config._csu_bar_slit_center[::2]),
            np.mean(csu_config._csu_bar_slit_width[::2])))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (even)".format(
            np.mean(csu_config._csu_bar_left[1::2]),
            np.mean(csu_config._csu_bar_right[1::2]),
            np.mean(csu_config._csu_bar_slit_center[1::2]),
            np.mean(csu_config._csu_bar_slit_width[1::2])))

    # display slit arrangement
    if abs(debugplot) % 10 != 0:
        fig = plt.figure()
        set_window_geometry(geometry)
        ax = fig.add_subplot(111)
        if bbox is None:
            if adjust:
                xmin = min(csu_config._csu_bar_left)
                xmax = max(csu_config._csu_bar_right)
                dx = xmax - xmin
                if dx == 0:
                    dx = 1
                # 5% margin at each side
                xmin -= dx / 20
                xmax += dx / 20
                ax.set_xlim(xmin, xmax)
            else:
                ax.set_xlim(0., x_max_mm)
            ax.set_ylim(0, 56)
        else:
            ax.set_xlim(bbox[0], bbox[1])
            ax.set_ylim(bbox[2], bbox[3])
        ax.set_xlabel('csu_bar_position (mm)')
        ax.set_ylabel('slit number')
        for i in range(EMIR_NBARS):
            ibar = i + 1
            # open part of the slitlet as a filled rectangle
            ax.add_patch(
                patches.Rectangle((csu_config.csu_bar_left(ibar), ibar - 0.5),
                                  csu_config.csu_bar_slit_width(ibar), 1.0))
            # bars drawn as gray lines from each edge of the plot
            ax.plot([0., csu_config.csu_bar_left(ibar)], [ibar, ibar],
                    '-',
                    color='gray')
            ax.plot([csu_config.csu_bar_right(ibar), x_max_mm],
                    [ibar, ibar],
                    '-',
                    color='gray')
        plt.title("File: {0}\ngrism={1}, filter={2}".format(
            fileobj.name, grism, spfilter))
        pause_debugplot(debugplot, pltshow=True)

    # return results
    return csu_config._csu_bar_left, csu_config._csu_bar_right, \
           csu_config._csu_bar_slit_center, csu_config._csu_bar_slit_width
def display_slitlet_arrangement(fileobj,
                                grism=None,
                                spfilter=None,
                                bbox=None,
                                adjust=None,
                                geometry=None,
                                debugplot=0):
    """Display slitlet arrangement from a FITS header or a TXT file.

    The CSU bar positions are read either from a TXT file (when the file
    name ends with '.txt') or from the CSUP keywords of a FITS header.
    For FITS input the grism and filter are taken from the 'grism' and
    'filter' header keywords, replacing any values passed as arguments.

    Parameters
    ----------
    fileobj : file object
        FITS or TXT file object. A TXT file object is closed and the
        file is reopened (by name) in text mode.
    grism : str
        Grism. Mandatory for TXT input.
    spfilter : str
        Filter. Mandatory for TXT input.
    bbox : tuple of 4 floats
        If not None, values for xmin, xmax, ymin and ymax.
    adjust : bool
        Adjust X range according to minimum and maximum csu_bar_left
        and csu_bar_right (note that this option is overridden by 'bbox').
    geometry : tuple (4 integers) or None
        x, y, dx, dy values employed to set the Qt backend geometry.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot. Here a table is printed
        when abs(debugplot) >= 10 and a plot is drawn when
        abs(debugplot) % 10 != 0.

    Returns
    -------
    csu_bar_left : list of floats
        Location (mm) of the left bar for each slitlet.
    csu_bar_right : list of floats
        Location (mm) of the right bar for each slitlet, using the
        same origin employed for csu_bar_left (which is not the
        value stored in the FITS keywords).
    csu_bar_slit_center : list of floats
        Middle point (mm) in between the two bars defining a slitlet.
    csu_bar_slit_width : list of floats
        Slitlet width (mm), computed as the distance between the two
        bars defining the slitlet.

    Raises
    ------
    ValueError
        If grism or filter is undefined for TXT input, if the TXT file
        is malformed (unexpected bar number, missing position, or not
        exactly EMIR_NBARS left and right bars), or if the grism + filter
        combination is not supported.

    """

    # 341.5 mm is the mirror point used to convert right-bar positions
    # read from the TXT file, and the upper X limit of the default plot
    x_max_mm = 341.5

    if fileobj.name.endswith(".txt"):
        if grism is None:
            raise ValueError("Undefined grism!")
        if spfilter is None:
            raise ValueError("Undefined filter!")
        # define CsuConfiguration object
        csu_config = CsuConfiguration()
        csu_config._csu_bar_left = []
        csu_config._csu_bar_right = []
        csu_config._csu_bar_slit_center = []
        csu_config._csu_bar_slit_width = []

        # since the input filename has been opened with argparse in binary
        # mode, it is necessary to close it and open it in text mode
        fileobj.close()
        with open(fileobj.name, mode='rt') as f:
            file_content = f.read().splitlines()

        # Each data line is "id_bar position". Bars must come in pairs:
        # left bar i (1 <= i <= EMIR_NBARS) immediately followed by its
        # right bar i + EMIR_NBARS, for i = 1, 2, ...
        next_id_bar = 1
        for line in file_content:
            line_contents = line.split()
            # skip empty/whitespace-only lines and comment lines
            if not line_contents or line_contents[0].startswith('#'):
                continue
            if len(line_contents) < 2:
                raise ValueError("Missing bar position in line: " +
                                 repr(line))
            id_bar = int(line_contents[0])
            position = float(line_contents[1])
            if id_bar != next_id_bar:
                raise ValueError("Unexpected id_bar:" + str(id_bar))
            if id_bar <= EMIR_NBARS:
                csu_config._csu_bar_left.append(position)
                next_id_bar = id_bar + EMIR_NBARS
            else:
                # express the right bar in the same origin as the left one
                csu_config._csu_bar_right.append(x_max_mm - position)
                next_id_bar = id_bar - EMIR_NBARS + 1

        # every slitlet must be defined by exactly one pair of bars
        nleft = len(csu_config._csu_bar_left)
        nright = len(csu_config._csu_bar_right)
        if nleft != EMIR_NBARS or nright != EMIR_NBARS:
            raise ValueError(
                "Expected {0} left and right bars, found {1} left and "
                "{2} right bars".format(EMIR_NBARS, nleft, nright))

        # compute slit width and center
        for left, right in zip(csu_config._csu_bar_left,
                               csu_config._csu_bar_right):
            csu_config._csu_bar_slit_center.append((left + right) / 2)
            csu_config._csu_bar_slit_width.append(right - left)

    else:
        # read input FITS header (the context manager guarantees the file
        # is closed even if reading the header fails)
        with fits.open(fileobj.name) as hdulist:
            image_header = hdulist[0].header

        # additional info from header
        grism = image_header['grism']
        spfilter = image_header['filter']

        # define slitlet arrangement
        csu_config = CsuConfiguration.define_from_fits(fileobj)

    # determine calibration (the grism may also be "OPEN")
    if grism in ["J", "OPEN"] and spfilter == "J":
        wv_parameters = set_wv_parameters("J", "J")
    elif grism in ["H", "OPEN"] and spfilter == "H":
        wv_parameters = set_wv_parameters("H", "H")
    elif grism in ["K", "OPEN"] and spfilter == "Ksp":
        wv_parameters = set_wv_parameters("Ksp", "K")
    elif grism in ["LR", "OPEN"] and spfilter == "YJ":
        wv_parameters = set_wv_parameters("YJ", "LR")
    elif grism in ["LR", "OPEN"] and spfilter == "HK":
        wv_parameters = set_wv_parameters("HK", "LR")
    else:
        raise ValueError("Invalid grism + filter configuration")

    # callables evaluated at the slit center to obtain crval1 / cdelt1
    crval1 = wv_parameters['poly_crval1_linear']
    cdelt1 = wv_parameters['poly_cdelt1_linear']

    wvmin_useful = wv_parameters['wvmin_useful']
    wvmax_useful = wv_parameters['wvmax_useful']

    # display arrangement
    if abs(debugplot) >= 10:
        print("slit     left    right   center   width   min.wave   max.wave")
        print("====  =======  =======  =======   =====   ========   ========")
        for i in range(EMIR_NBARS):
            ibar = i + 1
            csu_crval1 = crval1(csu_config.csu_bar_slit_center(ibar))
            csu_cdelt1 = cdelt1(csu_config.csu_bar_slit_center(ibar))
            csu_crvaln = csu_crval1 + (EMIR_NAXIS1 - 1) * csu_cdelt1
            # clip the wavelength range to the useful interval, if defined
            if wvmin_useful is not None:
                csu_crval1 = np.amax([csu_crval1, wvmin_useful])
            if wvmax_useful is not None:
                csu_crvaln = np.amin([csu_crvaln, wvmax_useful])
            print("{0:4d} {1:8.3f} {2:8.3f} {3:8.3f} {4:7.3f}   "
                  "{5:8.2f}   {6:8.2f}".format(
                      ibar, csu_config.csu_bar_left(ibar),
                      csu_config.csu_bar_right(ibar),
                      csu_config.csu_bar_slit_center(ibar),
                      csu_config.csu_bar_slit_width(ibar), csu_crval1,
                      csu_crvaln))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (all)".format(
            np.mean(csu_config._csu_bar_left),
            np.mean(csu_config._csu_bar_right),
            np.mean(csu_config._csu_bar_slit_center),
            np.mean(csu_config._csu_bar_slit_width)))
        # slitlets 1, 3, 5, ... are at even list indices
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (odd)".format(
            np.mean(csu_config._csu_bar_left[::2]),
            np.mean(csu_config._csu_bar_right[::2]),
            np.mean(csu_config._csu_bar_slit_center[::2]),
            np.mean(csu_config._csu_bar_slit_width[::2])))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (even)".format(
            np.mean(csu_config._csu_bar_left[1::2]),
            np.mean(csu_config._csu_bar_right[1::2]),
            np.mean(csu_config._csu_bar_slit_center[1::2]),
            np.mean(csu_config._csu_bar_slit_width[1::2])))

    # display slit arrangement
    if abs(debugplot) % 10 != 0:
        fig = plt.figure()
        set_window_geometry(geometry)
        ax = fig.add_subplot(111)
        if bbox is None:
            if adjust:
                xmin = min(csu_config._csu_bar_left)
                xmax = max(csu_config._csu_bar_right)
                dx = xmax - xmin
                if dx == 0:
                    dx = 1
                # 5% margin at each side
                xmin -= dx / 20
                xmax += dx / 20
                ax.set_xlim(xmin, xmax)
            else:
                ax.set_xlim(0., x_max_mm)
            ax.set_ylim(0, 56)
        else:
            ax.set_xlim(bbox[0], bbox[1])
            ax.set_ylim(bbox[2], bbox[3])
        ax.set_xlabel('csu_bar_position (mm)')
        ax.set_ylabel('slit number')
        for i in range(EMIR_NBARS):
            ibar = i + 1
            # open part of the slitlet as a filled rectangle
            ax.add_patch(
                patches.Rectangle((csu_config.csu_bar_left(ibar), ibar - 0.5),
                                  csu_config.csu_bar_slit_width(ibar), 1.0))
            # bars drawn as gray lines from each edge of the plot
            ax.plot([0., csu_config.csu_bar_left(ibar)], [ibar, ibar],
                    '-',
                    color='gray')
            ax.plot([csu_config.csu_bar_right(ibar), x_max_mm],
                    [ibar, ibar],
                    '-',
                    color='gray')
        plt.title("File: {0}\ngrism={1}, filter={2}".format(
            fileobj.name, grism, spfilter))
        pause_debugplot(debugplot, pltshow=True)

    # return results
    return csu_config._csu_bar_left, csu_config._csu_bar_right, \
           csu_config._csu_bar_slit_center, csu_config._csu_bar_slit_width
Esempio n. 6
0
    def run(self, rinput):
        self.logger.info('starting generation of flatpix2pix')

        self.logger.info('rectwv_coeff..........................: {}'.format(
            rinput.rectwv_coeff))
        self.logger.info('master_rectwv.........................: {}'.format(
            rinput.master_rectwv))
        self.logger.info('Minimum slitlet width (mm)............: {}'.format(
            rinput.minimum_slitlet_width_mm))
        self.logger.info('Maximum slitlet width (mm)............: {}'.format(
            rinput.maximum_slitlet_width_mm))
        self.logger.info('Global offset X direction (pixels)....: {}'.format(
            rinput.global_integer_offset_x_pix))
        self.logger.info('Global offset Y direction (pixels)....: {}'.format(
            rinput.global_integer_offset_y_pix))
        self.logger.info('Minimum fraction......................: {}'.format(
            rinput.minimum_fraction))
        self.logger.info('Minimum value in output...............: {}'.format(
            rinput.minimum_value_in_output))
        self.logger.info('Maximum value in output...............: {}'.format(
            rinput.maximum_value_in_output))

        # check rectification and wavelength calibration information
        if rinput.master_rectwv is None and rinput.rectwv_coeff is None:
            raise ValueError('No master_rectwv nor rectwv_coeff data have '
                             'been provided')
        elif rinput.master_rectwv is not None and \
                rinput.rectwv_coeff is not None:
            self.logger.warning('rectwv_coeff will be used instead of '
                                'master_rectwv')
        if rinput.rectwv_coeff is not None and \
                (rinput.global_integer_offset_x_pix != 0 or
                 rinput.global_integer_offset_y_pix != 0):
            raise ValueError('global_integer_offsets cannot be used '
                             'simultaneously with rectwv_coeff')

        # check headers to detect lamp status (on/off)
        list_lampincd = []
        for fname in rinput.obresult.frames:
            with fname.open() as f:
                list_lampincd.append(f[0].header['lampincd'])

        # check number of images
        nimages = len(rinput.obresult.frames)
        n_on = list_lampincd.count(1)
        n_off = list_lampincd.count(0)
        self.logger.info(
            'Number of images with lamp ON.........: {}'.format(n_on))
        self.logger.info(
            'Number of images with lamp OFF........: {}'.format(n_off))
        self.logger.info(
            'Total number of images................: {}'.format(nimages))
        if n_on == 0:
            raise ValueError('Insufficient number of images with lamp ON')
        if n_on + n_off != nimages:
            raise ValueError('Number of images does not match!')

        # check combination method
        if rinput.method_kwargs == {}:
            method_kwargs = None
        else:
            if rinput.method == 'sigmaclip':
                method_kwargs = rinput.method_kwargs
            else:
                raise ValueError('Unexpected method_kwargs={}'.format(
                    rinput.method_kwargs))

        # build object to proceed with bpm, bias, and dark (not flat)
        flow = self.init_filters(rinput)

        # available combination methods
        method = getattr(combine, rinput.method)

        # basic reduction of images with lamp ON or OFF
        lampmode = {0: 'off', 1: 'on'}
        reduced_image_on = None
        reduced_image_off = None
        for imode in lampmode.keys():
            self.logger.info('starting basic reduction of images with'
                             ' lamp {}'.format(lampmode[imode]))
            tmplist = [
                rinput.obresult.frames[i]
                for i, lampincd in enumerate(list_lampincd)
                if lampincd == imode
            ]
            if len(tmplist) > 0:
                with contextlib.ExitStack() as stack:
                    hduls = [
                        stack.enter_context(fname.open()) for fname in tmplist
                    ]
                    reduced_image = combine_imgs(hduls,
                                                 method=method,
                                                 method_kwargs=method_kwargs,
                                                 errors=False,
                                                 prolog=None)
                if imode == 0:
                    reduced_image_off = flow(reduced_image)
                    hdr = reduced_image_off[0].header
                    self.set_base_headers(hdr)
                    self.save_intermediate_img(reduced_image_off,
                                               'reduced_image_off.fits')
                elif imode == 1:
                    reduced_image_on = flow(reduced_image)
                    hdr = reduced_image_on[0].header
                    self.set_base_headers(hdr)
                    self.save_intermediate_img(reduced_image_on,
                                               'reduced_image_on.fits')
                else:
                    raise ValueError('Unexpected imode={}'.format(imode))

        # computation of ON-OFF
        header_on = reduced_image_on[0].header
        data_on = reduced_image_on[0].data.astype('float32')
        if n_off > 0:
            header_off = reduced_image_off[0].header
            data_off = reduced_image_off[0].data.astype('float32')
        else:
            header_off = None
            data_off = np.zeros_like(data_on)
        reduced_data = data_on - data_off

        # update reduced image header
        reduced_image = self.create_reduced_image(rinput, reduced_data,
                                                  header_on, header_off,
                                                  list_lampincd)

        # save intermediate image in work directory
        self.save_intermediate_img(reduced_image, 'reduced_image.fits')

        # define rectification and wavelength calibration coefficients
        if rinput.rectwv_coeff is None:
            rectwv_coeff = rectwv_coeff_from_mos_library(
                reduced_image, rinput.master_rectwv)
            # set global offsets
            rectwv_coeff.global_integer_offset_x_pix = \
                rinput.global_integer_offset_x_pix
            rectwv_coeff.global_integer_offset_y_pix = \
                rinput.global_integer_offset_y_pix
        else:
            rectwv_coeff = rinput.rectwv_coeff
        # save as JSON in work directory
        self.save_structured_as_json(rectwv_coeff, 'rectwv_coeff.json')
        # ds9 region files (to be saved in the work directory)
        if self.intermediate_results:
            save_four_ds9(rectwv_coeff)
            save_spectral_lines_ds9(rectwv_coeff)

        # clean (interpolate) defects
        self.logger.debug('interpolating image defect')
        reduced_data_clean = clean_defects(reduced_data, debugplot=0)

        # apply global offsets (to both, the original and the cleaned version)
        image2d = apply_integer_offsets(
            image2d=reduced_data,
            offx=rectwv_coeff.global_integer_offset_x_pix,
            offy=rectwv_coeff.global_integer_offset_y_pix)
        image2d_clean = apply_integer_offsets(
            image2d=reduced_data_clean,
            offx=rectwv_coeff.global_integer_offset_x_pix,
            offy=rectwv_coeff.global_integer_offset_y_pix)

        # load CSU configuration
        csu_conf = CsuConfiguration.define_from_header(reduced_image[0].header)
        # determine (pseudo) longslits
        dict_longslits = csu_conf.pseudo_longslits()

        # valid slitlet numbers
        list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
        for idel in rectwv_coeff.missing_slitlets:
            self.logger.info('-> Removing slitlet (not defined): ' + str(idel))
            list_valid_islitlets.remove(idel)
        # filter out slitlets with widths outside valid range
        list_outside_valid_width = []
        for islitlet in list_valid_islitlets:
            slitwidth = csu_conf.csu_bar_slit_width(islitlet)
            if (slitwidth < rinput.minimum_slitlet_width_mm) or \
                    (slitwidth > rinput.maximum_slitlet_width_mm):
                list_outside_valid_width.append(islitlet)
                self.logger.info('-> Removing slitlet (width out of range): ' +
                                 str(islitlet))
        if len(list_outside_valid_width) > 0:
            for idel in list_outside_valid_width:
                list_valid_islitlets.remove(idel)

        # initialize rectified image
        image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

        # main loop
        grism_name = rectwv_coeff.tags['grism']
        filter_name = rectwv_coeff.tags['filter']
        cout = '0'
        debugplot = rinput.debugplot
        for islitlet in list(range(1, EMIR_NBARS + 1)):
            if islitlet in list_valid_islitlets:
                # define Slitlet2D object
                slt = Slitlet2D(islitlet=islitlet,
                                rectwv_coeff=rectwv_coeff,
                                debugplot=debugplot)
                if abs(slt.debugplot) > 10:
                    print(slt)

                # extract (distorted) slitlet from the initial image
                slitlet2d = slt.extract_slitlet2d(image_2k2k=image2d,
                                                  subtitle='original image')
                slitlet2d_clean = slt.extract_slitlet2d(
                    image_2k2k=image2d_clean,
                    subtitle='original (cleaned) image')

                # rectify slitlet
                slitlet2d_rect = slt.rectify(
                    slitlet2d=slitlet2d_clean,
                    resampling=2,
                    subtitle='original (cleaned) rectified')
                naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

                if naxis1_slitlet2d != EMIR_NAXIS1:
                    print('naxis1_slitlet2d: ', naxis1_slitlet2d)
                    print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
                    raise ValueError("Unexpected naxis1_slitlet2d")

                slitlet2d_rect_mask = np.zeros(
                    (naxis2_slitlet2d, naxis1_slitlet2d), dtype=bool)

                # for grism LR set to zero data beyond useful wavelength range
                if grism_name == 'LR':
                    wv_parameters = set_wv_parameters(filter_name, grism_name)
                    x_pix = np.arange(1, naxis1_slitlet2d + 1)
                    wl_pix = polyval(x_pix, slt.wpoly)
                    lremove = wl_pix < wv_parameters['wvmin_useful']
                    slitlet2d_rect[:, lremove] = 0.0
                    slitlet2d_rect_mask[:, lremove] = True
                    lremove = wl_pix > wv_parameters['wvmax_useful']
                    slitlet2d_rect[:, lremove] = 0.0
                    slitlet2d_rect_mask[:, lremove] = True

                # get useful slitlet region (use boundaries)
                spectrail = slt.list_spectrails[0]
                yy0 = slt.corr_yrect_a + \
                      slt.corr_yrect_b * spectrail(slt.x0_reference)
                ii1 = int(yy0 + 0.5) - slt.bb_ns1_orig
                spectrail = slt.list_spectrails[2]
                yy0 = slt.corr_yrect_a + \
                      slt.corr_yrect_b * spectrail(slt.x0_reference)
                ii2 = int(yy0 + 0.5) - slt.bb_ns1_orig

                # median spatial profile along slitlet (to be used later)
                image2d_rect_masked = np.ma.masked_array(
                    slitlet2d_rect, mask=slitlet2d_rect_mask)
                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_rectified(
                        slitlet2d_rect=image2d_rect_masked.data,
                        subtitle='original (cleaned) rectified and masked')
                ycut_median = np.ma.median(image2d_rect_masked, axis=1).data
                ycut_median_median = np.median(ycut_median[ii1:(ii2 + 1)])
                ycut_median /= ycut_median_median
                if abs(slt.debugplot) % 10 != 0:
                    ximplotxy(np.arange(1, naxis2_slitlet2d + 1),
                              ycut_median,
                              'o',
                              title='median value at each scan',
                              debugplot=12)

                # median spectrum
                sp_collapsed = np.median(slitlet2d_rect[ii1:(ii2 + 1), :],
                                         axis=0)

                # smooth median spectrum along the spectral direction
                sp_median = ndimage.median_filter(sp_collapsed,
                                                  rinput.nwindow_median,
                                                  mode='nearest')

                ymax_spmedian = sp_median.max()
                y_threshold = ymax_spmedian * rinput.minimum_fraction
                lremove = np.where(sp_median < y_threshold)
                sp_median[lremove] = 0.0

                if abs(slt.debugplot) % 10 != 0:
                    xaxis1 = np.arange(1, naxis1_slitlet2d + 1)
                    title = 'Slitlet#' + str(islitlet) + ' (median spectrum)'
                    ax = ximplotxy(xaxis1,
                                   sp_collapsed,
                                   title=title,
                                   show=False,
                                   **{'label': 'collapsed spectrum'})
                    ax.plot(xaxis1, sp_median, label='fitted spectrum')
                    ax.plot([1, naxis1_slitlet2d],
                            2 * [y_threshold],
                            label='threshold')
                    # ax.plot(xknots, yknots, 'o', label='knots')
                    ax.legend()
                    ax.set_ylim(-0.05 * ymax_spmedian, 1.05 * ymax_spmedian)
                    pause_debugplot(slt.debugplot,
                                    pltshow=True,
                                    tight_layout=True)

                # generate rectified slitlet region filled with the
                # median spectrum
                slitlet2d_rect_spmedian = np.tile(sp_median,
                                                  (naxis2_slitlet2d, 1))
                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_rectified(
                        slitlet2d_rect=slitlet2d_rect_spmedian,
                        subtitle='rectified, filled with median spectrum')

                # apply median ycut
                for i in range(naxis2_slitlet2d):
                    slitlet2d_rect_spmedian[i, :] *= ycut_median[i]

                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_rectified(
                        slitlet2d_rect=slitlet2d_rect_spmedian,
                        subtitle='rectified, filled with rescaled median '
                        'spectrum')

                # unrectified image
                slitlet2d_unrect_spmedian = slt.rectify(
                    slitlet2d=slitlet2d_rect_spmedian,
                    resampling=2,
                    inverse=True,
                    subtitle='unrectified, filled with median spectrum')

                # normalize initial slitlet image (avoid division by zero)
                slitlet2d_norm = np.zeros_like(slitlet2d)
                for j in range(naxis1_slitlet2d):
                    for i in range(naxis2_slitlet2d):
                        den = slitlet2d_unrect_spmedian[i, j]
                        if den == 0:
                            slitlet2d_norm[i, j] = 1.0
                        else:
                            slitlet2d_norm[i, j] = slitlet2d[i, j] / den
                # set to 1.0 one additional pixel at each side (since
                # 'den' above is small at the borders and generates wrong
                # bright pixels)
                slitlet2d_norm = fix_pix_borders(image2d=slitlet2d_norm,
                                                 nreplace=1,
                                                 sought_value=1.0,
                                                 replacement_value=1.0)

                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_unrectified(
                        slitlet2d=slitlet2d_norm,
                        subtitle='unrectified, pixel-to-pixel')

                # compute smooth surface
                # clipped region
                slitlet2d_rect_clipped = slitlet2d_rect_spmedian.copy()
                slitlet2d_rect_clipped[:(ii1 - 1), :] = 0.0
                slitlet2d_rect_clipped[(ii2 + 2):, :] = 0.0
                # unrectified clipped image
                slitlet2d_unrect_clipped = slt.rectify(
                    slitlet2d=slitlet2d_rect_clipped,
                    resampling=2,
                    inverse=True,
                    subtitle='unrectified, filled with median spectrum '
                    '(clipped)')
                # normalize initial slitlet image (avoid division by zero)
                slitlet2d_norm_clipped = np.zeros_like(slitlet2d)
                for j in range(naxis1_slitlet2d):
                    for i in range(naxis2_slitlet2d):
                        den = slitlet2d_unrect_clipped[i, j]
                        if den == 0:
                            slitlet2d_norm_clipped[i, j] = 1.0
                        else:
                            slitlet2d_norm_clipped[i, j] = \
                                slitlet2d[i, j] / den
                # set to 1.0 one additional pixel at each side (since
                # 'den' above is small at the borders and generates wrong
                # bright pixels)
                slitlet2d_norm_clipped = fix_pix_borders(
                    image2d=slitlet2d_norm_clipped,
                    nreplace=1,
                    sought_value=1.0,
                    replacement_value=1.0)
                slitlet2d_norm_clipped = slitlet2d_norm_clipped.transpose()
                slitlet2d_norm_clipped = fix_pix_borders(
                    image2d=slitlet2d_norm_clipped,
                    nreplace=1,
                    sought_value=1.0,
                    replacement_value=1.0)
                slitlet2d_norm_clipped = slitlet2d_norm_clipped.transpose()
                slitlet2d_norm_smooth = ndimage.median_filter(
                    slitlet2d_norm_clipped, size=(5, 31), mode='nearest')
                # apply smooth surface to pix2pix
                slitlet2d_norm /= slitlet2d_norm_smooth

                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_unrectified(
                        slitlet2d=slitlet2d_norm_clipped,
                        subtitle='unrectified, pixel-to-pixel (clipped)')
                    slt.ximshow_unrectified(
                        slitlet2d=slitlet2d_norm_smooth,
                        subtitle='unrectified, pixel-to-pixel (smoothed)')
                    slt.ximshow_unrectified(
                        slitlet2d=slitlet2d_norm,
                        subtitle='unrectified, pixel-to-pixel (improved)')

                # ---

                # check for (pseudo) longslit with previous and next slitlet
                imin = dict_longslits[islitlet].imin()
                imax = dict_longslits[islitlet].imax()
                if islitlet > 1:
                    same_slitlet_below = (islitlet - 1) >= imin
                else:
                    same_slitlet_below = False
                if islitlet < EMIR_NBARS:
                    same_slitlet_above = (islitlet + 1) <= imax
                else:
                    same_slitlet_above = False

                for j in range(EMIR_NAXIS1):
                    xchannel = j + 1
                    y0_lower = slt.list_frontiers[0](xchannel)
                    y0_upper = slt.list_frontiers[1](xchannel)
                    n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                                    y0_frontier_upper=y0_upper,
                                                    resize=True)
                    # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
                    nn1 = n1 - slt.bb_ns1_orig + 1
                    nn2 = n2 - slt.bb_ns1_orig + 1
                    image2d_flatfielded[(n1 - 1):n2, j] = \
                        slitlet2d_norm[(nn1 - 1):nn2, j]

                    # force to 1.0 region around frontiers
                    if not same_slitlet_below:
                        image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
                    if not same_slitlet_above:
                        image2d_flatfielded[(n2 - 5):n2, j] = 1
                cout += '.'
            else:
                cout += 'i'

            if islitlet % 10 == 0:
                if cout != 'i':
                    cout = str(islitlet // 10)

            self.logger.info(cout)

        # restore global offsets
        image2d_flatfielded = apply_integer_offsets(
            image2d=image2d_flatfielded,
            offx=-rectwv_coeff.global_integer_offset_x_pix,
            offy=-rectwv_coeff.global_integer_offset_y_pix)

        # set pixels below minimum value to 1.0
        filtered = np.where(
            image2d_flatfielded < rinput.minimum_value_in_output)
        image2d_flatfielded[filtered] = 1.0

        # set pixels above maximum value to 1.0
        filtered = np.where(
            image2d_flatfielded > rinput.maximum_value_in_output)
        image2d_flatfielded[filtered] = 1.0

        # update image header
        reduced_flatpix2pix = self.create_reduced_image(
            rinput, image2d_flatfielded, header_on, header_off, list_lampincd)

        # ds9 region files (to be saved in the work directory)
        if self.intermediate_results:
            save_four_ds9(rectwv_coeff)
            save_spectral_lines_ds9(rectwv_coeff)

        # save results in results directory
        self.logger.info('end of flatpix2pix generation')
        result = self.create_result(reduced_flatpix2pix=reduced_flatpix2pix)
        return result
Esempio n. 7
0
def refine_rectwv_coeff(input_image, rectwv_coeff,
                        refine_wavecalib_mode,
                        minimum_slitlet_width_mm,
                        maximum_slitlet_width_mm,
                        save_intermediate_results=False,
                        debugplot=0):
    """Refine RectWaveCoeff object using a catalogue of lines

    The observed median spectra of the input image are cross-correlated
    with a synthetic spectrum built from a catalogue of ARC or OH lines
    (selected through `refine_wavecalib_mode`). The resulting offsets are
    used to correct the constant term of the wavelength calibration
    polynomial ('wpoly_coeff') of each slitlet.

    Parameters
    ----------
    input_image : HDUList object
        Input 2D image. The linear wavelength calibration is read from
        its primary header (CRPIX1, CRVAL1, CDELT1), together with
        FILTER, GRISM, NAXIS1 and NAXIS2.
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration. It is not modified (a deep copy
        is refined and returned).
    refine_wavecalib_mode : int
        Integer, indicating the type of refinement:
        1 : apply the same global offset to all the slitlets (using ARC lines)
        2 : apply individual offset to each slitlet (using ARC lines)
        11 : apply the same global offset to all the slitlets (using OH lines)
        12 : apply individual offset to each slitlet (using OH lines)
        Any other value (including 0, i.e. no refinement) raises
        ValueError: callers must not call this function when no
        refinement is requested.
    minimum_slitlet_width_mm : float
        Minimum slitlet width (mm) for a valid slitlet.
    maximum_slitlet_width_mm : float
        Maximum slitlet width (mm) for a valid slitlet.
    save_intermediate_results : bool
        If True, save the diagnostic plots in the PDF file
        'crosscorrelation.pdf' (created in the current directory).
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot.

    Returns
    -------
    refined_rectwv_coeff : RectWaveCoeff instance
        Refined rectification and wavelength calibration coefficients
        for the particular CSU configuration.
    expected_cat_image : HDUList object
        Output 2D image with the expected catalogued lines.

    Raises
    ------
    ValueError
        If `refine_wavecalib_mode` is not one of 1, 2, 11, 12, or if no
        catalogue line falls within the useful wavelength range.

    """

    logger = logging.getLogger(__name__)

    # protections: validate before opening any output file, so that an
    # invalid mode cannot leave a dangling (unclosed) PDF file behind
    if refine_wavecalib_mode not in [1, 2, 11, 12]:
        logger.error('Wavelength calibration refinement mode={}'.format(
            refine_wavecalib_mode
        ))
        raise ValueError("Invalid wavelength calibration refinement mode")

    if save_intermediate_results:
        from matplotlib.backends.backend_pdf import PdfPages
        pdf = PdfPages('crosscorrelation.pdf')
    else:
        pdf = None

    try:
        return _refine_rectwv_coeff_core(
            input_image=input_image,
            rectwv_coeff=rectwv_coeff,
            refine_wavecalib_mode=refine_wavecalib_mode,
            minimum_slitlet_width_mm=minimum_slitlet_width_mm,
            maximum_slitlet_width_mm=maximum_slitlet_width_mm,
            pdf=pdf,
            debugplot=debugplot
        )
    finally:
        # close output PDF file (also when an exception has been raised)
        if pdf is not None:
            pdf.close()


def _read_catalogue_lines(refine_wavecalib_mode, grism_name):
    """Return (wavelengths, fluxes) of the catalogue lines for the mode."""

    if refine_wavecalib_mode in [1, 2]:        # ARC lines
        if grism_name == 'LR':
            catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
        else:
            catlines_file = 'lines_argon_neon_xenon_empirical.dat'
        dumdata = pkgutil.get_data('emirdrp.instrument.configs', catlines_file)
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # define wavelength and flux as separate arrays
        return catlines[:, 0], catlines[:, 1]
    elif refine_wavecalib_mode in [11, 12]:    # OH lines
        dumdata = pkgutil.get_data(
            'emirdrp.instrument.configs',
            'Oliva_etal_2013.dat'
        )
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # columns 1 and 0 are both used as line wavelengths, each one
        # with the flux given in column 2
        catlines_all_wave = np.concatenate((catlines[:, 1], catlines[:, 0]))
        catlines_all_flux = np.concatenate((catlines[:, 2], catlines[:, 2]))
        return catlines_all_wave, catlines_all_flux
    else:
        raise ValueError('Unexpected mode={}'.format(refine_wavecalib_mode))


def _refine_rectwv_coeff_core(input_image, rectwv_coeff,
                              refine_wavecalib_mode,
                              minimum_slitlet_width_mm,
                              maximum_slitlet_width_mm,
                              pdf, debugplot):
    """Compute the refinement described in refine_rectwv_coeff().

    `refine_wavecalib_mode` must already be validated. `pdf` is either
    None or an open PdfPages object; it is used but NOT closed here.
    """

    logger = logging.getLogger(__name__)

    # image header
    main_header = input_image[0].header
    filter_name = main_header['filter']
    grism_name = main_header['grism']

    # read tabulated lines
    catlines_all_wave, catlines_all_flux = _read_catalogue_lines(
        refine_wavecalib_mode, grism_name
    )
    # 1: global offset, 2: individual offsets (ARC: 1, 2; OH: 11, 12)
    mode = refine_wavecalib_mode % 10

    # initialize output
    refined_rectwv_coeff = deepcopy(rectwv_coeff)

    logger.info('Computing median spectrum')
    # compute median spectrum and normalize it
    sp_median = median_slitlets_rectified(
        input_image,
        mode=2,
        minimum_slitlet_width_mm=minimum_slitlet_width_mm,
        maximum_slitlet_width_mm=maximum_slitlet_width_mm
    )[0].data
    sp_median /= sp_median.max()

    # determine minimum and maximum useful wavelength
    jmin, jmax = find_pix_borders(sp_median, 0)
    naxis1 = main_header['naxis1']
    naxis2 = main_header['naxis2']
    crpix1 = main_header['crpix1']
    crval1 = main_header['crval1']
    cdelt1 = main_header['cdelt1']
    xwave = crval1 + (np.arange(naxis1) + 1.0 - crpix1) * cdelt1
    if grism_name == 'LR':
        wv_parameters = set_wv_parameters(filter_name, grism_name)
        wave_min = wv_parameters['wvmin_useful']
        wave_max = wv_parameters['wvmax_useful']
    else:
        wave_min = crval1 + (jmin + 1 - crpix1) * cdelt1
        wave_max = crval1 + (jmax + 1 - crpix1) * cdelt1
    logger.info('Setting wave_min to {}'.format(wave_min))
    logger.info('Setting wave_max to {}'.format(wave_max))

    # extract subset of catalogue lines within current wavelength range
    lok = (catlines_all_wave >= wave_min) & (catlines_all_wave <= wave_max)
    catlines_reference_wave = catlines_all_wave[lok]
    catlines_reference_flux = catlines_all_flux[lok]
    if catlines_reference_wave.size == 0:
        # otherwise .max() below fails with a cryptic zero-size error
        raise ValueError(
            'No catalogue lines within the useful wavelength range '
            '[{}, {}]'.format(wave_min, wave_max)
        )
    catlines_reference_flux /= catlines_reference_flux.max()

    # estimate sigma to broaden catalogue lines
    csu_config = CsuConfiguration.define_from_header(main_header)
    # segregate slitlets
    list_useful_slitlets = csu_config.widths_in_range_mm(
        minwidth=minimum_slitlet_width_mm,
        maxwidth=maximum_slitlet_width_mm
    )
    # remove missing slitlets
    for iremove in refined_rectwv_coeff.missing_slitlets:
        if iremove in list_useful_slitlets:
            list_useful_slitlets.remove(iremove)

    list_not_useful_slitlets = [i for i in range(1, EMIR_NBARS + 1)
                                if i not in list_useful_slitlets]
    logger.info('list of useful slitlets: {}'.format(
        list_useful_slitlets))
    logger.info('list of not useful slitlets: {}'.format(
        list_not_useful_slitlets))
    tempwidths = np.array([csu_config.csu_bar_slit_width(islitlet)
                           for islitlet in list_useful_slitlets])
    widths_summary = summary(tempwidths)
    logger.info('Statistics of useful slitlet widths (mm):')
    logger.info('- npoints....: {0:d}'.format(widths_summary['npoints']))
    logger.info('- mean.......: {0:7.3f}'.format(widths_summary['mean']))
    logger.info('- median.....: {0:7.3f}'.format(widths_summary['median']))
    logger.info('- std........: {0:7.3f}'.format(widths_summary['std']))
    logger.info('- robust_std.: {0:7.3f}'.format(widths_summary['robust_std']))
    # empirical transformation of slit width (mm) to pixels
    sigma_broadening = cdelt1 * widths_summary['median']

    # convolve location of catalogue lines to generate expected spectrum
    xwave_reference, sp_reference = convolve_comb_lines(
        catlines_reference_wave, catlines_reference_flux, sigma_broadening,
        crpix1, crval1, cdelt1, naxis1
    )
    sp_reference /= sp_reference.max()

    # generate image2d with expected lines
    image2d_expected_lines = np.tile(sp_reference, (naxis2, 1))
    hdu = fits.PrimaryHDU(data=image2d_expected_lines, header=main_header)
    expected_cat_image = fits.HDUList([hdu])

    if (abs(debugplot) % 10 != 0) or (pdf is not None):
        ax = ximplotxy(xwave, sp_median, 'C1-',
                       xlabel='Wavelength (Angstroms, in vacuum)',
                       ylabel='Normalized number of counts',
                       title='Median spectrum',
                       label='observed spectrum', show=False)
        # overplot reference catalogue lines
        ax.stem(catlines_reference_wave, catlines_reference_flux, 'C4-',
                markerfmt=' ', basefmt='C4-', label='tabulated lines')
        # overplot convolved reference lines
        ax.plot(xwave_reference, sp_reference, 'C0-',
                label='expected spectrum')
        ax.legend()
        if pdf is not None:
            pdf.savefig()
        else:
            pause_debugplot(debugplot=debugplot, pltshow=True)

    # compute baseline signal in sp_median
    baseline = np.percentile(sp_median[sp_median > 0], q=10)
    if (abs(debugplot) % 10 != 0) or (pdf is not None):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.hist(sp_median, bins=1000, log=True)
        ax.set_xlabel('Normalized number of counts')
        ax.set_ylabel('Number of pixels')
        ax.set_title('Median spectrum')
        ax.axvline(float(baseline), linestyle='--', color='grey')
        if pdf is not None:
            pdf.savefig()
        else:
            geometry = (0, 0, 640, 480)
            set_window_geometry(geometry)
            plt.show()
    # subtract baseline to sp_median (only pixels with signal above zero)
    lok = sp_median > 0
    sp_median[lok] -= baseline

    # compute global offset through periodic correlation
    logger.info('Computing global offset')
    global_offset, fpeak = periodic_corr1d(
        sp_reference=sp_reference,
        sp_offset=sp_median,
        fminmax=None,
        naround_zero=50,
        plottitle='Median spectrum (cross-correlation)',
        pdf=pdf,
        debugplot=debugplot
    )
    logger.info('Global offset: {} pixels'.format(-global_offset))

    missing_slitlets = rectwv_coeff.missing_slitlets

    if mode == 1:
        # apply computed offset to obtain refined_rectwv_coeff_global
        for islitlet in range(1, EMIR_NBARS + 1):
            if islitlet not in missing_slitlets:
                i = islitlet - 1
                dumdict = refined_rectwv_coeff.contents[i]
                dumdict['wpoly_coeff'][0] -= global_offset*cdelt1

    elif mode == 2:
        # compute individual offset for each slitlet
        logger.info('Computing individual offsets')
        median_55sp = median_slitlets_rectified(input_image, mode=1)
        offset_array = np.zeros(EMIR_NBARS)
        xplot = []
        yplot = []
        xplot_skipped = []
        yplot_skipped = []
        cout = '0'
        for islitlet in range(1, EMIR_NBARS + 1):
            if islitlet in list_useful_slitlets:
                i = islitlet - 1
                # sp_median is a view of row i of median_55sp[0].data:
                # the in-place baseline subtraction and normalization
                # below also modify that row, which is the one passed to
                # periodic_corr1d and check_wlcalib_sp
                sp_median = median_55sp[0].data[i, :]
                # boolean mask (np.any() on the tuple returned by
                # np.where() would test the pixel indices, not the signal)
                lok = sp_median > 0
                if np.any(lok):
                    baseline = np.percentile(sp_median[lok], q=10)
                    sp_median[lok] -= baseline
                    sp_median /= sp_median.max()
                    offset_array[i], fpeak = periodic_corr1d(
                        sp_reference=sp_reference,
                        sp_offset=median_55sp[0].data[i, :],
                        fminmax=None,
                        naround_zero=50,
                        plottitle='slitlet #{0} (cross-correlation)'.format(
                            islitlet),
                        pdf=pdf,
                        debugplot=debugplot
                    )
                else:
                    offset_array[i] = 0.0
                dumdict = refined_rectwv_coeff.contents[i]
                dumdict['wpoly_coeff'][0] -= offset_array[i]*cdelt1
                xplot.append(islitlet)
                yplot.append(-offset_array[i])
                # second correction
                wpoly_coeff_refined = check_wlcalib_sp(
                    sp=median_55sp[0].data[i, :],
                    crpix1=crpix1,
                    crval1=crval1-offset_array[i]*cdelt1,
                    cdelt1=cdelt1,
                    wv_master=catlines_reference_wave,
                    coeff_ini=dumdict['wpoly_coeff'],
                    naxis1_ini=EMIR_NAXIS1,
                    title='slitlet #{0} (after applying offset)'.format(
                        islitlet),
                    ylogscale=False,
                    pdf=pdf,
                    debugplot=debugplot
                )
                dumdict['wpoly_coeff'] = wpoly_coeff_refined
                cout += '.'

            else:
                xplot_skipped.append(islitlet)
                yplot_skipped.append(0)
                cout += 'i'

            if islitlet % 10 == 0:
                if cout != 'i':
                    cout = str(islitlet // 10)

            logger.info(cout)

        # show offsets with opposite sign
        stat_summary = summary(np.array(yplot))
        logger.info('Statistics of individual slitlet offsets (pixels):')
        logger.info('- npoints....: {0:d}'.format(stat_summary['npoints']))
        logger.info('- mean.......: {0:7.3f}'.format(stat_summary['mean']))
        logger.info('- median.....: {0:7.3f}'.format(stat_summary['median']))
        logger.info('- std........: {0:7.3f}'.format(stat_summary['std']))
        logger.info('- robust_std.: {0:7.3f}'.format(stat_summary[
                                                        'robust_std']))
        if (abs(debugplot) % 10 != 0) or (pdf is not None):
            ax = ximplotxy(xplot, yplot,
                           linestyle='', marker='o', color='C0',
                           xlabel='slitlet number',
                           ylabel='-offset (pixels) = offset to be applied',
                           title='cross-correlation result',
                           show=False, **{'label': 'individual slitlets'})
            if len(xplot_skipped) > 0:
                ax.plot(xplot_skipped, yplot_skipped, 'mx')
            ax.axhline(-global_offset, linestyle='--', color='C1',
                       label='global offset')
            ax.legend()
            if pdf is not None:
                pdf.savefig()
            else:
                pause_debugplot(debugplot=debugplot, pltshow=True)
    else:
        raise ValueError('Unexpected mode={}'.format(mode))

    # return result
    return refined_rectwv_coeff, expected_cat_image
Esempio n. 8
0
def main(args=None):
    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: compute pixel-to-pixel flatfield'
    )

    # required arguments
    parser.add_argument("fitsfile",
                        help="Input FITS file (flat ON-OFF)",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rectwv_coeff", required=True,
                        help="Input JSON file with rectification and "
                             "wavelength calibration coefficients",
                        type=argparse.FileType('rt'))
    parser.add_argument("--minimum_slitlet_width_mm", required=True,
                        help="Minimum slitlet width in mm",
                        type=float)
    parser.add_argument("--maximum_slitlet_width_mm", required=True,
                        help="Maximum slitlet width in mm",
                        type=float)
    parser.add_argument("--minimum_fraction", required=True,
                        help="Minimum allowed flatfielding value",
                        type=float, default=0.01)
    parser.add_argument("--minimum_value_in_output",
                        help="Minimum value allowed in output file: pixels "
                             "below this value are set to 1.0 (default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--maximum_value_in_output",
                        help="Maximum value allowed in output file: pixels "
                             "above this value are set to 1.0 (default=10.0)",
                        type=float, default=10.0)
    parser.add_argument("--nwindow_median",
                        help="Window size to smooth median spectrum in the "
                             "spectral direction",
                        type=int)
    parser.add_argument("--outfile", required=True,
                        help="Output FITS file",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))

    # optional arguments
    parser.add_argument("--delta_global_integer_offset_x_pix",
                        help="Delta global integer offset in the X direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--delta_global_integer_offset_y_pix",
                        help="Delta global integer offset in the Y direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--resampling",
                        help="Resampling method: 1 -> nearest neighbor, "
                             "2 -> linear interpolation (default)",
                        default=2, type=int,
                        choices=(1, 2))
    parser.add_argument("--ignore_DTUconf",
                        help="Ignore DTU configurations differences between "
                             "model and input image",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                             " (default=0)",
                        default=0, type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # This code is obsolete
    raise ValueError('This code is obsolete: use recipe in '
                     'emirdrp/recipes/spec/flatpix2pix.py')

    # read calibration structure from JSON file
    rectwv_coeff = RectWaveCoeff._datatype_load(args.rectwv_coeff.name)

    # modify (when requested) global offsets
    rectwv_coeff.global_integer_offset_x_pix += \
        args.delta_global_integer_offset_x_pix
    rectwv_coeff.global_integer_offset_y_pix += \
        args.delta_global_integer_offset_y_pix

    # read FITS image and its corresponding header
    hdulist = fits.open(args.fitsfile)
    header = hdulist[0].header
    image2d = hdulist[0].data
    hdulist.close()

    # apply global offsets
    image2d = apply_integer_offsets(
        image2d=image2d,
        offx=rectwv_coeff.global_integer_offset_x_pix,
        offy=rectwv_coeff.global_integer_offset_y_pix
    )

    # protections
    naxis2, naxis1 = image2d.shape
    if naxis1 != header['naxis1'] or naxis2 != header['naxis2']:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)
        raise ValueError('Something is wrong with NAXIS1 and/or NAXIS2')
    if abs(args.debugplot) >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # check that the input FITS file grism and filter match
    filter_name = header['filter']
    if filter_name != rectwv_coeff.tags['filter']:
        raise ValueError("Filter name does not match!")
    grism_name = header['grism']
    if grism_name != rectwv_coeff.tags['grism']:
        raise ValueError("Filter name does not match!")
    if abs(args.debugplot) >= 10:
        print('>>> grism.......:', grism_name)
        print('>>> filter......:', filter_name)

    # check that the DTU configurations are compatible
    dtu_conf_fitsfile = DtuConfiguration.define_from_fits(args.fitsfile)
    dtu_conf_jsonfile = DtuConfiguration.define_from_dictionary(
        rectwv_coeff.meta_info['dtu_configuration'])
    if dtu_conf_fitsfile != dtu_conf_jsonfile:
        print('DTU configuration (FITS file):\n\t', dtu_conf_fitsfile)
        print('DTU configuration (JSON file):\n\t', dtu_conf_jsonfile)
        if args.ignore_DTUconf:
            print('WARNING: DTU configuration differences found!')
        else:
            raise ValueError('DTU configurations do not match')
    else:
        if abs(args.debugplot) >= 10:
            print('>>> DTU Configuration match!')
            print(dtu_conf_fitsfile)

    # load CSU configuration
    csu_conf_fitsfile = CsuConfiguration.define_from_fits(args.fitsfile)
    if abs(args.debugplot) >= 10:
        print(csu_conf_fitsfile)

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        print('-> Removing slitlet (not defined):', idel)
        list_valid_islitlets.remove(idel)
    # filter out slitlets with widths outside valid range
    list_outside_valid_width = []
    for islitlet in list_valid_islitlets:
        slitwidth = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
        if (slitwidth < args.minimum_slitlet_width_mm) or \
                (slitwidth > args.maximum_slitlet_width_mm):
            list_outside_valid_width.append(islitlet)
            print('-> Removing slitlet (invalid width):', islitlet)
    if len(list_outside_valid_width) > 0:
        for idel in list_outside_valid_width:
            list_valid_islitlets.remove(idel)
    print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # compute and store median spectrum (and masked region) for each
    # individual slitlet
    image2d_sp_median = np.zeros((EMIR_NBARS, EMIR_NAXIS1))
    image2d_sp_mask = np.zeros((EMIR_NBARS, EMIR_NAXIS1), dtype=bool)
    for islitlet in list(range(1, EMIR_NBARS + 1)):
        if islitlet in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
            # define Slitlet2D object
            slt = Slitlet2D(islitlet=islitlet,
                            rectwv_coeff=rectwv_coeff,
                            debugplot=args.debugplot)

            if abs(args.debugplot) >= 10:
                print(slt)

            # extract (distorted) slitlet from the initial image
            slitlet2d = slt.extract_slitlet2d(
                image_2k2k=image2d,
                subtitle='original image'
            )

            # rectify slitlet
            slitlet2d_rect = slt.rectify(
                slitlet2d=slitlet2d,
                resampling=args.resampling,
                subtitle='original rectified'
            )
            naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

            if naxis1_slitlet2d != EMIR_NAXIS1:
                print('naxis1_slitlet2d: ', naxis1_slitlet2d)
                print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
                raise ValueError("Unexpected naxis1_slitlet2d")

            sp_mask = np.zeros(naxis1_slitlet2d, dtype=bool)

            # for grism LR set to zero data beyond useful wavelength range
            if grism_name == 'LR':
                wv_parameters = set_wv_parameters(filter_name, grism_name)
                x_pix = np.arange(1, naxis1_slitlet2d + 1)
                wl_pix = polyval(x_pix, slt.wpoly)
                lremove = wl_pix < wv_parameters['wvmin_useful']
                sp_mask[lremove] = True
                slitlet2d_rect[:, lremove] = 0.0
                lremove = wl_pix > wv_parameters['wvmax_useful']
                slitlet2d_rect[:, lremove] = 0.0
                sp_mask[lremove] = True

            # get useful slitlet region (use boundaries instead of frontiers;
            # note that the nscan_minmax_frontiers() works well independently
            # of using frontiers of boundaries as arguments)
            nscan_min, nscan_max = nscan_minmax_frontiers(
                slt.y0_reference_lower,
                slt.y0_reference_upper,
                resize=False
            )
            ii1 = nscan_min - slt.bb_ns1_orig
            ii2 = nscan_max - slt.bb_ns1_orig + 1

            # median spectrum
            sp_collapsed = np.median(slitlet2d_rect[ii1:(ii2 + 1), :], axis=0)

            # smooth median spectrum along the spectral direction
            sp_median = ndimage.median_filter(
                sp_collapsed,
                args.nwindow_median,
                mode='nearest'
            )

            """
                nremove = 5
                spl = AdaptiveLSQUnivariateSpline(
                    x=xaxis1[nremove:-nremove],
                    y=sp_collapsed[nremove:-nremove],
                    t=11,
                    adaptive=True
                )
                xknots = spl.get_knots()
                yknots = spl(xknots)
                sp_median = spl(xaxis1)

                # compute rms within each knot interval
                nknots = len(xknots)
                rms_array = np.zeros(nknots - 1, dtype=float)
                for iknot in range(nknots - 1):
                    residuals = []
                    for xdum, ydum, yydum in \
                            zip(xaxis1, sp_collapsed, sp_median):
                        if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                            residuals.append(abs(ydum - yydum))
                    if len(residuals) > 5:
                        rms_array[iknot] = np.std(residuals)
                    else:
                        rms_array[iknot] = 0

                # determine in which knot interval falls each pixel
                iknot_array = np.zeros(len(xaxis1), dtype=int)
                for idum, xdum in enumerate(xaxis1):
                    for iknot in range(nknots - 1):
                        if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                            iknot_array[idum] = iknot

                # compute new fit removing deviant points (with fixed knots)
                xnewfit = []
                ynewfit = []
                for idum in range(len(xaxis1)):
                    delta_sp = abs(sp_collapsed[idum] - sp_median[idum])
                    rms_tmp = rms_array[iknot_array[idum]]
                    if idum == 0 or idum == (len(xaxis1) - 1):
                        lok = True
                    elif rms_tmp > 0:
                        if delta_sp < 3.0 * rms_tmp:
                            lok = True
                        else:
                            lok = False
                    else:
                        lok = True
                    if lok:
                        xnewfit.append(xaxis1[idum])
                        ynewfit.append(sp_collapsed[idum])
                nremove = 5
                splnew = AdaptiveLSQUnivariateSpline(
                    x=xnewfit[nremove:-nremove],
                    y=ynewfit[nremove:-nremove],
                    t=xknots[1:-1],
                    adaptive=False
                )
                sp_median = splnew(xaxis1)
            """

            ymax_spmedian = sp_median.max()
            y_threshold = ymax_spmedian * args.minimum_fraction
            lremove = np.where(sp_median < y_threshold)
            sp_median[lremove] = 0.0
            sp_mask[lremove] = True

            image2d_sp_median[islitlet - 1, :] = sp_median
            image2d_sp_mask[islitlet - 1, :] = sp_mask

            if abs(args.debugplot) % 10 != 0:
                xaxis1 = np.arange(1, naxis1_slitlet2d + 1)
                title = 'Slitlet#' + str(islitlet) + ' (median spectrum)'
                ax = ximplotxy(xaxis1, sp_collapsed,
                               title=title,
                               show=False, **{'label' : 'collapsed spectrum'})
                ax.plot(xaxis1, sp_median, label='fitted spectrum')
                ax.plot([1, naxis1_slitlet2d], 2*[y_threshold],
                        label='threshold')
                # ax.plot(xknots, yknots, 'o', label='knots')
                ax.legend()
                ax.set_ylim(-0.05*ymax_spmedian, 1.05*ymax_spmedian)
                pause_debugplot(args.debugplot,
                                pltshow=True, tight_layout=True)
        else:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)

    # ToDo: compute "average" spectrum for each pseudo-longslit, scaling
    #       with the median signal in each slitlet; derive a particular
    #       spectrum for each slitlet (scaling properly)

    image2d_sp_median_masked = np.ma.masked_array(
        image2d_sp_median,
        mask=image2d_sp_mask
    )
    ycut_median = np.ma.median(image2d_sp_median_masked, axis=1).data
    ycut_median_2d = np.repeat(ycut_median, EMIR_NAXIS1).reshape(
        EMIR_NBARS, EMIR_NAXIS1)
    image2d_sp_median_eq = image2d_sp_median_masked / ycut_median_2d
    image2d_sp_median_eq = image2d_sp_median_eq.data

    if True:
        ximshow(image2d_sp_median, title='sp_median', debugplot=12)
        ximplotxy(np.arange(1, EMIR_NBARS + 1), ycut_median, 'ro',
                  title='median value of each spectrum', debugplot=12)
        ximshow(image2d_sp_median_eq, title='sp_median_eq', debugplot=12)

    csu_conf_fitsfile.display_pseudo_longslits(
        list_valid_slitlets=list_valid_islitlets)
    dict_longslits = csu_conf_fitsfile.pseudo_longslits()

    # compute median spectrum for each longslit and insert (properly
    # scaled) that spectrum in each slitlet belonging to that longslit
    image2d_sp_median_longslit = np.zeros((EMIR_NBARS, EMIR_NAXIS1))
    islitlet = 1
    loop = True
    while loop:
        if islitlet in list_valid_islitlets:
            imin = dict_longslits[islitlet].imin()
            imax = dict_longslits[islitlet].imax()
            print('--> imin, imax: ', imin, imax)
            sp_median_longslit = np.median(
                image2d_sp_median_eq[(imin - 1):imax, :], axis=0)
            for i in range(imin, imax+1):
                print('----> i: ', i)
                image2d_sp_median_longslit[(i - 1), :] = \
                    sp_median_longslit * ycut_median[i - 1]
            islitlet = imax
        else:
            print('--> ignoring: ', islitlet)
        if islitlet == EMIR_NBARS:
            loop = False
        else:
            islitlet += 1
    if True:
        ximshow(image2d_sp_median_longslit, debugplot=12)

    # initialize output flatfield image (original, unrectified pixel frame)
    image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

    # main loop
    for islitlet in list(range(1, EMIR_NBARS + 1)):
        if islitlet in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
            # define Slitlet2D object
            slt = Slitlet2D(islitlet=islitlet,
                            rectwv_coeff=rectwv_coeff,
                            debugplot=args.debugplot)

            # extract (distorted) slitlet from the initial image
            slitlet2d = slt.extract_slitlet2d(
                image_2k2k=image2d,
                subtitle='original image'
            )

            # rectify slitlet
            slitlet2d_rect = slt.rectify(
                slitlet2d=slitlet2d,
                resampling=args.resampling,
                subtitle='original rectified'
            )
            naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

            sp_median = image2d_sp_median_longslit[islitlet - 1, :]

            # generate rectified slitlet region filled with the median spectrum
            slitlet2d_rect_spmedian = np.tile(sp_median, (naxis2_slitlet2d, 1))
            if abs(args.debugplot) > 10:
                slt.ximshow_rectified(
                    slitlet2d_rect=slitlet2d_rect_spmedian,
                    subtitle='rectified, filled with median spectrum'
                )

            # unrectified image
            slitlet2d_unrect_spmedian = slt.rectify(
                slitlet2d=slitlet2d_rect_spmedian,
                resampling=args.resampling,
                inverse=True,
                subtitle='unrectified, filled with median spectrum'
            )

            # normalize initial slitlet image (avoid division by zero)
            slitlet2d_norm = np.zeros_like(slitlet2d)
            for j in range(naxis1_slitlet2d):
                for i in range(naxis2_slitlet2d):
                    den = slitlet2d_unrect_spmedian[i, j]
                    if den == 0:
                        slitlet2d_norm[i, j] = 1.0
                    else:
                        slitlet2d_norm[i, j] = slitlet2d[i, j] / den

            if abs(args.debugplot) > 10:
                slt.ximshow_unrectified(
                    slitlet2d=slitlet2d_norm,
                    subtitle='unrectified, pixel-to-pixel'
                )

            # check for pseudo-longslit with previous slitlet
            if islitlet > 1:
                if (islitlet - 1) in list_valid_islitlets:
                    c1 = csu_conf_fitsfile.csu_bar_slit_center(islitlet - 1)
                    w1 = csu_conf_fitsfile.csu_bar_slit_width(islitlet - 1)
                    c2 = csu_conf_fitsfile.csu_bar_slit_center(islitlet)
                    w2 = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
                    if abs(w1-w2)/w1 < 0.25:
                        wmean = (w1 + w2) / 2.0
                        if abs(c1 - c2) < wmean/4.0:
                            same_slitlet_below = True
                        else:
                            same_slitlet_below = False
                    else:
                        same_slitlet_below = False
                else:
                    same_slitlet_below = False
            else:
                same_slitlet_below = False

            # check for pseudo-longslit with next slitlet
            if islitlet < EMIR_NBARS:
                if (islitlet + 1) in list_valid_islitlets:
                    c1 = csu_conf_fitsfile.csu_bar_slit_center(islitlet)
                    w1 = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
                    c2 = csu_conf_fitsfile.csu_bar_slit_center(islitlet + 1)
                    w2 = csu_conf_fitsfile.csu_bar_slit_width(islitlet + 1)
                    if abs(w1-w2)/w1 < 0.25:
                        wmean = (w1 + w2) / 2.0
                        if abs(c1 - c2) < wmean/4.0:
                            same_slitlet_above = True
                        else:
                            same_slitlet_above = False
                    else:
                        same_slitlet_above = False
                else:
                    same_slitlet_above = False
            else:
                same_slitlet_above = False

            for j in range(EMIR_NAXIS1):
                xchannel = j + 1
                y0_lower = slt.list_frontiers[0](xchannel)
                y0_upper = slt.list_frontiers[1](xchannel)
                n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                                y0_frontier_upper=y0_upper,
                                                resize=True)
                # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
                nn1 = n1 - slt.bb_ns1_orig + 1
                nn2 = n2 - slt.bb_ns1_orig + 1
                image2d_flatfielded[(n1 - 1):n2, j] = \
                    slitlet2d_norm[(nn1 - 1):nn2, j]

                # force to 1.0 region around frontiers
                if not same_slitlet_below:
                    image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
                if not same_slitlet_above:
                    image2d_flatfielded[(n2 - 5):n2, j] = 1
        else:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)

    if args.debugplot == 0:
        print('OK!')

    # restore global offsets
    image2d_flatfielded = apply_integer_offsets(
        image2d=image2d_flatfielded ,
        offx=-rectwv_coeff.global_integer_offset_x_pix,
        offy=-rectwv_coeff.global_integer_offset_y_pix
    )

    # set pixels below minimum value to 1.0
    filtered = np.where(image2d_flatfielded < args.minimum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # set pixels above maximum value to 1.0
    filtered = np.where(image2d_flatfielded > args.maximum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # save output file
    save_ndarray_to_fits(
        array=image2d_flatfielded,
        file_name=args.outfile,
        main_header=header,
        overwrite=True
    )
    print('>>> Saving file ' + args.outfile.name)