Esempio n. 1
0
def clean_defects(image2d, debugplot=0,
                  scan_range=(750, 860), channel_range=(980, 1050),
                  median_size=5, threshold=0.90):
    """Interpolate a problematic rectangular image region.

    The region (by default the one containing a "curved piece of hair")
    is median filtered along Y, each scan (row) of the filtered region is
    fitted with a straight line in X using ``fit_theil_sen``, and every
    pixel whose filtered/fitted ratio is below ``threshold`` is replaced
    by the fitted value.

    Parameters
    ----------
    image2d : numpy array
        Input 2D image. It is not modified; a copy is returned.
    debugplot : int
        Plotting/debugging code; intermediate images are displayed when
        ``abs(debugplot) % 10 != 0``.
    scan_range : tuple of int
        Inclusive (i1, i2) array-row limits of the region (Y direction).
    channel_range : tuple of int
        Inclusive (j1, j2) array-column limits of the region (X direction).
    median_size : int
        Size of the median filter applied along Y.
    threshold : float
        Pixels whose filtered/fitted ratio is below this value are
        replaced by the fitted image.

    Returns
    -------
    image2d_clean : numpy array
        Copy of ``image2d`` with the region cleaned.

    Raises
    ------
    ValueError
        If the requested region is empty or falls outside the image.

    """

    i1, i2 = scan_range
    j1, j2 = channel_range
    naxis2, naxis1 = image2d.shape
    # slicing silently truncates out-of-range limits, which would later
    # produce a size mismatch with xfit; fail early with a clear message
    if not (0 <= i1 <= i2 < naxis2 and 0 <= j1 <= j2 < naxis1):
        raise ValueError("Invalid region: rows {0}-{1}, columns {2}-{3} "
                         "for image of shape {4}".format(
                             i1, i2, j1, j2, image2d.shape))

    def _display(array2d, title):
        """Show an intermediate sub-image when debugplot requests it."""
        if abs(debugplot) % 10 != 0:
            ximshow(array2d,
                    title=title,
                    first_pixel=(j1, i1),
                    debugplot=debugplot)

    # define output image
    image2d_clean = image2d.copy()
    subimage = image2d_clean[i1:(i2 + 1), j1:(j2 + 1)].copy()
    _display(subimage, 'original image')

    # median filter in Y
    subimage_filtered = ndimage.median_filter(subimage,
                                              size=(median_size, 1))
    _display(subimage_filtered, 'median-filtered image')

    # fit each scan with a straight line in X
    xfit = np.arange(j1, j2 + 1, 1, dtype=float)
    yfit = subimage_filtered.transpose()
    coef = fit_theil_sen(xfit, yfit)
    # zeros_like keeps the dtype of the input sub-image; row i holds
    # coef[0, i] + coef[1, i] * xfit (evaluated for all rows at once)
    subimage_fitted = np.zeros_like(subimage)
    subimage_fitted[:, :] = (coef[0][:, np.newaxis] +
                             coef[1][:, np.newaxis] * xfit[np.newaxis, :])
    _display(subimage_fitted, 'fitted image')

    # filtered_image/fitted_image
    subimage_ratio = subimage_filtered / subimage_fitted
    _display(subimage_ratio, 'image ratio')

    # replace bad pixels (ratio below threshold) by the fitted image
    badpix = np.where(subimage_ratio < threshold)
    subimage[badpix] = subimage_fitted[badpix]
    _display(subimage, 'cleaned image')

    # replace interpolated region in original image
    image2d_clean[i1:(i2 + 1), j1:(j2 + 1)] = subimage

    return image2d_clean
Esempio n. 2
0
    def ximshow_rectified(self, slitlet2d_rect, subtitle=None):
        """Display rectified image with spectrails and frontiers.

        In the rectified image every spectrail/frontier is a horizontal
        line located at ``corr_yrect_a + corr_yrect_b * poly(x0_reference)``.
        Spectrails are drawn as solid blue lines and frontiers as dotted
        blue lines.

        Parameters
        ----------
        slitlet2d_rect : numpy array
            Array containing the rectified slitlet image.
        subtitle : string, optional
            Subtitle for plot.

        """

        title = "Slitlet#" + str(self.islitlet)
        if subtitle is not None:
            title += ' (' + subtitle + ')'
        ax = ximshow(slitlet2d_rect, title=title,
                     first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                     show=False)
        # grid with fitted transformation; ``np.float`` was removed in
        # NumPy 1.24, the builtin ``float`` is the same dtype (float64)
        xx = np.arange(0, self.bb_nc2_orig - self.bb_nc1_orig + 1,
                       dtype=float)
        # spectrum trails (solid) first, then frontiers (dotted)
        for list_polys, linestyle in ((self.list_spectrails, "b"),
                                      (self.list_frontiers, "b:")):
            for poly in list_polys:
                yy0 = self.corr_yrect_a + \
                      self.corr_yrect_b * poly(self.x0_reference)
                yy = np.tile([yy0 - self.bb_ns1_orig], xx.size)
                ax.plot(xx + self.bb_nc1_orig, yy + self.bb_ns1_orig,
                        linestyle)
        # show plot
        pause_debugplot(self.debugplot, pltshow=True)
Esempio n. 3
0
    def ximshow_unrectified(self, slitlet2d, subtitle=None):
        """Show the unrectified slitlet with spectrails and frontiers on top.

        The three spectrails are drawn in blue (solid for the first and
        third, dashed for the second) and the two frontiers as dotted blue
        lines, all evaluated at X = 1 ... EMIR_NAXIS1.

        Parameters
        ----------
        slitlet2d : numpy array
            Array containing the unrectified slitlet image.
        subtitle : string, optional
            Text appended, in parentheses, to the plot title.

        """

        plot_title = "Slitlet#" + str(self.islitlet)
        if subtitle is not None:
            plot_title += ' (' + subtitle + ')'
        axes = ximshow(slitlet2d, title=plot_title,
                       first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                       show=False)
        xpix = np.linspace(1, EMIR_NAXIS1, num=EMIR_NAXIS1)
        # (polynomial list, index, matplotlib format) in drawing order
        overlays = (
            (self.list_spectrails, 0, 'b-'),   # lower spectrail
            (self.list_spectrails, 1, 'b--'),  # middle spectrail
            (self.list_spectrails, 2, 'b-'),   # upper spectrail
            (self.list_frontiers, 0, 'b:'),    # lower frontier
            (self.list_frontiers, 1, 'b:'),    # upper frontier
        )
        for polys, k, fmt in overlays:
            axes.plot(xpix, polys[k](xpix), fmt)
        # plot_title is always a string at this point
        axes.set_title(plot_title)
        pause_debugplot(debugplot=self.debugplot, pltshow=True)
Esempio n. 4
0
    def ximshow_unrectified(self, slitlet2d):
        """Show the unrectified slitlet with spectrails and frontiers on top.

        Parameters
        ----------
        slitlet2d : numpy array
            Array containing the unrectified slitlet image.

        """

        axes = ximshow(slitlet2d, title="Slitlet#" + str(self.islitlet),
                       first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                       show=False)
        xgrid = np.linspace(1, EMIR_NAXIS1, num=EMIR_NAXIS1)
        # spectrails: lower (solid), middle (dashed), upper (solid)
        for k, fmt in enumerate(('b-', 'b--', 'b-')):
            axes.plot(xgrid, self.list_spectrails[k](xgrid), fmt)
        # frontiers: lower and upper, both dotted
        for k in range(2):
            axes.plot(xgrid, self.list_frontiers[k](xgrid), 'b:')
        pause_debugplot(debugplot=self.debugplot, pltshow=True)
Esempio n. 5
0
    def ximshow_unrectified(self, slitlet2d):
        """Plot the unrectified slitlet image with its spectrails and
        frontiers overlaid (spectrails in solid/dashed blue, frontiers in
        dotted blue).

        Parameters
        ----------
        slitlet2d : numpy array
            Array containing the unrectified slitlet image.

        """

        slitlet_label = "Slitlet#" + str(self.islitlet)
        axis = ximshow(slitlet2d,
                       title=slitlet_label,
                       first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                       show=False)
        xvalues = np.linspace(1, EMIR_NAXIS1, num=EMIR_NAXIS1)

        def overlay(poly, fmt):
            """Evaluate one polynomial over xvalues and draw it."""
            axis.plot(xvalues, poly(xvalues), fmt)

        overlay(self.list_spectrails[0], 'b-')
        overlay(self.list_spectrails[1], 'b--')
        overlay(self.list_spectrails[2], 'b-')
        overlay(self.list_frontiers[0], 'b:')
        overlay(self.list_frontiers[1], 'b:')
        pause_debugplot(debugplot=self.debugplot, pltshow=True)
Esempio n. 6
0
def _parse_slitlet_numbers(tuple_slit_numbers):
    """Convert an ``n1[,n2[,step]]`` string into the selected slitlets.

    Parameters
    ----------
    tuple_slit_numbers : str
        Comma-separated first slitlet, optional last slitlet (inclusive)
        and optional step.

    Returns
    -------
    list_slitlets : range or list of int
        Selected slitlet numbers.

    Raises
    ------
    ValueError
        If the string does not have 1 to 3 items, contains non-integers,
        or the limits fall outside 1 ... EMIR_NBARS.

    """
    tmp_str = tuple_slit_numbers.split(",")
    if len(tmp_str) not in (1, 2, 3):
        raise ValueError("Invalid tuple for slitlet numbers")
    n1 = int(tmp_str[0])
    if n1 < 1:
        raise ValueError("Invalid slitlet number < 1")
    if len(tmp_str) == 1:
        if n1 > EMIR_NBARS:
            raise ValueError("Invalid slitlet number > EMIR_NBARS")
        return [n1]
    n2 = int(tmp_str[1])
    if n2 > EMIR_NBARS:
        raise ValueError("Invalid slitlet number > EMIR_NBARS")
    step = int(tmp_str[2]) if len(tmp_str) == 3 else 1
    return range(n1, n2 + 1, step)


def main(args=None):
    """Display an EMIR FITS image with the boundaries of selected slitlets.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments. When None, ``sys.argv[1:]`` is used.

    """

    # parse command-line options
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--bounddict",
                        required=True,
                        help="bounddict file name",
                        type=argparse.FileType('rt'))
    parser.add_argument("--tuple_slit_numbers",
                        required=True,
                        help="Tuple n1[,n2[,step]] to define slitlet numbers")

    # optional arguments
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the ``args`` parameter (None -> sys.argv[1:]) so that main()
    # can also be driven programmatically
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # read slitlet numbers to be computed
    list_slitlets = _parse_slitlet_numbers(args.tuple_slit_numbers)

    # read input FITS file (data are accessed before the file is closed)
    with fits.open(args.fitsfile.name) as hdulist:
        image_header = hdulist[0].header
        image2d = hdulist[0].data

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")

    # remove path from fitsfile
    sfitsfile = os.path.basename(args.fitsfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # display full image
    ax = ximshow(image2d=image2d,
                 title=sfitsfile + "\ngrism=" + grism + ", filter=" +
                 spfilter + ", rotang=" + str(round(rotang, 2)),
                 image_bbox=(1, naxis1, 1, naxis2),
                 show=False)

    # overplot boundaries for each slitlet
    for slitlet_number in list_slitlets:
        pol_lower_boundary, pol_upper_boundary, \
        xmin_lower, xmax_lower, xmin_upper, xmax_upper,  \
        csu_bar_slit_center = \
            get_boundaries(args.bounddict, slitlet_number)
        if (pol_lower_boundary is not None) and \
                (pol_upper_boundary is not None):
            xp = np.linspace(start=xmin_lower, stop=xmax_lower, num=1000)
            yp = pol_lower_boundary(xp)
            ax.plot(xp, yp, 'g-')
            xp = np.linspace(start=xmin_upper, stop=xmax_upper, num=1000)
            yp = pol_upper_boundary(xp)
            ax.plot(xp, yp, 'b-')
            # slitlet label, placed midway between both boundaries
            yc_lower = pol_lower_boundary(EMIR_NAXIS1 / 2 + 0.5)
            yc_upper = pol_upper_boundary(EMIR_NAXIS1 / 2 + 0.5)
            tmpcolor = ['r', 'b'][slitlet_number % 2]
            xcsu = EMIR_NAXIS1 * csu_bar_slit_center / 341.5
            ax.text(xcsu, (yc_lower + yc_upper) / 2,
                    str(slitlet_number),
                    fontsize=10,
                    va='center',
                    ha='center',
                    bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="grey"),
                    color=tmpcolor,
                    fontweight='bold',
                    backgroundcolor='white')

    # show plot
    pause_debugplot(12, pltshow=True)
Esempio n. 7
0
def median_slitlets_rectified(
        input_image,
        mode=0,
        minimum_slitlet_width_mm=EMIR_MINIMUM_SLITLET_WIDTH_MM,
        maximum_slitlet_width_mm=EMIR_MAXIMUM_SLITLET_WIDTH_MM,
        debugplot=0):
    """Compute median spectrum for each slitlet.

    Parameters
    ----------
    input_image : HDUList object
        Input 2D image. NAXIS2 must be
        EMIR_NBARS * EMIR_NPIXPERSLIT_RECTIFIED.
    mode : int
        Indicate desired result:
        0 : image with the same size as the input image, with the
            median spectrum of each slitlet spanning all the spectra
            of the corresponding slitlet
        1 : image with 55 spectra, containing the median spectra of
            each slitlet
        2 : single collapsed median spectrum, using exclusively the
            useful slitlets from the input image
    minimum_slitlet_width_mm : float
        Minimum slitlet width (mm) for a valid slitlet.
    maximum_slitlet_width_mm : float
        Maximum slitlet width (mm) for a valid slitlet.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot.

    Returns
    -------
    image_median : HDUList object
        Output image.

    Raises
    ------
    ValueError
        If ``mode`` is not 0, 1 or 2, if NAXIS2 is not the expected
        value, or if the INSTRUME keyword is not 'EMIR'.

    """

    # any other value used to fall silently into the mode=1 branch
    if mode not in (0, 1, 2):
        raise ValueError("Invalid mode={0} (valid values: 0, 1, 2)".format(
            mode))

    image_header = input_image[0].header
    image2d = input_image[0].data

    # check image dimensions
    naxis2_expected = EMIR_NBARS * EMIR_NPIXPERSLIT_RECTIFIED

    naxis2, naxis1 = image2d.shape
    if naxis2 != naxis2_expected:
        raise ValueError("NAXIS2={0} should be {1}".format(
            naxis2, naxis2_expected))

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # initialize output image
    if mode == 0:
        image2d_median = np.zeros((naxis2, naxis1))
    else:
        image2d_median = np.zeros((EMIR_NBARS, naxis1))

    # main loop: median along Y of the rows belonging to each slitlet
    for i in range(EMIR_NBARS):
        ns1 = i * EMIR_NPIXPERSLIT_RECTIFIED + 1
        ns2 = ns1 + EMIR_NPIXPERSLIT_RECTIFIED - 1
        sp_median = np.median(image2d[(ns1 - 1):ns2, :], axis=0)

        if mode == 0:
            # broadcasting repeats the median spectrum on every row
            image2d_median[(ns1 - 1):ns2, :] = sp_median
        else:
            image2d_median[i] = sp_median

    if mode == 2:
        # get CSU configuration from FITS header
        csu_config = CsuConfiguration.define_from_header(image_header)

        # define wavelength calibration parameters
        crpix1 = image_header['crpix1']
        crval1 = image_header['crval1']
        cdelt1 = image_header['cdelt1']

        # segregate slitlets
        list_useful_slitlets = csu_config.widths_in_range_mm(
            minwidth=minimum_slitlet_width_mm,
            maxwidth=maximum_slitlet_width_mm)
        list_not_useful_slitlets = [
            i for i in range(1, EMIR_NBARS + 1)
            if i not in list_useful_slitlets
        ]
        if abs(debugplot) != 0:
            print('>>> list_useful_slitlets....:', list_useful_slitlets)
            print('>>> list_not_useful_slitlets:', list_not_useful_slitlets)

        # define mask from array data
        mask2d, borders = define_mask_borders(image2d_median, sought_value=0)
        if abs(debugplot) % 10 != 0:
            ximshow(mask2d.astype(int),
                    z1z2=(-.2, 1.2),
                    crpix1=crpix1,
                    crval1=crval1,
                    cdelt1=cdelt1,
                    debugplot=debugplot)

        # update mask with unused slitlets (mask the whole spectrum)
        for islitlet in list_not_useful_slitlets:
            mask2d[islitlet - 1, :] = True
        if abs(debugplot) % 10 != 0:
            ximshow(mask2d.astype(int),
                    z1z2=(-.2, 1.2),
                    crpix1=crpix1,
                    crval1=crval1,
                    cdelt1=cdelt1,
                    debugplot=debugplot)

        # useful image pixels (only needed for the display)
        if abs(debugplot) % 10 != 0:
            image2d_useful = image2d_median * (1 - mask2d.astype(int))
            ximshow(image2d_useful,
                    crpix1=crpix1,
                    crval1=crval1,
                    cdelt1=cdelt1,
                    debugplot=debugplot)

        # masked image
        image2d_masked = np.ma.masked_array(image2d_median, mask=mask2d)
        # median spectrum
        image1d_median = np.ma.median(image2d_masked, axis=0).data

        image_median = fits.PrimaryHDU(data=image1d_median,
                                       header=image_header)

    else:
        image_median = fits.PrimaryHDU(data=image2d_median,
                                       header=image_header)

    return fits.HDUList([image_median])
Esempio n. 8
0
def useful_mos_xpixels(reduced_mos_data,
                       base_header,
                       vpix_region,
                       npix_removed_near_ohlines=0,
                       list_valid_wvregions=None,
                       debugplot=0):
    """Return a boolean mask of the useful X-axis (wavelength) pixels.

    A pixel is useful when it lies inside the valid wavelength range and
    is not within +/- ``npix_removed_near_ohlines`` pixels of any OH line
    listed in the 'Oliva_etal_2013.dat' catalogue.

    Parameters
    ----------
    reduced_mos_data : numpy array
        2D image; only used for the debugging plots.
    base_header : FITS header
        Header providing NAXIS1, NAXIS2, CRPIX1, CRVAL1, CDELT1 and the
        JMNSLTnn / JMXSLTnn keywords, used as the first / last valid X
        pixel of slitlet nn.
    vpix_region : tuple of float
        (min, max) Y pixels (1-based, rounded to the nearest integer)
        defining the region.
    npix_removed_near_ohlines : int
        Number of pixels removed at each side of every OH line (0 or
        negative disables the OH-line masking).
    list_valid_wvregions : list of (float, float) or None
        Wavelength intervals to keep. When None, the X range is given by
        the JMNSLTnn / JMXSLTnn keywords of the slitlets crossed by
        ``vpix_region``.
    debugplot : int
        Plots are displayed when abs(debugplot) is 21 or 22.

    Returns
    -------
    xisok : numpy array of bool
        Array of length NAXIS1, True for useful pixels.

    Raises
    ------
    ValueError
        If ``vpix_region`` is invalid, a wavelength region is reversed,
        no valid wavelength range is provided, or no pixel survives.

    """

    # get wavelength calibration from image header
    naxis1 = base_header['naxis1']
    naxis2 = base_header['naxis2']
    crpix1 = base_header['crpix1']
    crval1 = base_header['crval1']
    cdelt1 = base_header['cdelt1']

    # check vertical region
    nsmin = int(vpix_region[0] + 0.5)
    nsmax = int(vpix_region[1] + 0.5)
    if nsmin > nsmax:
        raise ValueError('vpix_region values in wrong order')
    elif nsmin < 1 or nsmax > naxis2:
        raise ValueError('vpix_region outside valid range')

    # minimum and maximum pixels in the wavelength direction, combining
    # all the slitlets crossed by the vertical region
    islitlet_min = get_islitlet(nsmin)
    ncmin = base_header['jmnslt{:02d}'.format(islitlet_min)]
    ncmax = base_header['jmxslt{:02d}'.format(islitlet_min)]
    islitlet_max = get_islitlet(nsmax)
    if islitlet_max > islitlet_min:
        for islitlet in range(islitlet_min + 1, islitlet_max + 1):
            ncmin_ = base_header['jmnslt{:02d}'.format(islitlet)]
            # the maximum must come from JMXSLTnn (was JMNSLTnn)
            ncmax_ = base_header['jmxslt{:02d}'.format(islitlet)]
            ncmin = min(ncmin, ncmin_)
            ncmax = max(ncmax, ncmax_)

    # pixels within valid regions
    xisok_wvreg = np.zeros(naxis1, dtype='bool')
    if list_valid_wvregions is None:
        for ipix in range(ncmin, ncmax + 1):
            xisok_wvreg[ipix - 1] = True
    else:
        for wvregion in list_valid_wvregions:
            wvmin = float(wvregion[0])
            wvmax = float(wvregion[1])
            if wvmin > wvmax:
                raise ValueError('wvregion values in wrong order:'
                                 ' {}, {}'.format(wvmin, wvmax))
            minpix = int((wvmin - crval1) / cdelt1 + crpix1 + 0.5)
            maxpix = int((wvmax - crval1) / cdelt1 + crpix1 + 0.5)
            for ipix in range(minpix, maxpix + 1):
                if 1 <= ipix <= naxis1:
                    xisok_wvreg[ipix - 1] = True
        if np.sum(xisok_wvreg) < 1:
            raise ValueError('no valid wavelength ranges provided')

    # pixels affected by OH lines
    xisok_oh = np.ones(naxis1, dtype='bool')
    if int(npix_removed_near_ohlines) > 0:
        dumdata = pkgutil.get_data(
            'emirdrp.instrument.configs',
            'Oliva_etal_2013.dat'
        )
        oh_lines_tmpfile = StringIO(dumdata.decode('utf8'))
        catlines = np.genfromtxt(oh_lines_tmpfile)
        # the catalogue provides two wavelength columns; use both
        catlines_all_wave = np.concatenate(
            (catlines[:, 1], catlines[:, 0]))
        for waveline in catlines_all_wave:
            expected_pixel = int(
                (waveline - crval1) / cdelt1 + crpix1 + 0.5)
            minpix = expected_pixel - int(npix_removed_near_ohlines)
            maxpix = expected_pixel + int(npix_removed_near_ohlines)
            for ipix in range(minpix, maxpix + 1):
                if 1 <= ipix <= naxis1:
                    xisok_oh[ipix - 1] = False

    # pixels in valid regions not affected by OH lines
    xisok = np.logical_and(xisok_wvreg, xisok_oh)
    naxis1_effective = np.sum(xisok)
    if naxis1_effective < 1:
        raise ValueError('no valid wavelength range available after '
                         'removing OH lines')

    if abs(debugplot) in [21, 22]:
        slitlet2d = reduced_mos_data[(nsmin - 1):nsmax, :].copy()
        ximshow(slitlet2d,
                title='Rectified region',
                first_pixel=(1, nsmin),
                crval1=crval1, cdelt1=cdelt1, debugplot=debugplot)
        ax = ximshow(slitlet2d,
                     title='Rectified region\nafter blocking '
                           'removed wavelength ranges',
                     first_pixel=(1, nsmin),
                     crval1=crval1, cdelt1=cdelt1, show=False)
        for idum in range(1, naxis1 + 1):
            if not xisok_wvreg[idum - 1]:
                ax.plot([idum, idum], [nsmin, nsmax], 'g-')
        pause_debugplot(debugplot, pltshow=True)
        ax = ximshow(slitlet2d,
                     title='Rectified slitlet\nuseful regions after '
                           'removing OH lines',
                     first_pixel=(1, nsmin),
                     crval1=crval1, cdelt1=cdelt1, show=False)
        for idum in range(1, naxis1 + 1):
            if not xisok[idum - 1]:
                ax.plot([idum, idum], [nsmin, nsmax], 'm-')
        pause_debugplot(debugplot, pltshow=True)

    return xisok
def main(args=None):
    """Overplot the refined boundary model over an EMIR FITS image.

    Optionally overplots arc or OH lines and writes ds9 region files with
    the slitlet boundaries, frontiers and lines.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments. When None, ``sys.argv[1:]`` is used.

    """

    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: overplot boundary model over FITS image')

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rect_wpoly_MOSlibrary",
                        required=True,
                        help="Input JSON file with library of rectification "
                        "and wavelength calibration coefficients",
                        type=argparse.FileType('rt'))

    # optional arguments
    parser.add_argument("--global_integer_offset_x_pix",
                        help="Global integer offset in the X direction "
                        "(default=0)",
                        default=0,
                        type=int)
    parser.add_argument("--global_integer_offset_y_pix",
                        help="Global integer offset in the Y direction "
                        "(default=0)",
                        default=0,
                        type=int)
    parser.add_argument("--arc_lines",
                        help="Overplot arc lines",
                        action="store_true")
    parser.add_argument("--oh_lines",
                        help="Overplot OH lines",
                        action="store_true")
    parser.add_argument("--ds9_frontiers",
                        help="Output ds9 region file with slitlet frontiers",
                        type=lambda x: arg_file_is_new(parser, x))
    parser.add_argument("--ds9_boundaries",
                        help="Output ds9 region file with slitlet boundaries",
                        type=lambda x: arg_file_is_new(parser, x))
    parser.add_argument("--ds9_lines",
                        help="Output ds9 region file with arc/oh lines",
                        type=lambda x: arg_file_is_new(parser, x))
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting/debugging" +
                        " (default=12)",
                        type=int,
                        default=12,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the ``args`` parameter (None -> sys.argv[1:]) so that main()
    # can also be driven programmatically
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # ---

    # avoid incompatible options
    if args.arc_lines and args.oh_lines:
        raise ValueError("--arc_lines and --oh_lines cannot be used "
                         "simultaneously")

    # --ds9_lines requires --arc_lines or --oh_lines
    if args.ds9_lines:
        if not (args.arc_lines or args.oh_lines):
            raise ValueError("--ds9_lines requires the use of either "
                             "--arc_lines or --oh_lines")

    # read input FITS file
    hdulist = fits.open(args.fitsfile)
    image_header = hdulist[0].header
    image2d = hdulist[0].data
    hdulist.close()

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")
    if image2d.shape != (EMIR_NAXIS2, EMIR_NAXIS1):
        raise ValueError("Unexpected values for NAXIS1, NAXIS2")

    # remove path from fitsfile
    sfitsfile = os.path.basename(args.fitsfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # ---

    # generate MasterRectWave object
    master_rectwv = MasterRectWave._datatype_load(
        args.rect_wpoly_MOSlibrary.name)

    # check that grism and filter are the expected ones
    grism_ = master_rectwv.tags['grism']
    if grism_ != grism:
        raise ValueError('Unexpected grism: ' + str(grism_))
    spfilter_ = master_rectwv.tags['filter']
    if spfilter_ != spfilter:
        raise ValueError('Unexpected filter ' + str(spfilter_))

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in master_rectwv.missing_slitlets:
        list_valid_islitlets.remove(idel)

    # read CsuConfiguration object from FITS file
    csu_config = CsuConfiguration.define_from_fits(args.fitsfile)

    # list with csu_bar_slit_center for valid slitlets
    list_csu_bar_slit_center = []
    for islitlet in list_valid_islitlets:
        list_csu_bar_slit_center.append(
            csu_config.csu_bar_slit_center(islitlet))

    # define parmodel and params
    fitted_bound_param_json = {
        'contents': master_rectwv.meta_info['refined_boundary_model']
    }
    parmodel = fitted_bound_param_json['contents']['parmodel']
    fitted_bound_param_json.update({'meta_info': {'parmodel': parmodel}})
    params = bound_params_from_dict(fitted_bound_param_json)
    if parmodel != "multislit":
        raise ValueError('parmodel = "multislit" not found')

    # ---

    # define lines to be overplotted
    if args.arc_lines or args.oh_lines:

        rectwv_coeff = rectwv_coeff_from_mos_library(hdulist, master_rectwv)
        rectwv_coeff.global_integer_offset_x_pix = \
            args.global_integer_offset_x_pix
        rectwv_coeff.global_integer_offset_y_pix = \
            args.global_integer_offset_y_pix

        if args.arc_lines:
            if grism == 'LR':
                catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
            else:
                catlines_file = 'lines_argon_neon_xenon_empirical.dat'
            dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                       catlines_file)
            arc_lines_tmpfile = StringIO(dumdata.decode('utf8'))
            catlines = np.genfromtxt(arc_lines_tmpfile)
            # define wavelength and flux as separate arrays
            catlines_all_wave = catlines[:, 0]
            catlines_all_flux = catlines[:, 1]
        elif args.oh_lines:
            dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                       'Oliva_etal_2013.dat')
            oh_lines_tmpfile = StringIO(dumdata.decode('utf8'))
            catlines = np.genfromtxt(oh_lines_tmpfile)
            # define wavelength and flux as separate arrays (both
            # wavelength columns share the flux in column 2)
            catlines_all_wave = np.concatenate((catlines[:, 1],
                                                catlines[:, 0]))
            catlines_all_flux = np.concatenate((catlines[:, 2],
                                                catlines[:, 2]))
        else:
            raise ValueError("This should not happen!")

    else:
        rectwv_coeff = None
        catlines_all_wave = None
        catlines_all_flux = None

    # ---

    # generate output ds9 region file with slitlet boundaries
    if args.ds9_boundaries is not None:
        save_boundaries_from_params_ds9(
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            uuid=master_rectwv.uuid,
            grism=grism,
            spfilter=spfilter,
            ds9_filename=args.ds9_boundaries.name,
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

    # generate output ds9 region file with slitlet frontiers
    if args.ds9_frontiers is not None:
        save_frontiers_from_params_ds9(
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            uuid=master_rectwv.uuid,
            grism=grism,
            spfilter=spfilter,
            ds9_filename=args.ds9_frontiers.name,
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

    # ---

    # display full image
    if abs(args.debugplot) % 10 != 0:
        ax = ximshow(image2d=image2d,
                     title=sfitsfile + "\ngrism=" + grism + ", filter=" +
                     spfilter + ", rotang=" + str(round(rotang, 2)),
                     image_bbox=(1, naxis1, 1, naxis2),
                     show=False)

        # overplot boundaries
        overplot_boundaries_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

        # overplot frontiers
        overplot_frontiers_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            micolors=('b', 'b'),
            linetype='-',
            labels=False,  # already displayed with the boundaries
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

    else:
        ax = None

    # overplot lines
    if catlines_all_wave is not None:

        ds9_file = None
        # close the ds9 file even if writing or overplot_lines() fails
        try:
            if args.ds9_lines is not None:
                ds9_file = open(args.ds9_lines.name, 'w')
                ds9_file.write('# Region file format: DS9 version 4.1\n')
                ds9_file.write('global color=#00ffff dashlist=0 0 width=2 '
                               'font="helvetica 10 normal roman" select=1 '
                               'highlite=1 dash=0 fixed=0 edit=1 '
                               'move=1 delete=1 include=1 source=1\n')
                ds9_file.write('physical\n#\n')

                ds9_file.write('#\n# uuid..: {0}\n'.format(
                    master_rectwv.uuid))
                ds9_file.write('# filter: {0}\n'.format(spfilter))
                ds9_file.write('# grism.: {0}\n'.format(grism))
                ds9_file.write('#\n# global_offset_x_pix: {0}\n'.format(
                    args.global_integer_offset_x_pix))
                ds9_file.write('# global_offset_y_pix: {0}\n#\n'.format(
                    args.global_integer_offset_y_pix))
                if parmodel == "longslit":
                    for dumpar in EXPECTED_PARAMETER_LIST:
                        parvalue = params[dumpar].value
                        ds9_file.write('# {0}: {1}\n'.format(dumpar,
                                                             parvalue))
                else:
                    for dumpar in EXPECTED_PARAMETER_LIST_EXTENDED:
                        parvalue = params[dumpar].value
                        ds9_file.write('# {0}: {1}\n'.format(dumpar,
                                                             parvalue))

            overplot_lines(ax, catlines_all_wave, list_valid_islitlets,
                           rectwv_coeff, args.global_integer_offset_x_pix,
                           args.global_integer_offset_y_pix, ds9_file,
                           args.debugplot)
        finally:
            if ds9_file is not None:
                ds9_file.close()

    if ax is not None:
        # show plot
        pause_debugplot(12, pltshow=True)
Esempio n. 10
0
def main(args=None):
    """Select (or mask) unrectified slitlets in an EMIR image.

    Command-line entry point. Reads an EMIR FITS image and a JSON file with
    fitted boundary parameters (multislit model), builds a new image that
    contains only the requested slitlets (or their mask when --maskonly is
    given), saves it to --outfile when that option is provided and,
    depending on --debugplot, displays it with the slitlet boundaries and
    frontiers overplotted.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments. When None, ``sys.argv[1:]`` is used
        (standard argparse behaviour).

    Raises
    ------
    ValueError
        If the image dimensions are inconsistent with the header or with
        the EMIR detector, if the image was not obtained with EMIR, or if
        the parameter model is not 'multislit'.

    """

    # parse command-line options
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--fitted_bound_param", required=True,
                        help="JSON file with fitted boundary coefficients "
                             "corresponding to the multislit model",
                        type=argparse.FileType('rt'))
    parser.add_argument("--slitlets", required=True,
                        help="Slitlet selection: string between double "
                             "quotes providing tuples of the form "
                             "n1[,n2[,step]]",
                        type=str)

    # optional arguments
    parser.add_argument("--outfile",
                        help="Output FITS file name",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))
    parser.add_argument("--maskonly",
                        help="Generate mask for the indicated slitlets",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting/debugging" +
                             " (default=0)",
                        type=int, default=0,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the ``args`` parameter so that the function can also be
    # driven programmatically (None -> sys.argv[1:])
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # read input FITS file; the context manager closes it even when one of
    # the validations below raises
    with fits.open(args.fitsfile.name) as hdulist_image:
        image_header = hdulist_image[0].header
        image2d = hdulist_image[0].data

        naxis1 = image_header['naxis1']
        naxis2 = image_header['naxis2']

        if image2d.shape != (naxis2, naxis1):
            raise ValueError("Unexpected error with NAXIS1, NAXIS2")

        if image2d.shape != (EMIR_NAXIS2, EMIR_NAXIS1):
            raise ValueError("NAXIS1, NAXIS2 unexpected for EMIR detector")

        # file name without path (used in the plot title)
        if args.outfile is None:
            sfitsfile = os.path.basename(args.fitsfile.name)
        else:
            sfitsfile = os.path.basename(args.outfile.name)

        # check that the FITS file has been obtained with EMIR
        instrument = image_header['instrume']
        if instrument != 'EMIR':
            raise ValueError("INSTRUME keyword is not 'EMIR'!")

        # read GRISM, FILTER and ROTANG from FITS header (currently not
        # used below)
        grism = image_header['grism']
        spfilter = image_header['filter']
        rotang = image_header['rotang']

        # read fitted_bound_param JSON file from the handle argparse
        # already opened (avoids leaking a second, never-closed handle)
        fittedpar_dict = json.load(args.fitted_bound_param)
        params = bound_params_from_dict(fittedpar_dict)
        if abs(args.debugplot) in [21, 22]:
            params.pretty_print()

        parmodel = fittedpar_dict['meta_info']['parmodel']
        if parmodel != 'multislit':
            raise ValueError(
                "Unexpected parameter model: {}".format(parmodel))

        # define slitlet range
        islitlet_min = fittedpar_dict['tags']['islitlet_min']
        islitlet_max = fittedpar_dict['tags']['islitlet_max']
        list_islitlet = list_slitlets_from_string(
            s=args.slitlets,
            islitlet_min=islitlet_min,
            islitlet_max=islitlet_max
        )

        # read CsuConfiguration object from FITS file
        csu_config = CsuConfiguration.define_from_fits(args.fitsfile)

        # define csu_bar_slit_center associated to each slitlet
        list_csu_bar_slit_center = [
            csu_config.csu_bar_slit_center(islitlet)
            for islitlet in list_islitlet
        ]

        # accumulate the contribution of every selected slitlet
        image2d_output = np.zeros((naxis2, naxis1))
        for islitlet, csu_bar_slit_center in \
                zip(list_islitlet, list_csu_bar_slit_center):
            image2d_output += select_unrectified_slitlet(
                image2d=image2d,
                islitlet=islitlet,
                csu_bar_slit_center=csu_bar_slit_center,
                params=params,
                parmodel=parmodel,
                maskonly=args.maskonly
            )

        # update the array of the output file
        hdulist_image[0].data = image2d_output

        # save output FITS file; --outfile is optional, so only write when
        # an output file was actually requested
        if args.outfile is not None:
            hdulist_image.writeto(args.outfile)

    # display full image
    if abs(args.debugplot) % 10 != 0:
        ax = ximshow(image2d=image2d_output,
                     title=sfitsfile + "\n" + args.slitlets,
                     image_bbox=(1, naxis1, 1, naxis2), show=False)

        # overplot boundaries
        overplot_boundaries_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_islitlet,
            list_csu_bar_slit_center=list_csu_bar_slit_center
        )

        # overplot frontiers
        overplot_frontiers_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_islitlet,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            micolors=('b', 'b'), linetype='-',
            labels=False    # already displayed with the boundaries
        )

        # show plot
        pause_debugplot(12, pltshow=True)
Esempio n. 11
0
def main(args=None):
    """Display an EMIR image with slitlet boundaries overplotted.

    Command-line entry point. Reads an EMIR FITS image and a bounddict
    file, then plots the lower (green) and upper (blue) boundaries of the
    requested slitlets, labelling each one with its slitlet number.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments. When None, ``sys.argv[1:]`` is used
        (standard argparse behaviour).

    Raises
    ------
    ValueError
        If the slitlet tuple is malformed or out of range, if the image
        dimensions are inconsistent with the header, or if the image was
        not obtained with EMIR.

    """

    # parse command-line options
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--bounddict", required=True,
                        help="bounddict file name",
                        type=argparse.FileType('rt'))
    parser.add_argument("--tuple_slit_numbers", required=True,
                        help="Tuple n1[,n2[,step]] to define slitlet numbers")

    # optional arguments
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the ``args`` parameter (None -> sys.argv[1:])
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # read slitlet numbers to be computed: "n1", "n1,n2" or "n1,n2,step"
    tmp_str = args.tuple_slit_numbers.split(",")
    if len(tmp_str) not in (1, 2, 3):
        raise ValueError("Invalid tuple for slitlet numbers")
    n1 = int(tmp_str[0])
    n2 = int(tmp_str[1]) if len(tmp_str) > 1 else n1
    step = int(tmp_str[2]) if len(tmp_str) == 3 else 1
    if n1 < 1:
        raise ValueError("Invalid slitlet number < 1")
    if n2 > EMIR_NBARS:
        raise ValueError("Invalid slitlet number > EMIR_NBARS")
    if step < 1:
        raise ValueError("Invalid step for slitlet numbers")
    list_slitlets = range(n1, n2 + 1, step)

    # read input FITS file
    with fits.open(args.fitsfile.name) as hdulist:
        image_header = hdulist[0].header
        image2d = hdulist[0].data

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")

    # remove path from fitsfile
    sfitsfile = os.path.basename(args.fitsfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # display full image
    ax = ximshow(image2d=image2d,
                 title=sfitsfile + "\ngrism=" + grism +
                       ", filter=" + spfilter +
                       ", rotang=" + str(round(rotang, 2)),
                 image_bbox=(1, naxis1, 1, naxis2), show=False)

    # overplot boundaries for each slitlet
    for slitlet_number in list_slitlets:
        # NOTE(review): the same open file object is passed on every
        # iteration; confirm get_boundaries does not rely on the current
        # file position
        pol_lower_boundary, pol_upper_boundary, \
            xmin_lower, xmax_lower, xmin_upper, xmax_upper, \
            csu_bar_slit_center = \
            get_boundaries(args.bounddict, slitlet_number)
        if pol_lower_boundary is None or pol_upper_boundary is None:
            continue
        xp = np.linspace(start=xmin_lower, stop=xmax_lower, num=1000)
        ax.plot(xp, pol_lower_boundary(xp), 'g-')
        xp = np.linspace(start=xmin_upper, stop=xmax_upper, num=1000)
        ax.plot(xp, pol_upper_boundary(xp), 'b-')
        # slitlet label, placed midway between both boundaries at the
        # detector centre; X is scaled from csu_bar_slit_center using the
        # constant 341.5 (presumably the CSU extent in mm -- TODO confirm)
        yc_lower = pol_lower_boundary(EMIR_NAXIS1 / 2 + 0.5)
        yc_upper = pol_upper_boundary(EMIR_NAXIS1 / 2 + 0.5)
        tmpcolor = ['r', 'b'][slitlet_number % 2]
        xcsu = EMIR_NAXIS1 * csu_bar_slit_center / 341.5
        ax.text(xcsu, (yc_lower + yc_upper) / 2,
                str(slitlet_number),
                fontsize=10, va='center', ha='center',
                bbox=dict(boxstyle="round,pad=0.1",
                          fc="white", ec="grey"),
                color=tmpcolor, fontweight='bold',
                backgroundcolor='white')

    # show plot
    pause_debugplot(12, pltshow=True)
def compute_slitlet_boundaries(
        filename, grism, spfilter, list_slitlets,
        size_x_medfilt, size_y_savgol,
        times_sigma_threshold,
        bounddict, debugplot=0):
    """Compute slitlet boundaries using continuum lamp images.

    For each requested slitlet, the corresponding region of the image is
    extracted and median filtered along the spectral (X) direction. It is
    then differentiated along the spatial (Y) direction with a
    Savitzky-Golay filter. The largest connected region with significantly
    positive derivative and the largest one with significantly negative
    derivative are taken as the two slitlet boundaries. Each boundary is
    reduced to one derivative-weighted mean Y per column and fitted with a
    polynomial (SpectrumTrail). The result is stored in
    ``bounddict['contents'][slitlet_label][date_obs]``.

    Parameters
    ----------
    filename : string
        Input continuum lamp image.
    grism : string
        Grism name. It must be one in EMIR_VALID_GRISMS.
    spfilter : string
        Filter name. It must be one in EMIR_VALID_FILTERS.
    list_slitlets : list of integers
        Number of slitlets to be updated.
    size_x_medfilt : int
        Window in the X (spectral) direction, in pixels, to apply
        the 1d median filter in order to remove bad pixels.
    size_y_savgol : int
        Window in the Y (spatial) direction to be used when using the
        1d Savitzky-Golay filter.
    times_sigma_threshold : float
        Times sigma to detect peaks in derivatives.
    bounddict : dictionary of dictionaries
        Structure to store the boundaries. It must contain a 'contents'
        key and is updated in place.
    debugplot : int
        Determines whether intermediate computations and/or plots are
        displayed.

    Raises
    ------
    ValueError
        If the image was not obtained with EMIR, if the GRISM or FILTER
        keywords differ from the expected values, or if no significant
        positive or negative derivative region is found for a slitlet.

    """
    import getpass

    # read 2D image
    hdulist = fits.open(filename)
    image_header = hdulist[0].header
    image2d = hdulist[0].data
    naxis2, naxis1 = image2d.shape
    hdulist.close()
    if debugplot >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # ToDo: replace this by application of cosmetic defect mask!
    # Interpolate two hard-coded row segments from their neighbouring rows:
    # row 1024 in columns [0, 1024) and row 1023 in columns [1024, 2048)
    # (array indices). The regions read and written by both statements are
    # disjoint, so this matches the former column-by-column loop.
    image2d[1024, :1024] = (image2d[1023, :1024] + image2d[1025, :1024]) / 2
    image2d[1023, 1024:2048] = (image2d[1022, 1024:2048] +
                                image2d[1024, 1024:2048]) / 2

    # remove path from filename
    sfilename = os.path.basename(filename)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read CSU configuration from FITS header
    csu_config = CsuConfiguration.define_from_fits(filename)

    # read DTU configuration from FITS header
    dtu_config = DtuConfiguration.define_from_fits(filename)

    # read grism
    grism_in_header = image_header['grism']
    if grism != grism_in_header:
        raise ValueError("GRISM keyword=" + grism_in_header +
                         " is not the expected value=" + grism)

    # read filter
    spfilter_in_header = image_header['filter']
    if spfilter != spfilter_in_header:
        raise ValueError("FILTER keyword=" + spfilter_in_header +
                         " is not the expected value=" + spfilter)

    # read rotator position angle
    rotang = image_header['rotang']

    # read date-obs
    date_obs = image_header['date-obs']

    # user@host stored with every boundary; getpass.getuser() is used
    # because os.getlogin() raises OSError when the process has no
    # controlling terminal (batch jobs, some IDEs)
    zzz_info1 = getpass.getuser() + '@' + socket.gethostname()

    for islitlet in list_slitlets:
        if debugplot < 10:
            sys.stdout.write('.')
            sys.stdout.flush()

        sltlim = SlitletLimits(grism, spfilter, islitlet)
        first_pixel = (sltlim.bb_nc1_orig, sltlim.bb_ns1_orig)
        # common second line of every plot title for this slitlet
        title_suffix = ("\nslitlet=" + str(islitlet) +
                        ", grism=" + grism +
                        ", filter=" + spfilter +
                        ", rotang=" + str(round(rotang, 2)))

        # extract slitlet2d
        slitlet2d = extract_slitlet2d(image2d, sltlim)
        if debugplot % 10 != 0:
            ximshow(slitlet2d,
                    title=sfilename + " [original]" + title_suffix,
                    first_pixel=first_pixel,
                    debugplot=debugplot)

        # apply 1d median filtering (along the spectral direction)
        # to remove bad pixels
        slitlet2d_smooth = ndimage.median_filter(
            slitlet2d, size=(1, size_x_medfilt))

        if debugplot % 10 != 0:
            ximshow(slitlet2d_smooth,
                    title=sfilename + " [smoothed]" + title_suffix,
                    first_pixel=first_pixel,
                    debugplot=debugplot)

        # apply 1d Savitzky-Golay filter (along the spatial direction)
        # to compute first derivative
        slitlet2d_savgol = savgol_filter(
            slitlet2d_smooth, window_length=size_y_savgol, polyorder=2,
            deriv=1, axis=0)

        # compute basic statistics
        q25, q50, q75 = np.percentile(slitlet2d_savgol, q=[25.0, 50.0, 75.0])
        sigmag = 0.7413 * (q75 - q25)  # robust standard deviation
        if debugplot >= 10:
            print("q50, sigmag:", q50, sigmag)
        threshold_low = q50 - times_sigma_threshold * sigmag
        threshold_high = q50 + times_sigma_threshold * sigmag

        if debugplot % 10 != 0:
            ximshow(slitlet2d_savgol,
                    title=sfilename + " [S.-G.filt.]" + title_suffix,
                    first_pixel=first_pixel,
                    z1z2=(threshold_low, threshold_high),
                    debugplot=debugplot)

        # identify objects in slitlet2d_savgol: pixels with positive
        # derivatives are identified independently from pixels with
        # negative derivatives; then the two sets of potential features
        # are merged; this approach avoids some problems when, in
        # nearby regions, there are pixels with positive and negative
        # derivatives (in those circumstances a single search as
        # np.logical_or(
        #     slitlet2d_savgol < q50 - times_sigma_threshold * sigmag,
        #     slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # led to erroneous detections!)
        #
        # search for positive derivatives
        labels2d_objects_pos, no_objects_pos = ndimage.label(
            slitlet2d_savgol > threshold_high)
        # search for negative derivatives
        labels2d_objects_neg, no_objects_neg = ndimage.label(
            slitlet2d_savgol < threshold_low)
        # merge both sets: negative objects are relabelled after the
        # positive ones (no_objects_pos + 1 ... no_objects)
        non_zero_neg = np.where(labels2d_objects_neg > 0)
        labels2d_objects = np.copy(labels2d_objects_pos)
        labels2d_objects[non_zero_neg] += \
            labels2d_objects_neg[non_zero_neg] + no_objects_pos
        no_objects = no_objects_pos + no_objects_neg

        if debugplot >= 10:
            print("Number of objects with positive derivative:",
                  no_objects_pos)
            print("Number of objects with negative derivative:",
                  no_objects_neg)
            print("Total number of objects initially found...:", no_objects)

        if debugplot % 10 != 0:
            ximshow(labels2d_objects,
                    z1z2=(0, no_objects),
                    title=sfilename + " [objects]" + title_suffix,
                    first_pixel=first_pixel,
                    cbar_label="Object number",
                    debugplot=debugplot)

        # both boundaries are required; without this check label 0 (the
        # background) would silently be used as a boundary
        if no_objects_pos == 0:
            raise ValueError("slitlet {}: no region with significant "
                             "positive derivative found".format(islitlet))
        if no_objects_neg == 0:
            raise ValueError("slitlet {}: no region with significant "
                             "negative derivative found".format(islitlet))

        # select boundaries as the largest objects found with positive
        # and negative derivatives (np.argmax returns the first maximum,
        # i.e. the lowest label in case of ties)
        npix_per_label = np.bincount(labels2d_objects.ravel(),
                                     minlength=no_objects + 1)
        i_der_pos = 1 + int(np.argmax(
            npix_per_label[1:no_objects_pos + 1]))
        i_der_neg = no_objects_pos + 1 + int(np.argmax(
            npix_per_label[no_objects_pos + 1:no_objects + 1]))

        # determine which boundary is lower and which is upper
        y_center_mass_der_pos = ndimage.center_of_mass(
            slitlet2d_savgol, labels2d_objects, [i_der_pos])[0][0]
        y_center_mass_der_neg = ndimage.center_of_mass(
            slitlet2d_savgol, labels2d_objects, [i_der_neg])[0][0]
        if y_center_mass_der_pos < y_center_mass_der_neg:
            i_lower, i_upper = i_der_pos, i_der_neg
            if debugplot >= 10:
                print("-> lower boundary has positive derivatives")
        else:
            i_lower, i_upper = i_der_neg, i_der_pos
            if debugplot >= 10:
                print("-> lower boundary has negative derivatives")

        # adjust individual boundaries passing the selection:
        # - select points in the image belonging to a given boundary
        # - compute weighted mean of the pixels of the boundary, column
        #   by column (this reduces dramatically the number of points
        #   to be fitted to determine the boundary)
        list_boundaries = []
        for k, ilabel in enumerate((i_lower, i_upper)):
            # k=0 lower boundary, k=1 upper boundary
            # (note: be careful with array indices and pixel coordinates)
            xy_tmp = np.where(labels2d_objects == ilabel)
            xmin = xy_tmp[1].min()  # array indices (integers)
            xmax = xy_tmp[1].max()  # array indices (integers)
            # fix range for fit
            if k == 0:
                xmineff = max(sltlim.xmin_lower_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_lower_boundary_fit, xmax)
            else:
                xmineff = max(sltlim.xmin_upper_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_upper_boundary_fit, xmax)
            xfit = []
            yfit = []
            # loop in columns of the image belonging to the boundary
            for xdum in range(xmineff, xmaxeff + 1):  # array indices
                iok = np.where(xy_tmp[1] == xdum)
                y_tmp = xy_tmp[0][iok] + sltlim.bb_ns1_orig  # image pixel
                weight = slitlet2d_savgol[xy_tmp[0][iok], xy_tmp[1][iok]]
                y_wmean = np.sum(y_tmp * weight) / np.sum(weight)
                xfit.append(xdum + sltlim.bb_nc1_orig)
                yfit.append(y_wmean)
            xfit = np.array(xfit)
            yfit = np.array(yfit)
            # declare new SpectrumTrail instance and fit the boundary
            boundary = SpectrumTrail()
            boundary.fit(x=xfit, y=yfit, deg=sltlim.deg_boundary,
                         times_sigma_reject=10,
                         title="slit:" + str(sltlim.islitlet) +
                               ", deg=" + str(sltlim.deg_boundary),
                         debugplot=0)
            list_boundaries.append(boundary)

        if debugplot % 10 != 0:
            for tmp_img, tmp_label in zip(
                [slitlet2d_savgol, slitlet2d],
                [' [S.-G.filt.]', ' [original]']
            ):
                ax = ximshow(tmp_img,
                             title=sfilename + tmp_label + title_suffix,
                             first_pixel=first_pixel,
                             show=False,
                             debugplot=debugplot)
                for boundary in list_boundaries:
                    xpol, ypol = boundary.linspace_pix(
                        start=1, stop=EMIR_NAXIS1)
                    ax.plot(xpol, ypol, 'b--', linewidth=1)
                for boundary in list_boundaries:
                    xpol, ypol = boundary.linspace_pix()
                    ax.plot(xpol, ypol, 'g--', linewidth=4)
                # show plot
                pause_debugplot(debugplot, pltshow=True)

        # update bounddict
        tmp_dict = {
            'boundary_coef_lower':
                list_boundaries[0].poly_funct.coef.tolist(),
            'boundary_xmin_lower': list_boundaries[0].xlower_line,
            'boundary_xmax_lower': list_boundaries[0].xupper_line,
            'boundary_coef_upper':
                list_boundaries[1].poly_funct.coef.tolist(),
            'boundary_xmin_upper': list_boundaries[1].xlower_line,
            'boundary_xmax_upper': list_boundaries[1].xupper_line,
            'csu_bar_left': csu_config.csu_bar_left(islitlet),
            'csu_bar_right': csu_config.csu_bar_right(islitlet),
            'csu_bar_slit_center': csu_config.csu_bar_slit_center(islitlet),
            'csu_bar_slit_width': csu_config.csu_bar_slit_width(islitlet),
            'rotang': rotang,
            'xdtu': dtu_config.xdtu,
            'ydtu': dtu_config.ydtu,
            'zdtu': dtu_config.zdtu,
            'xdtu_0': dtu_config.xdtu_0,
            'ydtu_0': dtu_config.ydtu_0,
            'zdtu_0': dtu_config.zdtu_0,
            'zzz_info1': zzz_info1,
            'zzz_info2': datetime.now().isoformat()
        }
        slitlet_label = "slitlet" + str(islitlet).zfill(2)
        bounddict['contents'].setdefault(slitlet_label, {})[date_obs] = \
            tmp_dict

    if debugplot < 10:
        print("")
def compute_slitlet_boundaries(filename,
                               grism,
                               spfilter,
                               list_slitlets,
                               size_x_medfilt,
                               size_y_savgol,
                               times_sigma_threshold,
                               bounddict,
                               debugplot=0):
    """Compute slitlet boundaries using continuum lamp images.

    Parameters
    ----------
    filename : string
        Input continumm lamp image.
    grism : string
        Grism name. It must be one in EMIR_VALID_GRISMS.
    spfilter : string
        Filter name. It must be one in EMIR_VALID_FILTERS.
    list_slitlets : list of integers
        Number of slitlets to be updated.
    size_x_medfilt : int
        Window in the X (spectral) direction, in pixels, to apply
        the 1d median filter in order to remove bad pixels.
    size_y_savgol : int
        Window in the Y (spatial) direction to be used when using the
        1d Savitzky-Golay filter.
    times_sigma_threshold : float
        Times sigma to detect peaks in derivatives.
    bounddict : dictionary of dictionaries
        Structure to store the boundaries.
    debugplot : int
        Determines whether intermediate computations and/or plots are
        displayed.

    """

    # read 2D image
    hdulist = fits.open(filename)
    image_header = hdulist[0].header
    image2d = hdulist[0].data
    naxis2, naxis1 = image2d.shape
    hdulist.close()
    if debugplot >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # ToDo: replace this by application of cosmetic defect mask!
    for j in range(1024):
        image2d[1024, j] = (image2d[1023, j] + image2d[1025, j]) / 2
        image2d[1023, j +
                1024] = (image2d[1022, j + 1024] + image2d[1024, j + 1024]) / 2

    # remove path from filename
    sfilename = os.path.basename(filename)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read CSU configuration from FITS header
    csu_config = CsuConfiguration.define_from_fits(filename)

    # read DTU configuration from FITS header
    dtu_config = DtuConfiguration.define_from_fits(filename)

    # read grism
    grism_in_header = image_header['grism']
    if grism != grism_in_header:
        raise ValueError("GRISM keyword=" + grism_in_header +
                         " is not the expected value=" + grism)

    # read filter
    spfilter_in_header = image_header['filter']
    if spfilter != spfilter_in_header:
        raise ValueError("FILTER keyword=" + spfilter_in_header +
                         " is not the expected value=" + spfilter)

    # read rotator position angle
    rotang = image_header['rotang']

    # read date-obs
    date_obs = image_header['date-obs']

    for islitlet in list_slitlets:
        if debugplot < 10:
            sys.stdout.write('.')
            sys.stdout.flush()

        sltlim = SlitletLimits(grism, spfilter, islitlet)
        # extract slitlet2d
        slitlet2d = extract_slitlet2d(image2d, sltlim)
        if debugplot % 10 != 0:
            ximshow(slitlet2d,
                    title=sfilename + " [original]"
                    "\nslitlet=" + str(islitlet) + ", grism=" + grism +
                    ", filter=" + spfilter + ", rotang=" +
                    str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    debugplot=debugplot)

        # apply 1d median filtering (along the spectral direction)
        # to remove bad pixels
        size_x = size_x_medfilt
        size_y = 1
        slitlet2d_smooth = ndimage.filters.median_filter(slitlet2d,
                                                         size=(size_y, size_x))

        if debugplot % 10 != 0:
            ximshow(slitlet2d_smooth,
                    title=sfilename + " [smoothed]"
                    "\nslitlet=" + str(islitlet) + ", grism=" + grism +
                    ", filter=" + spfilter + ", rotang=" +
                    str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    debugplot=debugplot)

        # apply 1d Savitzky-Golay filter (along the spatial direction)
        # to compute first derivative
        slitlet2d_savgol = savgol_filter(slitlet2d_smooth,
                                         window_length=size_y_savgol,
                                         polyorder=2,
                                         deriv=1,
                                         axis=0)

        # compute basic statistics
        q25, q50, q75 = np.percentile(slitlet2d_savgol, q=[25.0, 50.0, 75.0])
        sigmag = 0.7413 * (q75 - q25)  # robust standard deviation
        if debugplot >= 10:
            print("q50, sigmag:", q50, sigmag)

        if debugplot % 10 != 0:
            ximshow(slitlet2d_savgol,
                    title=sfilename + " [S.-G.filt.]"
                    "\nslitlet=" + str(islitlet) + ", grism=" + grism +
                    ", filter=" + spfilter + ", rotang=" +
                    str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    z1z2=(q50 - times_sigma_threshold * sigmag,
                          q50 + times_sigma_threshold * sigmag),
                    debugplot=debugplot)

        # identify objects in slitlet2d_savgol: pixels with positive
        # derivatives are identify independently from pixels with
        # negative derivaties; then the two set of potential features
        # are merged; this approach avoids some problems when, in
        # nearby regions, there are pixels with positive and negative
        # derivatives (in those circumstances a single search as
        # np.logical_or(
        #     slitlet2d_savgol < q50 - times_sigma_threshold * sigmag,
        #     slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # led to erroneous detections!)
        #
        # search for positive derivatives
        labels2d_objects_pos, no_objects_pos = ndimage.label(
            slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # search for negative derivatives
        labels2d_objects_neg, no_objects_neg = ndimage.label(
            slitlet2d_savgol < q50 - times_sigma_threshold * sigmag)
        # merge both sets
        non_zero_neg = np.where(labels2d_objects_neg > 0)
        labels2d_objects = np.copy(labels2d_objects_pos)
        labels2d_objects[non_zero_neg] += \
            labels2d_objects_neg[non_zero_neg] + no_objects_pos
        no_objects = no_objects_pos + no_objects_neg

        if debugplot >= 10:
            print("Number of objects with positive derivative:",
                  no_objects_pos)
            print("Number of objects with negative derivative:",
                  no_objects_neg)
            print("Total number of objects initially found...:", no_objects)

        if debugplot % 10 != 0:
            ximshow(labels2d_objects,
                    z1z2=(0, no_objects),
                    title=sfilename + " [objects]"
                    "\nslitlet=" + str(islitlet) + ", grism=" + grism +
                    ", filter=" + spfilter + ", rotang=" +
                    str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    cbar_label="Object number",
                    debugplot=debugplot)

        # select boundaries as the largest objects found with
        # positive and negative derivatives
        n_der_pos = 0  # number of pixels covered by the object with deriv > 0
        i_der_pos = 0  # id of the object with deriv > 0
        n_der_neg = 0  # number of pixels covered by the object with deriv < 0
        i_der_neg = 0  # id of the object with deriv < 0
        for i in range(1, no_objects + 1):
            xy_tmp = np.where(labels2d_objects == i)
            n_pix = len(xy_tmp[0])
            if i <= no_objects_pos:
                if n_pix > n_der_pos:
                    i_der_pos = i
                    n_der_pos = n_pix
            else:
                if n_pix > n_der_neg:
                    i_der_neg = i
                    n_der_neg = n_pix

        # determine which boundary is lower and which is upper
        y_center_mass_der_pos = ndimage.center_of_mass(slitlet2d_savgol,
                                                       labels2d_objects,
                                                       [i_der_pos])[0][0]
        y_center_mass_der_neg = ndimage.center_of_mass(slitlet2d_savgol,
                                                       labels2d_objects,
                                                       [i_der_neg])[0][0]
        if y_center_mass_der_pos < y_center_mass_der_neg:
            i_lower = i_der_pos
            i_upper = i_der_neg
            if debugplot >= 10:
                print("-> lower boundary has positive derivatives")
        else:
            i_lower = i_der_neg
            i_upper = i_der_pos
            if debugplot >= 10:
                print("-> lower boundary has negative derivatives")
        list_slices_ok = [i_lower, i_upper]

        # adjust individual boundaries passing the selection:
        # - select points in the image belonging to a given boundary
        # - compute weighted mean of the pixels of the boundary, column
        #   by column (this reduces dramatically the number of points
        #   to be fitted to determine the boundary)
        list_boundaries = []
        for k in range(2):  # k=0 lower boundary, k=1 upper boundary
            # select points to be fitted for a particular boundary
            # (note: be careful with array indices and pixel
            # coordinates)
            xy_tmp = np.where(labels2d_objects == list_slices_ok[k])
            xmin = xy_tmp[1].min()  # array indices (integers)
            xmax = xy_tmp[1].max()  # array indices (integers)
            xfit = []
            yfit = []
            # fix range for fit
            if k == 0:
                xmineff = max(sltlim.xmin_lower_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_lower_boundary_fit, xmax)
            else:
                xmineff = max(sltlim.xmin_upper_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_upper_boundary_fit, xmax)
            # loop in columns of the image belonging to the boundary
            for xdum in range(xmineff, xmaxeff + 1):  # array indices (integer)
                iok = np.where(xy_tmp[1] == xdum)
                y_tmp = xy_tmp[0][iok] + sltlim.bb_ns1_orig  # image pixel
                weight = slitlet2d_savgol[xy_tmp[0][iok], xy_tmp[1][iok]]
                y_wmean = sum(y_tmp * weight) / sum(weight)
                xfit.append(xdum + sltlim.bb_nc1_orig)
                yfit.append(y_wmean)
            xfit = np.array(xfit)
            yfit = np.array(yfit)
            # declare new SpectrumTrail instance
            boundary = SpectrumTrail()
            # define new boundary
            boundary.fit(x=xfit,
                         y=yfit,
                         deg=sltlim.deg_boundary,
                         times_sigma_reject=10,
                         title="slit:" + str(sltlim.islitlet) + ", deg=" +
                         str(sltlim.deg_boundary),
                         debugplot=0)
            list_boundaries.append(boundary)

        if debugplot % 10 != 0:
            for tmp_img, tmp_label in zip([slitlet2d_savgol, slitlet2d],
                                          [' [S.-G.filt.]', ' [original]']):
                ax = ximshow(tmp_img,
                             title=sfilename + tmp_label + "\nslitlet=" +
                             str(islitlet) + ", grism=" + grism + ", filter=" +
                             spfilter + ", rotang=" + str(round(rotang, 2)),
                             first_pixel=(sltlim.bb_nc1_orig,
                                          sltlim.bb_ns1_orig),
                             show=False,
                             debugplot=debugplot)
                for k in range(2):
                    xpol, ypol = list_boundaries[k].linspace_pix(
                        start=1, stop=EMIR_NAXIS1)
                    ax.plot(xpol, ypol, 'b--', linewidth=1)
                for k in range(2):
                    xpol, ypol = list_boundaries[k].linspace_pix()
                    ax.plot(xpol, ypol, 'g--', linewidth=4)
                # show plot
                pause_debugplot(debugplot, pltshow=True)

        # update bounddict
        tmp_dict = {
            'boundary_coef_lower': list_boundaries[0].poly_funct.coef.tolist(),
            'boundary_xmin_lower': list_boundaries[0].xlower_line,
            'boundary_xmax_lower': list_boundaries[0].xupper_line,
            'boundary_coef_upper': list_boundaries[1].poly_funct.coef.tolist(),
            'boundary_xmin_upper': list_boundaries[1].xlower_line,
            'boundary_xmax_upper': list_boundaries[1].xupper_line,
            'csu_bar_left': csu_config.csu_bar_left(islitlet),
            'csu_bar_right': csu_config.csu_bar_right(islitlet),
            'csu_bar_slit_center': csu_config.csu_bar_slit_center(islitlet),
            'csu_bar_slit_width': csu_config.csu_bar_slit_width(islitlet),
            'rotang': rotang,
            'xdtu': dtu_config.xdtu,
            'ydtu': dtu_config.ydtu,
            'zdtu': dtu_config.zdtu,
            'xdtu_0': dtu_config.xdtu_0,
            'ydtu_0': dtu_config.ydtu_0,
            'zdtu_0': dtu_config.zdtu_0,
            'zzz_info1': os.getlogin() + '@' + socket.gethostname(),
            'zzz_info2': datetime.now().isoformat()
        }
        slitlet_label = "slitlet" + str(islitlet).zfill(2)
        if slitlet_label not in bounddict['contents']:
            bounddict['contents'][slitlet_label] = {}
        bounddict['contents'][slitlet_label][date_obs] = tmp_dict

    if debugplot < 10:
        print("")
Esempio n. 14
0
def main(args=None):
    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: compute pixel-to-pixel flatfield'
    )

    # required arguments
    parser.add_argument("fitsfile",
                        help="Input FITS file (flat ON-OFF)",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rectwv_coeff", required=True,
                        help="Input JSON file with rectification and "
                             "wavelength calibration coefficients",
                        type=argparse.FileType('rt'))
    parser.add_argument("--minimum_slitlet_width_mm", required=True,
                        help="Minimum slitlet width in mm",
                        type=float)
    parser.add_argument("--maximum_slitlet_width_mm", required=True,
                        help="Maximum slitlet width in mm",
                        type=float)
    parser.add_argument("--minimum_fraction", required=True,
                        help="Minimum allowed flatfielding value",
                        type=float, default=0.01)
    parser.add_argument("--minimum_value_in_output",
                        help="Minimum value allowed in output file: pixels "
                             "below this value are set to 1.0 (default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--maximum_value_in_output",
                        help="Maximum value allowed in output file: pixels "
                             "above this value are set to 1.0 (default=10.0)",
                        type=float, default=10.0)
    parser.add_argument("--nwindow_median",
                        help="Window size to smooth median spectrum in the "
                             "spectral direction",
                        type=int)
    parser.add_argument("--outfile", required=True,
                        help="Output FITS file",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))

    # optional arguments
    parser.add_argument("--delta_global_integer_offset_x_pix",
                        help="Delta global integer offset in the X direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--delta_global_integer_offset_y_pix",
                        help="Delta global integer offset in the Y direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--resampling",
                        help="Resampling method: 1 -> nearest neighbor, "
                             "2 -> linear interpolation (default)",
                        default=2, type=int,
                        choices=(1, 2))
    parser.add_argument("--ignore_DTUconf",
                        help="Ignore DTU configurations differences between "
                             "model and input image",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                             " (default=0)",
                        default=0, type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # This code is obsolete
    raise ValueError('This code is obsolete: use recipe in '
                     'emirdrp/recipes/spec/flatpix2pix.py')

    # read calibration structure from JSON file
    rectwv_coeff = RectWaveCoeff._datatype_load(args.rectwv_coeff.name)

    # modify (when requested) global offsets
    rectwv_coeff.global_integer_offset_x_pix += \
        args.delta_global_integer_offset_x_pix
    rectwv_coeff.global_integer_offset_y_pix += \
        args.delta_global_integer_offset_y_pix

    # read FITS image and its corresponding header
    hdulist = fits.open(args.fitsfile)
    header = hdulist[0].header
    image2d = hdulist[0].data
    hdulist.close()

    # apply global offsets
    image2d = apply_integer_offsets(
        image2d=image2d,
        offx=rectwv_coeff.global_integer_offset_x_pix,
        offy=rectwv_coeff.global_integer_offset_y_pix
    )

    # protections
    naxis2, naxis1 = image2d.shape
    if naxis1 != header['naxis1'] or naxis2 != header['naxis2']:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)
        raise ValueError('Something is wrong with NAXIS1 and/or NAXIS2')
    if abs(args.debugplot) >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # check that the input FITS file grism and filter match
    filter_name = header['filter']
    if filter_name != rectwv_coeff.tags['filter']:
        raise ValueError("Filter name does not match!")
    grism_name = header['grism']
    if grism_name != rectwv_coeff.tags['grism']:
        raise ValueError("Filter name does not match!")
    if abs(args.debugplot) >= 10:
        print('>>> grism.......:', grism_name)
        print('>>> filter......:', filter_name)

    # check that the DTU configurations are compatible
    dtu_conf_fitsfile = DtuConfiguration.define_from_fits(args.fitsfile)
    dtu_conf_jsonfile = DtuConfiguration.define_from_dictionary(
        rectwv_coeff.meta_info['dtu_configuration'])
    if dtu_conf_fitsfile != dtu_conf_jsonfile:
        print('DTU configuration (FITS file):\n\t', dtu_conf_fitsfile)
        print('DTU configuration (JSON file):\n\t', dtu_conf_jsonfile)
        if args.ignore_DTUconf:
            print('WARNING: DTU configuration differences found!')
        else:
            raise ValueError('DTU configurations do not match')
    else:
        if abs(args.debugplot) >= 10:
            print('>>> DTU Configuration match!')
            print(dtu_conf_fitsfile)

    # load CSU configuration
    csu_conf_fitsfile = CsuConfiguration.define_from_fits(args.fitsfile)
    if abs(args.debugplot) >= 10:
        print(csu_conf_fitsfile)

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        print('-> Removing slitlet (not defined):', idel)
        list_valid_islitlets.remove(idel)
    # filter out slitlets with widths outside valid range
    list_outside_valid_width = []
    for islitlet in list_valid_islitlets:
        slitwidth = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
        if (slitwidth < args.minimum_slitlet_width_mm) or \
                (slitwidth > args.maximum_slitlet_width_mm):
            list_outside_valid_width.append(islitlet)
            print('-> Removing slitlet (invalid width):', islitlet)
    if len(list_outside_valid_width) > 0:
        for idel in list_outside_valid_width:
            list_valid_islitlets.remove(idel)
    print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # compute and store median spectrum (and masked region) for each
    # individual slitlet
    image2d_sp_median = np.zeros((EMIR_NBARS, EMIR_NAXIS1))
    image2d_sp_mask = np.zeros((EMIR_NBARS, EMIR_NAXIS1), dtype=bool)
    for islitlet in list(range(1, EMIR_NBARS + 1)):
        if islitlet in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
            # define Slitlet2D object
            slt = Slitlet2D(islitlet=islitlet,
                            rectwv_coeff=rectwv_coeff,
                            debugplot=args.debugplot)

            if abs(args.debugplot) >= 10:
                print(slt)

            # extract (distorted) slitlet from the initial image
            slitlet2d = slt.extract_slitlet2d(
                image_2k2k=image2d,
                subtitle='original image'
            )

            # rectify slitlet
            slitlet2d_rect = slt.rectify(
                slitlet2d=slitlet2d,
                resampling=args.resampling,
                subtitle='original rectified'
            )
            naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

            if naxis1_slitlet2d != EMIR_NAXIS1:
                print('naxis1_slitlet2d: ', naxis1_slitlet2d)
                print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
                raise ValueError("Unexpected naxis1_slitlet2d")

            sp_mask = np.zeros(naxis1_slitlet2d, dtype=bool)

            # for grism LR set to zero data beyond useful wavelength range
            if grism_name == 'LR':
                wv_parameters = set_wv_parameters(filter_name, grism_name)
                x_pix = np.arange(1, naxis1_slitlet2d + 1)
                wl_pix = polyval(x_pix, slt.wpoly)
                lremove = wl_pix < wv_parameters['wvmin_useful']
                sp_mask[lremove] = True
                slitlet2d_rect[:, lremove] = 0.0
                lremove = wl_pix > wv_parameters['wvmax_useful']
                slitlet2d_rect[:, lremove] = 0.0
                sp_mask[lremove] = True

            # get useful slitlet region (use boundaries instead of frontiers;
            # note that the nscan_minmax_frontiers() works well independently
            # of using frontiers of boundaries as arguments)
            nscan_min, nscan_max = nscan_minmax_frontiers(
                slt.y0_reference_lower,
                slt.y0_reference_upper,
                resize=False
            )
            ii1 = nscan_min - slt.bb_ns1_orig
            ii2 = nscan_max - slt.bb_ns1_orig + 1

            # median spectrum
            sp_collapsed = np.median(slitlet2d_rect[ii1:(ii2 + 1), :], axis=0)

            # smooth median spectrum along the spectral direction
            sp_median = ndimage.median_filter(
                sp_collapsed,
                args.nwindow_median,
                mode='nearest'
            )

            """
                nremove = 5
                spl = AdaptiveLSQUnivariateSpline(
                    x=xaxis1[nremove:-nremove],
                    y=sp_collapsed[nremove:-nremove],
                    t=11,
                    adaptive=True
                )
                xknots = spl.get_knots()
                yknots = spl(xknots)
                sp_median = spl(xaxis1)

                # compute rms within each knot interval
                nknots = len(xknots)
                rms_array = np.zeros(nknots - 1, dtype=float)
                for iknot in range(nknots - 1):
                    residuals = []
                    for xdum, ydum, yydum in \
                            zip(xaxis1, sp_collapsed, sp_median):
                        if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                            residuals.append(abs(ydum - yydum))
                    if len(residuals) > 5:
                        rms_array[iknot] = np.std(residuals)
                    else:
                        rms_array[iknot] = 0

                # determine in which knot interval falls each pixel
                iknot_array = np.zeros(len(xaxis1), dtype=int)
                for idum, xdum in enumerate(xaxis1):
                    for iknot in range(nknots - 1):
                        if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                            iknot_array[idum] = iknot

                # compute new fit removing deviant points (with fixed knots)
                xnewfit = []
                ynewfit = []
                for idum in range(len(xaxis1)):
                    delta_sp = abs(sp_collapsed[idum] - sp_median[idum])
                    rms_tmp = rms_array[iknot_array[idum]]
                    if idum == 0 or idum == (len(xaxis1) - 1):
                        lok = True
                    elif rms_tmp > 0:
                        if delta_sp < 3.0 * rms_tmp:
                            lok = True
                        else:
                            lok = False
                    else:
                        lok = True
                    if lok:
                        xnewfit.append(xaxis1[idum])
                        ynewfit.append(sp_collapsed[idum])
                nremove = 5
                splnew = AdaptiveLSQUnivariateSpline(
                    x=xnewfit[nremove:-nremove],
                    y=ynewfit[nremove:-nremove],
                    t=xknots[1:-1],
                    adaptive=False
                )
                sp_median = splnew(xaxis1)
            """

            ymax_spmedian = sp_median.max()
            y_threshold = ymax_spmedian * args.minimum_fraction
            lremove = np.where(sp_median < y_threshold)
            sp_median[lremove] = 0.0
            sp_mask[lremove] = True

            image2d_sp_median[islitlet - 1, :] = sp_median
            image2d_sp_mask[islitlet - 1, :] = sp_mask

            if abs(args.debugplot) % 10 != 0:
                xaxis1 = np.arange(1, naxis1_slitlet2d + 1)
                title = 'Slitlet#' + str(islitlet) + ' (median spectrum)'
                ax = ximplotxy(xaxis1, sp_collapsed,
                               title=title,
                               show=False, **{'label' : 'collapsed spectrum'})
                ax.plot(xaxis1, sp_median, label='fitted spectrum')
                ax.plot([1, naxis1_slitlet2d], 2*[y_threshold],
                        label='threshold')
                # ax.plot(xknots, yknots, 'o', label='knots')
                ax.legend()
                ax.set_ylim(-0.05*ymax_spmedian, 1.05*ymax_spmedian)
                pause_debugplot(args.debugplot,
                                pltshow=True, tight_layout=True)
        else:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)

    # ToDo: compute "average" spectrum for each pseudo-longslit, scaling
    #       with the median signal in each slitlet; derive a particular
    #       spectrum for each slitlet (scaling properly)

    image2d_sp_median_masked = np.ma.masked_array(
        image2d_sp_median,
        mask=image2d_sp_mask
    )
    ycut_median = np.ma.median(image2d_sp_median_masked, axis=1).data
    ycut_median_2d = np.repeat(ycut_median, EMIR_NAXIS1).reshape(
        EMIR_NBARS, EMIR_NAXIS1)
    image2d_sp_median_eq = image2d_sp_median_masked / ycut_median_2d
    image2d_sp_median_eq = image2d_sp_median_eq.data

    if True:
        ximshow(image2d_sp_median, title='sp_median', debugplot=12)
        ximplotxy(np.arange(1, EMIR_NBARS + 1), ycut_median, 'ro',
                  title='median value of each spectrum', debugplot=12)
        ximshow(image2d_sp_median_eq, title='sp_median_eq', debugplot=12)

    csu_conf_fitsfile.display_pseudo_longslits(
        list_valid_slitlets=list_valid_islitlets)
    dict_longslits = csu_conf_fitsfile.pseudo_longslits()

    # compute median spectrum for each longslit and insert (properly
    # scaled) that spectrum in each slitlet belonging to that longslit
    image2d_sp_median_longslit = np.zeros((EMIR_NBARS, EMIR_NAXIS1))
    islitlet = 1
    loop = True
    while loop:
        if islitlet in list_valid_islitlets:
            imin = dict_longslits[islitlet].imin()
            imax = dict_longslits[islitlet].imax()
            print('--> imin, imax: ', imin, imax)
            sp_median_longslit = np.median(
                image2d_sp_median_eq[(imin - 1):imax, :], axis=0)
            for i in range(imin, imax+1):
                print('----> i: ', i)
                image2d_sp_median_longslit[(i - 1), :] = \
                    sp_median_longslit * ycut_median[i - 1]
            islitlet = imax
        else:
            print('--> ignoring: ', islitlet)
        if islitlet == EMIR_NBARS:
            loop = False
        else:
            islitlet += 1
    if True:
        ximshow(image2d_sp_median_longslit, debugplot=12)

    # initialize rectified image
    image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

    # main loop
    for islitlet in list(range(1, EMIR_NBARS + 1)):
        if islitlet in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
            # define Slitlet2D object
            slt = Slitlet2D(islitlet=islitlet,
                            rectwv_coeff=rectwv_coeff,
                            debugplot=args.debugplot)

            # extract (distorted) slitlet from the initial image
            slitlet2d = slt.extract_slitlet2d(
                image_2k2k=image2d,
                subtitle='original image'
            )

            # rectify slitlet
            slitlet2d_rect = slt.rectify(
                slitlet2d=slitlet2d,
                resampling=args.resampling,
                subtitle='original rectified'
            )
            naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

            sp_median = image2d_sp_median_longslit[islitlet - 1, :]

            # generate rectified slitlet region filled with the median spectrum
            slitlet2d_rect_spmedian = np.tile(sp_median, (naxis2_slitlet2d, 1))
            if abs(args.debugplot) > 10:
                slt.ximshow_rectified(
                    slitlet2d_rect=slitlet2d_rect_spmedian,
                    subtitle='rectified, filled with median spectrum'
                )

            # unrectified image
            slitlet2d_unrect_spmedian = slt.rectify(
                slitlet2d=slitlet2d_rect_spmedian,
                resampling=args.resampling,
                inverse=True,
                subtitle='unrectified, filled with median spectrum'
            )

            # normalize initial slitlet image (avoid division by zero)
            slitlet2d_norm = np.zeros_like(slitlet2d)
            for j in range(naxis1_slitlet2d):
                for i in range(naxis2_slitlet2d):
                    den = slitlet2d_unrect_spmedian[i, j]
                    if den == 0:
                        slitlet2d_norm[i, j] = 1.0
                    else:
                        slitlet2d_norm[i, j] = slitlet2d[i, j] / den

            if abs(args.debugplot) > 10:
                slt.ximshow_unrectified(
                    slitlet2d=slitlet2d_norm,
                    subtitle='unrectified, pixel-to-pixel'
                )

            # check for pseudo-longslit with previous slitlet
            if islitlet > 1:
                if (islitlet - 1) in list_valid_islitlets:
                    c1 = csu_conf_fitsfile.csu_bar_slit_center(islitlet - 1)
                    w1 = csu_conf_fitsfile.csu_bar_slit_width(islitlet - 1)
                    c2 = csu_conf_fitsfile.csu_bar_slit_center(islitlet)
                    w2 = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
                    if abs(w1-w2)/w1 < 0.25:
                        wmean = (w1 + w2) / 2.0
                        if abs(c1 - c2) < wmean/4.0:
                            same_slitlet_below = True
                        else:
                            same_slitlet_below = False
                    else:
                        same_slitlet_below = False
                else:
                    same_slitlet_below = False
            else:
                same_slitlet_below = False

            # check for pseudo-longslit with next slitlet
            if islitlet < EMIR_NBARS:
                if (islitlet + 1) in list_valid_islitlets:
                    c1 = csu_conf_fitsfile.csu_bar_slit_center(islitlet)
                    w1 = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
                    c2 = csu_conf_fitsfile.csu_bar_slit_center(islitlet + 1)
                    w2 = csu_conf_fitsfile.csu_bar_slit_width(islitlet + 1)
                    if abs(w1-w2)/w1 < 0.25:
                        wmean = (w1 + w2) / 2.0
                        if abs(c1 - c2) < wmean/4.0:
                            same_slitlet_above = True
                        else:
                            same_slitlet_above = False
                    else:
                        same_slitlet_above = False
                else:
                    same_slitlet_above = False
            else:
                same_slitlet_above = False

            for j in range(EMIR_NAXIS1):
                xchannel = j + 1
                y0_lower = slt.list_frontiers[0](xchannel)
                y0_upper = slt.list_frontiers[1](xchannel)
                n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                                y0_frontier_upper=y0_upper,
                                                resize=True)
                # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
                nn1 = n1 - slt.bb_ns1_orig + 1
                nn2 = n2 - slt.bb_ns1_orig + 1
                image2d_flatfielded[(n1 - 1):n2, j] = \
                    slitlet2d_norm[(nn1 - 1):nn2, j]

                # force to 1.0 region around frontiers
                if not same_slitlet_below:
                    image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
                if not same_slitlet_above:
                    image2d_flatfielded[(n2 - 5):n2, j] = 1
        else:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)

    if args.debugplot == 0:
        print('OK!')

    # restore global offsets
    image2d_flatfielded = apply_integer_offsets(
        image2d=image2d_flatfielded ,
        offx=-rectwv_coeff.global_integer_offset_x_pix,
        offy=-rectwv_coeff.global_integer_offset_y_pix
    )

    # set pixels below minimum value to 1.0
    filtered = np.where(image2d_flatfielded < args.minimum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # set pixels above maximum value to 1.0
    filtered = np.where(image2d_flatfielded > args.maximum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # save output file
    save_ndarray_to_fits(
        array=image2d_flatfielded,
        file_name=args.outfile,
        main_header=header,
        overwrite=True
    )
    print('>>> Saving file ' + args.outfile.name)
def median_slitlets_rectified(
        input_image,
        mode=0,
        minimum_slitlet_width_mm=EMIR_MINIMUM_SLITLET_WIDTH_MM,
        maximum_slitlet_width_mm=EMIR_MAXIMUM_SLITLET_WIDTH_MM,
        debugplot=0
    ):
    """Compute median spectrum for each slitlet.

    The input image must be a rectified image in which each of the
    EMIR_NBARS slitlets occupies EMIR_NPIXPERSLIT_RECTIFIED consecutive
    rows, i.e. NAXIS2 == EMIR_NBARS * EMIR_NPIXPERSLIT_RECTIFIED.

    Parameters
    ----------
    input_image : HDUList object
        Input 2D image. Its primary header must contain INSTRUME='EMIR'.
        When mode=2 it must also provide CRPIX1, CRVAL1, CDELT1 and the
        keywords read by CsuConfiguration.define_from_header.
    mode : int
        Indicate desired result:
        0 : image with the same size as the input image, with the
            median spectrum of each slitlet spanning all the spectra
            of the corresponding slitlet
        1 : image with EMIR_NBARS spectra (one row per slitlet),
            containing the median spectra of each slitlet
        2 : single collapsed median spectrum, using exclusively the
            useful slitlets from the input image (slitlets whose width
            lies within [minimum_slitlet_width_mm,
            maximum_slitlet_width_mm])
    minimum_slitlet_width_mm : float
        Minimum slitlet width (mm) for a valid slitlet (mode 2 only).
    maximum_slitlet_width_mm : float
        Maximum slitlet width (mm) for a valid slitlet (mode 2 only).
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot.

    Returns
    -------
    image_median : HDUList object
        Output image: 2D for modes 0 and 1, 1D for mode 2. The primary
        HDU is built with the input primary header. In mode 2, wavelength
        columns where every pixel is masked are set to 0.

    Raises
    ------
    ValueError
        If `mode` is not 0, 1 or 2, if NAXIS2 does not match the
        expected rectified size, or if INSTRUME is not 'EMIR'.

    """

    # reject undocumented modes instead of silently falling back to mode 1
    if mode not in (0, 1, 2):
        raise ValueError("Invalid mode={0}: must be 0, 1 or 2".format(mode))

    image_header = input_image[0].header
    image2d = input_image[0].data

    # check image dimensions
    naxis2_expected = EMIR_NBARS * EMIR_NPIXPERSLIT_RECTIFIED

    naxis2, naxis1 = image2d.shape
    if naxis2 != naxis2_expected:
        raise ValueError("NAXIS2={0} should be {1}".format(
            naxis2, naxis2_expected
        ))

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # initialize output image: one row per input row (mode 0) or one row
    # per slitlet (modes 1 and 2)
    if mode == 0:
        image2d_median = np.zeros((naxis2, naxis1))
    else:
        image2d_median = np.zeros((EMIR_NBARS, naxis1))

    # main loop: slitlet number i+1 spans scans ns1..ns2 (1-based)
    for i in range(EMIR_NBARS):
        ns1 = i * EMIR_NPIXPERSLIT_RECTIFIED + 1
        ns2 = ns1 + EMIR_NPIXPERSLIT_RECTIFIED - 1
        sp_median = np.median(image2d[(ns1-1):ns2, :], axis=0)

        if mode == 0:
            # broadcasting repeats the 1D median spectrum over every row
            # of the slitlet (same result as np.tile, no temporary array)
            image2d_median[(ns1-1):ns2, :] = sp_median
        else:
            # row assignment copies the values into the output array
            image2d_median[i] = sp_median

    if mode == 2:
        # get CSU configuration from FITS header
        csu_config = CsuConfiguration.define_from_header(image_header)

        # define wavelength calibration parameters (used only for display)
        crpix1 = image_header['crpix1']
        crval1 = image_header['crval1']
        cdelt1 = image_header['cdelt1']

        # segregate slitlets
        list_useful_slitlets = csu_config.widths_in_range_mm(
            minwidth=minimum_slitlet_width_mm,
            maxwidth=maximum_slitlet_width_mm
        )
        list_not_useful_slitlets = [
            islitlet for islitlet in range(1, EMIR_NBARS + 1)
            if islitlet not in list_useful_slitlets
        ]
        if abs(debugplot) != 0:
            print('>>> list_useful_slitlets....:', list_useful_slitlets)
            print('>>> list_not_useful_slitlets:', list_not_useful_slitlets)

        # define mask from array data
        mask2d, _ = define_mask_borders(image2d_median, sought_value=0)
        if abs(debugplot) % 10 != 0:
            ximshow(mask2d.astype(int), z1z2=(-.2, 1.2), crpix1=crpix1,
                    crval1=crval1, cdelt1=cdelt1, debugplot=debugplot)

        # update mask with unused slitlets (rows are 0-based, slitlets 1-based)
        for islitlet in list_not_useful_slitlets:
            mask2d[islitlet - 1, :] = True
        if abs(debugplot) % 10 != 0:
            ximshow(mask2d.astype(int), z1z2=(-.2, 1.2), crpix1=crpix1,
                    crval1=crval1, cdelt1=cdelt1, debugplot=debugplot)

        # useful image pixels (masked pixels set to zero); this product is
        # only displayed, so compute it only when plotting
        if abs(debugplot) % 10 != 0:
            image2d_useful = image2d_median * (1 - mask2d.astype(int))
            ximshow(image2d_useful, crpix1=crpix1, crval1=crval1,
                    cdelt1=cdelt1, debugplot=debugplot)

        # masked image
        image2d_masked = np.ma.masked_array(image2d_median, mask=mask2d)
        # median spectrum; fill fully-masked columns explicitly with 0
        # instead of relying on the data stored under the mask
        image1d_median = np.ma.median(image2d_masked, axis=0).filled(0.0)

        image_median = fits.PrimaryHDU(data=image1d_median,
                                       header=image_header)

    else:
        image_median = fits.PrimaryHDU(data=image2d_median,
                                       header=image_header)

    return fits.HDUList([image_median])