Example #1
0
    def ximshow_rectified(self, slitlet2d_rect, subtitle=None):
        """Display rectified image with spectrails and frontiers.

        In the rectified image every spectrum trail and frontier is drawn
        as a horizontal line, placed at the ordinate obtained by applying
        the linear correction ``corr_yrect_a + corr_yrect_b * y`` to the
        value that the curve takes at ``self.x0_reference``.

        Parameters
        ----------
        slitlet2d_rect : numpy array
            Array containing the rectified slitlet image.
        subtitle : string, optional
            Subtitle for plot (appended to the title between parentheses).

        """

        title = "Slitlet#" + str(self.islitlet)
        if subtitle is not None:
            title += ' (' + subtitle + ')'
        ax = ximshow(slitlet2d_rect, title=title,
                     first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                     show=False)
        # channel offsets within the bounding box; the builtin float is
        # used because the np.float alias was removed in NumPy 1.24
        xx = np.arange(0, self.bb_nc2_orig - self.bb_nc1_orig + 1,
                       dtype=float)
        xplot = xx + self.bb_nc1_orig
        # grid with fitted transformation: spectrum trails (solid lines)
        # and frontiers (dotted lines)
        for curves, linestyle in ((self.list_spectrails, "b"),
                                  (self.list_frontiers, "b:")):
            for curve in curves:
                yy0 = self.corr_yrect_a + \
                      self.corr_yrect_b * curve(self.x0_reference)
                ax.plot(xplot, np.full(xx.size, yy0), linestyle)
        # show plot
        pause_debugplot(self.debugplot, pltshow=True)
Example #2
0
    def ximshow_unrectified(self, slitlet2d, subtitle=None):
        """Display unrectified image with spectrails and frontiers.

        Parameters
        ----------
        slitlet2d : numpy array
            Array containing the unrectified slitlet image.
        subtitle : string, optional
            Subtitle for plot (appended to the title between parentheses).
            Accepted so that callers passing ``subtitle=...`` work.

        """

        title = "Slitlet#" + str(self.islitlet)
        if subtitle is not None:
            title += ' (' + subtitle + ')'
        ax = ximshow(slitlet2d,
                     title=title,
                     first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                     show=False)
        # evaluate every curve at each channel of the full detector width
        xdum = np.linspace(1, EMIR_NAXIS1, num=EMIR_NAXIS1)
        # lower/upper spectrum trails solid, middle trail dashed,
        # lower/upper frontiers dotted
        curves_and_styles = (
            (self.list_spectrails[0], 'b-'),
            (self.list_spectrails[1], 'b--'),
            (self.list_spectrails[2], 'b-'),
            (self.list_frontiers[0], 'b:'),
            (self.list_frontiers[1], 'b:'),
        )
        for curve, linestyle in curves_and_styles:
            ax.plot(xdum, curve(xdum), linestyle)
        pause_debugplot(debugplot=self.debugplot, pltshow=True)
Example #3
0
    def ximshow_unrectified(self, slitlet2d, subtitle=None):
        """Plot the unrectified slitlet overlaid with its boundary curves.

        The three spectrum trails (outer ones solid, middle one dashed) and
        the two frontiers (dotted) are evaluated over every channel of the
        detector and drawn on top of the image.

        Parameters
        ----------
        slitlet2d : numpy array
            Array containing the unrectified slitlet image.
        subtitle : string, optional
            Text appended, between parentheses, to the plot title.

        """

        plot_title = "Slitlet#" + str(self.islitlet)
        if subtitle is not None:
            plot_title = plot_title + ' (' + subtitle + ')'
        axes = ximshow(slitlet2d, title=plot_title,
                       first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                       show=False)
        channels = np.linspace(1, EMIR_NAXIS1, num=EMIR_NAXIS1)
        for index, style in enumerate(('b-', 'b--', 'b-')):
            axes.plot(channels, self.list_spectrails[index](channels), style)
        for index in range(2):
            axes.plot(channels, self.list_frontiers[index](channels), 'b:')
        # the title is always a string at this point, so it is re-applied
        # unconditionally
        axes.set_title(plot_title)
        pause_debugplot(debugplot=self.debugplot, pltshow=True)
Example #4
0
    def ximshow_unrectified(self, slitlet2d, subtitle=None):
        """Display unrectified image with spectrails and frontiers.

        Parameters
        ----------
        slitlet2d : numpy array
            Array containing the unrectified slitlet image.
        subtitle : string, optional
            Subtitle for plot (appended to the title between parentheses).
            Accepted so that callers passing ``subtitle=...`` work.

        """

        title = "Slitlet#" + str(self.islitlet)
        if subtitle is not None:
            title += ' (' + subtitle + ')'
        ax = ximshow(slitlet2d, title=title,
                     first_pixel=(self.bb_nc1_orig, self.bb_ns1_orig),
                     show=False)
        xdum = np.linspace(1, EMIR_NAXIS1, num=EMIR_NAXIS1)
        # spectrum trails: lower and upper solid, middle dashed
        for index, linestyle in enumerate(('b-', 'b--', 'b-')):
            ax.plot(xdum, self.list_spectrails[index](xdum), linestyle)
        # lower and upper frontiers: dotted
        for index in range(2):
            ax.plot(xdum, self.list_frontiers[index](xdum), 'b:')
        pause_debugplot(debugplot=self.debugplot, pltshow=True)
Example #5
0
def display_slitlet_histogram(csu_bar_slit_width,
                              n_clusters=2,
                              geometry=None,
                              debugplot=0):
    """Find separations between groups of slitlet widths.

    The slitlet widths are grouped with a one-dimensional K-means
    clustering; a separator is placed halfway between each pair of
    consecutive cluster centers (sorted by increasing width).

    Parameters
    ----------
    csu_bar_slit_width : numpy array
        Array containing the csu_bar_slit_width values (mm).
    n_clusters : int
        Number of slitlet groups to be sought (must be >= 2).
    geometry : tuple (4 integers) or None
        x, y, dx, dy values employed to set the Qt backend geometry.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot

    Returns
    -------
    separator_list : list of floats
        Width values separating consecutive groups, in increasing order
        (n_clusters - 1 elements).

    Raises
    ------
    ValueError
        If n_clusters is smaller than 2.

    """

    # protections
    if n_clusters < 2:
        raise ValueError("n_clusters must be >= 2")

    # determine useful slitlets from their widths
    km = KMeans(n_clusters=n_clusters)
    fit = km.fit(np.array(csu_bar_slit_width).reshape(-1, 1))
    # KMeans returns the cluster centers in no particular order: sort them
    # so that each separator lies between two neighbouring groups
    centers = np.sort(fit.cluster_centers_[:, 0])
    separator_list = []
    for peak1, peak2 in zip(centers[:-1], centers[1:]):
        separator = (peak1 + peak2) / 2
        separator_list.append(separator)
        print('--->  separator: {0:7.3f}'.format(separator))

    # display histogram
    if abs(debugplot) % 10 != 0:
        fig = plt.figure()
        set_window_geometry(geometry)
        ax = fig.add_subplot(111)
        ax.hist(csu_bar_slit_width, bins=100)
        ax.set_xlabel('slit width (mm)')
        ax.set_ylabel('number of slitlets')
        for separator in separator_list:
            ax.axvline(separator, color='C1', linestyle='dashed')
        pause_debugplot(debugplot, pltshow=True)

    return separator_list
def display_slitlet_histogram(csu_bar_slit_width,
                              n_clusters=2,
                              geometry=None,
                              debugplot=0):
    """Find separations between groups of slitlet widths.

    The slitlet widths are grouped with a one-dimensional K-means
    clustering; a separator is placed halfway between each pair of
    consecutive cluster centers (sorted by increasing width).

    Parameters
    ----------
    csu_bar_slit_width : numpy array
        Array containing the csu_bar_slit_width values (mm).
    n_clusters : int
        Number of slitlet groups to be sought (must be >= 2).
    geometry : tuple (4 integers) or None
        x, y, dx, dy values employed to set the Qt backend geometry.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot

    Returns
    -------
    separator_list : list of floats
        Width values separating consecutive groups, in increasing order
        (n_clusters - 1 elements).

    Raises
    ------
    ValueError
        If n_clusters is smaller than 2.

    """

    # protections
    if n_clusters < 2:
        raise ValueError("n_clusters must be >= 2")

    # determine useful slitlets from their widths
    km = KMeans(n_clusters=n_clusters)
    fit = km.fit(np.array(csu_bar_slit_width).reshape(-1, 1))
    # KMeans returns the cluster centers in no particular order: sort them
    # so that each separator lies between two neighbouring groups
    centers = np.sort(fit.cluster_centers_[:, 0])
    separator_list = []
    for peak1, peak2 in zip(centers[:-1], centers[1:]):
        separator = (peak1 + peak2) / 2
        separator_list.append(separator)
        print('--->  separator: {0:7.3f}'.format(separator))

    # display histogram
    if abs(debugplot) % 10 != 0:
        fig = plt.figure()
        set_window_geometry(geometry)
        ax = fig.add_subplot(111)
        ax.hist(csu_bar_slit_width, bins=100)
        ax.set_xlabel('slit width (mm)')
        ax.set_ylabel('number of slitlets')
        for separator in separator_list:
            ax.axvline(separator, color='C1', linestyle='dashed')
        pause_debugplot(debugplot, pltshow=True)

    return separator_list
Example #7
0
def _is_pseudo_longslit(csu_conf, islitlet_a, islitlet_b):
    """Return True when two slitlets seem to form a pseudo-longslit.

    The slitlets are considered part of the same pseudo-longslit when
    their widths differ by less than 25% of the width of ``islitlet_a``
    and their centers differ by less than one quarter of their mean width.

    Parameters
    ----------
    csu_conf : CsuConfiguration
        Object providing csu_bar_slit_center() and csu_bar_slit_width()
        for each slitlet number.
    islitlet_a, islitlet_b : int
        Slitlet numbers to be compared.

    Returns
    -------
    bool
        True if both slitlets are compatible with a pseudo-longslit.

    """
    c1 = csu_conf.csu_bar_slit_center(islitlet_a)
    w1 = csu_conf.csu_bar_slit_width(islitlet_a)
    c2 = csu_conf.csu_bar_slit_center(islitlet_b)
    w2 = csu_conf.csu_bar_slit_width(islitlet_b)
    if abs(w1 - w2) / w1 < 0.25:
        wmean = (w1 + w2) / 2.0
        return abs(c1 - c2) < wmean / 4.0
    return False


def main(args=None):
    """Compute a pixel-to-pixel flatfield from a flat ON-OFF image.

    For every valid slitlet the image is rectified, its median spectrum
    is fitted with an adaptive spline (a second fit, with fixed knots,
    rejects points deviating more than 3 times the rms within each knot
    interval), the fitted spectrum is mapped back to the unrectified
    frame and the original slitlet is divided by it. The resulting image
    is saved in --outfile.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments; when None argparse uses sys.argv[1:].

    Raises
    ------
    ValueError
        If the image dimensions, filter, grism or (unless
        --ignore_DTUconf is set) DTU configuration do not match the
        rectification and wavelength calibration coefficients.

    """
    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: compute pixel-to-pixel flatfield'
    )

    # required arguments
    parser.add_argument("fitsfile",
                        help="Input FITS file (flat ON-OFF)",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rectwv_coeff", required=True,
                        help="Input JSON file with rectification and "
                             "wavelength calibration coefficients",
                        type=argparse.FileType('rt'))
    parser.add_argument("--minimum_slitlet_width_mm", required=True,
                        help="Minimum slitlet width in mm",
                        type=float)
    parser.add_argument("--maximum_slitlet_width_mm", required=True,
                        help="Maximum slitlet width in mm",
                        type=float)
    # not required: it was previously declared required=True together
    # with a default value, which made the default unreachable
    parser.add_argument("--minimum_fraction",
                        help="Minimum allowed flatfielding value: fitted "
                             "median spectrum values below this fraction "
                             "of its maximum are set to zero "
                             "(default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--minimum_value_in_output",
                        help="Minimum value allowed in output file: pixels "
                             "below this value are set to 1.0 (default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--maximum_value_in_output",
                        help="Maximum value allowed in output file: pixels "
                             "above this value are set to 1.0 (default=10.0)",
                        type=float, default=10.0)
    parser.add_argument("--outfile", required=True,
                        help="Output FITS file",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))

    # optional arguments
    parser.add_argument("--delta_global_integer_offset_x_pix",
                        help="Delta global integer offset in the X direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--delta_global_integer_offset_y_pix",
                        help="Delta global integer offset in the Y direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--resampling",
                        help="Resampling method: 1 -> nearest neighbor, "
                             "2 -> linear interpolation (default)",
                        default=2, type=int,
                        choices=(1, 2))
    parser.add_argument("--ignore_DTUconf",
                        help="Ignore DTU configurations differences between "
                             "model and input image",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                             " (default=0)",
                        default=0, type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # read calibration structure from JSON file
    rectwv_coeff = RectWaveCoeff._datatype_load(args.rectwv_coeff.name)

    # modify (when requested) global offsets
    rectwv_coeff.global_integer_offset_x_pix += \
        args.delta_global_integer_offset_x_pix
    rectwv_coeff.global_integer_offset_y_pix += \
        args.delta_global_integer_offset_y_pix

    # read FITS image and its corresponding header
    with fits.open(args.fitsfile) as hdulist:
        header = hdulist[0].header
        image2d = hdulist[0].data

    # apply global offsets
    image2d = apply_integer_offsets(
        image2d=image2d,
        offx=rectwv_coeff.global_integer_offset_x_pix,
        offy=rectwv_coeff.global_integer_offset_y_pix
    )

    # protections
    naxis2, naxis1 = image2d.shape
    if naxis1 != header['naxis1'] or naxis2 != header['naxis2']:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)
        raise ValueError('Something is wrong with NAXIS1 and/or NAXIS2')
    if abs(args.debugplot) >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # check that the input FITS file grism and filter match
    filter_name = header['filter']
    if filter_name != rectwv_coeff.tags['filter']:
        raise ValueError("Filter name does not match!")
    grism_name = header['grism']
    if grism_name != rectwv_coeff.tags['grism']:
        raise ValueError("Grism name does not match!")
    if abs(args.debugplot) >= 10:
        print('>>> grism.......:', grism_name)
        print('>>> filter......:', filter_name)

    # check that the DTU configurations are compatible
    dtu_conf_fitsfile = DtuConfiguration.define_from_fits(args.fitsfile)
    dtu_conf_jsonfile = DtuConfiguration.define_from_dictionary(
        rectwv_coeff.meta_info['dtu_configuration'])
    if dtu_conf_fitsfile != dtu_conf_jsonfile:
        print('DTU configuration (FITS file):\n\t', dtu_conf_fitsfile)
        print('DTU configuration (JSON file):\n\t', dtu_conf_jsonfile)
        if args.ignore_DTUconf:
            print('WARNING: DTU configuration differences found!')
        else:
            raise ValueError('DTU configurations do not match')
    else:
        if abs(args.debugplot) >= 10:
            print('>>> DTU Configuration match!')
            print(dtu_conf_fitsfile)

    # load CSU configuration
    csu_conf_fitsfile = CsuConfiguration.define_from_fits(args.fitsfile)
    if abs(args.debugplot) >= 10:
        print(csu_conf_fitsfile)

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        print('-> Removing slitlet (not defined):', idel)
        list_valid_islitlets.remove(idel)
    # filter out slitlets with widths outside valid range
    list_outside_valid_width = []
    for islitlet in list_valid_islitlets:
        slitwidth = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
        if (slitwidth < args.minimum_slitlet_width_mm) or \
                (slitwidth > args.maximum_slitlet_width_mm):
            list_outside_valid_width.append(islitlet)
            print('-> Removing slitlet (invalid width):', islitlet)
    for idel in list_outside_valid_width:
        list_valid_islitlets.remove(idel)
    print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # initialize output (flatfield) image
    image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

    # main loop
    for islitlet in range(1, EMIR_NBARS + 1):
        if islitlet not in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)
            continue

        if args.debugplot == 0:
            islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
        # define Slitlet2D object
        slt = Slitlet2D(islitlet=islitlet,
                        rectwv_coeff=rectwv_coeff,
                        debugplot=args.debugplot)

        if abs(args.debugplot) >= 10:
            print(slt)

        # extract (distorted) slitlet from the initial image
        slitlet2d = slt.extract_slitlet2d(
            image_2k2k=image2d,
            subtitle='original image'
        )

        # rectify slitlet
        slitlet2d_rect = slt.rectify(
            slitlet2d=slitlet2d,
            resampling=args.resampling,
            subtitle='original rectified'
        )
        naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

        if naxis1_slitlet2d != EMIR_NAXIS1:
            print('naxis1_slitlet2d: ', naxis1_slitlet2d)
            print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
            raise ValueError("Unexpected naxis1_slitlet2d")

        # get useful slitlet region (use boundaries instead of frontiers;
        # note that the nscan_minmax_frontiers() works well independently
        # of using frontiers of boundaries as arguments)
        nscan_min, nscan_max = nscan_minmax_frontiers(
            slt.y0_reference_lower,
            slt.y0_reference_upper,
            resize=False
        )
        # array row indices covering scans nscan_min..nscan_max (both
        # included); ii2 is already an exclusive upper bound
        ii1 = nscan_min - slt.bb_ns1_orig
        ii2 = nscan_max - slt.bb_ns1_orig + 1

        # median spectrum (no extra +1 here: it would add one scan
        # beyond nscan_max)
        sp_collapsed = np.median(slitlet2d_rect[ii1:ii2, :], axis=0)

        # smooth median spectrum along the spectral direction, ignoring
        # nremove channels at each border
        xaxis1 = np.arange(1, naxis1_slitlet2d + 1)
        nremove = 5
        spl = AdaptiveLSQUnivariateSpline(
            x=xaxis1[nremove:-nremove],
            y=sp_collapsed[nremove:-nremove],
            t=11,
            adaptive=True
        )
        xknots = spl.get_knots()
        yknots = spl(xknots)
        sp_median = spl(xaxis1)

        # compute rms within each knot interval
        nknots = len(xknots)
        rms_array = np.zeros(nknots - 1, dtype=float)
        for iknot in range(nknots - 1):
            residuals = []
            for xdum, ydum, yydum in \
                    zip(xaxis1, sp_collapsed, sp_median):
                if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                    residuals.append(abs(ydum - yydum))
            if len(residuals) > 5:
                rms_array[iknot] = np.std(residuals)
            else:
                rms_array[iknot] = 0

        # determine in which knot interval falls each pixel
        iknot_array = np.zeros(len(xaxis1), dtype=int)
        for idum, xdum in enumerate(xaxis1):
            for iknot in range(nknots - 1):
                if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                    iknot_array[idum] = iknot

        # compute new fit removing deviant points (with fixed knots); the
        # first and last channels, and intervals with zero rms, are kept
        xnewfit = []
        ynewfit = []
        for idum in range(len(xaxis1)):
            delta_sp = abs(sp_collapsed[idum] - sp_median[idum])
            rms_tmp = rms_array[iknot_array[idum]]
            if idum == 0 or idum == (len(xaxis1) - 1):
                lok = True
            elif rms_tmp > 0:
                lok = delta_sp < 3.0 * rms_tmp
            else:
                lok = True
            if lok:
                xnewfit.append(xaxis1[idum])
                ynewfit.append(sp_collapsed[idum])
        splnew = AdaptiveLSQUnivariateSpline(
            x=xnewfit[nremove:-nremove],
            y=ynewfit[nremove:-nremove],
            t=xknots[1:-1],
            adaptive=False
        )
        sp_median = splnew(xaxis1)

        # values below minimum_fraction of the maximum are set to zero
        ymax_spmedian = sp_median.max()
        y_threshold = ymax_spmedian * args.minimum_fraction
        sp_median[np.where(sp_median < y_threshold)] = 0.0

        if abs(args.debugplot) > 10:
            title = 'Slitlet#' + str(islitlet) + ' (median spectrum)'
            ax = ximplotxy(xaxis1, sp_collapsed,
                           title=title,
                           show=False, **{'label': 'collapsed spectrum'})
            ax.plot(xaxis1, sp_median, label='fitted spectrum')
            ax.plot([1, naxis1_slitlet2d], 2*[y_threshold],
                    label='threshold')
            ax.plot(xknots, yknots, 'o', label='knots')
            ax.legend()
            ax.set_ylim(-0.05*ymax_spmedian, 1.05*ymax_spmedian)
            pause_debugplot(args.debugplot,
                            pltshow=True, tight_layout=True)

        # generate rectified slitlet region filled with the median spectrum
        slitlet2d_rect_spmedian = np.tile(sp_median, (naxis2_slitlet2d, 1))
        if abs(args.debugplot) > 10:
            slt.ximshow_rectified(
                slitlet2d_rect=slitlet2d_rect_spmedian,
                subtitle='rectified, filled with median spectrum'
            )

        # unrectified image
        slitlet2d_unrect_spmedian = slt.rectify(
            slitlet2d=slitlet2d_rect_spmedian,
            resampling=args.resampling,
            inverse=True,
            subtitle='unrectified, filled with median spectrum'
        )

        # normalize initial slitlet image; pixels where the unrectified
        # median model is zero are set to 1.0 to avoid division by zero
        slitlet2d_norm = np.zeros_like(slitlet2d)
        den = slitlet2d_unrect_spmedian[:naxis2_slitlet2d, :naxis1_slitlet2d]
        num = slitlet2d[:naxis2_slitlet2d, :naxis1_slitlet2d]
        norm = slitlet2d_norm[:naxis2_slitlet2d, :naxis1_slitlet2d]  # view
        zero_den = (den == 0)
        norm[zero_den] = 1.0
        norm[~zero_den] = num[~zero_den] / den[~zero_den]

        if abs(args.debugplot) > 10:
            slt.ximshow_unrectified(
                slitlet2d=slitlet2d_norm,
                subtitle='unrectified, pixel-to-pixel'
            )

        # check for pseudo-longslit with previous and next slitlets
        same_slitlet_below = (
            islitlet > 1 and
            (islitlet - 1) in list_valid_islitlets and
            _is_pseudo_longslit(csu_conf_fitsfile, islitlet - 1, islitlet)
        )
        same_slitlet_above = (
            islitlet < EMIR_NBARS and
            (islitlet + 1) in list_valid_islitlets and
            _is_pseudo_longslit(csu_conf_fitsfile, islitlet, islitlet + 1)
        )

        for j in range(EMIR_NAXIS1):
            xchannel = j + 1
            y0_lower = slt.list_frontiers[0](xchannel)
            y0_upper = slt.list_frontiers[1](xchannel)
            n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                            y0_frontier_upper=y0_upper,
                                            resize=True)
            # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
            nn1 = n1 - slt.bb_ns1_orig + 1
            nn2 = n2 - slt.bb_ns1_orig + 1
            image2d_flatfielded[(n1 - 1):n2, j] = \
                slitlet2d_norm[(nn1 - 1):nn2, j]

            # force to 1.0 region around frontiers (unless the adjacent
            # slitlet belongs to the same pseudo-longslit)
            if not same_slitlet_below:
                image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
            if not same_slitlet_above:
                image2d_flatfielded[(n2 - 5):n2, j] = 1

    if args.debugplot == 0:
        print('OK!')

    # restore global offsets
    image2d_flatfielded = apply_integer_offsets(
        image2d=image2d_flatfielded,
        offx=-rectwv_coeff.global_integer_offset_x_pix,
        offy=-rectwv_coeff.global_integer_offset_y_pix
    )

    # set pixels below minimum value to 1.0
    filtered = np.where(image2d_flatfielded < args.minimum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # set pixels above maximum value to 1.0
    filtered = np.where(image2d_flatfielded > args.maximum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # save output file
    save_ndarray_to_fits(
        array=image2d_flatfielded,
        file_name=args.outfile,
        main_header=header,
        overwrite=True
    )
    print('>>> Saving file ' + args.outfile.name)
def display_slitlet_arrangement(fileobj,
                                grism=None,
                                spfilter=None,
                                bbox=None,
                                adjust=None,
                                geometry=None,
                                debugplot=0):
    """Display slitlet arrangement from CSUP keywords in FITS header.

    The CSU configuration is read either from the header of a FITS file
    or from a TXT file with one "id_bar position" pair per line (lines
    starting with '#' are comments). In the TXT file the left and right
    bars must appear interleaved: 1, 1 + EMIR_NBARS, 2, 2 + EMIR_NBARS, ...

    Parameters
    ----------
    fileobj : file object
        FITS or TXT file object.
    grism : str
        Grism (required when fileobj is a TXT file; otherwise it is read
        from the FITS header).
    spfilter : str
        Filter (required when fileobj is a TXT file; otherwise it is read
        from the FITS header).
    bbox : tuple of 4 floats
        If not None, values for xmin, xmax, ymin and ymax.
    adjust : bool
        Adjust X range according to minimum and maximum csu_bar_left
        and csu_bar_right (note that this option is overridden by 'bbox')
    geometry : tuple (4 integers) or None
        x, y, dx, dy values employed to set the Qt backend geometry.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot

    Returns
    -------
    csu_bar_left : list of floats
        Location (mm) of the left bar for each slitlet.
    csu_bar_right : list of floats
        Location (mm) of the right bar for each slitlet, using the
        same origin employed for csu_bar_left (which is not the
        value stored in the FITS keywords).
    csu_bar_slit_center : list of floats
        Middle point (mm) in between the two bars defining a slitlet.
    csu_bar_slit_width : list of floats
        Slitlet width (mm), computed as the distance between the two
        bars defining the slitlet.

    Raises
    ------
    ValueError
        If grism or filter are missing for a TXT file, if the TXT file
        bars are out of sequence or incomplete, or if the grism + filter
        combination is not supported.

    """

    if fileobj.name.endswith(".txt"):
        if grism is None:
            raise ValueError("Undefined grism!")
        if spfilter is None:
            raise ValueError("Undefined filter!")
        # define CsuConfiguration object
        csu_config = CsuConfiguration()
        csu_config._csu_bar_left = []
        csu_config._csu_bar_right = []
        csu_config._csu_bar_slit_center = []
        csu_config._csu_bar_slit_width = []

        # since the input filename has been opened with argparse in binary
        # mode, it is necessary to close it and open it in text mode
        fileobj.close()
        # read TXT file
        with open(fileobj.name, mode='rt') as f:
            file_content = f.read().splitlines()
        next_id_bar = 1
        for line in file_content:
            stripped = line.strip()
            # skip blank (or whitespace-only) lines and comments
            if not stripped or stripped.startswith('#'):
                continue
            line_contents = stripped.split()
            id_bar = int(line_contents[0])
            position = float(line_contents[1])
            if id_bar != next_id_bar:
                raise ValueError("Unexpected id_bar:" + str(id_bar))
            if id_bar <= EMIR_NBARS:
                csu_config._csu_bar_left.append(position)
                next_id_bar = id_bar + EMIR_NBARS
            else:
                # right bars are converted to the origin of the left bars
                csu_config._csu_bar_right.append(341.5 - position)
                next_id_bar = id_bar - EMIR_NBARS + 1

        # every slitlet needs exactly one left and one right bar
        if len(csu_config._csu_bar_left) != EMIR_NBARS or \
                len(csu_config._csu_bar_right) != EMIR_NBARS:
            raise ValueError("Incomplete CSU configuration in file: " +
                             fileobj.name)

        # compute slit width and center
        for left, right in zip(csu_config._csu_bar_left,
                               csu_config._csu_bar_right):
            csu_config._csu_bar_slit_center.append((left + right) / 2)
            csu_config._csu_bar_slit_width.append(right - left)

    else:
        # read input FITS file header
        with fits.open(fileobj.name) as hdulist:
            image_header = hdulist[0].header

        # additional info from header
        grism = image_header['grism']
        spfilter = image_header['filter']

        # define slitlet arrangement
        csu_config = CsuConfiguration.define_from_fits(fileobj)

    # determine calibration
    if grism in ["J", "OPEN"] and spfilter == "J":
        wv_parameters = set_wv_parameters("J", "J")
    elif grism in ["H", "OPEN"] and spfilter == "H":
        wv_parameters = set_wv_parameters("H", "H")
    elif grism in ["K", "OPEN"] and spfilter == "Ksp":
        wv_parameters = set_wv_parameters("Ksp", "K")
    elif grism in ["LR", "OPEN"] and spfilter == "YJ":
        wv_parameters = set_wv_parameters("YJ", "LR")
    elif grism in ["LR", "OPEN"] and spfilter == "HK":
        wv_parameters = set_wv_parameters("HK", "LR")
    else:
        raise ValueError("Invalid grism + filter configuration")

    crval1 = wv_parameters['poly_crval1_linear']
    cdelt1 = wv_parameters['poly_cdelt1_linear']

    wvmin_useful = wv_parameters['wvmin_useful']
    wvmax_useful = wv_parameters['wvmax_useful']

    # display arrangement
    if abs(debugplot) >= 10:
        print("slit     left    right   center   width   min.wave   max.wave")
        print("====  =======  =======  =======   =====   ========   ========")
        for i in range(EMIR_NBARS):
            ibar = i + 1
            # linear wavelength calibration as a function of slit center
            csu_crval1 = crval1(csu_config.csu_bar_slit_center(ibar))
            csu_cdelt1 = cdelt1(csu_config.csu_bar_slit_center(ibar))
            csu_crvaln = csu_crval1 + (EMIR_NAXIS1 - 1) * csu_cdelt1
            if wvmin_useful is not None:
                csu_crval1 = np.amax([csu_crval1, wvmin_useful])
            if wvmax_useful is not None:
                csu_crvaln = np.amin([csu_crvaln, wvmax_useful])
            print("{0:4d} {1:8.3f} {2:8.3f} {3:8.3f} {4:7.3f}   "
                  "{5:8.2f}   {6:8.2f}".format(
                ibar, csu_config.csu_bar_left(ibar),
                csu_config.csu_bar_right(ibar),
                csu_config.csu_bar_slit_center(ibar),
                csu_config.csu_bar_slit_width(ibar),
                csu_crval1, csu_crvaln)
            )
        print(
            "---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (all)".format(
                np.mean(csu_config._csu_bar_left),
                np.mean(csu_config._csu_bar_right),
                np.mean(csu_config._csu_bar_slit_center),
                np.mean(csu_config._csu_bar_slit_width)
            )
        )
        # index 0 corresponds to slitlet 1, so [::2] selects odd slitlets
        print(
            "---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (odd)".format(
                np.mean(csu_config._csu_bar_left[::2]),
                np.mean(csu_config._csu_bar_right[::2]),
                np.mean(csu_config._csu_bar_slit_center[::2]),
                np.mean(csu_config._csu_bar_slit_width[::2])
            )
        )
        print(
            "---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (even)".format(
                np.mean(csu_config._csu_bar_left[1::2]),
                np.mean(csu_config._csu_bar_right[1::2]),
                np.mean(csu_config._csu_bar_slit_center[1::2]),
                np.mean(csu_config._csu_bar_slit_width[1::2])
            )
        )

    # display slit arrangement
    if abs(debugplot) % 10 != 0:
        fig = plt.figure()
        set_window_geometry(geometry)
        ax = fig.add_subplot(111)
        if bbox is None:
            if adjust:
                xmin = min(csu_config._csu_bar_left)
                xmax = max(csu_config._csu_bar_right)
                dx = xmax - xmin
                if dx == 0:
                    dx = 1
                xmin -= dx/20
                xmax += dx/20
                ax.set_xlim(xmin, xmax)
            else:
                ax.set_xlim(0., 341.5)
            ax.set_ylim(0, 56)
        else:
            ax.set_xlim(bbox[0], bbox[1])
            ax.set_ylim(bbox[2], bbox[3])
        ax.set_xlabel('csu_bar_position (mm)')
        ax.set_ylabel('slit number')
        for i in range(EMIR_NBARS):
            ibar = i + 1
            ax.add_patch(patches.Rectangle(
                (csu_config.csu_bar_left(ibar), ibar-0.5),
                csu_config.csu_bar_slit_width(ibar), 1.0))
            ax.plot([0., csu_config.csu_bar_left(ibar)], [ibar, ibar],
                    '-', color='gray')
            ax.plot([csu_config.csu_bar_right(ibar), 341.5],
                    [ibar, ibar], '-', color='gray')
        plt.title("File: " + fileobj.name + "\ngrism=" + grism +
                  ", filter=" + spfilter)
        pause_debugplot(debugplot, pltshow=True)

    # return results
    return csu_config._csu_bar_left, csu_config._csu_bar_right, \
           csu_config._csu_bar_slit_center, csu_config._csu_bar_slit_width
def main(args=None):
    """Display the arrangement of the EMIR CSU bars for one or more files.

    Parses the command line and calls ``display_slitlet_arrangement`` for
    every input file. When ``--n_clusters >= 2``, it also calls
    ``display_slitlet_histogram`` with the slit widths of that file. When
    more than one file is given and verbosity is enabled
    (``abs(debugplot) >= 10``), it prints the per-slit standard deviation of
    the bar positions across the files.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments. When None, argparse reads ``sys.argv[1:]``.

    """

    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: display arrangement of EMIR CSU bars'
    )

    # positional arguments
    parser.add_argument("filename",
                        help="FITS files (wildcards accepted) or single TXT "
                             "file with CSU configuration from OSP",
                        type=argparse.FileType('rb'),
                        nargs='+')

    # optional arguments
    parser.add_argument("--grism",
                        help="Grism (J, H, K, LR)",
                        choices=["J", "H", "K", "LR"])
    parser.add_argument("--filter",
                        help="Filter (J, H, Ksp, YJ, HK)",
                        choices=["J", "H", "Ksp", "YJ", "HK"])
    parser.add_argument("--n_clusters",
                        help="Display histogram of slitlet widths",
                        default=0, type=int)
    parser.add_argument("--bbox",
                        help="Bounding box tuple xmin,xmax,ymin,ymax "
                             "indicating plot limits")
    parser.add_argument("--adjust",
                        help="Adjust X range according to minimum and maximum"
                             " csu_bar_left and csu_bar_right (note that this "
                             "option is overridden by --bbox)",
                        action='store_true')
    parser.add_argument("--geometry",
                        help="Tuple x,y,dx,dy indicating window geometry",
                        default="0,0,640,480")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                             " (default=12)",
                        default=12, type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # window geometry: four integers x,y,dx,dy
    if args.geometry is None:
        geometry = None
    else:
        tmp_str = args.geometry.split(",")
        if len(tmp_str) != 4:
            parser.error("--geometry must have the form x,y,dx,dy")
        geometry = tuple(int(item) for item in tmp_str)

    # plot limits: four numbers xmin,xmax,ymin,ymax (floats are accepted)
    if args.bbox is None:
        bbox = None
    else:
        str_bbox = args.bbox.split(",")
        if len(str_bbox) != 4:
            parser.error("--bbox must have the form xmin,xmax,ymin,ymax")
        bbox = tuple(float(item) for item in str_bbox)

    # argparse has already opened every positional argument
    list_fits_file_objects = list(args.filename)

    # total number of files to be examined
    nfiles = len(list_fits_file_objects)

    # arrays to store CSU values (one row per file, one column per bar)
    csu_bar_left = np.zeros((nfiles, EMIR_NBARS))
    csu_bar_right = np.zeros((nfiles, EMIR_NBARS))
    csu_bar_slit_center = np.zeros((nfiles, EMIR_NBARS))
    csu_bar_slit_width = np.zeros((nfiles, EMIR_NBARS))

    # display CSU bar arrangement
    try:
        for ifile, fileobj in enumerate(list_fits_file_objects):
            print("\nFile {0}/{1}: {2}".format(ifile + 1, nfiles,
                                               fileobj.name))
            (csu_bar_left[ifile, :], csu_bar_right[ifile, :],
             csu_bar_slit_center[ifile, :], csu_bar_slit_width[ifile, :]) = \
                display_slitlet_arrangement(
                    fileobj,
                    grism=args.grism,
                    spfilter=args.filter,
                    bbox=bbox,
                    adjust=args.adjust,
                    geometry=geometry,
                    debugplot=args.debugplot
                )
            if args.n_clusters >= 2:
                display_slitlet_histogram(
                    csu_bar_slit_width[ifile, :],
                    n_clusters=args.n_clusters,
                    geometry=geometry,
                    debugplot=args.debugplot
                )
    finally:
        # the files were opened by argparse.FileType; release them here
        for fileobj in list_fits_file_objects:
            fileobj.close()

    # print summary of comparison between files (same abs() verbosity
    # convention used for debugplot everywhere else)
    if nfiles > 1 and abs(args.debugplot) >= 10:
        # population standard deviation (ddof=0) of each bar across files
        std_csu_bar_left = np.std(csu_bar_left, axis=0)
        std_csu_bar_right = np.std(csu_bar_right, axis=0)
        std_csu_bar_slit_center = np.std(csu_bar_slit_center, axis=0)
        std_csu_bar_slit_width = np.std(csu_bar_slit_width, axis=0)
        print("\n   STANDARD DEVIATION BETWEEN IMAGES")
        print("slit     left    right   center   width")
        print("====  =======  =======  =======   =====")
        for i in range(EMIR_NBARS):
            print("{0:4d} {1:8.3f} {2:8.3f} {3:8.3f} {4:7.3f}".format(
                i + 1,
                std_csu_bar_left[i],
                std_csu_bar_right[i],
                std_csu_bar_slit_center[i],
                std_csu_bar_slit_width[i]))
        print("====  =======  =======  =======   =====")
        print("MIN: {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f}".format(
            std_csu_bar_left.min(),
            std_csu_bar_right.min(),
            std_csu_bar_slit_center.min(),
            std_csu_bar_slit_width.min()))
        print("MAX: {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f}".format(
            std_csu_bar_left.max(),
            std_csu_bar_right.max(),
            std_csu_bar_slit_center.max(),
            std_csu_bar_slit_width.max()))
        print("====  =======  =======  =======   =====")
        print("Total number of files examined:", nfiles)

    # stop program execution
    if nfiles > 1:
        pause_debugplot(12, optional_prompt="Press RETURN to STOP")
Example #10
0
def _same_pseudo_longslit(csu_conf, islitlet_a, islitlet_b):
    """Return True when two slitlets look like parts of one pseudo-longslit.

    Parameters
    ----------
    csu_conf : CsuConfiguration
        Object providing ``csu_bar_slit_center(islitlet)`` and
        ``csu_bar_slit_width(islitlet)``.
    islitlet_a, islitlet_b : int
        Slitlet numbers to compare. The width of ``islitlet_a`` is the
        reference for the relative width difference.

    Returns
    -------
    bool
        True when the relative width difference is below 25% and the slit
        centers differ by less than a quarter of the mean width.

    """
    c1 = csu_conf.csu_bar_slit_center(islitlet_a)
    w1 = csu_conf.csu_bar_slit_width(islitlet_a)
    c2 = csu_conf.csu_bar_slit_center(islitlet_b)
    w2 = csu_conf.csu_bar_slit_width(islitlet_b)
    if abs(w1 - w2) / w1 < 0.25:
        wmean = (w1 + w2) / 2.0
        return abs(c1 - c2) < wmean / 4.0
    return False


def main(args=None):
    """Compute a pixel-to-pixel flatfield (obsolete command-line tool).

    The function raises ``ValueError`` right after parsing the command line.
    Use the recipe in emirdrp/recipes/spec/flatpix2pix.py instead. The code
    after that point is unreachable and is kept only as a reference.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments. When None, argparse reads ``sys.argv[1:]``.

    Raises
    ------
    ValueError
        Always, because the tool is obsolete. The unreachable part would
        also raise it for NAXIS, grism, filter, DTU or slitlet-size
        inconsistencies.

    """
    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: compute pixel-to-pixel flatfield'
    )

    # required arguments
    parser.add_argument("fitsfile",
                        help="Input FITS file (flat ON-OFF)",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rectwv_coeff", required=True,
                        help="Input JSON file with rectification and "
                             "wavelength calibration coefficients",
                        type=argparse.FileType('rt'))
    parser.add_argument("--minimum_slitlet_width_mm", required=True,
                        help="Minimum slitlet width in mm",
                        type=float)
    parser.add_argument("--maximum_slitlet_width_mm", required=True,
                        help="Maximum slitlet width in mm",
                        type=float)
    parser.add_argument("--minimum_fraction", required=True,
                        help="Minimum allowed flatfielding value",
                        type=float, default=0.01)
    parser.add_argument("--minimum_value_in_output",
                        help="Minimum value allowed in output file: pixels "
                             "below this value are set to 1.0 (default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--maximum_value_in_output",
                        help="Maximum value allowed in output file: pixels "
                             "above this value are set to 1.0 (default=10.0)",
                        type=float, default=10.0)
    parser.add_argument("--nwindow_median",
                        help="Window size to smooth median spectrum in the "
                             "spectral direction",
                        type=int)
    parser.add_argument("--outfile", required=True,
                        help="Output FITS file",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))

    # optional arguments
    parser.add_argument("--delta_global_integer_offset_x_pix",
                        help="Delta global integer offset in the X direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--delta_global_integer_offset_y_pix",
                        help="Delta global integer offset in the Y direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--resampling",
                        help="Resampling method: 1 -> nearest neighbor, "
                             "2 -> linear interpolation (default)",
                        default=2, type=int,
                        choices=(1, 2))
    parser.add_argument("--ignore_DTUconf",
                        help="Ignore DTU configurations differences between "
                             "model and input image",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                             " (default=0)",
                        default=0, type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # This code is obsolete: everything below this statement is unreachable
    # and kept only as a reference for the replacement recipe
    raise ValueError('This code is obsolete: use recipe in '
                     'emirdrp/recipes/spec/flatpix2pix.py')

    # read calibration structure from JSON file
    rectwv_coeff = RectWaveCoeff._datatype_load(args.rectwv_coeff.name)

    # modify (when requested) global offsets
    rectwv_coeff.global_integer_offset_x_pix += \
        args.delta_global_integer_offset_x_pix
    rectwv_coeff.global_integer_offset_y_pix += \
        args.delta_global_integer_offset_y_pix

    # read FITS image and its corresponding header
    with fits.open(args.fitsfile) as hdulist:
        header = hdulist[0].header
        image2d = hdulist[0].data

    # apply global offsets
    image2d = apply_integer_offsets(
        image2d=image2d,
        offx=rectwv_coeff.global_integer_offset_x_pix,
        offy=rectwv_coeff.global_integer_offset_y_pix
    )

    # protections
    naxis2, naxis1 = image2d.shape
    if naxis1 != header['naxis1'] or naxis2 != header['naxis2']:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)
        raise ValueError('Something is wrong with NAXIS1 and/or NAXIS2')
    if abs(args.debugplot) >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # check that the input FITS file grism and filter match
    filter_name = header['filter']
    if filter_name != rectwv_coeff.tags['filter']:
        raise ValueError("Filter name does not match!")
    grism_name = header['grism']
    if grism_name != rectwv_coeff.tags['grism']:
        raise ValueError("Grism name does not match!")
    if abs(args.debugplot) >= 10:
        print('>>> grism.......:', grism_name)
        print('>>> filter......:', filter_name)

    # check that the DTU configurations are compatible
    dtu_conf_fitsfile = DtuConfiguration.define_from_fits(args.fitsfile)
    dtu_conf_jsonfile = DtuConfiguration.define_from_dictionary(
        rectwv_coeff.meta_info['dtu_configuration'])
    if dtu_conf_fitsfile != dtu_conf_jsonfile:
        print('DTU configuration (FITS file):\n\t', dtu_conf_fitsfile)
        print('DTU configuration (JSON file):\n\t', dtu_conf_jsonfile)
        if args.ignore_DTUconf:
            print('WARNING: DTU configuration differences found!')
        else:
            raise ValueError('DTU configurations do not match')
    else:
        if abs(args.debugplot) >= 10:
            print('>>> DTU Configuration match!')
            print(dtu_conf_fitsfile)

    # load CSU configuration
    csu_conf_fitsfile = CsuConfiguration.define_from_fits(args.fitsfile)
    if abs(args.debugplot) >= 10:
        print(csu_conf_fitsfile)

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        print('-> Removing slitlet (not defined):', idel)
        list_valid_islitlets.remove(idel)
    # filter out slitlets with widths outside valid range
    list_outside_valid_width = []
    for islitlet in list_valid_islitlets:
        slitwidth = csu_conf_fitsfile.csu_bar_slit_width(islitlet)
        if (slitwidth < args.minimum_slitlet_width_mm) or \
                (slitwidth > args.maximum_slitlet_width_mm):
            list_outside_valid_width.append(islitlet)
            print('-> Removing slitlet (invalid width):', islitlet)
    for idel in list_outside_valid_width:
        list_valid_islitlets.remove(idel)
    print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # compute and store median spectrum (and masked region) for each
    # individual slitlet
    image2d_sp_median = np.zeros((EMIR_NBARS, EMIR_NAXIS1))
    image2d_sp_mask = np.zeros((EMIR_NBARS, EMIR_NAXIS1), dtype=bool)
    for islitlet in range(1, EMIR_NBARS + 1):
        if islitlet in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
            # define Slitlet2D object
            slt = Slitlet2D(islitlet=islitlet,
                            rectwv_coeff=rectwv_coeff,
                            debugplot=args.debugplot)

            if abs(args.debugplot) >= 10:
                print(slt)

            # extract (distorted) slitlet from the initial image
            slitlet2d = slt.extract_slitlet2d(
                image_2k2k=image2d,
                subtitle='original image'
            )

            # rectify slitlet
            slitlet2d_rect = slt.rectify(
                slitlet2d=slitlet2d,
                resampling=args.resampling,
                subtitle='original rectified'
            )
            naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

            if naxis1_slitlet2d != EMIR_NAXIS1:
                print('naxis1_slitlet2d: ', naxis1_slitlet2d)
                print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
                raise ValueError("Unexpected naxis1_slitlet2d")

            sp_mask = np.zeros(naxis1_slitlet2d, dtype=bool)

            # for grism LR set to zero data beyond useful wavelength range
            if grism_name == 'LR':
                wv_parameters = set_wv_parameters(filter_name, grism_name)
                x_pix = np.arange(1, naxis1_slitlet2d + 1)
                wl_pix = polyval(x_pix, slt.wpoly)
                lremove = wl_pix < wv_parameters['wvmin_useful']
                sp_mask[lremove] = True
                slitlet2d_rect[:, lremove] = 0.0
                lremove = wl_pix > wv_parameters['wvmax_useful']
                slitlet2d_rect[:, lremove] = 0.0
                sp_mask[lremove] = True

            # get useful slitlet region (use boundaries instead of frontiers;
            # note that the nscan_minmax_frontiers() works well independently
            # of using frontiers of boundaries as arguments)
            nscan_min, nscan_max = nscan_minmax_frontiers(
                slt.y0_reference_lower,
                slt.y0_reference_upper,
                resize=False
            )
            ii1 = nscan_min - slt.bb_ns1_orig
            ii2 = nscan_max - slt.bb_ns1_orig + 1

            # median spectrum
            sp_collapsed = np.median(slitlet2d_rect[ii1:(ii2 + 1), :], axis=0)

            # smooth median spectrum along the spectral direction
            sp_median = ndimage.median_filter(
                sp_collapsed,
                args.nwindow_median,
                mode='nearest'
            )

            """
                nremove = 5
                spl = AdaptiveLSQUnivariateSpline(
                    x=xaxis1[nremove:-nremove],
                    y=sp_collapsed[nremove:-nremove],
                    t=11,
                    adaptive=True
                )
                xknots = spl.get_knots()
                yknots = spl(xknots)
                sp_median = spl(xaxis1)

                # compute rms within each knot interval
                nknots = len(xknots)
                rms_array = np.zeros(nknots - 1, dtype=float)
                for iknot in range(nknots - 1):
                    residuals = []
                    for xdum, ydum, yydum in \
                            zip(xaxis1, sp_collapsed, sp_median):
                        if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                            residuals.append(abs(ydum - yydum))
                    if len(residuals) > 5:
                        rms_array[iknot] = np.std(residuals)
                    else:
                        rms_array[iknot] = 0

                # determine in which knot interval falls each pixel
                iknot_array = np.zeros(len(xaxis1), dtype=int)
                for idum, xdum in enumerate(xaxis1):
                    for iknot in range(nknots - 1):
                        if xknots[iknot] <= xdum <= xknots[iknot + 1]:
                            iknot_array[idum] = iknot

                # compute new fit removing deviant points (with fixed knots)
                xnewfit = []
                ynewfit = []
                for idum in range(len(xaxis1)):
                    delta_sp = abs(sp_collapsed[idum] - sp_median[idum])
                    rms_tmp = rms_array[iknot_array[idum]]
                    if idum == 0 or idum == (len(xaxis1) - 1):
                        lok = True
                    elif rms_tmp > 0:
                        if delta_sp < 3.0 * rms_tmp:
                            lok = True
                        else:
                            lok = False
                    else:
                        lok = True
                    if lok:
                        xnewfit.append(xaxis1[idum])
                        ynewfit.append(sp_collapsed[idum])
                nremove = 5
                splnew = AdaptiveLSQUnivariateSpline(
                    x=xnewfit[nremove:-nremove],
                    y=ynewfit[nremove:-nremove],
                    t=xknots[1:-1],
                    adaptive=False
                )
                sp_median = splnew(xaxis1)
            """

            # mask channels below a fraction of the spectrum maximum
            ymax_spmedian = sp_median.max()
            y_threshold = ymax_spmedian * args.minimum_fraction
            lremove = np.where(sp_median < y_threshold)
            sp_median[lremove] = 0.0
            sp_mask[lremove] = True

            image2d_sp_median[islitlet - 1, :] = sp_median
            image2d_sp_mask[islitlet - 1, :] = sp_mask

            if abs(args.debugplot) % 10 != 0:
                xaxis1 = np.arange(1, naxis1_slitlet2d + 1)
                title = 'Slitlet#' + str(islitlet) + ' (median spectrum)'
                ax = ximplotxy(xaxis1, sp_collapsed,
                               title=title,
                               show=False, label='collapsed spectrum')
                ax.plot(xaxis1, sp_median, label='fitted spectrum')
                ax.plot([1, naxis1_slitlet2d], 2 * [y_threshold],
                        label='threshold')
                ax.legend()
                ax.set_ylim(-0.05 * ymax_spmedian, 1.05 * ymax_spmedian)
                pause_debugplot(args.debugplot,
                                pltshow=True, tight_layout=True)
        else:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)

    # ToDo: compute "average" spectrum for each pseudo-longslit, scaling
    #       with the median signal in each slitlet; derive a particular
    #       spectrum for each slitlet (scaling properly)

    # normalize each median spectrum by its own (masked) median value
    image2d_sp_median_masked = np.ma.masked_array(
        image2d_sp_median,
        mask=image2d_sp_mask
    )
    ycut_median = np.ma.median(image2d_sp_median_masked, axis=1).data
    ycut_median_2d = np.repeat(ycut_median, EMIR_NAXIS1).reshape(
        EMIR_NBARS, EMIR_NAXIS1)
    image2d_sp_median_eq = image2d_sp_median_masked / ycut_median_2d
    image2d_sp_median_eq = image2d_sp_median_eq.data

    # diagnostic plots only when plotting is requested via --debugplot
    if abs(args.debugplot) % 10 != 0:
        ximshow(image2d_sp_median, title='sp_median',
                debugplot=args.debugplot)
        ximplotxy(np.arange(1, EMIR_NBARS + 1), ycut_median, 'ro',
                  title='median value of each spectrum',
                  debugplot=args.debugplot)
        ximshow(image2d_sp_median_eq, title='sp_median_eq',
                debugplot=args.debugplot)

    csu_conf_fitsfile.display_pseudo_longslits(
        list_valid_slitlets=list_valid_islitlets)
    dict_longslits = csu_conf_fitsfile.pseudo_longslits()

    # compute median spectrum for each longslit and insert (properly
    # scaled) that spectrum in each slitlet belonging to that longslit
    image2d_sp_median_longslit = np.zeros((EMIR_NBARS, EMIR_NAXIS1))
    islitlet = 1
    loop = True
    while loop:
        if islitlet in list_valid_islitlets:
            imin = dict_longslits[islitlet].imin()
            imax = dict_longslits[islitlet].imax()
            if abs(args.debugplot) >= 10:
                print('--> imin, imax: ', imin, imax)
            sp_median_longslit = np.median(
                image2d_sp_median_eq[(imin - 1):imax, :], axis=0)
            for i in range(imin, imax + 1):
                if abs(args.debugplot) >= 10:
                    print('----> i: ', i)
                image2d_sp_median_longslit[(i - 1), :] = \
                    sp_median_longslit * ycut_median[i - 1]
            # skip the remaining slitlets of this pseudo-longslit
            islitlet = imax
        else:
            if abs(args.debugplot) >= 10:
                print('--> ignoring: ', islitlet)
        if islitlet == EMIR_NBARS:
            loop = False
        else:
            islitlet += 1
    if abs(args.debugplot) % 10 != 0:
        ximshow(image2d_sp_median_longslit, debugplot=args.debugplot)

    # initialize rectified image
    image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

    # main loop
    for islitlet in range(1, EMIR_NBARS + 1):
        if islitlet in list_valid_islitlets:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=False)
            # define Slitlet2D object
            slt = Slitlet2D(islitlet=islitlet,
                            rectwv_coeff=rectwv_coeff,
                            debugplot=args.debugplot)

            # extract (distorted) slitlet from the initial image
            slitlet2d = slt.extract_slitlet2d(
                image_2k2k=image2d,
                subtitle='original image'
            )

            # rectify slitlet
            slitlet2d_rect = slt.rectify(
                slitlet2d=slitlet2d,
                resampling=args.resampling,
                subtitle='original rectified'
            )
            naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

            sp_median = image2d_sp_median_longslit[islitlet - 1, :]

            # generate rectified slitlet region filled with the median spectrum
            slitlet2d_rect_spmedian = np.tile(sp_median, (naxis2_slitlet2d, 1))
            if abs(args.debugplot) > 10:
                slt.ximshow_rectified(
                    slitlet2d_rect=slitlet2d_rect_spmedian,
                    subtitle='rectified, filled with median spectrum'
                )

            # unrectified image
            slitlet2d_unrect_spmedian = slt.rectify(
                slitlet2d=slitlet2d_rect_spmedian,
                resampling=args.resampling,
                inverse=True,
                subtitle='unrectified, filled with median spectrum'
            )

            # normalize initial slitlet image; pixels where the model is zero
            # are set to 1.0 to avoid division by zero
            slitlet2d_norm = np.zeros_like(slitlet2d)
            den = slitlet2d_unrect_spmedian[:naxis2_slitlet2d,
                                            :naxis1_slitlet2d]
            num = slitlet2d[:naxis2_slitlet2d, :naxis1_slitlet2d]
            den_is_zero = (den == 0)
            slitlet2d_norm[:naxis2_slitlet2d, :naxis1_slitlet2d] = np.where(
                den_is_zero, 1.0, num / np.where(den_is_zero, 1.0, den))

            if abs(args.debugplot) > 10:
                slt.ximshow_unrectified(
                    slitlet2d=slitlet2d_norm,
                    subtitle='unrectified, pixel-to-pixel'
                )

            # check for pseudo-longslit with previous and next slitlets
            same_slitlet_below = (
                islitlet > 1 and
                (islitlet - 1) in list_valid_islitlets and
                _same_pseudo_longslit(csu_conf_fitsfile,
                                      islitlet - 1, islitlet)
            )
            same_slitlet_above = (
                islitlet < EMIR_NBARS and
                (islitlet + 1) in list_valid_islitlets and
                _same_pseudo_longslit(csu_conf_fitsfile,
                                      islitlet, islitlet + 1)
            )

            for j in range(EMIR_NAXIS1):
                xchannel = j + 1
                y0_lower = slt.list_frontiers[0](xchannel)
                y0_upper = slt.list_frontiers[1](xchannel)
                n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                                y0_frontier_upper=y0_upper,
                                                resize=True)
                # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
                nn1 = n1 - slt.bb_ns1_orig + 1
                nn2 = n2 - slt.bb_ns1_orig + 1
                image2d_flatfielded[(n1 - 1):n2, j] = \
                    slitlet2d_norm[(nn1 - 1):nn2, j]

                # force to 1.0 region around frontiers
                if not same_slitlet_below:
                    image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
                if not same_slitlet_above:
                    image2d_flatfielded[(n2 - 5):n2, j] = 1
        else:
            if args.debugplot == 0:
                islitlet_progress(islitlet, EMIR_NBARS, ignore=True)

    if args.debugplot == 0:
        print('OK!')

    # restore global offsets
    image2d_flatfielded = apply_integer_offsets(
        image2d=image2d_flatfielded,
        offx=-rectwv_coeff.global_integer_offset_x_pix,
        offy=-rectwv_coeff.global_integer_offset_y_pix
    )

    # set pixels below minimum value to 1.0
    filtered = np.where(image2d_flatfielded < args.minimum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # set pixels above maximum value to 1.0
    filtered = np.where(image2d_flatfielded > args.maximum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # save output file
    save_ndarray_to_fits(
        array=image2d_flatfielded,
        file_name=args.outfile,
        main_header=header,
        overwrite=True
    )
    print('>>> Saving file ' + args.outfile.name)
Example #11
0
def main(args=None):
    """Compute a pixel-to-pixel flatfield from a flat ON-OFF image.

    For each valid slitlet, the distorted slitlet is extracted from the
    input image and rectified; its median spectrum (collapsed along the
    spatial direction and smoothed along the spectral direction) is
    expanded back into a 2D rectified image, which is then unrectified and
    used to normalize the original slitlet data. The normalized slitlets
    are assembled into a single image, pixels below
    ``--minimum_value_in_output`` are set to 1.0, the global integer
    offsets are undone and the result is saved to ``--outfile``.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments. When None, ``argparse`` reads them from
        ``sys.argv``.

    Raises
    ------
    ValueError
        If NAXIS1/NAXIS2, the grism, the filter or the DTU configuration
        (unless ``--ignore_DTUconf`` is given) do not match the
        rectification and wavelength calibration coefficients, or if a
        rectified slitlet does not have EMIR_NAXIS1 columns.

    """
    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: compute pixel-to-pixel flatfield')

    # required arguments
    parser.add_argument("fitsfile",
                        help="Input FITS file (flat ON-OFF)",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rectwv_coeff",
                        required=True,
                        help="Input JSON file with rectification and "
                        "wavelength calibration coefficients",
                        type=argparse.FileType('rt'))
    # note: this option was previously declared as required while also
    # providing a default value (which was therefore never used); it is
    # now optional so that the default actually applies
    parser.add_argument("--minimum_fraction",
                        help="Minimum fraction (relative to its maximum) "
                        "below which the smoothed median spectrum is set "
                        "to zero (default=0.01)",
                        type=float,
                        default=0.01)
    parser.add_argument("--minimum_value_in_output",
                        help="Minimum value allowed in output file: pixels "
                        "below this value are set to 1.0 (default=0.01)",
                        type=float,
                        default=0.01)
    parser.add_argument("--nwindow_median",
                        required=True,
                        help="Window size to smooth median spectrum in the "
                        "spectral direction",
                        type=int)
    parser.add_argument("--outfile",
                        required=True,
                        help="Output FITS file",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))

    # optional arguments
    parser.add_argument("--delta_global_integer_offset_x_pix",
                        help="Delta global integer offset in the X direction "
                        "(default=0)",
                        default=0,
                        type=int)
    parser.add_argument("--delta_global_integer_offset_y_pix",
                        help="Delta global integer offset in the Y direction "
                        "(default=0)",
                        default=0,
                        type=int)
    parser.add_argument("--resampling",
                        help="Resampling method: 1 -> nearest neighbor, "
                        "2 -> linear interpolation (default)",
                        default=2,
                        type=int,
                        choices=(1, 2))
    parser.add_argument("--ignore_DTUconf",
                        help="Ignore DTU configurations differences between "
                        "model and input image",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                        " (default=0)",
                        default=0,
                        type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # read calibration structure from JSON file
    rectwv_coeff = RectWaveCoeff._datatype_load(args.rectwv_coeff.name)

    # modify (when requested) global offsets
    rectwv_coeff.global_integer_offset_x_pix += \
        args.delta_global_integer_offset_x_pix
    rectwv_coeff.global_integer_offset_y_pix += \
        args.delta_global_integer_offset_y_pix

    # read FITS image and its corresponding header
    with fits.open(args.fitsfile) as hdulist:
        header = hdulist[0].header
        image2d = hdulist[0].data

    # apply global offsets
    image2d = apply_integer_offsets(
        image2d=image2d,
        offx=rectwv_coeff.global_integer_offset_x_pix,
        offy=rectwv_coeff.global_integer_offset_y_pix)

    # protections
    naxis2, naxis1 = image2d.shape
    if naxis1 != header['naxis1'] or naxis2 != header['naxis2']:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)
        raise ValueError('Something is wrong with NAXIS1 and/or NAXIS2')
    if abs(args.debugplot) >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # check that the input FITS file grism and filter match
    filter_name = header['filter']
    if filter_name != rectwv_coeff.tags['filter']:
        raise ValueError("Filter name does not match!")
    grism_name = header['grism']
    if grism_name != rectwv_coeff.tags['grism']:
        raise ValueError("Grism name does not match!")
    if abs(args.debugplot) >= 10:
        print('>>> grism.......:', grism_name)
        print('>>> filter......:', filter_name)

    # check that the DTU configurations are compatible
    dtu_conf_fitsfile = DtuConfiguration.define_from_fits(args.fitsfile)
    dtu_conf_jsonfile = DtuConfiguration.define_from_dictionary(
        rectwv_coeff.meta_info['dtu_configuration'])
    if dtu_conf_fitsfile != dtu_conf_jsonfile:
        print('DTU configuration (FITS file):\n\t', dtu_conf_fitsfile)
        print('DTU configuration (JSON file):\n\t', dtu_conf_jsonfile)
        if args.ignore_DTUconf:
            print('WARNING: DTU configuration differences found!')
        else:
            raise ValueError('DTU configurations do not match')
    else:
        if abs(args.debugplot) >= 10:
            print('>>> DTU Configuration match!')
            print(dtu_conf_fitsfile)

    # valid slitlet numbers (all slitlets except the missing ones)
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        list_valid_islitlets.remove(idel)
    if abs(args.debugplot) >= 10:
        print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # initialize output (flatfield) image
    image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

    # main loop
    for islitlet in list_valid_islitlets:
        if args.debugplot == 0:
            islitlet_progress(islitlet, EMIR_NBARS)

        # define Slitlet2D object
        slt = Slitlet2D(islitlet=islitlet,
                        rectwv_coeff=rectwv_coeff,
                        debugplot=args.debugplot)

        if abs(args.debugplot) >= 10:
            print(slt)

        # extract (distorted) slitlet from the initial image
        slitlet2d = slt.extract_slitlet2d(image2d)

        # rectify slitlet
        slitlet2d_rect = slt.rectify(slitlet2d, resampling=args.resampling)
        naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

        if naxis1_slitlet2d != EMIR_NAXIS1:
            print('naxis1_slitlet2d: ', naxis1_slitlet2d)
            print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
            raise ValueError("Unexpected naxis1_slitlet2d")

        # get useful slitlet region (use boundaries instead of frontiers;
        # note that nscan_minmax_frontiers() works well independently
        # of using frontiers or boundaries as arguments)
        nscan_min, nscan_max = nscan_minmax_frontiers(slt.y0_reference_lower,
                                                      slt.y0_reference_upper,
                                                      resize=False)
        ii1 = nscan_min - slt.bb_ns1_orig
        ii2 = nscan_max - slt.bb_ns1_orig + 1

        # median spectrum
        # NOTE(review): with the indexing convention used below for the
        # frontiers (nn2 = n2 - bb_ns1_orig + 1 used as exclusive end), ii2
        # is already an exclusive bound, so ii1:(ii2 + 1) includes one
        # extra row beyond nscan_max -- confirm this is intended.
        sp_collapsed = np.median(slitlet2d_rect[ii1:(ii2 + 1), :], axis=0)

        # smooth median spectrum along the spectral direction and discard
        # (set to zero) the channels below the requested fraction of the
        # maximum
        sp_median = ndimage.median_filter(sp_collapsed,
                                          args.nwindow_median,
                                          mode='nearest')
        ymax_spmedian = sp_median.max()
        y_threshold = ymax_spmedian * args.minimum_fraction
        sp_median[np.where(sp_median < y_threshold)] = 0.0

        if abs(args.debugplot) > 10:
            title = 'Slitlet#' + str(islitlet) + '(median spectrum)'
            xdum = np.arange(1, naxis1_slitlet2d + 1)
            ax = ximplotxy(xdum,
                           sp_collapsed,
                           title=title,
                           show=False,
                           **{'label': 'collapsed spectrum'})
            ax.plot(xdum, sp_median, label='filtered spectrum')
            ax.plot([1, naxis1_slitlet2d],
                    2 * [y_threshold],
                    label='threshold')
            ax.legend()
            ax.set_ylim(-0.05 * ymax_spmedian, 1.05 * ymax_spmedian)
            pause_debugplot(args.debugplot, pltshow=True, tight_layout=True)

        # generate rectified slitlet region filled with the median spectrum
        slitlet2d_rect_spmedian = np.tile(sp_median, (naxis2_slitlet2d, 1))
        if abs(args.debugplot) > 10:
            slt.ximshow_rectified(slitlet2d_rect_spmedian)

        # unrectified image
        slitlet2d_unrect_spmedian = slt.rectify(slitlet2d_rect_spmedian,
                                                resampling=args.resampling,
                                                inverse=True)

        # normalize initial slitlet image; pixels where the unrectified
        # median model is zero are set to 1.0 (avoids division by zero).
        # Only the region spanned by the rectified slitlet shape is filled;
        # the rest of slitlet2d_norm stays at zero.
        slitlet2d_norm = np.zeros_like(slitlet2d)
        region = (slice(0, naxis2_slitlet2d), slice(0, naxis1_slitlet2d))
        den = slitlet2d_unrect_spmedian[region]
        num = slitlet2d[region]
        norm = slitlet2d_norm[region]  # view: writes go to slitlet2d_norm
        valid = den != 0
        norm[~valid] = 1.0
        norm[valid] = num[valid] / den[valid]

        if abs(args.debugplot) > 10:
            slt.ximshow_unrectified(slitlet2d_norm)

        # copy the normalized slitlet, column by column, into the output
        # image between the slitlet frontiers
        for j in range(EMIR_NAXIS1):
            xchannel = j + 1
            y0_lower = slt.list_frontiers[0](xchannel)
            y0_upper = slt.list_frontiers[1](xchannel)
            n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                            y0_frontier_upper=y0_upper,
                                            resize=True)
            # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
            nn1 = n1 - slt.bb_ns1_orig + 1
            nn2 = n2 - slt.bb_ns1_orig + 1
            image2d_flatfielded[(n1 - 1):n2, j] = \
                slitlet2d_norm[(nn1 - 1):nn2, j]

            # force to 1.0 region around frontiers
            image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
            image2d_flatfielded[(n2 - 5):n2, j] = 1
    if args.debugplot == 0:
        print('OK!')

    # set pixels below minimum value to 1.0
    filtered = np.where(image2d_flatfielded < args.minimum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # restore global offsets
    image2d_flatfielded = apply_integer_offsets(
        image2d=image2d_flatfielded,
        offx=-rectwv_coeff.global_integer_offset_x_pix,
        offy=-rectwv_coeff.global_integer_offset_y_pix)

    # save output file
    save_ndarray_to_fits(array=image2d_flatfielded,
                         file_name=args.outfile,
                         main_header=header,
                         overwrite=True)
    print('>>> Saving file ' + args.outfile.name)
def compute_slitlet_boundaries(filename,
                               grism,
                               spfilter,
                               list_slitlets,
                               size_x_medfilt,
                               size_y_savgol,
                               times_sigma_threshold,
                               bounddict,
                               debugplot=0):
    """Compute slitlet boundaries using continuum lamp images.

    For each requested slitlet, the slitlet region is extracted from the
    image, median filtered along the spectral direction (to remove bad
    pixels) and differentiated along the spatial direction with a
    Savitzky-Golay filter. The largest connected region with significant
    positive derivative and the largest one with significant negative
    derivative are taken as the two slitlet boundaries, which are fitted
    with polynomials and stored in `bounddict`.

    Parameters
    ----------
    filename : string
        Input continuum lamp image.
    grism : string
        Grism name. It must be one in EMIR_VALID_GRISMS.
    spfilter : string
        Filter name. It must be one in EMIR_VALID_FILTERS.
    list_slitlets : list of integers
        Numbers of the slitlets to be updated.
    size_x_medfilt : int
        Window in the X (spectral) direction, in pixels, to apply
        the 1d median filter in order to remove bad pixels.
    size_y_savgol : int
        Window in the Y (spatial) direction to be used when using the
        1d Savitzky-Golay filter.
    times_sigma_threshold : float
        Times sigma to detect peaks in derivatives.
    bounddict : dictionary of dictionaries
        Structure to store the boundaries. It is updated in place: the
        results for each slitlet are stored in
        ``bounddict['contents']['slitletNN'][date_obs]``.
    debugplot : int
        Determines whether intermediate computations and/or plots are
        displayed.

    Raises
    ------
    ValueError
        If INSTRUME is not 'EMIR', if GRISM or FILTER in the header do not
        match `grism` or `spfilter`, or if no region with significant
        positive (or negative) derivative is found in a slitlet.

    """
    import getpass  # used to record the user name in the provenance info

    # read 2D image
    with fits.open(filename) as hdulist:
        image_header = hdulist[0].header
        image2d = hdulist[0].data
    naxis2, naxis1 = image2d.shape
    if debugplot >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # ToDo: replace this by application of cosmetic defect mask!
    # Replace two half rows by the average of their neighbouring rows:
    # row 1024 (columns 0-1023) and row 1023 (columns 1024-2047). Neither
    # update reads pixels written by the other, so the vectorized form is
    # equivalent to a column-by-column loop.
    image2d[1024, :1024] = (image2d[1023, :1024] + image2d[1025, :1024]) / 2
    image2d[1023, 1024:2048] = \
        (image2d[1022, 1024:2048] + image2d[1024, 1024:2048]) / 2

    # remove path from filename
    sfilename = os.path.basename(filename)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read CSU configuration from FITS header
    csu_config = CsuConfiguration.define_from_fits(filename)

    # read DTU configuration from FITS header
    dtu_config = DtuConfiguration.define_from_fits(filename)

    # read grism
    grism_in_header = image_header['grism']
    if grism != grism_in_header:
        raise ValueError("GRISM keyword=" + grism_in_header +
                         " is not the expected value=" + grism)

    # read filter
    spfilter_in_header = image_header['filter']
    if spfilter != spfilter_in_header:
        raise ValueError("FILTER keyword=" + spfilter_in_header +
                         " is not the expected value=" + spfilter)

    # read rotator position angle
    rotang = image_header['rotang']

    # read date-obs
    date_obs = image_header['date-obs']

    for islitlet in list_slitlets:
        if debugplot < 10:
            sys.stdout.write('.')
            sys.stdout.flush()

        # common part of the titles of the debugging plots
        title_info = ("\nslitlet=" + str(islitlet) + ", grism=" + grism +
                      ", filter=" + spfilter + ", rotang=" +
                      str(round(rotang, 2)))

        sltlim = SlitletLimits(grism, spfilter, islitlet)
        # extract slitlet2d
        slitlet2d = extract_slitlet2d(image2d, sltlim)
        if debugplot % 10 != 0:
            ximshow(slitlet2d,
                    title=sfilename + " [original]" + title_info,
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    debugplot=debugplot)

        # apply 1d median filtering (along the spectral direction)
        # to remove bad pixels
        size_x = size_x_medfilt
        size_y = 1
        slitlet2d_smooth = ndimage.median_filter(slitlet2d,
                                                 size=(size_y, size_x))

        if debugplot % 10 != 0:
            ximshow(slitlet2d_smooth,
                    title=sfilename + " [smoothed]" + title_info,
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    debugplot=debugplot)

        # apply 1d Savitzky-Golay filter (along the spatial direction)
        # to compute first derivative
        slitlet2d_savgol = savgol_filter(slitlet2d_smooth,
                                         window_length=size_y_savgol,
                                         polyorder=2,
                                         deriv=1,
                                         axis=0)

        # compute basic statistics
        q25, q50, q75 = np.percentile(slitlet2d_savgol, q=[25.0, 50.0, 75.0])
        sigmag = 0.7413 * (q75 - q25)  # robust standard deviation
        if debugplot >= 10:
            print("q50, sigmag:", q50, sigmag)

        if debugplot % 10 != 0:
            ximshow(slitlet2d_savgol,
                    title=sfilename + " [S.-G.filt.]" + title_info,
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    z1z2=(q50 - times_sigma_threshold * sigmag,
                          q50 + times_sigma_threshold * sigmag),
                    debugplot=debugplot)

        # identify objects in slitlet2d_savgol: pixels with positive
        # derivatives are identified independently from pixels with
        # negative derivatives; then the two sets of potential features
        # are merged; this approach avoids some problems when, in
        # nearby regions, there are pixels with positive and negative
        # derivatives (in those circumstances a single search as
        # np.logical_or(
        #     slitlet2d_savgol < q50 - times_sigma_threshold * sigmag,
        #     slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # led to erroneous detections!)
        #
        # search for positive derivatives
        labels2d_objects_pos, no_objects_pos = ndimage.label(
            slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # search for negative derivatives
        labels2d_objects_neg, no_objects_neg = ndimage.label(
            slitlet2d_savgol < q50 - times_sigma_threshold * sigmag)
        # merge both sets (negative-derivative labels are shifted after
        # the positive-derivative ones)
        non_zero_neg = np.where(labels2d_objects_neg > 0)
        labels2d_objects = np.copy(labels2d_objects_pos)
        labels2d_objects[non_zero_neg] += \
            labels2d_objects_neg[non_zero_neg] + no_objects_pos
        no_objects = no_objects_pos + no_objects_neg

        if debugplot >= 10:
            print("Number of objects with positive derivative:",
                  no_objects_pos)
            print("Number of objects with negative derivative:",
                  no_objects_neg)
            print("Total number of objects initially found...:", no_objects)

        if debugplot % 10 != 0:
            ximshow(labels2d_objects,
                    z1z2=(0, no_objects),
                    title=sfilename + " [objects]" + title_info,
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    cbar_label="Object number",
                    debugplot=debugplot)

        # both boundaries are required: without this check label 0 (the
        # background) would be silently used as a boundary
        if no_objects_pos == 0 or no_objects_neg == 0:
            raise ValueError(
                "Slitlet " + str(islitlet) + ": no object found with " +
                ("positive" if no_objects_pos == 0 else "negative") +
                " derivative; cannot determine slitlet boundaries")

        # select boundaries as the largest objects found with positive and
        # negative derivatives: labels 1..no_objects_pos correspond to
        # positive derivatives, labels no_objects_pos+1..no_objects to
        # negative derivatives (ties resolved in favour of the lowest label)
        npix_per_label = np.bincount(labels2d_objects.ravel(),
                                     minlength=no_objects + 1)
        i_der_pos = 1 + int(np.argmax(npix_per_label[1:no_objects_pos + 1]))
        i_der_neg = no_objects_pos + 1 + int(
            np.argmax(npix_per_label[no_objects_pos + 1:no_objects + 1]))

        # determine which boundary is lower and which is upper
        y_center_mass_der_pos = ndimage.center_of_mass(slitlet2d_savgol,
                                                       labels2d_objects,
                                                       [i_der_pos])[0][0]
        y_center_mass_der_neg = ndimage.center_of_mass(slitlet2d_savgol,
                                                       labels2d_objects,
                                                       [i_der_neg])[0][0]
        if y_center_mass_der_pos < y_center_mass_der_neg:
            i_lower = i_der_pos
            i_upper = i_der_neg
            if debugplot >= 10:
                print("-> lower boundary has positive derivatives")
        else:
            i_lower = i_der_neg
            i_upper = i_der_pos
            if debugplot >= 10:
                print("-> lower boundary has negative derivatives")
        list_slices_ok = [i_lower, i_upper]

        # adjust individual boundaries passing the selection:
        # - select points in the image belonging to a given boundary
        # - compute weighted mean of the pixels of the boundary, column
        #   by column (this reduces dramatically the number of points
        #   to be fitted to determine the boundary)
        list_boundaries = []
        for k in range(2):  # k=0 lower boundary, k=1 upper boundary
            # select points to be fitted for a particular boundary
            # (note: be careful with array indices and pixel
            # coordinates)
            xy_tmp = np.where(labels2d_objects == list_slices_ok[k])
            xmin = xy_tmp[1].min()  # array indices (integers)
            xmax = xy_tmp[1].max()  # array indices (integers)
            xfit = []
            yfit = []
            # fix range for fit
            if k == 0:
                xmineff = max(sltlim.xmin_lower_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_lower_boundary_fit, xmax)
            else:
                xmineff = max(sltlim.xmin_upper_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_upper_boundary_fit, xmax)
            # loop in columns of the image belonging to the boundary
            for xdum in range(xmineff, xmaxeff + 1):  # array indices (integer)
                iok = np.where(xy_tmp[1] == xdum)
                y_tmp = xy_tmp[0][iok] + sltlim.bb_ns1_orig  # image pixel
                weight = slitlet2d_savgol[xy_tmp[0][iok], xy_tmp[1][iok]]
                y_wmean = np.sum(y_tmp * weight) / np.sum(weight)
                xfit.append(xdum + sltlim.bb_nc1_orig)
                yfit.append(y_wmean)
            xfit = np.array(xfit)
            yfit = np.array(yfit)
            # declare new SpectrumTrail instance
            boundary = SpectrumTrail()
            # define new boundary
            boundary.fit(x=xfit,
                         y=yfit,
                         deg=sltlim.deg_boundary,
                         times_sigma_reject=10,
                         title="slit:" + str(sltlim.islitlet) + ", deg=" +
                         str(sltlim.deg_boundary),
                         debugplot=0)
            list_boundaries.append(boundary)

        if debugplot % 10 != 0:
            for tmp_img, tmp_label in zip([slitlet2d_savgol, slitlet2d],
                                          [' [S.-G.filt.]', ' [original]']):
                ax = ximshow(tmp_img,
                             title=sfilename + tmp_label + title_info,
                             first_pixel=(sltlim.bb_nc1_orig,
                                          sltlim.bb_ns1_orig),
                             show=False,
                             debugplot=debugplot)
                for k in range(2):
                    xpol, ypol = list_boundaries[k].linspace_pix(
                        start=1, stop=EMIR_NAXIS1)
                    ax.plot(xpol, ypol, 'b--', linewidth=1)
                for k in range(2):
                    xpol, ypol = list_boundaries[k].linspace_pix()
                    ax.plot(xpol, ypol, 'g--', linewidth=4)
                # show plot
                pause_debugplot(debugplot, pltshow=True)

        # update bounddict
        tmp_dict = {
            'boundary_coef_lower': list_boundaries[0].poly_funct.coef.tolist(),
            'boundary_xmin_lower': list_boundaries[0].xlower_line,
            'boundary_xmax_lower': list_boundaries[0].xupper_line,
            'boundary_coef_upper': list_boundaries[1].poly_funct.coef.tolist(),
            'boundary_xmin_upper': list_boundaries[1].xlower_line,
            'boundary_xmax_upper': list_boundaries[1].xupper_line,
            'csu_bar_left': csu_config.csu_bar_left(islitlet),
            'csu_bar_right': csu_config.csu_bar_right(islitlet),
            'csu_bar_slit_center': csu_config.csu_bar_slit_center(islitlet),
            'csu_bar_slit_width': csu_config.csu_bar_slit_width(islitlet),
            'rotang': rotang,
            'xdtu': dtu_config.xdtu,
            'ydtu': dtu_config.ydtu,
            'zdtu': dtu_config.zdtu,
            'xdtu_0': dtu_config.xdtu_0,
            'ydtu_0': dtu_config.ydtu_0,
            'zdtu_0': dtu_config.zdtu_0,
            # os.getlogin() fails when there is no controlling terminal
            # (cron, some IDEs, containers); getpass.getuser() does not
            'zzz_info1': getpass.getuser() + '@' + socket.gethostname(),
            'zzz_info2': datetime.now().isoformat()
        }
        slitlet_label = "slitlet" + str(islitlet).zfill(2)
        if slitlet_label not in bounddict['contents']:
            bounddict['contents'][slitlet_label] = {}
        bounddict['contents'][slitlet_label][date_obs] = tmp_dict

    if debugplot < 10:
        print("")
Example #13
0
def main(args=None):
    """Display the arrangement of the EMIR CSU bars.

    The CSU configuration of every input file is displayed with
    ``display_slitlet_arrangement`` and, when ``--n_clusters`` is at least
    2, a histogram of the slitlet widths is also shown. When several files
    are given and ``--debugplot`` is at least 10, the standard deviation of
    each CSU bar parameter across files is printed.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments. When None, ``argparse`` reads them from
        ``sys.argv``.

    """

    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: display arrangement of EMIR CSU bars')

    # positional arguments
    parser.add_argument("filename",
                        help="FITS files (wildcards accepted) or single TXT "
                        "file with CSU configuration from OSP",
                        type=argparse.FileType('rb'),
                        nargs='+')

    # optional arguments
    parser.add_argument("--grism",
                        help="Grism (J, H, K, LR)",
                        choices=["J", "H", "K", "LR"])
    parser.add_argument("--filter",
                        help="Filter (J, H, Ksp, YJ, HK)",
                        choices=["J", "H", "Ksp", "YJ", "HK"])
    parser.add_argument("--n_clusters",
                        help="Display histogram of slitlet widths",
                        default=0,
                        type=int)
    parser.add_argument("--bbox",
                        help="Bounding box tuple xmin,xmax,ymin,ymax "
                        "indicating plot limits")
    parser.add_argument("--adjust",
                        help="Adjust X range according to minimum and maximum"
                        " csu_bar_left and csu_bar_right (note that this "
                        "option is overridden by --bbox)",
                        action='store_true')
    parser.add_argument("--geometry",
                        help="Tuple x,y,dx,dy indicating window geometry",
                        default="0,0,640,480")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                        " (default=12)",
                        default=12,
                        type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    def parse_int_tuple(text, option):
        """Return the four comma-separated integers in `text` as a tuple.

        Malformed input terminates the program through ``parser.error``
        instead of being silently truncated or raising an IndexError.
        """
        items = text.split(",")
        try:
            values = tuple(int(item) for item in items)
        except ValueError:
            values = ()
        if len(values) != 4:
            parser.error("argument " + option + ": expected four "
                         "comma-separated integers, got '" + text + "'")
        return values

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # geometry: x, y, dx, dy
    if args.geometry is None:
        geometry = None
    else:
        geometry = parse_int_tuple(args.geometry, "--geometry")

    # read bounding box: xmin, xmax, ymin, ymax
    if args.bbox is None:
        bbox = None
    else:
        bbox = parse_int_tuple(args.bbox, "--bbox")

    # input files (FITS files or a single TXT file; each one is handled
    # by display_slitlet_arrangement)
    list_fits_file_objects = list(args.filename)

    # total number of files to be examined
    nfiles = len(list_fits_file_objects)

    # declare arrays to store CSU values (one row per file)
    csu_bar_left = np.zeros((nfiles, EMIR_NBARS))
    csu_bar_right = np.zeros((nfiles, EMIR_NBARS))
    csu_bar_slit_center = np.zeros((nfiles, EMIR_NBARS))
    csu_bar_slit_width = np.zeros((nfiles, EMIR_NBARS))

    # display CSU bar arrangement
    for ifile, fileobj in enumerate(list_fits_file_objects):
        print("\nFile " + str(ifile + 1) + "/" + str(nfiles) + ": " +
              fileobj.name)
        csu_bar_left[ifile, :], csu_bar_right[ifile, :], \
            csu_bar_slit_center[ifile, :], csu_bar_slit_width[ifile, :] = \
            display_slitlet_arrangement(
                fileobj,
                grism=args.grism,
                spfilter=args.filter,
                bbox=bbox,
                adjust=args.adjust,
                geometry=geometry,
                debugplot=args.debugplot
            )
        # the file opened by argparse is no longer needed
        fileobj.close()
        if args.n_clusters >= 2:
            display_slitlet_histogram(csu_bar_slit_width[ifile, :],
                                      n_clusters=args.n_clusters,
                                      geometry=geometry,
                                      debugplot=args.debugplot)

    # print summary of comparison between files
    if nfiles > 1 and args.debugplot >= 10:
        # standard deviation of each bar parameter across files
        std_csu_bar_left = np.std(csu_bar_left, axis=0)
        std_csu_bar_right = np.std(csu_bar_right, axis=0)
        std_csu_bar_slit_center = np.std(csu_bar_slit_center, axis=0)
        std_csu_bar_slit_width = np.std(csu_bar_slit_width, axis=0)
        print("\n   STANDARD DEVIATION BETWEEN IMAGES")
        print("slit     left    right   center   width")
        print("====  =======  =======  =======   =====")
        for i in range(EMIR_NBARS):
            print("{0:4d} {1:8.3f} {2:8.3f} {3:8.3f} {4:7.3f}".format(
                i + 1, std_csu_bar_left[i], std_csu_bar_right[i],
                std_csu_bar_slit_center[i], std_csu_bar_slit_width[i]))
        print("====  =======  =======  =======   =====")
        print("MIN: {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f}".format(
            std_csu_bar_left.min(), std_csu_bar_right.min(),
            std_csu_bar_slit_center.min(), std_csu_bar_slit_width.min()))
        print("MAX: {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f}".format(
            std_csu_bar_left.max(), std_csu_bar_right.max(),
            std_csu_bar_slit_center.max(), std_csu_bar_slit_width.max()))
        print("====  =======  =======  =======   =====")
        print("Total number of files examined:", nfiles)

    # stop program execution
    if nfiles > 1:
        pause_debugplot(12, optional_prompt="Press RETURN to STOP")
Example #14
0
def _read_catalogue_lines(refine_wavecalib_mode, grism_name):
    """Read the tabulated line catalogue associated to a refinement mode.

    Parameters
    ----------
    refine_wavecalib_mode : int
        Refinement mode: 1 or 2 (ARC lines), 11 or 12 (OH lines).
    grism_name : str
        Grism name; 'LR' selects a specific ARC line list.

    Returns
    -------
    catlines_all_wave : numpy array
        Wavelengths of the tabulated lines.
    catlines_all_flux : numpy array
        Fluxes of the tabulated lines.
    mode : int
        1 (global offset) or 2 (individual offset per slitlet).

    Raises
    ------
    ValueError
        If `refine_wavecalib_mode` is not one of 1, 2, 11, 12.

    """
    if refine_wavecalib_mode in [1, 2]:  # ARC lines
        if grism_name == 'LR':
            catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
        else:
            catlines_file = 'lines_argon_neon_xenon_empirical.dat'
        dumdata = pkgutil.get_data('emirdrp.instrument.configs', catlines_file)
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # column 0: wavelength, column 1: flux
        return catlines[:, 0], catlines[:, 1], refine_wavecalib_mode

    if refine_wavecalib_mode in [11, 12]:  # OH lines
        dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                   'Oliva_etal_2013.dat')
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # columns 1 and 0 are both used as line wavelengths, each of them
        # paired with the flux tabulated in column 2
        catlines_all_wave = np.concatenate((catlines[:, 1], catlines[:, 0]))
        catlines_all_flux = np.concatenate((catlines[:, 2], catlines[:, 2]))
        return catlines_all_wave, catlines_all_flux, refine_wavecalib_mode - 10

    raise ValueError('Unexpected mode={}'.format(refine_wavecalib_mode))


def _log_stat_summary(logger, title, stat_summary):
    """Log a heading followed by the statistics computed by `summary`.

    Parameters
    ----------
    logger : logging.Logger
        Logger used to emit the messages.
    title : str
        Heading logged before the statistics.
    stat_summary : dict
        Dictionary with the keys 'npoints', 'mean', 'median', 'std' and
        'robust_std'.

    """
    logger.info(title)
    logger.info('- npoints....: {0:d}'.format(stat_summary['npoints']))
    logger.info('- mean.......: {0:7.3f}'.format(stat_summary['mean']))
    logger.info('- median.....: {0:7.3f}'.format(stat_summary['median']))
    logger.info('- std........: {0:7.3f}'.format(stat_summary['std']))
    logger.info('- robust_std.: {0:7.3f}'.format(stat_summary['robust_std']))


def refine_rectwv_coeff(input_image,
                        rectwv_coeff,
                        refine_wavecalib_mode,
                        minimum_slitlet_width_mm,
                        maximum_slitlet_width_mm,
                        save_intermediate_results=False,
                        debugplot=0):
    """Refine RectWaveCoeff object using a catalogue of lines

    The wavelength calibration polynomials stored in `rectwv_coeff` are
    shifted by the offset obtained by cross-correlating an observed median
    spectrum with a synthetic spectrum generated from tabulated ARC or OH
    lines. Either a single global offset is applied to all the slitlets, or
    an individual offset is computed for each useful slitlet (followed by
    an additional refinement of its polynomial with `check_wlcalib_sp`).

    Parameters
    ----------
    input_image : HDUList object
        Input 2D image.
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration. It is not modified (a deep copy is
        refined and returned).
    refine_wavecalib_mode : int
        Integer, indicating the type of refinement:
        1 : apply the same global offset to all the slitlets (using ARC lines)
        2 : apply individual offset to each slitlet (using ARC lines)
        11 : apply the same global offset to all the slitlets (using OH lines)
        12 : apply individual offset to each slitlet (using OH lines)
        Any other value (including 0, i.e. no refinement) is rejected.
    minimum_slitlet_width_mm : float
        Minimum slitlet width (mm) for a valid slitlet.
    maximum_slitlet_width_mm : float
        Maximum slitlet width (mm) for a valid slitlet.
    save_intermediate_results : bool
        If True, save plots in the PDF file 'crosscorrelation.pdf'.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot.

    Returns
    -------
    refined_rectwv_coeff : RectWaveCoeff instance
        Refined rectification and wavelength calibration coefficients
        for the particular CSU configuration.
    expected_cat_image : HDUList object
        Output 2D image with the expected catalogued lines.

    Raises
    ------
    ValueError
        If `refine_wavecalib_mode` is invalid, or if no catalogue line
        falls within the useful wavelength range.

    """

    logger = logging.getLogger(__name__)

    # protections (checked before creating any output file, so that an
    # invalid mode does not leave an unclosed PDF file behind)
    if refine_wavecalib_mode not in [1, 2, 11, 12]:
        logger.error('Wavelength calibration refinement mode={}'.format(
            refine_wavecalib_mode))
        raise ValueError("Invalid wavelength calibration refinement mode")

    # image header
    main_header = input_image[0].header
    filter_name = main_header['filter']
    grism_name = main_header['grism']

    # read tabulated lines; mode becomes 1 (global) or 2 (individual)
    catlines_all_wave, catlines_all_flux, mode = _read_catalogue_lines(
        refine_wavecalib_mode, grism_name)

    if save_intermediate_results:
        from matplotlib.backends.backend_pdf import PdfPages
        pdf = PdfPages('crosscorrelation.pdf')
    else:
        pdf = None

    try:
        # initialize output
        refined_rectwv_coeff = deepcopy(rectwv_coeff)

        logger.info('Computing median spectrum')
        # compute median spectrum and normalize it
        sp_median = median_slitlets_rectified(
            input_image,
            mode=2,
            minimum_slitlet_width_mm=minimum_slitlet_width_mm,
            maximum_slitlet_width_mm=maximum_slitlet_width_mm)[0].data
        sp_median /= sp_median.max()

        # determine minimum and maximum useful wavelength
        jmin, jmax = find_pix_borders(sp_median, 0)
        naxis1 = main_header['naxis1']
        naxis2 = main_header['naxis2']
        crpix1 = main_header['crpix1']
        crval1 = main_header['crval1']
        cdelt1 = main_header['cdelt1']
        xwave = crval1 + (np.arange(naxis1) + 1.0 - crpix1) * cdelt1
        if grism_name == 'LR':
            wv_parameters = set_wv_parameters(filter_name, grism_name)
            wave_min = wv_parameters['wvmin_useful']
            wave_max = wv_parameters['wvmax_useful']
        else:
            wave_min = crval1 + (jmin + 1 - crpix1) * cdelt1
            wave_max = crval1 + (jmax + 1 - crpix1) * cdelt1
        logger.info('Setting wave_min to {}'.format(wave_min))
        logger.info('Setting wave_max to {}'.format(wave_max))

        # extract subset of catalogue lines within current wavelength range
        lok = (catlines_all_wave >= wave_min) & (catlines_all_wave <= wave_max)
        catlines_reference_wave = catlines_all_wave[lok]
        catlines_reference_flux = catlines_all_flux[lok]
        if catlines_reference_wave.size == 0:
            raise ValueError(
                'No catalogue lines between wave_min={} and '
                'wave_max={}'.format(wave_min, wave_max))
        catlines_reference_flux /= catlines_reference_flux.max()

        # estimate sigma to broaden catalogue lines
        csu_config = CsuConfiguration.define_from_header(main_header)
        # segregate slitlets
        list_useful_slitlets = csu_config.widths_in_range_mm(
            minwidth=minimum_slitlet_width_mm,
            maxwidth=maximum_slitlet_width_mm)
        list_not_useful_slitlets = [
            i for i in range(1, EMIR_NBARS + 1)
            if i not in list_useful_slitlets
        ]
        logger.info('list of useful slitlets: {}'.format(list_useful_slitlets))
        logger.info(
            'list of not useful slitlets: {}'.format(list_not_useful_slitlets))
        tempwidths = np.array([
            csu_config.csu_bar_slit_width(islitlet)
            for islitlet in list_useful_slitlets
        ])
        widths_summary = summary(tempwidths)
        _log_stat_summary(logger, 'Statistics of useful slitlet widths (mm):',
                          widths_summary)
        # empirical transformation of slit width (mm) to pixels
        sigma_broadening = cdelt1 * widths_summary['median']

        # convolve location of catalogue lines to generate expected spectrum
        xwave_reference, sp_reference = convolve_comb_lines(
            catlines_reference_wave, catlines_reference_flux, sigma_broadening,
            crpix1, crval1, cdelt1, naxis1)
        sp_reference /= sp_reference.max()

        # generate image2d with expected lines
        image2d_expected_lines = np.tile(sp_reference, (naxis2, 1))
        hdu = fits.PrimaryHDU(data=image2d_expected_lines, header=main_header)
        expected_cat_image = fits.HDUList([hdu])

        if (abs(debugplot) % 10 != 0) or (pdf is not None):
            ax = ximplotxy(xwave,
                           sp_median,
                           'C1-',
                           xlabel='Wavelength (Angstroms, in vacuum)',
                           ylabel='Normalized number of counts',
                           title='Median spectrum',
                           label='observed spectrum',
                           show=False)
            # overplot reference catalogue lines
            ax.stem(catlines_reference_wave,
                    catlines_reference_flux,
                    'C4-',
                    markerfmt=' ',
                    basefmt='C4-',
                    label='tabulated lines')
            # overplot convolved reference lines
            ax.plot(xwave_reference,
                    sp_reference,
                    'C0-',
                    label='expected spectrum')
            ax.legend()
            if pdf is not None:
                pdf.savefig()
            else:
                pause_debugplot(debugplot=debugplot, pltshow=True)

        # compute baseline signal in sp_median (pixels with signal > 0 only)
        lok = sp_median > 0
        baseline = np.percentile(sp_median[lok], q=10)
        if (abs(debugplot) % 10 != 0) or (pdf is not None):
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.hist(sp_median, bins=1000, log=True)
            ax.set_xlabel('Normalized number of counts')
            ax.set_ylabel('Number of pixels')
            ax.set_title('Median spectrum')
            ax.axvline(float(baseline), linestyle='--', color='grey')
            if pdf is not None:
                pdf.savefig()
            else:
                geometry = (0, 0, 640, 480)
                set_window_geometry(geometry)
                plt.show()
        # subtract baseline to sp_median (only pixels with signal above zero)
        sp_median[lok] -= baseline

        # compute global offset through periodic correlation
        logger.info('Computing global offset')
        global_offset, _ = periodic_corr1d(
            sp_reference=sp_reference,
            sp_offset=sp_median,
            fminmax=None,
            naround_zero=50,
            plottitle='Median spectrum (cross-correlation)',
            pdf=pdf,
            debugplot=debugplot)
        logger.info('Global offset: {} pixels'.format(-global_offset))

        missing_slitlets = rectwv_coeff.missing_slitlets

        if mode == 1:
            # apply computed offset to obtain refined_rectwv_coeff_global
            for islitlet in range(1, EMIR_NBARS + 1):
                if islitlet not in missing_slitlets:
                    dumdict = refined_rectwv_coeff.contents[islitlet - 1]
                    dumdict['wpoly_coeff'][0] -= global_offset * cdelt1

        elif mode == 2:
            # compute individual offset for each slitlet
            logger.info('Computing individual offsets')
            median_55sp = median_slitlets_rectified(input_image, mode=1)
            offset_array = np.zeros(EMIR_NBARS)
            xplot = []
            yplot = []
            xplot_skipped = []
            yplot_skipped = []
            cout = '0'
            for islitlet in range(1, EMIR_NBARS + 1):
                i = islitlet - 1
                sp_slitlet = None
                if islitlet in list_useful_slitlets:
                    if islitlet in missing_slitlets:
                        # same policy as mode 1: missing slitlets are
                        # never modified
                        logger.warning(
                            'slitlet #{0} is missing: skipped'.format(
                                islitlet))
                    else:
                        # view into median_55sp[0].data: the in-place
                        # operations below also modify that array
                        sp_slitlet = median_55sp[0].data[i, :]
                        if not np.any(sp_slitlet > 0):
                            # np.percentile would fail on an empty selection
                            logger.warning(
                                'slitlet #{0} has no signal above zero: '
                                'skipped'.format(islitlet))
                            sp_slitlet = None

                if sp_slitlet is not None:
                    lok = sp_slitlet > 0
                    baseline = np.percentile(sp_slitlet[lok], q=10)
                    sp_slitlet[lok] -= baseline
                    sp_slitlet /= sp_slitlet.max()
                    offset_array[i], _ = periodic_corr1d(
                        sp_reference=sp_reference,
                        sp_offset=sp_slitlet,
                        fminmax=None,
                        naround_zero=50,
                        plottitle='slitlet #{0} (cross-correlation)'.format(
                            islitlet),
                        pdf=pdf,
                        debugplot=debugplot)
                    dumdict = refined_rectwv_coeff.contents[i]
                    dumdict['wpoly_coeff'][0] -= offset_array[i] * cdelt1
                    xplot.append(islitlet)
                    yplot.append(-offset_array[i])
                    # second correction
                    wpoly_coeff_refined = check_wlcalib_sp(
                        sp=sp_slitlet,
                        crpix1=crpix1,
                        crval1=crval1 - offset_array[i] * cdelt1,
                        cdelt1=cdelt1,
                        wv_master=catlines_reference_wave,
                        coeff_ini=dumdict['wpoly_coeff'],
                        naxis1_ini=EMIR_NAXIS1,
                        title='slitlet #{0} (after applying offset)'.format(
                            islitlet),
                        ylogscale=False,
                        pdf=pdf,
                        debugplot=debugplot)
                    dumdict['wpoly_coeff'] = wpoly_coeff_refined
                    cout += '.'
                else:
                    xplot_skipped.append(islitlet)
                    yplot_skipped.append(0)
                    cout += 'i'

                if islitlet % 10 == 0:
                    if cout != 'i':
                        cout = str(islitlet // 10)

                logger.info(cout)

            # show offsets with opposite sign
            stat_summary = summary(np.array(yplot))
            _log_stat_summary(
                logger, 'Statistics of individual slitlet offsets (pixels):',
                stat_summary)
            if (abs(debugplot) % 10 != 0) or (pdf is not None):
                ax = ximplotxy(xplot,
                               yplot,
                               linestyle='',
                               marker='o',
                               color='C0',
                               xlabel='slitlet number',
                               ylabel='-offset (pixels) = offset to be applied',
                               title='cross-correlation result',
                               show=False,
                               **{'label': 'individual slitlets'})
                if len(xplot_skipped) > 0:
                    ax.plot(xplot_skipped, yplot_skipped, 'mx')
                ax.axhline(-global_offset,
                           linestyle='--',
                           color='C1',
                           label='global offset')
                ax.legend()
                if pdf is not None:
                    pdf.savefig()
                else:
                    pause_debugplot(debugplot=debugplot, pltshow=True)
        else:
            raise ValueError('Unexpected mode={}'.format(mode))

    finally:
        # close output PDF file, also when an exception has been raised
        if pdf is not None:
            pdf.close()

    # return result
    return refined_rectwv_coeff, expected_cat_image
Example #15
0
def main(args=None):
    """Display a FITS image with the boundaries of selected slitlets.

    The image must have been obtained with EMIR (INSTRUME keyword). The
    lower (green) and upper (blue) boundaries of each requested slitlet,
    read from the bounddict file, are overplotted together with a label
    containing the slitlet number.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments to be parsed. When None (default), the
        arguments are taken from sys.argv.

    Raises
    ------
    ValueError
        If the slitlet tuple is malformed or out of range, if the image
        dimensions disagree with NAXIS1/NAXIS2, or if the image was not
        obtained with EMIR.

    """

    # parse command-line options
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--bounddict", required=True,
                        help="bounddict file name",
                        type=argparse.FileType('rt'))
    parser.add_argument("--tuple_slit_numbers", required=True,
                        help="Tuple n1[,n2[,step]] to define slitlet numbers")

    # optional arguments
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the explicit argument list (None means sys.argv[1:])
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # read slitlet numbers to be computed: "n1", "n1,n2" or "n1,n2,step"
    tmp_str = args.tuple_slit_numbers.split(",")
    if len(tmp_str) not in (1, 2, 3):
        raise ValueError("Invalid tuple for slitlet numbers")
    n1 = int(tmp_str[0])
    if n1 < 1:
        raise ValueError("Invalid slitlet number < 1")
    # a single number n1 is equivalent to the range n1,n1
    n2 = int(tmp_str[1]) if len(tmp_str) > 1 else n1
    if n2 > EMIR_NBARS:
        raise ValueError("Invalid slitlet number > EMIR_NBARS")
    step = int(tmp_str[2]) if len(tmp_str) == 3 else 1
    list_slitlets = range(n1, n2 + 1, step)

    # read input FITS file
    with fits.open(args.fitsfile.name) as hdulist:
        image_header = hdulist[0].header
        image2d = hdulist[0].data
    # only the file name was needed; release the handle opened by argparse
    args.fitsfile.close()

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")

    # remove path from fitsfile
    sfitsfile = os.path.basename(args.fitsfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # display full image
    ax = ximshow(image2d=image2d,
                 title=sfitsfile + "\ngrism=" + grism +
                       ", filter=" + spfilter +
                       ", rotang=" + str(round(rotang, 2)),
                 image_bbox=(1, naxis1, 1, naxis2), show=False)

    # overplot boundaries for each slitlet
    for slitlet_number in list_slitlets:
        (pol_lower_boundary, pol_upper_boundary,
         xmin_lower, xmax_lower, xmin_upper, xmax_upper,
         csu_bar_slit_center) = get_boundaries(args.bounddict, slitlet_number)
        if (pol_lower_boundary is None) or (pol_upper_boundary is None):
            continue
        xp = np.linspace(start=xmin_lower, stop=xmax_lower, num=1000)
        ax.plot(xp, pol_lower_boundary(xp), 'g-')
        xp = np.linspace(start=xmin_upper, stop=xmax_upper, num=1000)
        ax.plot(xp, pol_upper_boundary(xp), 'b-')
        # slitlet label, vertically centred between both boundaries
        yc_lower = pol_lower_boundary(EMIR_NAXIS1 / 2 + 0.5)
        yc_upper = pol_upper_boundary(EMIR_NAXIS1 / 2 + 0.5)
        tmpcolor = ['r', 'b'][slitlet_number % 2]
        # linear mapping of the slit centre (mm) to a pixel column;
        # NOTE(review): 341.5 presumably is the CSU span in mm — confirm
        xcsu = EMIR_NAXIS1 * csu_bar_slit_center / 341.5
        ax.text(xcsu, (yc_lower + yc_upper) / 2,
                str(slitlet_number),
                fontsize=10, va='center', ha='center',
                bbox=dict(boxstyle="round,pad=0.1",
                          fc="white", ec="grey"),
                color=tmpcolor, fontweight='bold',
                backgroundcolor='white')
    # boundaries already read; release the handle opened by argparse
    args.bounddict.close()

    # show plot
    pause_debugplot(12, pltshow=True)
def rectwv_coeff_from_arc_image(reduced_image,
                                bound_param,
                                lines_catalog,
                                args_nbrightlines=None,
                                args_ymargin_bb=2,
                                args_remove_sp_background=True,
                                args_times_sigma_threshold=10,
                                args_order_fmap=2,
                                args_sigma_gaussian_filtering=2,
                                args_margin_npix=50,
                                args_poldeg_initial=3,
                                args_poldeg_refined=5,
                                args_interactive=False,
                                args_threshold_wv=0,
                                args_ylogscale=False,
                                args_pdf=None,
                                args_geometry=(0,0,640,480),
                                debugplot=0):
    """Evaluate rect.+wavecal. coefficients from arc image

    Parameters
    ----------
    reduced_image : HDUList object
        Image with preliminary basic reduction: bpm, bias, dark and
        flatfield.
    bound_param : RefinedBoundaryModelParam instance
        Refined boundary model.
    lines_catalog : Numpy array
        2D numpy array with the contents of the master file with the
        expected arc line wavelengths.
    args_nbrightlines : int
        TBD
    args_ymargin_bb : int
        TBD
    args_remove_sp_background : bool
        TBD
    args_times_sigma_threshold : float
        TBD
    args_order_fmap : int
        TBD
    args_sigma_gaussian_filtering : float
        TBD
    args_margin_npix : int
        TBD
    args_poldeg_initial : int
        TBD
    args_poldeg_refined : int
        TBD
    args_interactive : bool
        TBD
    args_threshold_wv : float
        TBD
    args_ylogscale : bool
        TBD
    args_pdf : TBD
    args_geometry : TBD
    debugplot : int
            Debugging level for messages and plots. For details see
            'numina.array.display.pause_debugplot.py'.

    Returns
    -------
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration of the input arc image.
    reduced_55sp : HDUList object
        Image with 55 spectra corresponding to the median spectrum for
        each slitlet, employed to derived the wavelength calibration
        polynomial.

    """

    logger = logging.getLogger(__name__)

    # protections
    if args_interactive and args_pdf is not None:
        logger.error('--interactive and --pdf are incompatible options')
        raise ValueError('--interactive and --pdf are incompatible options')

    # header and data array
    header = reduced_image[0].header
    image2d = reduced_image[0].data

    # check grism and filter
    filter_name = header['filter']
    logger.info('Filter: ' + filter_name)
    if filter_name != bound_param.tags['filter']:
        raise ValueError('Filter name does not match!')
    grism_name = header['grism']
    logger.info('Grism: ' + grism_name)
    if grism_name != bound_param.tags['grism']:
        raise ValueError('Grism name does not match!')

    # read the CSU configuration from the image header
    csu_conf = CsuConfiguration.define_from_header(header)
    logger.debug(csu_conf)

    # read the DTU configuration from the image header
    dtu_conf = DtuConfiguration.define_from_header(header)
    logger.debug(dtu_conf)

    # set boundary parameters
    parmodel = bound_param.meta_info['parmodel']
    params = bound_params_from_dict(bound_param.__getstate__())
    if abs(debugplot) >= 10:
        print('-' * 83)
        print('* FITTED BOUND PARAMETERS')
        params.pretty_print()
        pause_debugplot(debugplot)

    # determine parameters according to grism+filter combination
    wv_parameters = set_wv_parameters(filter_name, grism_name)
    islitlet_min = wv_parameters['islitlet_min']
    islitlet_max = wv_parameters['islitlet_max']
    if args_nbrightlines is None:
        nbrightlines = wv_parameters['nbrightlines']
    else:
        nbrightlines = [int(idum) for idum in args_nbrightlines.split(',')]
    poly_crval1_linear = wv_parameters['poly_crval1_linear']
    poly_cdelt1_linear = wv_parameters['poly_cdelt1_linear']
    wvmin_expected = wv_parameters['wvmin_expected']
    wvmax_expected = wv_parameters['wvmax_expected']
    wvmin_useful = wv_parameters['wvmin_useful']
    wvmax_useful = wv_parameters['wvmax_useful']

    # list of slitlets to be computed
    logger.info('list_slitlets: [' + str(islitlet_min) + ',... ' +
                str(islitlet_max) + ']')

    # read master arc line wavelengths (only brightest lines)
    wv_master = read_wv_master_from_array(
        master_table=lines_catalog, lines='brightest', debugplot=debugplot
    )

    # read master arc line wavelengths (whole data set)
    wv_master_all = read_wv_master_from_array(
        master_table=lines_catalog, lines='all', debugplot=debugplot
    )

    # check that the arc lines in the master file are properly sorted
    # in ascending order
    for i in range(len(wv_master_all) - 1):
        if wv_master_all[i] >= wv_master_all[i + 1]:
            logger.error('>>> wavelengths: ' +
                         str(wv_master_all[i]) + '  ' +
                         str(wv_master_all[i+1]))
            raise ValueError('Arc lines are not sorted in master file')

    # ---

    image2d_55sp = np.zeros((EMIR_NBARS, EMIR_NAXIS1))

    # compute rectification transformation and wavelength calibration
    # polynomials

    measured_slitlets = []

    cout = '0'
    for islitlet in range(1, EMIR_NBARS + 1):

        if islitlet_min <= islitlet <= islitlet_max:

            # define Slitlet2dArc object
            slt = Slitlet2dArc(
                islitlet=islitlet,
                csu_conf=csu_conf,
                ymargin_bb=args_ymargin_bb,
                params=params,
                parmodel=parmodel,
                debugplot=debugplot
            )

            # extract 2D image corresponding to the selected slitlet, clipping
            # the image beyond the unrectified slitlet (in order to isolate
            # the arc lines of the current slitlet; otherwise there are
            # problems with arc lines from neighbour slitlets)
            image2d_tmp = select_unrectified_slitlet(
                image2d=image2d,
                islitlet=islitlet,
                csu_bar_slit_center=csu_conf.csu_bar_slit_center(islitlet),
                params=params,
                parmodel=parmodel,
                maskonly=False
            )
            slitlet2d = slt.extract_slitlet2d(image2d_tmp)

            # subtract smooth background computed as follows:
            # - median collapsed spectrum of the whole slitlet2d
            # - independent median filtering of the previous spectrum in the
            #   two halves in the spectral direction
            if args_remove_sp_background:
                spmedian = np.median(slitlet2d, axis=0)
                naxis1_tmp = spmedian.shape[0]
                jmidpoint = naxis1_tmp // 2
                sp1 = medfilt(spmedian[:jmidpoint], [201])
                sp2 = medfilt(spmedian[jmidpoint:], [201])
                spbackground = np.concatenate((sp1, sp2))
                slitlet2d -= spbackground

            # locate unknown arc lines
            slt.locate_unknown_arc_lines(
                slitlet2d=slitlet2d,
                times_sigma_threshold=args_times_sigma_threshold)

            # continue working with current slitlet only if arc lines have
            # been detected
            if slt.list_arc_lines is not None:

                # compute intersections between spectrum trails and arc lines
                slt.xy_spectrail_arc_intersections(slitlet2d=slitlet2d)

                # compute rectification transformation
                slt.estimate_tt_to_rectify(order=args_order_fmap,
                                           slitlet2d=slitlet2d)

                # rectify image
                slitlet2d_rect = slt.rectify(slitlet2d,
                                             resampling=2,
                                             transformation=1)

                # median spectrum and line peaks from rectified image
                sp_median, fxpeaks = slt.median_spectrum_from_rectified_image(
                    slitlet2d_rect,
                    sigma_gaussian_filtering=args_sigma_gaussian_filtering,
                    nwinwidth_initial=5,
                    nwinwidth_refined=5,
                    times_sigma_threshold=5,
                    npix_avoid_border=6,
                    nbrightlines=nbrightlines
                )

                image2d_55sp[islitlet - 1, :] = sp_median

                # determine expected wavelength limits prior to the wavelength
                # calibration
                csu_bar_slit_center = csu_conf.csu_bar_slit_center(islitlet)
                crval1_linear = poly_crval1_linear(csu_bar_slit_center)
                cdelt1_linear = poly_cdelt1_linear(csu_bar_slit_center)
                expected_wvmin = crval1_linear - \
                                 args_margin_npix * cdelt1_linear
                naxis1_linear = sp_median.shape[0]
                crvaln_linear = crval1_linear + \
                                (naxis1_linear - 1) * cdelt1_linear
                expected_wvmax = crvaln_linear + \
                                 args_margin_npix * cdelt1_linear
                # override previous estimates when necessary
                if wvmin_expected is not None:
                    expected_wvmin = wvmin_expected
                if wvmax_expected is not None:
                    expected_wvmax = wvmax_expected

                # clip initial master arc line list with bright lines to
                # the expected wavelength range
                lok1 = expected_wvmin <= wv_master
                lok2 = wv_master <= expected_wvmax
                lok = lok1 * lok2
                wv_master_eff = wv_master[lok]

                # perform initial wavelength calibration
                solution_wv = wvcal_spectrum(
                    sp=sp_median,
                    fxpeaks=fxpeaks,
                    poly_degree_wfit=args_poldeg_initial,
                    wv_master=wv_master_eff,
                    wv_ini_search=expected_wvmin,
                    wv_end_search=expected_wvmax,
                    wvmin_useful=wvmin_useful,
                    wvmax_useful=wvmax_useful,
                    geometry=args_geometry,
                    debugplot=slt.debugplot
                )
                # store initial wavelength calibration polynomial in current
                # slitlet instance
                slt.wpoly = np.polynomial.Polynomial(solution_wv.coeff)
                pause_debugplot(debugplot)

                # clip initial master arc line list with all the lines to
                # the expected wavelength range
                lok1 = expected_wvmin <= wv_master_all
                lok2 = wv_master_all <= expected_wvmax
                lok = lok1 * lok2
                wv_master_all_eff = wv_master_all[lok]

                # clip master arc line list to useful region
                if wvmin_useful is not None:
                    lok = wvmin_useful <= wv_master_all_eff
                    wv_master_all_eff  = wv_master_all_eff[lok]
                if wvmax_useful is not None:
                    lok = wv_master_all_eff <= wvmax_useful
                    wv_master_all_eff  = wv_master_all_eff[lok]

                # refine wavelength calibration
                if args_poldeg_refined > 0:
                    plottitle = '[slitlet#{}, refined]'.format(islitlet)
                    poly_refined, yres_summary = refine_arccalibration(
                        sp=sp_median,
                        poly_initial=slt.wpoly,
                        wv_master=wv_master_all_eff,
                        poldeg=args_poldeg_refined,
                        ntimes_match_wv=1,
                        interactive=args_interactive,
                        threshold=args_threshold_wv,
                        plottitle=plottitle,
                        ylogscale=args_ylogscale,
                        geometry=args_geometry,
                        pdf=args_pdf,
                        debugplot=slt.debugplot
                    )
                    # store refined wavelength calibration polynomial in
                    # current slitlet instance
                    slt.wpoly = poly_refined

                # compute approximate linear values for CRVAL1 and CDELT1
                naxis1_linear = sp_median.shape[0]
                crmin1_linear = slt.wpoly(1)
                crmax1_linear = slt.wpoly(naxis1_linear)
                slt.crval1_linear = crmin1_linear
                slt.cdelt1_linear = \
                    (crmax1_linear - crmin1_linear) / (naxis1_linear - 1)

                # check that the trimming of wv_master and wv_master_all has
                # preserved the wavelength range [crmin1_linear, crmax1_linear]
                if crmin1_linear < expected_wvmin:
                    logger.warning(">>> islitlet: " +str(islitlet))
                    logger.warning("expected_wvmin: " + str(expected_wvmin))
                    logger.warning("crmin1_linear.: " + str(crmin1_linear))
                    logger.warning("WARNING: Unexpected crmin1_linear < "
                                   "expected_wvmin")
                if crmax1_linear > expected_wvmax:
                    logger.warning(">>> islitlet: " +str(islitlet))
                    logger.warning("expected_wvmax: " + str(expected_wvmax))
                    logger.warning("crmax1_linear.: " + str(crmax1_linear))
                    logger.warning("WARNING: Unexpected crmax1_linear > "
                                   "expected_wvmax")

                cout += '.'

            else:

                cout += 'x'

            if islitlet % 10 == 0:
                if cout != 'x':
                    cout = str(islitlet // 10)

            if debugplot != 0:
                pause_debugplot(debugplot)

        else:

            # define Slitlet2dArc object
            slt = Slitlet2dArc(
                islitlet=islitlet,
                csu_conf=csu_conf,
                ymargin_bb=args_ymargin_bb,
                params=None,
                parmodel=None,
                debugplot=debugplot
            )

            cout += 'i'

        # store current slitlet in list of measured slitlets
        measured_slitlets.append(slt)

        logger.info(cout)

    # ---

    # generate FITS file structure with 55 spectra corresponding to the
    # median spectrum for each slitlet
    reduced_55sp = fits.PrimaryHDU(data=image2d_55sp)
    reduced_55sp.header['crpix1'] = (0.0, 'reference pixel')
    reduced_55sp.header['crval1'] = (0.0, 'central value at crpix2')
    reduced_55sp.header['cdelt1'] = (1.0, 'increment')
    reduced_55sp.header['ctype1'] = 'PIXEL'
    reduced_55sp.header['cunit1'] = ('Pixel', 'units along axis2')
    reduced_55sp.header['crpix2'] = (0.0, 'reference pixel')
    reduced_55sp.header['crval2'] = (0.0, 'central value at crpix2')
    reduced_55sp.header['cdelt2'] = (1.0, 'increment')
    reduced_55sp.header['ctype2'] = 'PIXEL'
    reduced_55sp.header['cunit2'] = ('Pixel', 'units along axis2')

    # ---

    # Generate structure to store intermediate results
    outdict = {}
    outdict['instrument'] = 'EMIR'
    outdict['meta_info'] = {}
    outdict['meta_info']['creation_date'] = datetime.now().isoformat()
    outdict['meta_info']['description'] = \
        'computation of rectification and wavelength calibration polynomial ' \
        'coefficients for a particular CSU configuration'
    outdict['meta_info']['recipe_name'] = 'undefined'
    outdict['meta_info']['origin'] = {}
    outdict['meta_info']['origin']['bound_param_uuid'] = \
        bound_param.uuid
    outdict['meta_info']['origin']['arc_image_uuid'] = 'undefined'
    outdict['tags'] = {}
    outdict['tags']['grism'] = grism_name
    outdict['tags']['filter'] = filter_name
    outdict['tags']['islitlet_min'] = islitlet_min
    outdict['tags']['islitlet_max'] = islitlet_max
    outdict['dtu_configuration'] = dtu_conf.outdict()
    outdict['uuid'] = str(uuid4())
    outdict['contents'] = {}

    missing_slitlets = []
    for slt in measured_slitlets:

        islitlet = slt.islitlet

        if islitlet_min <= islitlet <= islitlet_max:

            # avoid error when creating a python list of coefficients from
            # numpy polynomials when the polynomials do not exist (note that
            # the JSON format doesn't handle numpy arrays and such arrays must
            # be transformed into native python lists)
            if slt.wpoly is None:
                wpoly_coeff = None
            else:
                wpoly_coeff = slt.wpoly.coef.tolist()
            if slt.wpoly_longslit_model is None:
                wpoly_coeff_longslit_model = None
            else:
                wpoly_coeff_longslit_model = \
                    slt.wpoly_longslit_model.coef.tolist()

            # avoid similar error when creating a python list of coefficients
            # when the numpy array does not exist; note that this problem
            # does not happen with tt?_aij_longslit_model and
            # tt?_bij_longslit_model because the latter have already been
            # created as native python lists
            if slt.ttd_aij is None:
                ttd_aij = None
            else:
                ttd_aij = slt.ttd_aij.tolist()
            if slt.ttd_bij is None:
                ttd_bij = None
            else:
                ttd_bij = slt.ttd_bij.tolist()
            if slt.tti_aij is None:
                tti_aij = None
            else:
                tti_aij = slt.tti_aij.tolist()
            if slt.tti_bij is None:
                tti_bij = None
            else:
                tti_bij = slt.tti_bij.tolist()

            # creating temporary dictionary with the information corresponding
            # to the current slitlett that will be saved in the JSON file
            tmp_dict = {
                'csu_bar_left': slt.csu_bar_left,
                'csu_bar_right': slt.csu_bar_right,
                'csu_bar_slit_center': slt.csu_bar_slit_center,
                'csu_bar_slit_width': slt.csu_bar_slit_width,
                'x0_reference': slt.x0_reference,
                'y0_reference_lower': slt.y0_reference_lower,
                'y0_reference_middle': slt.y0_reference_middle,
                'y0_reference_upper': slt.y0_reference_upper,
                'y0_reference_lower_expected':
                    slt.y0_reference_lower_expected,
                'y0_reference_middle_expected':
                    slt.y0_reference_middle_expected,
                'y0_reference_upper_expected':
                    slt.y0_reference_upper_expected,
                'y0_frontier_lower': slt.y0_frontier_lower,
                'y0_frontier_upper': slt.y0_frontier_upper,
                'y0_frontier_lower_expected': slt.y0_frontier_lower_expected,
                'y0_frontier_upper_expected': slt.y0_frontier_upper_expected,
                'corr_yrect_a': slt.corr_yrect_a,
                'corr_yrect_b': slt.corr_yrect_b,
                'min_row_rectified': slt.min_row_rectified,
                'max_row_rectified': slt.max_row_rectified,
                'ymargin_bb': slt.ymargin_bb,
                'bb_nc1_orig': slt.bb_nc1_orig,
                'bb_nc2_orig': slt.bb_nc2_orig,
                'bb_ns1_orig': slt.bb_ns1_orig,
                'bb_ns2_orig': slt.bb_ns2_orig,
                'spectrail': {
                    'poly_coef_lower':
                        slt.list_spectrails[
                            slt.i_lower_spectrail].poly_funct.coef.tolist(),
                    'poly_coef_middle':
                        slt.list_spectrails[
                            slt.i_middle_spectrail].poly_funct.coef.tolist(),
                    'poly_coef_upper':
                        slt.list_spectrails[
                            slt.i_upper_spectrail].poly_funct.coef.tolist(),
                },
                'frontier': {
                    'poly_coef_lower':
                        slt.list_frontiers[0].poly_funct.coef.tolist(),
                    'poly_coef_upper':
                        slt.list_frontiers[1].poly_funct.coef.tolist(),
                },
                'ttd_order': slt.ttd_order,
                'ttd_aij': ttd_aij,
                'ttd_bij': ttd_bij,
                'tti_aij': tti_aij,
                'tti_bij': tti_bij,
                'ttd_order_longslit_model': slt.ttd_order_longslit_model,
                'ttd_aij_longslit_model': slt.ttd_aij_longslit_model,
                'ttd_bij_longslit_model': slt.ttd_bij_longslit_model,
                'tti_aij_longslit_model': slt.tti_aij_longslit_model,
                'tti_bij_longslit_model': slt.tti_bij_longslit_model,
                'wpoly_coeff': wpoly_coeff,
                'wpoly_coeff_longslit_model': wpoly_coeff_longslit_model,
                'crval1_linear': slt.crval1_linear,
                'cdelt1_linear': slt.cdelt1_linear
            }
        else:
            missing_slitlets.append(islitlet)
            tmp_dict = {
                'csu_bar_left': slt.csu_bar_left,
                'csu_bar_right': slt.csu_bar_right,
                'csu_bar_slit_center': slt.csu_bar_slit_center,
                'csu_bar_slit_width': slt.csu_bar_slit_width,
                'x0_reference': slt.x0_reference,
                'y0_frontier_lower_expected': slt.y0_frontier_lower_expected,
                'y0_frontier_upper_expected': slt.y0_frontier_upper_expected
            }
        slitlet_label = "slitlet" + str(islitlet).zfill(2)
        outdict['contents'][slitlet_label] = tmp_dict

    # ---

    # OBSOLETE
    '''
    # save JSON file needed to compute the MOS model
    with open(args.out_json.name, 'w') as fstream:
        json.dump(outdict, fstream, indent=2, sort_keys=True)
        print('>>> Saving file ' + args.out_json.name)
    '''

    # ---

    # Create object of type RectWaveCoeff with coefficients for
    # rectification and wavelength calibration
    rectwv_coeff = RectWaveCoeff(instrument='EMIR')
    rectwv_coeff.quality_control = numina.types.qc.QC.GOOD
    rectwv_coeff.tags['grism'] = grism_name
    rectwv_coeff.tags['filter'] = filter_name
    rectwv_coeff.meta_info['origin']['bound_param'] = \
        'uuid' + bound_param.uuid
    rectwv_coeff.meta_info['dtu_configuration'] = outdict['dtu_configuration']
    rectwv_coeff.total_slitlets = EMIR_NBARS
    rectwv_coeff.missing_slitlets = missing_slitlets
    for i in range(EMIR_NBARS):
        islitlet = i + 1
        dumdict = {'islitlet': islitlet}
        cslitlet = 'slitlet' + str(islitlet).zfill(2)
        if cslitlet in outdict['contents']:
            dumdict.update(outdict['contents'][cslitlet])
        else:
            raise ValueError("Unexpected error")
        rectwv_coeff.contents.append(dumdict)
    # debugging __getstate__ and __setstate__
    # rectwv_coeff.writeto(args.out_json.name)
    # print('>>> Saving file ' + args.out_json.name)
    # check_setstate_getstate(rectwv_coeff, args.out_json.name)
    logger.info('Generating RectWaveCoeff object with uuid=' +
                rectwv_coeff.uuid)

    return rectwv_coeff, reduced_55sp
Example #17
0
def _read_refinement_catalogue(refine_wavecalib_mode, grism_name):
    """Read the catalogue of lines used to refine the wavelength calibration.

    Parameters
    ----------
    refine_wavecalib_mode : int
        Refinement mode: 1 or 2 select the ARC line catalogue, 11 or 12
        select the OH line catalogue (Oliva et al. 2013).
    grism_name : str
        Grism name. For ARC lines, 'LR' selects a dedicated catalogue.

    Returns
    -------
    catlines_all_wave : numpy array
        Wavelengths of the catalogued lines.
    catlines_all_flux : numpy array
        Fluxes of the catalogued lines.

    Raises
    ------
    ValueError
        If `refine_wavecalib_mode` is not one of 1, 2, 11 or 12.

    """
    if refine_wavecalib_mode in (1, 2):        # ARC lines
        if grism_name == 'LR':
            catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
        else:
            catlines_file = 'lines_argon_neon_xenon_empirical.dat'
        dumdata = pkgutil.get_data('emirdrp.instrument.configs', catlines_file)
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # define wavelength and flux as separate arrays
        return catlines[:, 0], catlines[:, 1]

    if refine_wavecalib_mode in (11, 12):      # OH lines
        dumdata = pkgutil.get_data(
            'emirdrp.instrument.configs',
            'Oliva_etal_2013.dat'
        )
        catlines = np.genfromtxt(StringIO(dumdata.decode('utf8')))
        # columns 1 and 0 of every row are both used as line wavelengths,
        # and both share the flux stored in column 2
        catlines_all_wave = np.concatenate((catlines[:, 1], catlines[:, 0]))
        catlines_all_flux = np.concatenate((catlines[:, 2], catlines[:, 2]))
        return catlines_all_wave, catlines_all_flux

    raise ValueError('Unexpected mode={}'.format(refine_wavecalib_mode))


def refine_rectwv_coeff(input_image, rectwv_coeff,
                        refine_wavecalib_mode,
                        minimum_slitlet_width_mm,
                        maximum_slitlet_width_mm,
                        save_intermediate_results=False,
                        debugplot=0):
    """Refine RectWaveCoeff object using a catalogue of lines.

    The offset between the median spectrum of the rectified input image
    and a synthetic spectrum (built by broadening the catalogued ARC or
    OH lines) is computed through periodic cross-correlation. The
    independent term of the wavelength calibration polynomials is then
    corrected using either a single global offset or individual offsets
    per slitlet.

    Parameters
    ----------
    input_image : HDUList object
        Input 2D image.
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration. It is not modified (a deep copy is
        refined and returned).
    refine_wavecalib_mode : int
        Integer, indicating the type of refinement:
        1 : apply the same global offset to all the slitlets (using ARC lines)
        2 : apply individual offset to each slitlet (using ARC lines)
        11 : apply the same global offset to all the slitlets (using OH lines)
        12 : apply individual offset to each slitlet (using OH lines)
        Mode 0 (no refinement) must be handled by not calling this
        function; any other value raises ValueError.
    minimum_slitlet_width_mm : float
        Minimum slitlet width (mm) for a valid slitlet.
    maximum_slitlet_width_mm : float
        Maximum slitlet width (mm) for a valid slitlet.
    save_intermediate_results : bool
        If True, save plots in the PDF file 'crosscorrelation.pdf'
        (relative to the current working directory).
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot.

    Returns
    -------
    refined_rectwv_coeff : RectWaveCoeff instance
        Refined rectification and wavelength calibration coefficients
        for the particular CSU configuration.
    expected_cat_image : HDUList object
        Output 2D image with the expected catalogued lines.

    Raises
    ------
    ValueError
        If `refine_wavecalib_mode` is not one of 1, 2, 11 or 12, or if
        no catalogued line falls within the useful wavelength range.

    """

    logger = logging.getLogger(__name__)

    # protections (checked before opening any output file, so that an
    # invalid mode does not leave an unclosed PDF file behind)
    if refine_wavecalib_mode not in [1, 2, 11, 12]:
        logger.error('Wavelength calibration refinement mode={}'.format(
            refine_wavecalib_mode
        ))
        raise ValueError("Invalid wavelength calibration refinement mode")
    # modes 11 and 12 behave as modes 1 and 2, but using OH lines
    mode = refine_wavecalib_mode % 10

    # image header
    main_header = input_image[0].header
    filter_name = main_header['filter']
    grism_name = main_header['grism']

    # read tabulated lines
    catlines_all_wave, catlines_all_flux = _read_refinement_catalogue(
        refine_wavecalib_mode, grism_name
    )

    if save_intermediate_results:
        from matplotlib.backends.backend_pdf import PdfPages
        pdf = PdfPages('crosscorrelation.pdf')
    else:
        pdf = None

    try:
        # initialize output
        refined_rectwv_coeff = deepcopy(rectwv_coeff)

        logger.info('Computing median spectrum')
        # compute median spectrum and normalize it
        sp_median = median_slitlets_rectified(
            input_image,
            mode=2,
            minimum_slitlet_width_mm=minimum_slitlet_width_mm,
            maximum_slitlet_width_mm=maximum_slitlet_width_mm
        )[0].data
        sp_median /= sp_median.max()

        # determine minimum and maximum useful wavelength
        jmin, jmax = find_pix_borders(sp_median, 0)
        naxis1 = main_header['naxis1']
        naxis2 = main_header['naxis2']
        crpix1 = main_header['crpix1']
        crval1 = main_header['crval1']
        cdelt1 = main_header['cdelt1']
        xwave = crval1 + (np.arange(naxis1) + 1.0 - crpix1) * cdelt1
        if grism_name == 'LR':
            wv_parameters = set_wv_parameters(filter_name, grism_name)
            wave_min = wv_parameters['wvmin_useful']
            wave_max = wv_parameters['wvmax_useful']
        else:
            wave_min = crval1 + (jmin + 1 - crpix1) * cdelt1
            wave_max = crval1 + (jmax + 1 - crpix1) * cdelt1
        logger.info('Setting wave_min to {}'.format(wave_min))
        logger.info('Setting wave_max to {}'.format(wave_max))

        # extract subset of catalogue lines within current wavelength range
        lok = ((catlines_all_wave >= wave_min) &
               (catlines_all_wave <= wave_max))
        catlines_reference_wave = catlines_all_wave[lok]
        catlines_reference_flux = catlines_all_flux[lok]
        if catlines_reference_wave.size == 0:
            raise ValueError(
                'No catalogued lines within the wavelength range '
                '[{}, {}]'.format(wave_min, wave_max)
            )
        catlines_reference_flux /= catlines_reference_flux.max()

        # estimate sigma to broaden catalogue lines
        csu_config = CsuConfiguration.define_from_header(main_header)
        # segregate slitlets
        list_useful_slitlets = csu_config.widths_in_range_mm(
            minwidth=minimum_slitlet_width_mm,
            maxwidth=maximum_slitlet_width_mm
        )
        # remove missing slitlets
        for iremove in refined_rectwv_coeff.missing_slitlets:
            if iremove in list_useful_slitlets:
                list_useful_slitlets.remove(iremove)

        list_not_useful_slitlets = [i for i in range(1, EMIR_NBARS + 1)
                                    if i not in list_useful_slitlets]
        logger.info('list of useful slitlets: {}'.format(
            list_useful_slitlets))
        logger.info('list of not useful slitlets: {}'.format(
            list_not_useful_slitlets))
        tempwidths = np.array([csu_config.csu_bar_slit_width(islitlet)
                               for islitlet in list_useful_slitlets])
        widths_summary = summary(tempwidths)
        logger.info('Statistics of useful slitlet widths (mm):')
        logger.info('- npoints....: {0:d}'.format(widths_summary['npoints']))
        logger.info('- mean.......: {0:7.3f}'.format(widths_summary['mean']))
        logger.info('- median.....: {0:7.3f}'.format(
            widths_summary['median']))
        logger.info('- std........: {0:7.3f}'.format(widths_summary['std']))
        logger.info('- robust_std.: {0:7.3f}'.format(
            widths_summary['robust_std']))
        # empirical transformation of slit width (mm) to pixels
        sigma_broadening = cdelt1 * widths_summary['median']

        # convolve location of catalogue lines to generate expected spectrum
        xwave_reference, sp_reference = convolve_comb_lines(
            catlines_reference_wave, catlines_reference_flux,
            sigma_broadening, crpix1, crval1, cdelt1, naxis1
        )
        sp_reference /= sp_reference.max()

        # generate image2d with expected lines
        image2d_expected_lines = np.tile(sp_reference, (naxis2, 1))
        hdu = fits.PrimaryHDU(data=image2d_expected_lines, header=main_header)
        expected_cat_image = fits.HDUList([hdu])

        # plots are either displayed (debugplot) or saved in the PDF file
        show_plots = (abs(debugplot) % 10 != 0) or (pdf is not None)

        if show_plots:
            ax = ximplotxy(xwave, sp_median, 'C1-',
                           xlabel='Wavelength (Angstroms, in vacuum)',
                           ylabel='Normalized number of counts',
                           title='Median spectrum',
                           label='observed spectrum', show=False)
            # overplot reference catalogue lines
            ax.stem(catlines_reference_wave, catlines_reference_flux, 'C4-',
                    markerfmt=' ', basefmt='C4-', label='tabulated lines')
            # overplot convolved reference lines
            ax.plot(xwave_reference, sp_reference, 'C0-',
                    label='expected spectrum')
            ax.legend()
            if pdf is not None:
                pdf.savefig()
            else:
                pause_debugplot(debugplot=debugplot, pltshow=True)

        # compute baseline signal in sp_median
        baseline = np.percentile(sp_median[sp_median > 0], q=10)
        if show_plots:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.hist(sp_median, bins=1000, log=True)
            ax.set_xlabel('Normalized number of counts')
            ax.set_ylabel('Number of pixels')
            ax.set_title('Median spectrum')
            ax.axvline(float(baseline), linestyle='--', color='grey')
            if pdf is not None:
                pdf.savefig()
            else:
                geometry = (0, 0, 640, 480)
                set_window_geometry(geometry)
                plt.show()
        # subtract baseline to sp_median (only pixels with signal above zero)
        lok = sp_median > 0
        sp_median[lok] -= baseline

        # compute global offset through periodic correlation
        logger.info('Computing global offset')
        global_offset, fpeak = periodic_corr1d(
            sp_reference=sp_reference,
            sp_offset=sp_median,
            fminmax=None,
            naround_zero=50,
            plottitle='Median spectrum (cross-correlation)',
            pdf=pdf,
            debugplot=debugplot
        )
        logger.info('Global offset: {} pixels'.format(-global_offset))

        missing_slitlets = rectwv_coeff.missing_slitlets

        if mode == 1:
            # apply computed offset to obtain refined_rectwv_coeff_global
            for islitlet in range(1, EMIR_NBARS + 1):
                if islitlet not in missing_slitlets:
                    dumdict = refined_rectwv_coeff.contents[islitlet - 1]
                    dumdict['wpoly_coeff'][0] -= global_offset * cdelt1
        else:
            # mode == 2: compute individual offset for each slitlet
            logger.info('Computing individual offsets')
            median_55sp = median_slitlets_rectified(input_image, mode=1)
            offset_array = np.zeros(EMIR_NBARS)
            xplot = []
            yplot = []
            xplot_skipped = []
            yplot_skipped = []
            cout = '0'
            for islitlet in range(1, EMIR_NBARS + 1):
                if islitlet in list_useful_slitlets:
                    i = islitlet - 1
                    # basic slicing returns a view: the in-place operations
                    # below also modify median_55sp[0].data
                    sp_median = median_55sp[0].data[i, :]
                    # boolean mask (np.where would return a tuple of index
                    # arrays, and np.any on it ignores pixel index 0)
                    lok = sp_median > 0
                    if np.any(lok):
                        baseline = np.percentile(sp_median[lok], q=10)
                        sp_median[lok] -= baseline
                        sp_max = sp_median.max()
                        # avoid 0/0 (NaN) when the signal is flat
                        if sp_max > 0:
                            sp_median /= sp_max
                        offset_array[i], fpeak = periodic_corr1d(
                            sp_reference=sp_reference,
                            sp_offset=sp_median,
                            fminmax=None,
                            naround_zero=50,
                            plottitle='slitlet #{0} (cross-correlation)'.format(
                                islitlet),
                            pdf=pdf,
                            debugplot=debugplot
                        )
                    else:
                        offset_array[i] = 0.0
                    dumdict = refined_rectwv_coeff.contents[i]
                    dumdict['wpoly_coeff'][0] -= offset_array[i] * cdelt1
                    xplot.append(islitlet)
                    yplot.append(-offset_array[i])
                    # second correction
                    wpoly_coeff_refined = check_wlcalib_sp(
                        sp=sp_median,
                        crpix1=crpix1,
                        crval1=crval1 - offset_array[i] * cdelt1,
                        cdelt1=cdelt1,
                        wv_master=catlines_reference_wave,
                        coeff_ini=dumdict['wpoly_coeff'],
                        naxis1_ini=EMIR_NAXIS1,
                        title='slitlet #{0} (after applying offset)'.format(
                            islitlet),
                        ylogscale=False,
                        pdf=pdf,
                        debugplot=debugplot
                    )
                    dumdict['wpoly_coeff'] = wpoly_coeff_refined
                    cout += '.'
                else:
                    xplot_skipped.append(islitlet)
                    yplot_skipped.append(0)
                    cout += 'i'

                # restart the progress string every 10 slitlets
                if islitlet % 10 == 0:
                    cout = str(islitlet // 10)

                logger.info(cout)

            # show offsets with opposite sign
            stat_summary = summary(np.array(yplot))
            logger.info('Statistics of individual slitlet offsets (pixels):')
            logger.info('- npoints....: {0:d}'.format(
                stat_summary['npoints']))
            logger.info('- mean.......: {0:7.3f}'.format(
                stat_summary['mean']))
            logger.info('- median.....: {0:7.3f}'.format(
                stat_summary['median']))
            logger.info('- std........: {0:7.3f}'.format(
                stat_summary['std']))
            logger.info('- robust_std.: {0:7.3f}'.format(
                stat_summary['robust_std']))
            if show_plots:
                ax = ximplotxy(xplot, yplot,
                               linestyle='', marker='o', color='C0',
                               xlabel='slitlet number',
                               ylabel='-offset (pixels) = offset to be applied',
                               title='cross-correlation result',
                               show=False, **{'label': 'individual slitlets'})
                if len(xplot_skipped) > 0:
                    ax.plot(xplot_skipped, yplot_skipped, 'mx')
                ax.axhline(-global_offset, linestyle='--', color='C1',
                           label='global offset')
                ax.legend()
                if pdf is not None:
                    pdf.savefig()
                else:
                    pause_debugplot(debugplot=debugplot, pltshow=True)

        # return result
        return refined_rectwv_coeff, expected_cat_image

    finally:
        # close output PDF file, also when an exception has been raised
        if pdf is not None:
            pdf.close()
def rectwv_coeff_add_longslit_model(rectwv_coeff, geometry, debugplot=0):
    """Compute longslit_model coefficients for RectWaveCoeff object.

    For the valid (i.e., not missing) slitlets, each wavelength
    calibration coefficient and each rectification transformation
    coefficient is fitted, with sigma rejection, as a polynomial function
    of `y0_reference_middle`. The fitted polynomials are then evaluated
    at the `y0_reference_middle` of every valid slitlet to obtain the
    smooth ``*_longslit_model`` coefficients.

    Parameters
    ----------
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for a
        particular CSU configuration corresponding to a longslit
        observation. Its `contents` dictionaries are updated in place.
    geometry : tuple or None
        Window geometry forwarded to polfit_residuals_with_sigma_rejection
        for the debugging plots (presumably a tuple such as
        (0, 0, 640, 480), as used elsewhere in this module; TODO confirm).
    debugplot : int
        Debugging level for messages and plots. For details see
        'numina.array.display.pause_debugplot.py'.

    Returns
    -------
    rectwv_coeff : RectWaveCoeff instance
        The same input object, updated with the longslit_model
        coefficients, a new uuid and a new creation date.

    Raises
    ------
    ValueError
        If the polynomial degree of the wavelength calibration, or the
        order of the rectification transformation, is not the same for
        all the slitlets where it is available (or is not available for
        any slitlet).

    """

    logger = logging.getLogger(__name__)

    # check grism and filter
    grism_name = rectwv_coeff.tags['grism']
    logger.info('Grism: ' + grism_name)
    filter_name = rectwv_coeff.tags['filter']
    logger.info('Filter: ' + filter_name)

    # list of slitlets to be computed (tolerates repeated entries in
    # missing_slitlets)
    missing_slitlets = set(rectwv_coeff.missing_slitlets)
    list_valid_islitlets = [islitlet for islitlet in range(1, EMIR_NBARS + 1)
                            if islitlet not in missing_slitlets]
    if abs(debugplot) >= 10:
        print('>>> valid slitlet numbers:\n', list_valid_islitlets)
    # dictionaries (one per valid slitlet) stored in rectwv_coeff.contents
    valid_contents = [rectwv_coeff.contents[islitlet - 1]
                      for islitlet in list_valid_islitlets]

    # ---

    # display csu_bar_slit_center values (expected to be similar for a
    # longslit configuration); this is only reported when debugging, no
    # check is enforced
    if abs(debugplot) >= 10:
        csu_bar_slit_center_list = [tmp_dict['csu_bar_slit_center']
                                    for tmp_dict in valid_contents]
        logger.debug('Checking csu_bar_slit_center values:')
        summary(np.array(csu_bar_slit_center_list), debug=True)
        pause_debugplot(debugplot)

    # ---

    # polynomial coefficients corresponding to the wavelength calibration;
    # only slitlets with an available calibration polynomial are used
    wpoly_contents = [tmp_dict for tmp_dict in valid_contents
                      if tmp_dict['wpoly_coeff'] is not None]

    # step 0: determine poldeg_refined, checking that it is the same for
    # all the slitlets
    poldeg_refined_list = sorted(
        {len(tmp_dict['wpoly_coeff']) - 1 for tmp_dict in wpoly_contents}
    )
    if len(poldeg_refined_list) != 1:
        raise ValueError('Unexpected different poldeg_refined found: ' +
                         str(poldeg_refined_list))
    poldeg_refined = poldeg_refined_list[0]

    # step 1: compute variation of each coefficient as a function of
    # y0_reference_middle of each slitlet
    xp = [tmp_dict['y0_reference_middle'] for tmp_dict in wpoly_contents]
    list_poly = []
    for i in range(poldeg_refined + 1):
        yp = [tmp_dict['wpoly_coeff'][i] for tmp_dict in wpoly_contents]
        poly, yres, reject = polfit_residuals_with_sigma_rejection(
            x=np.array(xp),
            y=np.array(yp),
            deg=2,
            times_sigma_reject=5,
            xlabel='y0_rectified',
            ylabel='coeff[' + str(i) + ']',
            title="Fit to refined wavelength calibration coefficients",
            geometry=geometry,
            debugplot=debugplot
        )
        list_poly.append(poly)

    # step 2: use the variation of each polynomial coefficient with
    # y0_reference_middle to infer the expected wavelength calibration
    # polynomial for each rectified slitlet
    for tmp_dict in valid_contents:
        y0_reference_middle = tmp_dict['y0_reference_middle']
        tmp_dict['wpoly_coeff_longslit_model'] = [
            poly(y0_reference_middle) for poly in list_poly
        ]

    # ---

    # rectification transformation coefficients aij and bij; only
    # slitlets with an available transformation are used
    tt_keys = ('ttd_aij', 'ttd_bij', 'tti_aij', 'tti_bij')
    tt_contents = [tmp_dict for tmp_dict in valid_contents
                   if tmp_dict['ttd_aij'] is not None]

    # step 0: determine order_fmap, checking that it is the same for
    # all the slitlets
    order_fmap_list = list(set(tmp_dict['ttd_order']
                               for tmp_dict in tt_contents))
    if len(order_fmap_list) != 1:
        raise ValueError('Unexpected different order_fmap found: ' +
                         str(order_fmap_list))
    order_fmap = order_fmap_list[0]

    # step 1: compute variation of each coefficient as a function of
    # y0_reference_middle of each slitlet (fits are performed, for each
    # coefficient index, in the order ttd_aij, ttd_bij, tti_aij, tti_bij)
    ncoef_ttd = ncoef_fmap(order_fmap)
    xp = [tmp_dict['y0_reference_middle'] for tmp_dict in tt_contents]
    list_poly_tt = {key: [] for key in tt_keys}
    for i in range(ncoef_ttd):
        for key in tt_keys:
            yp = [tmp_dict[key][i] for tmp_dict in tt_contents]
            poly, yres, reject = polfit_residuals_with_sigma_rejection(
                x=np.array(xp),
                y=np.array(yp),
                deg=5,
                times_sigma_reject=5,
                xlabel='y0_rectified',
                ylabel=key + '[' + str(i) + ']',
                geometry=geometry,
                debugplot=debugplot
            )
            list_poly_tt[key].append(poly)

    # step 2: use the variation of each coefficient with y0_reference_middle
    # to infer the expected rectification transformation for each slitlet
    for tmp_dict in valid_contents:
        y0_reference_middle = tmp_dict['y0_reference_middle']
        tmp_dict['ttd_order_longslit_model'] = order_fmap
        for key in tt_keys:
            tmp_dict[key + '_longslit_model'] = [
                poly(y0_reference_middle) for poly in list_poly_tt[key]
            ]

    # ---

    # update uuid and meta_info in output JSON structure
    rectwv_coeff.uuid = str(uuid4())
    rectwv_coeff.meta_info['creation_date'] = datetime.now().isoformat()

    # return updated object
    return rectwv_coeff
def compute_slitlet_boundaries(
        filename, grism, spfilter, list_slitlets,
        size_x_medfilt, size_y_savgol,
        times_sigma_threshold,
        bounddict, debugplot=0):
    """Compute slitlet boundaries using continuum lamp images.

    For each requested slitlet the image is median-filtered along the
    spectral (X) direction, differentiated along the spatial (Y) direction
    with a Savitzky-Golay filter, and the largest connected regions with
    significant positive and negative derivatives are taken as the two
    slitlet boundaries. Each boundary is then fitted with a polynomial.

    Parameters
    ----------
    filename : string
        Input continuum lamp image.
    grism : string
        Grism name. It must be one in EMIR_VALID_GRISMS.
    spfilter : string
        Filter name. It must be one in EMIR_VALID_FILTERS.
    list_slitlets : list of integers
        Number of slitlets to be updated.
    size_x_medfilt : int
        Window in the X (spectral) direction, in pixels, to apply
        the 1d median filter in order to remove bad pixels.
    size_y_savgol : int
        Window in the Y (spatial) direction to be used when using the
        1d Savitzky-Golay filter.
    times_sigma_threshold : float
        Times sigma to detect peaks in derivatives.
    bounddict : dictionary of dictionaries
        Structure to store the boundaries. It is updated in place:
        ``bounddict['contents'][slitlet_label][date_obs]`` receives the
        fitted boundary coefficients together with the CSU/DTU
        configuration read from the image header.
    debugplot : int
        Determines whether intermediate computations and/or plots are
        displayed.

    Raises
    ------
    ValueError
        If the image was not obtained with EMIR, if GRISM/FILTER do not
        match the expected values, or if the two boundaries of a slitlet
        cannot be detected.

    """
    # local import: only needed to tag the output with the current user
    import getpass

    # read 2D image (data are accessed while the file is still open)
    with fits.open(filename) as hdulist:
        image_header = hdulist[0].header
        image2d = hdulist[0].data
        naxis2, naxis1 = image2d.shape
    if debugplot >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # ToDo: replace this by application of cosmetic defect mask!
    # Row 1024 (first 1024 columns) and row 1023 (next 1024 columns) are
    # replaced by the mean of their neighbouring rows. The pixels written
    # by each statement are never read by the other one, so this is
    # equivalent to the former column-by-column loop.
    image2d[1024, :1024] = (image2d[1023, :1024] + image2d[1025, :1024]) / 2
    image2d[1023, 1024:2048] = (image2d[1022, 1024:2048] +
                                image2d[1024, 1024:2048]) / 2

    # remove path from filename
    sfilename = os.path.basename(filename)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read CSU configuration from FITS header
    csu_config = CsuConfiguration.define_from_fits(filename)

    # read DTU configuration from FITS header
    dtu_config = DtuConfiguration.define_from_fits(filename)

    # read grism
    grism_in_header = image_header['grism']
    if grism != grism_in_header:
        raise ValueError("GRISM keyword=" + grism_in_header +
                         " is not the expected value=" + grism)

    # read filter
    spfilter_in_header = image_header['filter']
    if spfilter != spfilter_in_header:
        raise ValueError("FILTER keyword=" + spfilter_in_header +
                         " is not the expected value=" + spfilter)

    # read rotator position angle
    rotang = image_header['rotang']

    # read date-obs
    date_obs = image_header['date-obs']

    for islitlet in list_slitlets:
        if debugplot < 10:
            sys.stdout.write('.')
            sys.stdout.flush()

        sltlim = SlitletLimits(grism, spfilter, islitlet)
        # extract slitlet2d
        slitlet2d = extract_slitlet2d(image2d, sltlim)
        if debugplot % 10 != 0:
            ximshow(slitlet2d,
                    title=sfilename + " [original]"
                          "\nslitlet=" + str(islitlet) +
                          ", grism=" + grism +
                          ", filter=" + spfilter +
                          ", rotang=" + str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    debugplot=debugplot)

        # apply 1d median filtering (along the spectral direction)
        # to remove bad pixels
        size_x = size_x_medfilt
        size_y = 1
        slitlet2d_smooth = ndimage.median_filter(
            slitlet2d, size=(size_y, size_x))

        if debugplot % 10 != 0:
            ximshow(slitlet2d_smooth,
                    title=sfilename + " [smoothed]"
                          "\nslitlet=" + str(islitlet) +
                          ", grism=" + grism +
                          ", filter=" + spfilter +
                          ", rotang=" + str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    debugplot=debugplot)

        # apply 1d Savitzky-Golay filter (along the spatial direction)
        # to compute first derivative
        slitlet2d_savgol = savgol_filter(
            slitlet2d_smooth, window_length=size_y_savgol, polyorder=2,
            deriv=1, axis=0)

        # compute basic statistics
        q25, q50, q75 = np.percentile(slitlet2d_savgol, q=[25.0, 50.0, 75.0])
        sigmag = 0.7413 * (q75 - q25)  # robust standard deviation
        if debugplot >= 10:
            print("q50, sigmag:", q50, sigmag)

        if debugplot % 10 != 0:
            ximshow(slitlet2d_savgol,
                    title=sfilename + " [S.-G.filt.]"
                          "\nslitlet=" + str(islitlet) +
                          ", grism=" + grism +
                          ", filter=" + spfilter +
                          ", rotang=" + str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    z1z2=(q50-times_sigma_threshold*sigmag,
                          q50+times_sigma_threshold*sigmag),
                    debugplot=debugplot)

        # identify objects in slitlet2d_savgol: pixels with positive
        # derivatives are identified independently from pixels with
        # negative derivatives; then the two sets of potential features
        # are merged; this approach avoids some problems when, in
        # nearby regions, there are pixels with positive and negative
        # derivatives (in those circumstances a single search as
        # np.logical_or(
        #     slitlet2d_savgol < q50 - times_sigma_threshold * sigmag,
        #     slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # led to erroneous detections!)
        #
        # search for positive derivatives
        labels2d_objects_pos, no_objects_pos = ndimage.label(
            slitlet2d_savgol > q50 + times_sigma_threshold * sigmag)
        # search for negative derivatives
        labels2d_objects_neg, no_objects_neg = ndimage.label(
            slitlet2d_savgol < q50 - times_sigma_threshold * sigmag)
        # merge both sets: negative-derivative labels are shifted so that
        # they follow the positive ones (1..pos are positive, pos+1.. are
        # negative)
        non_zero_neg = np.where(labels2d_objects_neg > 0)
        labels2d_objects = np.copy(labels2d_objects_pos)
        labels2d_objects[non_zero_neg] += \
            labels2d_objects_neg[non_zero_neg] + no_objects_pos
        no_objects = no_objects_pos + no_objects_neg

        if debugplot >= 10:
            print("Number of objects with positive derivative:",
                  no_objects_pos)
            print("Number of objects with negative derivative:",
                  no_objects_neg)
            print("Total number of objects initially found...:", no_objects)

        if debugplot % 10 != 0:
            ximshow(labels2d_objects,
                    z1z2=(0, no_objects),
                    title=sfilename + " [objects]"
                          "\nslitlet=" + str(islitlet) +
                          ", grism=" + grism +
                          ", filter=" + spfilter +
                          ", rotang=" + str(round(rotang, 2)),
                    first_pixel=(sltlim.bb_nc1_orig, sltlim.bb_ns1_orig),
                    cbar_label="Object number",
                    debugplot=debugplot)

        # without at least one object of each sign the boundary indices
        # would fall back to label 0 (the background) and the "boundary"
        # would be fitted to meaningless pixels
        if no_objects_pos == 0 or no_objects_neg == 0:
            raise ValueError(
                "Unable to detect both boundaries of slitlet " +
                str(islitlet) + " in " + sfilename +
                " (objects with positive derivative: " +
                str(no_objects_pos) + ", with negative derivative: " +
                str(no_objects_neg) + ")")

        # select boundaries as the largest objects found with
        # positive and negative derivatives (np.argmax returns the first
        # maximum, i.e. the lowest label in case of ties)
        npix_per_label = np.bincount(labels2d_objects.ravel(),
                                     minlength=no_objects + 1)
        i_der_pos = 1 + int(np.argmax(npix_per_label[1:no_objects_pos + 1]))
        i_der_neg = no_objects_pos + 1 + int(np.argmax(
            npix_per_label[no_objects_pos + 1:no_objects + 1]))

        # determine which boundary is lower and which is upper
        y_center_mass_der_pos = ndimage.center_of_mass(
            slitlet2d_savgol, labels2d_objects, [i_der_pos])[0][0]
        y_center_mass_der_neg = ndimage.center_of_mass(
            slitlet2d_savgol, labels2d_objects, [i_der_neg])[0][0]
        if y_center_mass_der_pos < y_center_mass_der_neg:
            i_lower = i_der_pos
            i_upper = i_der_neg
            if debugplot >= 10:
                print("-> lower boundary has positive derivatives")
        else:
            i_lower = i_der_neg
            i_upper = i_der_pos
            if debugplot >= 10:
                print("-> lower boundary has negative derivatives")
        list_slices_ok = [i_lower, i_upper]

        # adjust individual boundaries passing the selection:
        # - select points in the image belonging to a given boundary
        # - compute weighted mean of the pixels of the boundary, column
        #   by column (this reduces dramatically the number of points
        #   to be fitted to determine the boundary)
        list_boundaries = []
        for k in range(2):  # k=0 lower boundary, k=1 upper boundary
            # select points to be fitted for a particular boundary
            # (note: be careful with array indices and pixel
            # coordinates)
            xy_tmp = np.where(labels2d_objects == list_slices_ok[k])
            xmin = xy_tmp[1].min()  # array indices (integers)
            xmax = xy_tmp[1].max()  # array indices (integers)
            xfit = []
            yfit = []
            # fix range for fit
            if k == 0:
                xmineff = max(sltlim.xmin_lower_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_lower_boundary_fit, xmax)
            else:
                xmineff = max(sltlim.xmin_upper_boundary_fit, xmin)
                xmaxeff = min(sltlim.xmax_upper_boundary_fit, xmax)
            if xmineff > xmaxeff:
                raise ValueError(
                    "Empty fitting range for the " +
                    ("lower" if k == 0 else "upper") +
                    " boundary of slitlet " + str(islitlet) +
                    " in " + sfilename)
            # loop in columns of the image belonging to the boundary
            for xdum in range(xmineff, xmaxeff + 1):  # array indices (integer)
                iok = np.where(xy_tmp[1] == xdum)
                y_tmp = xy_tmp[0][iok] + sltlim.bb_ns1_orig  # image pixel
                weight = slitlet2d_savgol[xy_tmp[0][iok], xy_tmp[1][iok]]
                y_wmean = sum(y_tmp * weight) / sum(weight)
                xfit.append(xdum + sltlim.bb_nc1_orig)
                yfit.append(y_wmean)
            xfit = np.array(xfit)
            yfit = np.array(yfit)
            # declare new SpectrumTrail instance
            boundary = SpectrumTrail()
            # define new boundary
            boundary.fit(x=xfit, y=yfit, deg=sltlim.deg_boundary,
                         times_sigma_reject=10,
                         title="slit:" + str(sltlim.islitlet) +
                               ", deg=" + str(sltlim.deg_boundary),
                         debugplot=0)
            list_boundaries.append(boundary)

        if debugplot % 10 != 0:
            for tmp_img, tmp_label in zip(
                [slitlet2d_savgol, slitlet2d],
                [' [S.-G.filt.]', ' [original]']
            ):
                ax = ximshow(tmp_img,
                             title=sfilename + tmp_label +
                                   "\nslitlet=" + str(islitlet) +
                                   ", grism=" + grism +
                                   ", filter=" + spfilter +
                                   ", rotang=" + str(round(rotang, 2)),
                             first_pixel=(sltlim.bb_nc1_orig,
                                          sltlim.bb_ns1_orig),
                             show=False,
                             debugplot=debugplot)
                # full-detector extrapolation of each boundary (thin line)
                for k in range(2):
                    xpol, ypol = list_boundaries[k].linspace_pix(
                        start=1, stop=EMIR_NAXIS1)
                    ax.plot(xpol, ypol, 'b--', linewidth=1)
                # default range of each boundary (thick line)
                for k in range(2):
                    xpol, ypol = list_boundaries[k].linspace_pix()
                    ax.plot(xpol, ypol, 'g--', linewidth=4)
                # show plot
                pause_debugplot(debugplot, pltshow=True)

        # update bounddict
        tmp_dict = {
            'boundary_coef_lower':
                list_boundaries[0].poly_funct.coef.tolist(),
            'boundary_xmin_lower': list_boundaries[0].xlower_line,
            'boundary_xmax_lower': list_boundaries[0].xupper_line,
            'boundary_coef_upper':
                list_boundaries[1].poly_funct.coef.tolist(),
            'boundary_xmin_upper': list_boundaries[1].xlower_line,
            'boundary_xmax_upper': list_boundaries[1].xupper_line,
            'csu_bar_left': csu_config.csu_bar_left(islitlet),
            'csu_bar_right': csu_config.csu_bar_right(islitlet),
            'csu_bar_slit_center': csu_config.csu_bar_slit_center(islitlet),
            'csu_bar_slit_width': csu_config.csu_bar_slit_width(islitlet),
            'rotang': rotang,
            'xdtu': dtu_config.xdtu,
            'ydtu': dtu_config.ydtu,
            'zdtu': dtu_config.zdtu,
            'xdtu_0': dtu_config.xdtu_0,
            'ydtu_0': dtu_config.ydtu_0,
            'zdtu_0': dtu_config.zdtu_0,
            # getpass.getuser() works without a controlling terminal,
            # unlike os.getlogin() (which raises OSError e.g. under cron)
            'zzz_info1': getpass.getuser() + '@' + socket.gethostname(),
            'zzz_info2': datetime.now().isoformat()
        }
        slitlet_label = "slitlet" + str(islitlet).zfill(2)
        if slitlet_label not in bounddict['contents']:
            bounddict['contents'][slitlet_label] = {}
        bounddict['contents'][slitlet_label][date_obs] = tmp_dict

    if debugplot < 10:
        print("")
Example #20
0
def display_slitlet_arrangement(fileobj,
                                grism=None,
                                spfilter=None,
                                bbox=None,
                                adjust=None,
                                geometry=None,
                                debugplot=0):
    """Display slitlet arrangement from CSUP keywords in FITS header.

    Parameters
    ----------
    fileobj : file object
        FITS or TXT file object. A file whose name ends in '.txt' is
        read as a text table with one "bar_id position" pair per line
        (blank lines and lines starting with '#' are ignored). Any other
        file is read as FITS.
    grism : str
        Grism. Mandatory for TXT input; for FITS input it is read from
        the GRISM keyword (the argument is ignored).
    spfilter : str
        Filter. Mandatory for TXT input; for FITS input it is read from
        the FILTER keyword (the argument is ignored).
    bbox : tuple of 4 floats
        If not None, values for xmin, xmax, ymin and ymax.
    adjust : bool
        Adjust X range according to minimum and maximum csu_bar_left
        and csu_bar_right (note that this option is overridden by 'bbox')
    geometry : tuple (4 integers) or None
        x, y, dx, dy values employed to set the Qt backend geometry.
    debugplot : int
        Determines whether intermediate computations and/or plots
        are displayed. The valid codes are defined in
        numina.array.display.pause_debugplot

    Returns
    -------
    csu_bar_left : list of floats
        Location (mm) of the left bar for each slitlet.
    csu_bar_right : list of floats
        Location (mm) of the right bar for each slitlet, using the
        same origin employed for csu_bar_left (which is not the
        value stored in the FITS keywords).
    csu_bar_slit_center : list of floats
        Middle point (mm) in between the two bars defining a slitlet.
    csu_bar_slit_width : list of floats
        Slitlet width (mm), computed as the distance between the two
        bars defining the slitlet.

    Raises
    ------
    ValueError
        If grism/filter are missing for TXT input, if the TXT bar
        sequence is unexpected or incomplete, or if the grism + filter
        combination is not supported.

    """

    # Right-hand end (mm) of the CSU coordinate range used in this
    # function: right-bar positions in TXT files are measured from it, and
    # it is the default upper X limit of the plot.
    csu_max_position = 341.5

    if fileobj.name[-4:] == ".txt":
        if grism is None:
            raise ValueError("Undefined grism!")
        if spfilter is None:
            raise ValueError("Undefined filter!")
        # define CsuConfiguration object
        csu_config = CsuConfiguration()
        csu_config._csu_bar_left = []
        csu_config._csu_bar_right = []
        csu_config._csu_bar_slit_center = []
        csu_config._csu_bar_slit_width = []

        # since the input filename has been opened with argparse in binary
        # mode, it is necessary to close it and open it in text mode
        fileobj.close()
        # read TXT file
        with open(fileobj.name, mode='rt') as f:
            file_content = f.read().splitlines()
        # bars are expected in the order 1, 1+N, 2, 2+N, ..., N, 2N
        # (left bar followed by the matching right bar), N = EMIR_NBARS
        next_id_bar = 1
        for line in file_content:
            stripped = line.strip()
            # skip blank lines (including whitespace-only) and comments
            if not stripped or stripped.startswith('#'):
                continue
            line_contents = stripped.split()
            id_bar = int(line_contents[0])
            position = float(line_contents[1])
            if id_bar != next_id_bar:
                raise ValueError("Unexpected id_bar:" + str(id_bar))
            if id_bar <= EMIR_NBARS:
                csu_config._csu_bar_left.append(position)
                next_id_bar = id_bar + EMIR_NBARS
            else:
                csu_config._csu_bar_right.append(
                    csu_max_position - position)
                next_id_bar = id_bar - EMIR_NBARS + 1

        # every slitlet needs exactly one left and one right bar
        if len(csu_config._csu_bar_left) != EMIR_NBARS or \
                len(csu_config._csu_bar_right) != EMIR_NBARS:
            raise ValueError(
                "Unexpected number of bars in " + fileobj.name +
                ": left=" + str(len(csu_config._csu_bar_left)) +
                ", right=" + str(len(csu_config._csu_bar_right)) +
                ", expected=" + str(EMIR_NBARS))

        # compute slit width and center
        for bar_left, bar_right in zip(csu_config._csu_bar_left,
                                       csu_config._csu_bar_right):
            csu_config._csu_bar_slit_center.append(
                (bar_left + bar_right) / 2)
            csu_config._csu_bar_slit_width.append(bar_right - bar_left)

    else:
        # read input FITS file (header is kept in memory after closing)
        with fits.open(fileobj.name) as hdulist:
            image_header = hdulist[0].header

        # additional info from header
        grism = image_header['grism']
        spfilter = image_header['filter']

        # define slitlet arrangement
        csu_config = CsuConfiguration.define_from_fits(fileobj)

    # determine calibration
    if grism in ["J", "OPEN"] and spfilter == "J":
        wv_parameters = set_wv_parameters("J", "J")
    elif grism in ["H", "OPEN"] and spfilter == "H":
        wv_parameters = set_wv_parameters("H", "H")
    elif grism in ["K", "OPEN"] and spfilter == "Ksp":
        wv_parameters = set_wv_parameters("Ksp", "K")
    elif grism in ["LR", "OPEN"] and spfilter == "YJ":
        wv_parameters = set_wv_parameters("YJ", "LR")
    elif grism in ["LR", "OPEN"] and spfilter == "HK":
        wv_parameters = set_wv_parameters("HK", "LR")
    else:
        raise ValueError("Invalid grism + filter configuration")

    # crval1 and cdelt1 are callables evaluated at the slit center
    crval1 = wv_parameters['poly_crval1_linear']
    cdelt1 = wv_parameters['poly_cdelt1_linear']

    wvmin_useful = wv_parameters['wvmin_useful']
    wvmax_useful = wv_parameters['wvmax_useful']

    # display arrangement
    if abs(debugplot) >= 10:
        print("slit     left    right   center   width   min.wave   max.wave")
        print("====  =======  =======  =======   =====   ========   ========")
        for i in range(EMIR_NBARS):
            ibar = i + 1
            csu_crval1 = crval1(csu_config.csu_bar_slit_center(ibar))
            csu_cdelt1 = cdelt1(csu_config.csu_bar_slit_center(ibar))
            csu_crvaln = csu_crval1 + (EMIR_NAXIS1 - 1) * csu_cdelt1
            # clip the wavelength range to the useful interval, if defined
            if wvmin_useful is not None:
                csu_crval1 = np.amax([csu_crval1, wvmin_useful])
            if wvmax_useful is not None:
                csu_crvaln = np.amin([csu_crvaln, wvmax_useful])
            print("{0:4d} {1:8.3f} {2:8.3f} {3:8.3f} {4:7.3f}   "
                  "{5:8.2f}   {6:8.2f}".format(
                      ibar, csu_config.csu_bar_left(ibar),
                      csu_config.csu_bar_right(ibar),
                      csu_config.csu_bar_slit_center(ibar),
                      csu_config.csu_bar_slit_width(ibar), csu_crval1,
                      csu_crvaln))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (all)".format(
            np.mean(csu_config._csu_bar_left),
            np.mean(csu_config._csu_bar_right),
            np.mean(csu_config._csu_bar_slit_center),
            np.mean(csu_config._csu_bar_slit_width)))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (odd)".format(
            np.mean(csu_config._csu_bar_left[::2]),
            np.mean(csu_config._csu_bar_right[::2]),
            np.mean(csu_config._csu_bar_slit_center[::2]),
            np.mean(csu_config._csu_bar_slit_width[::2])))
        print("---> {0:8.3f} {1:8.3f} {2:8.3f} {3:7.3f} <- mean (even)".format(
            np.mean(csu_config._csu_bar_left[1::2]),
            np.mean(csu_config._csu_bar_right[1::2]),
            np.mean(csu_config._csu_bar_slit_center[1::2]),
            np.mean(csu_config._csu_bar_slit_width[1::2])))

    # display slit arrangement
    if abs(debugplot) % 10 != 0:
        fig = plt.figure()
        set_window_geometry(geometry)
        ax = fig.add_subplot(111)
        if bbox is None:
            if adjust:
                xmin = min(csu_config._csu_bar_left)
                xmax = max(csu_config._csu_bar_right)
                dx = xmax - xmin
                if dx == 0:
                    dx = 1
                xmin -= dx / 20
                xmax += dx / 20
                ax.set_xlim(xmin, xmax)
            else:
                ax.set_xlim(0., csu_max_position)
            ax.set_ylim(0, 56)
        else:
            ax.set_xlim(bbox[0], bbox[1])
            ax.set_ylim(bbox[2], bbox[3])
        ax.set_xlabel('csu_bar_position (mm)')
        ax.set_ylabel('slit number')
        for i in range(EMIR_NBARS):
            ibar = i + 1
            # open slit as a rectangle, bars as gray lines on each side
            ax.add_patch(
                patches.Rectangle((csu_config.csu_bar_left(ibar), ibar - 0.5),
                                  csu_config.csu_bar_slit_width(ibar), 1.0))
            ax.plot([0., csu_config.csu_bar_left(ibar)], [ibar, ibar],
                    '-',
                    color='gray')
            ax.plot([csu_config.csu_bar_right(ibar), csu_max_position],
                    [ibar, ibar],
                    '-',
                    color='gray')
        plt.title("File: " + fileobj.name + "\ngrism=" + grism + ", filter=" +
                  spfilter)
        pause_debugplot(debugplot, pltshow=True)

    # return results
    return csu_config._csu_bar_left, csu_config._csu_bar_right, \
           csu_config._csu_bar_slit_center, csu_config._csu_bar_slit_width
Example #21
0
def main(args=None):
    """Select (or mask) unrectified slitlets in an EMIR FITS image.

    Command-line entry point. The requested slitlets are extracted from
    the input image (or, with ``--maskonly``, a mask is generated for
    them) using the fitted boundary parameters of the multislit model
    stored in a JSON file. The accumulated result is written to
    ``--outfile`` when given and, depending on ``--debugplot``, displayed
    with boundaries and frontiers overplotted.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments to parse. When None, ``sys.argv[1:]`` is
        used (standard argparse behaviour).

    Raises
    ------
    ValueError
        If the image size or INSTRUME keyword is unexpected, or if the
        JSON file does not correspond to the multislit model.

    """

    # parse command-line options
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--fitted_bound_param", required=True,
                        help="JSON file with fitted boundary coefficients "
                             "corresponding to the multislit model",
                        type=argparse.FileType('rt'))
    parser.add_argument("--slitlets", required=True,
                        help="Slitlet selection: string between double "
                             "quotes providing tuples of the form "
                             "n1[,n2[,step]]",
                        type=str)

    # optional arguments
    parser.add_argument("--outfile",
                        help="Output FITS file name",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))
    parser.add_argument("--maskonly",
                        help="Generate mask for the indicated slitlets",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting/debugging" +
                             " (default=0)",
                        type=int, default=0,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the 'args' parameter (None -> sys.argv[1:])
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # read input FITS file
    hdulist_image = fits.open(args.fitsfile.name)
    image_header = hdulist_image[0].header
    image2d = hdulist_image[0].data

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")

    if image2d.shape != (EMIR_NAXIS2, EMIR_NAXIS1):
        raise ValueError("NAXIS1, NAXIS2 unexpected for EMIR detector")

    # remove path from fitsfile
    if args.outfile is None:
        sfitsfile = os.path.basename(args.fitsfile.name)
    else:
        sfitsfile = os.path.basename(args.outfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # read fitted_bound_param JSON file (closing it afterwards)
    with open(args.fitted_bound_param.name) as fjson:
        fittedpar_dict = json.load(fjson)
    params = bound_params_from_dict(fittedpar_dict)
    if abs(args.debugplot) in [21, 22]:
        params.pretty_print()

    parmodel = fittedpar_dict['meta_info']['parmodel']
    if parmodel != 'multislit':
        raise ValueError("Unexpected parameter model: " + str(parmodel))

    # define slitlet range
    islitlet_min = fittedpar_dict['tags']['islitlet_min']
    islitlet_max = fittedpar_dict['tags']['islitlet_max']
    list_islitlet = list_slitlets_from_string(
        s=args.slitlets,
        islitlet_min=islitlet_min,
        islitlet_max=islitlet_max
    )

    # read CsuConfiguration object from FITS file
    csu_config = CsuConfiguration.define_from_fits(args.fitsfile)

    # define csu_bar_slit_center associated to each slitlet
    list_csu_bar_slit_center = [
        csu_config.csu_bar_slit_center(islitlet)
        for islitlet in list_islitlet
    ]

    # initialize output data array
    image2d_output = np.zeros((naxis2, naxis1))

    # main loop: accumulate the selected slitlets (or their masks)
    for islitlet, csu_bar_slit_center in \
            zip(list_islitlet, list_csu_bar_slit_center):
        image2d_tmp = select_unrectified_slitlet(
            image2d=image2d,
            islitlet=islitlet,
            csu_bar_slit_center=csu_bar_slit_center,
            params=params,
            parmodel=parmodel,
            maskonly=args.maskonly
        )
        image2d_output += image2d_tmp

    # update the array of the output file
    hdulist_image[0].data = image2d_output

    # save output FITS file (--outfile is optional)
    if args.outfile is not None:
        hdulist_image.writeto(args.outfile)

    # close original image
    hdulist_image.close()

    # display full image
    if abs(args.debugplot) % 10 != 0:
        ax = ximshow(image2d=image2d_output,
                     title=sfitsfile + "\n" + args.slitlets,
                     image_bbox=(1, naxis1, 1, naxis2), show=False)

        # overplot boundaries
        overplot_boundaries_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_islitlet,
            list_csu_bar_slit_center=list_csu_bar_slit_center
        )

        # overplot frontiers
        overplot_frontiers_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_islitlet,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            micolors=('b', 'b'), linetype='-',
            labels=False    # already displayed with the boundaries
        )

        # show plot
        # NOTE(review): a fixed code (12) is passed instead of
        # args.debugplot; presumably to always pause here -- confirm
        pause_debugplot(12, pltshow=True)
Example #22
0
    def run(self, rinput):
        self.logger.info('starting generation of flatlowfreq')

        self.logger.info('rectwv_coeff..........................: {}'.format(
            rinput.rectwv_coeff))
        self.logger.info('master_rectwv.........................: {}'.format(
            rinput.master_rectwv))
        self.logger.info('Minimum slitlet width (mm)............: {}'.format(
            rinput.minimum_slitlet_width_mm))
        self.logger.info('Maximum slitlet width (mm)............: {}'.format(
            rinput.maximum_slitlet_width_mm))
        self.logger.info('Global offset X direction (pixels)....: {}'.format(
            rinput.global_integer_offset_x_pix))
        self.logger.info('Global offset Y direction (pixels)....: {}'.format(
            rinput.global_integer_offset_y_pix))
        self.logger.info('nwindow_x_median......................: {}'.format(
            rinput.nwindow_x_median))
        self.logger.info('nwindow_y_median......................: {}'.format(
            rinput.nwindow_y_median))
        self.logger.info('Minimum fraction......................: {}'.format(
            rinput.minimum_fraction))
        self.logger.info('Minimum value in output...............: {}'.format(
            rinput.minimum_value_in_output))
        self.logger.info('Maximum value in output...............: {}'.format(
            rinput.maximum_value_in_output))

        # check rectification and wavelength calibration information
        if rinput.master_rectwv is None and rinput.rectwv_coeff is None:
            raise ValueError('No master_rectwv nor rectwv_coeff data have '
                             'been provided')
        elif rinput.master_rectwv is not None and \
                rinput.rectwv_coeff is not None:
            self.logger.warning('rectwv_coeff will be used instead of '
                                'master_rectwv')
        if rinput.rectwv_coeff is not None and \
                (rinput.global_integer_offset_x_pix != 0 or
                 rinput.global_integer_offset_y_pix != 0):
            raise ValueError('global_integer_offsets cannot be used '
                             'simultaneously with rectwv_coeff')

        # check headers to detect lamp status (on/off)
        list_lampincd = []
        for fname in rinput.obresult.frames:
            with fname.open() as f:
                list_lampincd.append(f[0].header['lampincd'])

        # check number of images
        nimages = len(rinput.obresult.frames)
        n_on = list_lampincd.count(1)
        n_off = list_lampincd.count(0)
        self.logger.info(
            'Number of images with lamp ON.........: {}'.format(n_on))
        self.logger.info(
            'Number of images with lamp OFF........: {}'.format(n_off))
        self.logger.info(
            'Total number of images................: {}'.format(nimages))
        if n_on == 0:
            raise ValueError('Insufficient number of images with lamp ON')
        if n_on + n_off != nimages:
            raise ValueError('Number of images does not match!')

        # check combination method
        if rinput.method_kwargs == {}:
            method_kwargs = None
        else:
            if rinput.method == 'sigmaclip':
                method_kwargs = rinput.method_kwargs
            else:
                raise ValueError('Unexpected method_kwargs={}'.format(
                    rinput.method_kwargs))

        # build object to proceed with bpm, bias, and dark (not flat)
        flow = self.init_filters(rinput)

        # available combination methods
        method = getattr(combine, rinput.method)

        # basic reduction of images with lamp ON or OFF
        lampmode = {0: 'off', 1: 'on'}
        reduced_image_on = None
        reduced_image_off = None
        for imode in lampmode.keys():
            self.logger.info('starting basic reduction of images with'
                             ' lamp {}'.format(lampmode[imode]))
            tmplist = [
                rinput.obresult.frames[i]
                for i, lampincd in enumerate(list_lampincd)
                if lampincd == imode
            ]
            if len(tmplist) > 0:
                with contextlib.ExitStack() as stack:
                    hduls = [
                        stack.enter_context(fname.open()) for fname in tmplist
                    ]
                    reduced_image = combine_imgs(hduls,
                                                 method=method,
                                                 method_kwargs=method_kwargs,
                                                 errors=False,
                                                 prolog=None)
                if imode == 0:
                    reduced_image_off = flow(reduced_image)
                    hdr = reduced_image_off[0].header
                    self.set_base_headers(hdr)
                    self.save_intermediate_img(reduced_image_off,
                                               'reduced_image_off.fits')
                elif imode == 1:
                    reduced_image_on = flow(reduced_image)
                    hdr = reduced_image_on[0].header
                    self.set_base_headers(hdr)
                    self.save_intermediate_img(reduced_image_on,
                                               'reduced_image_on.fits')
                else:
                    raise ValueError('Unexpected imode={}'.format(imode))

        # computation of ON-OFF
        header_on = reduced_image_on[0].header
        data_on = reduced_image_on[0].data.astype('float32')
        if n_off > 0:
            header_off = reduced_image_off[0].header
            data_off = reduced_image_off[0].data.astype('float32')
        else:
            header_off = None
            data_off = np.zeros_like(data_on)
        reduced_data = data_on - data_off

        # update reduced image header
        reduced_image = self.create_reduced_image(rinput, reduced_data,
                                                  header_on, header_off,
                                                  list_lampincd)

        # save intermediate image in work directory
        self.save_intermediate_img(reduced_image, 'reduced_image.fits')

        # define rectification and wavelength calibration coefficients
        if rinput.rectwv_coeff is None:
            rectwv_coeff = rectwv_coeff_from_mos_library(
                reduced_image, rinput.master_rectwv)
            # set global offsets
            rectwv_coeff.global_integer_offset_x_pix = \
                rinput.global_integer_offset_x_pix
            rectwv_coeff.global_integer_offset_y_pix = \
                rinput.global_integer_offset_y_pix
        else:
            rectwv_coeff = rinput.rectwv_coeff
        # save as JSON in work directory
        self.save_structured_as_json(rectwv_coeff, 'rectwv_coeff.json')
        # ds9 region files (to be saved in the work directory)
        if self.intermediate_results:
            save_four_ds9(rectwv_coeff)
            save_spectral_lines_ds9(rectwv_coeff)

        # apply global offsets (to both, the original and the cleaned version)
        image2d = apply_integer_offsets(
            image2d=reduced_data,
            offx=rectwv_coeff.global_integer_offset_x_pix,
            offy=rectwv_coeff.global_integer_offset_y_pix)

        # load CSU configuration
        csu_conf = CsuConfiguration.define_from_header(reduced_image[0].header)
        # determine (pseudo) longslits
        dict_longslits = csu_conf.pseudo_longslits()

        # valid slitlet numbers
        list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
        for idel in rectwv_coeff.missing_slitlets:
            self.logger.info('-> Removing slitlet (not defined): ' + str(idel))
            list_valid_islitlets.remove(idel)
        # filter out slitlets with widths outside valid range
        list_outside_valid_width = []
        for islitlet in list_valid_islitlets:
            slitwidth = csu_conf.csu_bar_slit_width(islitlet)
            if (slitwidth < rinput.minimum_slitlet_width_mm) or \
                    (slitwidth > rinput.maximum_slitlet_width_mm):
                list_outside_valid_width.append(islitlet)
                self.logger.info('-> Removing slitlet (width out of range): ' +
                                 str(islitlet))
        if len(list_outside_valid_width) > 0:
            for idel in list_outside_valid_width:
                list_valid_islitlets.remove(idel)

        # initialize rectified image
        image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

        # main loop
        grism_name = rectwv_coeff.tags['grism']
        filter_name = rectwv_coeff.tags['filter']
        cout = '0'
        debugplot = rinput.debugplot
        for islitlet in list(range(1, EMIR_NBARS + 1)):
            if islitlet in list_valid_islitlets:
                # define Slitlet2D object
                slt = Slitlet2D(islitlet=islitlet,
                                rectwv_coeff=rectwv_coeff,
                                debugplot=debugplot)
                if abs(slt.debugplot) > 10:
                    print(slt)

                # extract (distorted) slitlet from the initial image
                slitlet2d = slt.extract_slitlet2d(image_2k2k=image2d,
                                                  subtitle='original image')

                # rectify slitlet
                slitlet2d_rect = slt.rectify(
                    slitlet2d=slitlet2d,
                    resampling=2,
                    subtitle='original (cleaned) rectified')
                naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

                if naxis1_slitlet2d != EMIR_NAXIS1:
                    print('naxis1_slitlet2d: ', naxis1_slitlet2d)
                    print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
                    raise ValueError("Unexpected naxis1_slitlet2d")

                # get useful slitlet region (use boundaries)
                spectrail = slt.list_spectrails[0]
                yy0 = slt.corr_yrect_a + \
                      slt.corr_yrect_b * spectrail(slt.x0_reference)
                ii1 = int(yy0 + 0.5) - slt.bb_ns1_orig
                spectrail = slt.list_spectrails[2]
                yy0 = slt.corr_yrect_a + \
                      slt.corr_yrect_b * spectrail(slt.x0_reference)
                ii2 = int(yy0 + 0.5) - slt.bb_ns1_orig

                # median spectrum
                sp_collapsed = np.median(slitlet2d_rect[ii1:(ii2 + 1), :],
                                         axis=0)

                # smooth median spectrum along the spectral direction
                sp_median = ndimage.median_filter(sp_collapsed,
                                                  5,
                                                  mode='nearest')

                ymax_spmedian = sp_median.max()
                y_threshold = ymax_spmedian * rinput.minimum_fraction
                lremove = np.where(sp_median < y_threshold)
                sp_median[lremove] = 0.0

                if abs(slt.debugplot) % 10 != 0:
                    xaxis1 = np.arange(1, naxis1_slitlet2d + 1)
                    title = 'Slitlet#' + str(islitlet) + ' (median spectrum)'
                    ax = ximplotxy(xaxis1,
                                   sp_collapsed,
                                   title=title,
                                   show=False,
                                   **{'label': 'collapsed spectrum'})
                    ax.plot(xaxis1, sp_median, label='fitted spectrum')
                    ax.plot([1, naxis1_slitlet2d],
                            2 * [y_threshold],
                            label='threshold')
                    # ax.plot(xknots, yknots, 'o', label='knots')
                    ax.legend()
                    ax.set_ylim(-0.05 * ymax_spmedian, 1.05 * ymax_spmedian)
                    pause_debugplot(slt.debugplot,
                                    pltshow=True,
                                    tight_layout=True)

                # generate rectified slitlet region filled with the
                # median spectrum
                slitlet2d_rect_spmedian = np.tile(sp_median,
                                                  (naxis2_slitlet2d, 1))
                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_rectified(
                        slitlet2d_rect=slitlet2d_rect_spmedian,
                        subtitle='rectified, filled with median spectrum')

                # compute smooth surface
                # clipped region
                slitlet2d_rect_clipped = slitlet2d_rect_spmedian.copy()
                slitlet2d_rect_clipped[:(ii1 - 1), :] = 0.0
                slitlet2d_rect_clipped[(ii2 + 2):, :] = 0.0
                # unrectified clipped image
                slitlet2d_unrect_clipped = slt.rectify(
                    slitlet2d=slitlet2d_rect_clipped,
                    resampling=2,
                    inverse=True,
                    subtitle='unrectified, filled with median spectrum '
                    '(clipped)')
                # normalize initial slitlet image (avoid division by zero)
                slitlet2d_norm_clipped = np.zeros_like(slitlet2d)
                for j in range(naxis1_slitlet2d):
                    for i in range(naxis2_slitlet2d):
                        den = slitlet2d_unrect_clipped[i, j]
                        if den == 0:
                            slitlet2d_norm_clipped[i, j] = 1.0
                        else:
                            slitlet2d_norm_clipped[i, j] = \
                                slitlet2d[i, j] / den
                # set to 1.0 one additional pixel at each side (since
                # 'den' above is small at the borders and generates wrong
                # bright pixels)
                slitlet2d_norm_clipped = fix_pix_borders(
                    image2d=slitlet2d_norm_clipped,
                    nreplace=1,
                    sought_value=1.0,
                    replacement_value=1.0)
                slitlet2d_norm_clipped = slitlet2d_norm_clipped.transpose()
                slitlet2d_norm_clipped = fix_pix_borders(
                    image2d=slitlet2d_norm_clipped,
                    nreplace=1,
                    sought_value=1.0,
                    replacement_value=1.0)
                slitlet2d_norm_clipped = slitlet2d_norm_clipped.transpose()
                slitlet2d_norm_smooth = ndimage.median_filter(
                    slitlet2d_norm_clipped,
                    size=(rinput.nwindow_y_median, rinput.nwindow_x_median),
                    mode='nearest')

                if abs(slt.debugplot) % 10 != 0:
                    slt.ximshow_unrectified(
                        slitlet2d=slitlet2d_norm_clipped,
                        subtitle='unrectified, pixel-to-pixel (clipped)')
                    slt.ximshow_unrectified(
                        slitlet2d=slitlet2d_norm_smooth,
                        subtitle='unrectified, pixel-to-pixel (smoothed)')

                # ---

                # check for (pseudo) longslit with previous and next slitlet
                imin = dict_longslits[islitlet].imin()
                imax = dict_longslits[islitlet].imax()
                if islitlet > 1:
                    same_slitlet_below = (islitlet - 1) >= imin
                else:
                    same_slitlet_below = False
                if islitlet < EMIR_NBARS:
                    same_slitlet_above = (islitlet + 1) <= imax
                else:
                    same_slitlet_above = False

                for j in range(EMIR_NAXIS1):
                    xchannel = j + 1
                    y0_lower = slt.list_frontiers[0](xchannel)
                    y0_upper = slt.list_frontiers[1](xchannel)
                    n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                                    y0_frontier_upper=y0_upper,
                                                    resize=True)
                    # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
                    nn1 = n1 - slt.bb_ns1_orig + 1
                    nn2 = n2 - slt.bb_ns1_orig + 1
                    image2d_flatfielded[(n1 - 1):n2, j] = \
                        slitlet2d_norm_smooth[(nn1 - 1):nn2, j]

                    # force to 1.0 region around frontiers
                    if not same_slitlet_below:
                        image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
                    if not same_slitlet_above:
                        image2d_flatfielded[(n2 - 5):n2, j] = 1
                cout += '.'
            else:
                cout += 'i'

            if islitlet % 10 == 0:
                if cout != 'i':
                    cout = str(islitlet // 10)

            self.logger.info(cout)

        # restore global offsets
        image2d_flatfielded = apply_integer_offsets(
            image2d=image2d_flatfielded,
            offx=-rectwv_coeff.global_integer_offset_x_pix,
            offy=-rectwv_coeff.global_integer_offset_y_pix)

        # set pixels below minimum value to 1.0
        filtered = np.where(
            image2d_flatfielded < rinput.minimum_value_in_output)
        image2d_flatfielded[filtered] = 1.0

        # set pixels above maximum value to 1.0
        filtered = np.where(
            image2d_flatfielded > rinput.maximum_value_in_output)
        image2d_flatfielded[filtered] = 1.0

        # update image header
        reduced_flatlowfreq = self.create_reduced_image(
            rinput, image2d_flatfielded, header_on, header_off, list_lampincd)

        # ds9 region files (to be saved in the work directory)
        if self.intermediate_results:
            save_four_ds9(rectwv_coeff)
            save_spectral_lines_ds9(rectwv_coeff)

        # save results in results directory
        self.logger.info('end of flatlowfreq generation')
        result = self.create_result(reduced_flatlowfreq=reduced_flatlowfreq)
        return result
def rectwv_coeff_from_arc_image(reduced_image,
                                bound_param,
                                lines_catalog,
                                args_nbrightlines=None,
                                args_ymargin_bb=2,
                                args_remove_sp_background=True,
                                args_times_sigma_threshold=10,
                                args_order_fmap=2,
                                args_sigma_gaussian_filtering=2,
                                args_margin_npix=50,
                                args_poldeg_initial=3,
                                args_poldeg_refined=5,
                                args_interactive=False,
                                args_threshold_wv=0,
                                args_ylogscale=False,
                                args_pdf=None,
                                args_geometry=(0, 0, 640, 480),
                                debugplot=0):
    """Evaluate rect.+wavecal. coefficients from arc image

    Parameters
    ----------
    reduced_image : HDUList object
        Image with preliminary basic reduction: bpm, bias, dark and
        flatfield.
    bound_param : RefinedBoundaryModelParam instance
        Refined boundary model.
    lines_catalog : Numpy array
        2D numpy array with the contents of the master file with the
        expected arc line wavelengths.
    args_nbrightlines : int
        TBD
    args_ymargin_bb : int
        TBD
    args_remove_sp_background : bool
        TBD
    args_times_sigma_threshold : float
        TBD
    args_order_fmap : int
        TBD
    args_sigma_gaussian_filtering : float
        TBD
    args_margin_npix : int
        TBD
    args_poldeg_initial : int
        TBD
    args_poldeg_refined : int
        TBD
    args_interactive : bool
        TBD
    args_threshold_wv : float
        TBD
    args_ylogscale : bool
        TBD
    args_pdf : TBD
    args_geometry : TBD
    debugplot : int
            Debugging level for messages and plots. For details see
            'numina.array.display.pause_debugplot.py'.

    Returns
    -------
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for the
        particular CSU configuration of the input arc image.
    reduced_55sp : HDUList object
        Image with 55 spectra corresponding to the median spectrum for
        each slitlet, employed to derived the wavelength calibration
        polynomial.

    """

    logger = logging.getLogger(__name__)

    # protections
    if args_interactive and args_pdf is not None:
        logger.error('--interactive and --pdf are incompatible options')
        raise ValueError('--interactive and --pdf are incompatible options')

    # header and data array
    header = reduced_image[0].header
    image2d = reduced_image[0].data

    # check grism and filter
    filter_name = header['filter']
    logger.info('Filter: ' + filter_name)
    if filter_name != bound_param.tags['filter']:
        raise ValueError('Filter name does not match!')
    grism_name = header['grism']
    logger.info('Grism: ' + grism_name)
    if grism_name != bound_param.tags['grism']:
        raise ValueError('Grism name does not match!')

    # read the CSU configuration from the image header
    csu_conf = CsuConfiguration.define_from_header(header)
    logger.debug(csu_conf)

    # read the DTU configuration from the image header
    dtu_conf = DtuConfiguration.define_from_header(header)
    logger.debug(dtu_conf)

    # set boundary parameters
    parmodel = bound_param.meta_info['parmodel']
    params = bound_params_from_dict(bound_param.__getstate__())
    if abs(debugplot) >= 10:
        print('-' * 83)
        print('* FITTED BOUND PARAMETERS')
        params.pretty_print()
        pause_debugplot(debugplot)

    # determine parameters according to grism+filter combination
    wv_parameters = set_wv_parameters(filter_name, grism_name)
    islitlet_min = wv_parameters['islitlet_min']
    islitlet_max = wv_parameters['islitlet_max']
    if args_nbrightlines is None:
        nbrightlines = wv_parameters['nbrightlines']
    else:
        nbrightlines = [int(idum) for idum in args_nbrightlines.split(',')]
    poly_crval1_linear = wv_parameters['poly_crval1_linear']
    poly_cdelt1_linear = wv_parameters['poly_cdelt1_linear']
    wvmin_expected = wv_parameters['wvmin_expected']
    wvmax_expected = wv_parameters['wvmax_expected']
    wvmin_useful = wv_parameters['wvmin_useful']
    wvmax_useful = wv_parameters['wvmax_useful']

    # list of slitlets to be computed
    logger.info('list_slitlets: [' + str(islitlet_min) + ',... ' +
                str(islitlet_max) + ']')

    # read master arc line wavelengths (only brightest lines)
    wv_master = read_wv_master_from_array(master_table=lines_catalog,
                                          lines='brightest',
                                          debugplot=debugplot)

    # read master arc line wavelengths (whole data set)
    wv_master_all = read_wv_master_from_array(master_table=lines_catalog,
                                              lines='all',
                                              debugplot=debugplot)

    # check that the arc lines in the master file are properly sorted
    # in ascending order
    for i in range(len(wv_master_all) - 1):
        if wv_master_all[i] >= wv_master_all[i + 1]:
            logger.error('>>> wavelengths: ' + str(wv_master_all[i]) + '  ' +
                         str(wv_master_all[i + 1]))
            raise ValueError('Arc lines are not sorted in master file')

    # ---

    image2d_55sp = np.zeros((EMIR_NBARS, EMIR_NAXIS1))

    # compute rectification transformation and wavelength calibration
    # polynomials

    measured_slitlets = []

    cout = '0'
    for islitlet in range(1, EMIR_NBARS + 1):

        if islitlet_min <= islitlet <= islitlet_max:

            # define Slitlet2dArc object
            slt = Slitlet2dArc(islitlet=islitlet,
                               csu_conf=csu_conf,
                               ymargin_bb=args_ymargin_bb,
                               params=params,
                               parmodel=parmodel,
                               debugplot=debugplot)

            # extract 2D image corresponding to the selected slitlet, clipping
            # the image beyond the unrectified slitlet (in order to isolate
            # the arc lines of the current slitlet; otherwise there are
            # problems with arc lines from neighbour slitlets)
            image2d_tmp = select_unrectified_slitlet(
                image2d=image2d,
                islitlet=islitlet,
                csu_bar_slit_center=csu_conf.csu_bar_slit_center(islitlet),
                params=params,
                parmodel=parmodel,
                maskonly=False)
            slitlet2d = slt.extract_slitlet2d(image2d_tmp)

            # subtract smooth background computed as follows:
            # - median collapsed spectrum of the whole slitlet2d
            # - independent median filtering of the previous spectrum in the
            #   two halves in the spectral direction
            if args_remove_sp_background:
                spmedian = np.median(slitlet2d, axis=0)
                naxis1_tmp = spmedian.shape[0]
                jmidpoint = naxis1_tmp // 2
                sp1 = medfilt(spmedian[:jmidpoint], [201])
                sp2 = medfilt(spmedian[jmidpoint:], [201])
                spbackground = np.concatenate((sp1, sp2))
                slitlet2d -= spbackground

            # locate unknown arc lines
            slt.locate_unknown_arc_lines(
                slitlet2d=slitlet2d,
                times_sigma_threshold=args_times_sigma_threshold)

            # continue working with current slitlet only if arc lines have
            # been detected
            if slt.list_arc_lines is not None:

                # compute intersections between spectrum trails and arc lines
                slt.xy_spectrail_arc_intersections(slitlet2d=slitlet2d)

                # compute rectification transformation
                slt.estimate_tt_to_rectify(order=args_order_fmap,
                                           slitlet2d=slitlet2d)

                # rectify image
                slitlet2d_rect = slt.rectify(slitlet2d,
                                             resampling=2,
                                             transformation=1)

                # median spectrum and line peaks from rectified image
                sp_median, fxpeaks = slt.median_spectrum_from_rectified_image(
                    slitlet2d_rect,
                    sigma_gaussian_filtering=args_sigma_gaussian_filtering,
                    nwinwidth_initial=5,
                    nwinwidth_refined=5,
                    times_sigma_threshold=5,
                    npix_avoid_border=6,
                    nbrightlines=nbrightlines)

                image2d_55sp[islitlet - 1, :] = sp_median

                # determine expected wavelength limits prior to the wavelength
                # calibration
                csu_bar_slit_center = csu_conf.csu_bar_slit_center(islitlet)
                crval1_linear = poly_crval1_linear(csu_bar_slit_center)
                cdelt1_linear = poly_cdelt1_linear(csu_bar_slit_center)
                expected_wvmin = crval1_linear - \
                                 args_margin_npix * cdelt1_linear
                naxis1_linear = sp_median.shape[0]
                crvaln_linear = crval1_linear + \
                                (naxis1_linear - 1) * cdelt1_linear
                expected_wvmax = crvaln_linear + \
                                 args_margin_npix * cdelt1_linear
                # override previous estimates when necessary
                if wvmin_expected is not None:
                    expected_wvmin = wvmin_expected
                if wvmax_expected is not None:
                    expected_wvmax = wvmax_expected

                # clip initial master arc line list with bright lines to
                # the expected wavelength range
                lok1 = expected_wvmin <= wv_master
                lok2 = wv_master <= expected_wvmax
                lok = lok1 * lok2
                wv_master_eff = wv_master[lok]

                # perform initial wavelength calibration
                solution_wv = wvcal_spectrum(
                    sp=sp_median,
                    fxpeaks=fxpeaks,
                    poly_degree_wfit=args_poldeg_initial,
                    wv_master=wv_master_eff,
                    wv_ini_search=expected_wvmin,
                    wv_end_search=expected_wvmax,
                    wvmin_useful=wvmin_useful,
                    wvmax_useful=wvmax_useful,
                    geometry=args_geometry,
                    debugplot=slt.debugplot)
                # store initial wavelength calibration polynomial in current
                # slitlet instance
                slt.wpoly = np.polynomial.Polynomial(solution_wv.coeff)
                pause_debugplot(debugplot)

                # clip initial master arc line list with all the lines to
                # the expected wavelength range
                lok1 = expected_wvmin <= wv_master_all
                lok2 = wv_master_all <= expected_wvmax
                lok = lok1 * lok2
                wv_master_all_eff = wv_master_all[lok]

                # clip master arc line list to useful region
                if wvmin_useful is not None:
                    lok = wvmin_useful <= wv_master_all_eff
                    wv_master_all_eff = wv_master_all_eff[lok]
                if wvmax_useful is not None:
                    lok = wv_master_all_eff <= wvmax_useful
                    wv_master_all_eff = wv_master_all_eff[lok]

                # refine wavelength calibration
                if args_poldeg_refined > 0:
                    plottitle = '[slitlet#{}, refined]'.format(islitlet)
                    poly_refined, yres_summary = refine_arccalibration(
                        sp=sp_median,
                        poly_initial=slt.wpoly,
                        wv_master=wv_master_all_eff,
                        poldeg=args_poldeg_refined,
                        ntimes_match_wv=1,
                        interactive=args_interactive,
                        threshold=args_threshold_wv,
                        plottitle=plottitle,
                        ylogscale=args_ylogscale,
                        geometry=args_geometry,
                        pdf=args_pdf,
                        debugplot=slt.debugplot)
                    # store refined wavelength calibration polynomial in
                    # current slitlet instance
                    slt.wpoly = poly_refined

                # compute approximate linear values for CRVAL1 and CDELT1
                naxis1_linear = sp_median.shape[0]
                crmin1_linear = slt.wpoly(1)
                crmax1_linear = slt.wpoly(naxis1_linear)
                slt.crval1_linear = crmin1_linear
                slt.cdelt1_linear = \
                    (crmax1_linear - crmin1_linear) / (naxis1_linear - 1)

                # check that the trimming of wv_master and wv_master_all has
                # preserved the wavelength range [crmin1_linear, crmax1_linear]
                if crmin1_linear < expected_wvmin:
                    logger.warning(">>> islitlet: " + str(islitlet))
                    logger.warning("expected_wvmin: " + str(expected_wvmin))
                    logger.warning("crmin1_linear.: " + str(crmin1_linear))
                    logger.warning("WARNING: Unexpected crmin1_linear < "
                                   "expected_wvmin")
                if crmax1_linear > expected_wvmax:
                    logger.warning(">>> islitlet: " + str(islitlet))
                    logger.warning("expected_wvmax: " + str(expected_wvmax))
                    logger.warning("crmax1_linear.: " + str(crmax1_linear))
                    logger.warning("WARNING: Unexpected crmax1_linear > "
                                   "expected_wvmax")

                cout += '.'

            else:

                cout += 'x'

            if islitlet % 10 == 0:
                if cout != 'x':
                    cout = str(islitlet // 10)

            if debugplot != 0:
                pause_debugplot(debugplot)

        else:

            # define Slitlet2dArc object
            slt = Slitlet2dArc(islitlet=islitlet,
                               csu_conf=csu_conf,
                               ymargin_bb=args_ymargin_bb,
                               params=None,
                               parmodel=None,
                               debugplot=debugplot)

            cout += 'i'

        # store current slitlet in list of measured slitlets
        measured_slitlets.append(slt)

        logger.info(cout)

    # ---

    # generate FITS file structure with 55 spectra corresponding to the
    # median spectrum for each slitlet
    reduced_55sp = fits.PrimaryHDU(data=image2d_55sp)
    reduced_55sp.header['crpix1'] = (0.0, 'reference pixel')
    reduced_55sp.header['crval1'] = (0.0, 'central value at crpix2')
    reduced_55sp.header['cdelt1'] = (1.0, 'increment')
    reduced_55sp.header['ctype1'] = 'PIXEL'
    reduced_55sp.header['cunit1'] = ('Pixel', 'units along axis2')
    reduced_55sp.header['crpix2'] = (0.0, 'reference pixel')
    reduced_55sp.header['crval2'] = (0.0, 'central value at crpix2')
    reduced_55sp.header['cdelt2'] = (1.0, 'increment')
    reduced_55sp.header['ctype2'] = 'PIXEL'
    reduced_55sp.header['cunit2'] = ('Pixel', 'units along axis2')

    # ---

    # Generate structure to store intermediate results
    outdict = {}
    outdict['instrument'] = 'EMIR'
    outdict['meta_info'] = {}
    outdict['meta_info']['creation_date'] = datetime.now().isoformat()
    outdict['meta_info']['description'] = \
        'computation of rectification and wavelength calibration polynomial ' \
        'coefficients for a particular CSU configuration'
    outdict['meta_info']['recipe_name'] = 'undefined'
    outdict['meta_info']['origin'] = {}
    outdict['meta_info']['origin']['bound_param_uuid'] = \
        bound_param.uuid
    outdict['meta_info']['origin']['arc_image_uuid'] = 'undefined'
    outdict['tags'] = {}
    outdict['tags']['grism'] = grism_name
    outdict['tags']['filter'] = filter_name
    outdict['tags']['islitlet_min'] = islitlet_min
    outdict['tags']['islitlet_max'] = islitlet_max
    outdict['dtu_configuration'] = dtu_conf.outdict()
    outdict['uuid'] = str(uuid4())
    outdict['contents'] = {}

    missing_slitlets = []
    for slt in measured_slitlets:

        islitlet = slt.islitlet

        if islitlet_min <= islitlet <= islitlet_max:

            # avoid error when creating a python list of coefficients from
            # numpy polynomials when the polynomials do not exist (note that
            # the JSON format doesn't handle numpy arrays and such arrays must
            # be transformed into native python lists)
            if slt.wpoly is None:
                wpoly_coeff = None
            else:
                wpoly_coeff = slt.wpoly.coef.tolist()
            if slt.wpoly_longslit_model is None:
                wpoly_coeff_longslit_model = None
            else:
                wpoly_coeff_longslit_model = \
                    slt.wpoly_longslit_model.coef.tolist()

            # avoid similar error when creating a python list of coefficients
            # when the numpy array does not exist; note that this problem
            # does not happen with tt?_aij_longslit_model and
            # tt?_bij_longslit_model because the latter have already been
            # created as native python lists
            if slt.ttd_aij is None:
                ttd_aij = None
            else:
                ttd_aij = slt.ttd_aij.tolist()
            if slt.ttd_bij is None:
                ttd_bij = None
            else:
                ttd_bij = slt.ttd_bij.tolist()
            if slt.tti_aij is None:
                tti_aij = None
            else:
                tti_aij = slt.tti_aij.tolist()
            if slt.tti_bij is None:
                tti_bij = None
            else:
                tti_bij = slt.tti_bij.tolist()

            # creating temporary dictionary with the information corresponding
            # to the current slitlett that will be saved in the JSON file
            tmp_dict = {
                'csu_bar_left': slt.csu_bar_left,
                'csu_bar_right': slt.csu_bar_right,
                'csu_bar_slit_center': slt.csu_bar_slit_center,
                'csu_bar_slit_width': slt.csu_bar_slit_width,
                'x0_reference': slt.x0_reference,
                'y0_reference_lower': slt.y0_reference_lower,
                'y0_reference_middle': slt.y0_reference_middle,
                'y0_reference_upper': slt.y0_reference_upper,
                'y0_reference_lower_expected': slt.y0_reference_lower_expected,
                'y0_reference_middle_expected':
                slt.y0_reference_middle_expected,
                'y0_reference_upper_expected': slt.y0_reference_upper_expected,
                'y0_frontier_lower': slt.y0_frontier_lower,
                'y0_frontier_upper': slt.y0_frontier_upper,
                'y0_frontier_lower_expected': slt.y0_frontier_lower_expected,
                'y0_frontier_upper_expected': slt.y0_frontier_upper_expected,
                'corr_yrect_a': slt.corr_yrect_a,
                'corr_yrect_b': slt.corr_yrect_b,
                'min_row_rectified': slt.min_row_rectified,
                'max_row_rectified': slt.max_row_rectified,
                'ymargin_bb': slt.ymargin_bb,
                'bb_nc1_orig': slt.bb_nc1_orig,
                'bb_nc2_orig': slt.bb_nc2_orig,
                'bb_ns1_orig': slt.bb_ns1_orig,
                'bb_ns2_orig': slt.bb_ns2_orig,
                'spectrail': {
                    'poly_coef_lower':
                    slt.list_spectrails[
                        slt.i_lower_spectrail].poly_funct.coef.tolist(),
                    'poly_coef_middle':
                    slt.list_spectrails[
                        slt.i_middle_spectrail].poly_funct.coef.tolist(),
                    'poly_coef_upper':
                    slt.list_spectrails[
                        slt.i_upper_spectrail].poly_funct.coef.tolist(),
                },
                'frontier': {
                    'poly_coef_lower':
                    slt.list_frontiers[0].poly_funct.coef.tolist(),
                    'poly_coef_upper':
                    slt.list_frontiers[1].poly_funct.coef.tolist(),
                },
                'ttd_order': slt.ttd_order,
                'ttd_aij': ttd_aij,
                'ttd_bij': ttd_bij,
                'tti_aij': tti_aij,
                'tti_bij': tti_bij,
                'ttd_order_longslit_model': slt.ttd_order_longslit_model,
                'ttd_aij_longslit_model': slt.ttd_aij_longslit_model,
                'ttd_bij_longslit_model': slt.ttd_bij_longslit_model,
                'tti_aij_longslit_model': slt.tti_aij_longslit_model,
                'tti_bij_longslit_model': slt.tti_bij_longslit_model,
                'wpoly_coeff': wpoly_coeff,
                'wpoly_coeff_longslit_model': wpoly_coeff_longslit_model,
                'crval1_linear': slt.crval1_linear,
                'cdelt1_linear': slt.cdelt1_linear
            }
        else:
            missing_slitlets.append(islitlet)
            tmp_dict = {
                'csu_bar_left': slt.csu_bar_left,
                'csu_bar_right': slt.csu_bar_right,
                'csu_bar_slit_center': slt.csu_bar_slit_center,
                'csu_bar_slit_width': slt.csu_bar_slit_width,
                'x0_reference': slt.x0_reference,
                'y0_frontier_lower_expected': slt.y0_frontier_lower_expected,
                'y0_frontier_upper_expected': slt.y0_frontier_upper_expected
            }
        slitlet_label = "slitlet" + str(islitlet).zfill(2)
        outdict['contents'][slitlet_label] = tmp_dict

    # ---

    # OBSOLETE
    '''
    # save JSON file needed to compute the MOS model
    with open(args.out_json.name, 'w') as fstream:
        json.dump(outdict, fstream, indent=2, sort_keys=True)
        print('>>> Saving file ' + args.out_json.name)
    '''

    # ---

    # Create object of type RectWaveCoeff with coefficients for
    # rectification and wavelength calibration
    rectwv_coeff = RectWaveCoeff(instrument='EMIR')
    rectwv_coeff.quality_control = numina.types.qc.QC.GOOD
    rectwv_coeff.tags['grism'] = grism_name
    rectwv_coeff.tags['filter'] = filter_name
    rectwv_coeff.meta_info['origin']['bound_param'] = \
        'uuid' + bound_param.uuid
    rectwv_coeff.meta_info['dtu_configuration'] = outdict['dtu_configuration']
    rectwv_coeff.total_slitlets = EMIR_NBARS
    rectwv_coeff.missing_slitlets = missing_slitlets
    for i in range(EMIR_NBARS):
        islitlet = i + 1
        dumdict = {'islitlet': islitlet}
        cslitlet = 'slitlet' + str(islitlet).zfill(2)
        if cslitlet in outdict['contents']:
            dumdict.update(outdict['contents'][cslitlet])
        else:
            raise ValueError("Unexpected error")
        rectwv_coeff.contents.append(dumdict)
    # debugging __getstate__ and __setstate__
    # rectwv_coeff.writeto(args.out_json.name)
    # print('>>> Saving file ' + args.out_json.name)
    # check_setstate_getstate(rectwv_coeff, args.out_json.name)
    logger.info('Generating RectWaveCoeff object with uuid=' +
                rectwv_coeff.uuid)

    return rectwv_coeff, reduced_55sp
Example #24
0
def rectwv_coeff_add_longslit_model(rectwv_coeff, geometry, debugplot=0):
    """Compute longslit_model coefficients for RectWaveCoeff object.

    For the wavelength calibration polynomial coefficients and for the
    rectification transformation coefficients (ttd_aij, ttd_bij, tti_aij,
    tti_bij), each individual coefficient is fitted, across the valid
    slitlets, as a polynomial function of 'y0_reference_middle'. These
    fitted polynomials are then evaluated at the 'y0_reference_middle' of
    every valid slitlet and stored in the corresponding
    '*_longslit_model' entries of rectwv_coeff.contents.

    Slitlets lacking a wavelength calibration polynomial ('wpoly_coeff'
    is None) or a rectification transformation ('ttd_aij' is None) do not
    take part in the corresponding fits, but still receive the model
    coefficients.

    Parameters
    ----------
    rectwv_coeff : RectWaveCoeff instance
        Rectification and wavelength calibration coefficients for a
        particular CSU configuration corresponding to a longslit
        observation. It is modified in place.
    geometry : TBD
        Forwarded unchanged to polfit_residuals_with_sigma_rejection
        (presumably the geometry of the debugging plot window; confirm).
    debugplot : int
        Debugging level for messages and plots. For details see
        'numina.array.display.pause_debugplot.py'.

    Returns
    -------
    rectwv_coeff : RectWaveCoeff instance
        Updated object with longslit_model coefficients computed (the
        same object received as input, with a new uuid and
        creation_date).

    Raises
    ------
    ValueError
        If the slitlets used in the fits do not share a single wavelength
        calibration polynomial degree or a single rectification
        transformation order (including the case where no slitlet
        provides such information).

    """

    logger = logging.getLogger(__name__)

    # check grism and filter
    grism_name = rectwv_coeff.tags['grism']
    logger.info('Grism: ' + grism_name)
    filter_name = rectwv_coeff.tags['filter']
    logger.info('Filter: ' + filter_name)

    # list of slitlets to be computed
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        list_valid_islitlets.remove(idel)
    if abs(debugplot) >= 10:
        print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # display the csu_bar_slit_center values (only when debugging) so that
    # the user can check that the CSU configuration corresponds to longslit
    csu_bar_slit_center_list = [
        rectwv_coeff.contents[islitlet - 1]['csu_bar_slit_center']
        for islitlet in list_valid_islitlets
    ]
    if abs(debugplot) >= 10:
        logger.debug('Checking csu_bar_slit_center values:')
        summary(np.array(csu_bar_slit_center_list), debug=True)
        pause_debugplot(debugplot)

    # ---

    # polynomial coefficients corresponding to the wavelength calibration

    # step 0: collect the slitlets with a wavelength calibration polynomial
    # (the same slitlets are used in the fits below) and check that all of
    # them share the same polynomial degree (poldeg_refined)
    xp_wpoly = []
    list_wpoly_coeff = []
    for islitlet in list_valid_islitlets:
        tmp_dict = rectwv_coeff.contents[islitlet - 1]
        wpoly_coeff = tmp_dict['wpoly_coeff']
        if wpoly_coeff is not None:
            xp_wpoly.append(tmp_dict['y0_reference_middle'])
            list_wpoly_coeff.append(wpoly_coeff)
    poldeg_refined_set = {len(wpoly_coeff) - 1
                          for wpoly_coeff in list_wpoly_coeff}
    if len(poldeg_refined_set) != 1:
        raise ValueError('Unexpected different poldeg_refined found: ' +
                         str(sorted(poldeg_refined_set)))
    poldeg_refined = poldeg_refined_set.pop()

    # step 1: compute variation of each coefficient as a function of
    # y0_reference_middle of each slitlet
    list_poly = []
    for i in range(poldeg_refined + 1):
        poly, _, _ = polfit_residuals_with_sigma_rejection(
            x=np.array(xp_wpoly),
            y=np.array([wpoly_coeff[i] for wpoly_coeff in list_wpoly_coeff]),
            deg=2,
            times_sigma_reject=5,
            xlabel='y0_rectified',
            ylabel='coeff[' + str(i) + ']',
            title="Fit to refined wavelength calibration coefficients",
            geometry=geometry,
            debugplot=debugplot)
        list_poly.append(poly)

    # step 2: use the variation of each polynomial coefficient with
    # y0_reference_middle to infer the expected wavelength calibration
    # polynomial for each rectified slitlet
    for islitlet in list_valid_islitlets:
        tmp_dict = rectwv_coeff.contents[islitlet - 1]
        y0_reference_middle = tmp_dict['y0_reference_middle']
        tmp_dict['wpoly_coeff_longslit_model'] = [
            poly(y0_reference_middle) for poly in list_poly
        ]

    # ---

    # rectification transformation coefficients aij and bij
    tt_keys = ('ttd_aij', 'ttd_bij', 'tti_aij', 'tti_bij')

    # step 0: collect the slitlets with a rectification transformation (the
    # same slitlets are used in the fits below) and check that all of them
    # share the same transformation order (order_fmap)
    xp_tt = []
    dict_tt_coeff = {key: [] for key in tt_keys}
    order_fmap_set = set()
    for islitlet in list_valid_islitlets:
        tmp_dict = rectwv_coeff.contents[islitlet - 1]
        if tmp_dict['ttd_aij'] is not None:
            xp_tt.append(tmp_dict['y0_reference_middle'])
            for key in tt_keys:
                dict_tt_coeff[key].append(tmp_dict[key])
            order_fmap_set.add(tmp_dict['ttd_order'])
    if len(order_fmap_set) != 1:
        raise ValueError('Unexpected different order_fmap found')
    order_fmap = order_fmap_set.pop()

    # step 1: compute variation of each coefficient as a function of
    # y0_reference_middle of each slitlet (for every coefficient index the
    # four fits are performed in the order given by tt_keys)
    dict_list_poly = {key: [] for key in tt_keys}
    ncoef_ttd = ncoef_fmap(order_fmap)
    for i in range(ncoef_ttd):
        for key in tt_keys:
            poly, _, _ = polfit_residuals_with_sigma_rejection(
                x=np.array(xp_tt),
                y=np.array([coeff[i] for coeff in dict_tt_coeff[key]]),
                deg=5,
                times_sigma_reject=5,
                xlabel='y0_rectified',
                ylabel=key + '[' + str(i) + ']',
                geometry=geometry,
                debugplot=debugplot)
            dict_list_poly[key].append(poly)

    # step 2: use the variation of each coefficient with y0_reference_middle
    # to infer the expected rectification transformation for each slitlet
    for islitlet in list_valid_islitlets:
        tmp_dict = rectwv_coeff.contents[islitlet - 1]
        y0_reference_middle = tmp_dict['y0_reference_middle']
        tmp_dict['ttd_order_longslit_model'] = order_fmap
        for key in tt_keys:
            tmp_dict[key + '_longslit_model'] = [
                poly(y0_reference_middle) for poly in dict_list_poly[key]
            ]

    # ---

    # update uuid and meta_info in output JSON structure
    rectwv_coeff.uuid = str(uuid4())
    rectwv_coeff.meta_info['creation_date'] = datetime.now().isoformat()

    # return updated object
    return rectwv_coeff
Example #25
0
def main(args=None):
    """Compute a pixel-to-pixel flatfield from a flat (ON-OFF) FITS image.

    For every valid slitlet the image is rectified. A median spectrum is
    then computed along the slitlet, between its boundaries. It is smoothed
    with a median filter along the spectral direction and values below
    ``minimum_fraction`` times its maximum are set to zero. This smoothed
    spectrum is tiled over the slitlet and un-rectified, and the result is
    used to normalize the original slitlet image. The normalized slitlets
    are pasted into the output image between their frontiers. Pixels below
    ``--minimum_value_in_output`` are set to 1.0, global offsets are
    restored, and the result is saved to ``--outfile``.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments. When None, argparse reads sys.argv[1:].

    Raises
    ------
    ValueError
        If the image dimensions, grism, filter or DTU configuration are
        inconsistent with the rectification and wavelength calibration
        coefficients (DTU differences can be ignored with
        --ignore_DTUconf).

    """
    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: compute pixel-to-pixel flatfield'
    )

    # required arguments
    parser.add_argument("fitsfile",
                        help="Input FITS file (flat ON-OFF)",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rectwv_coeff", required=True,
                        help="Input JSON file with rectification and "
                             "wavelength calibration coefficients",
                        type=argparse.FileType('rt'))
    parser.add_argument("--nwindow_median", required=True,
                        help="Window size to smooth median spectrum in the "
                             "spectral direction",
                        type=int)
    parser.add_argument("--outfile", required=True,
                        help="Output FITS file",
                        type=lambda x: arg_file_is_new(parser, x, mode='wb'))

    # optional arguments
    # fraction of the maximum of the smoothed median spectrum below which
    # that spectrum is set to zero (see y_threshold below)
    parser.add_argument("--minimum_fraction",
                        help="Minimum allowed flatfielding value "
                             "(default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--minimum_value_in_output",
                        help="Minimum value allowed in output file: pixels "
                             "below this value are set to 1.0 (default=0.01)",
                        type=float, default=0.01)
    parser.add_argument("--delta_global_integer_offset_x_pix",
                        help="Delta global integer offset in the X direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--delta_global_integer_offset_y_pix",
                        help="Delta global integer offset in the Y direction "
                             "(default=0)",
                        default=0, type=int)
    parser.add_argument("--resampling",
                        help="Resampling method: 1 -> nearest neighbor, "
                             "2 -> linear interpolation (default)",
                        default=2, type=int,
                        choices=(1, 2))
    parser.add_argument("--ignore_DTUconf",
                        help="Ignore DTU configurations differences between "
                             "model and input image",
                        action="store_true")
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting & debugging options"
                             " (default=0)",
                        default=0, type=int,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # read calibration structure from JSON file
    rectwv_coeff = RectWaveCoeff._datatype_load(args.rectwv_coeff.name)

    # modify (when requested) global offsets
    rectwv_coeff.global_integer_offset_x_pix += \
        args.delta_global_integer_offset_x_pix
    rectwv_coeff.global_integer_offset_y_pix += \
        args.delta_global_integer_offset_y_pix

    # read FITS image and its corresponding header (the context manager
    # closes the file even if reading fails)
    with fits.open(args.fitsfile) as hdulist:
        header = hdulist[0].header
        image2d = hdulist[0].data

    # apply global offsets
    image2d = apply_integer_offsets(
        image2d=image2d,
        offx=rectwv_coeff.global_integer_offset_x_pix,
        offy=rectwv_coeff.global_integer_offset_y_pix
    )

    # protections
    naxis2, naxis1 = image2d.shape
    if naxis1 != header['naxis1'] or naxis2 != header['naxis2']:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)
        raise ValueError('Something is wrong with NAXIS1 and/or NAXIS2')
    if abs(args.debugplot) >= 10:
        print('>>> NAXIS1:', naxis1)
        print('>>> NAXIS2:', naxis2)

    # check that the input FITS file grism and filter match
    filter_name = header['filter']
    if filter_name != rectwv_coeff.tags['filter']:
        raise ValueError("Filter name does not match!")
    grism_name = header['grism']
    if grism_name != rectwv_coeff.tags['grism']:
        raise ValueError("Grism name does not match!")
    if abs(args.debugplot) >= 10:
        print('>>> grism.......:', grism_name)
        print('>>> filter......:', filter_name)

    # check that the DTU configurations are compatible
    dtu_conf_fitsfile = DtuConfiguration.define_from_fits(args.fitsfile)
    dtu_conf_jsonfile = DtuConfiguration.define_from_dictionary(
        rectwv_coeff.meta_info['dtu_configuration'])
    if dtu_conf_fitsfile != dtu_conf_jsonfile:
        print('DTU configuration (FITS file):\n\t', dtu_conf_fitsfile)
        print('DTU configuration (JSON file):\n\t', dtu_conf_jsonfile)
        if args.ignore_DTUconf:
            print('WARNING: DTU configuration differences found!')
        else:
            raise ValueError('DTU configurations do not match')
    else:
        if abs(args.debugplot) >= 10:
            print('>>> DTU Configuration match!')
            print(dtu_conf_fitsfile)

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in rectwv_coeff.missing_slitlets:
        list_valid_islitlets.remove(idel)
    if abs(args.debugplot) >= 10:
        print('>>> valid slitlet numbers:\n', list_valid_islitlets)

    # ---

    # initialize rectified image
    image2d_flatfielded = np.zeros((EMIR_NAXIS2, EMIR_NAXIS1))

    # main loop
    for islitlet in list_valid_islitlets:
        if args.debugplot == 0:
            islitlet_progress(islitlet, EMIR_NBARS)

        # define Slitlet2D object
        slt = Slitlet2D(islitlet=islitlet,
                        rectwv_coeff=rectwv_coeff,
                        debugplot=args.debugplot)

        if abs(args.debugplot) >= 10:
            print(slt)

        # extract (distorted) slitlet from the initial image
        slitlet2d = slt.extract_slitlet2d(image2d)

        # rectify slitlet
        slitlet2d_rect = slt.rectify(
            slitlet2d,
            resampling=args.resampling
        )
        naxis2_slitlet2d, naxis1_slitlet2d = slitlet2d_rect.shape

        if naxis1_slitlet2d != EMIR_NAXIS1:
            print('naxis1_slitlet2d: ', naxis1_slitlet2d)
            print('EMIR_NAXIS1.....: ', EMIR_NAXIS1)
            raise ValueError("Unexpected naxis1_slitlet2d")

        # get useful slitlet region (use boundaries instead of frontiers;
        # note that nscan_minmax_frontiers() works well independently
        # of using frontiers or boundaries as arguments)
        nscan_min, nscan_max = nscan_minmax_frontiers(
            slt.y0_reference_lower,
            slt.y0_reference_upper,
            resize=False
        )
        # scans nscan_min..nscan_max (1-based, inclusive) correspond to the
        # half-open, 0-based row range [ii1, ii2) of the slitlet image
        ii1 = nscan_min - slt.bb_ns1_orig
        ii2 = nscan_max - slt.bb_ns1_orig + 1

        # median spectrum (ii2 is already an exclusive bound, matching the
        # frontier slicing used below)
        sp_collapsed = np.median(slitlet2d_rect[ii1:ii2, :], axis=0)

        # smooth median spectrum along the spectral direction
        sp_median = ndimage.median_filter(sp_collapsed, args.nwindow_median,
                                          mode='nearest')
        ymax_spmedian = sp_median.max()
        y_threshold = ymax_spmedian * args.minimum_fraction
        sp_median[np.where(sp_median < y_threshold)] = 0.0

        if abs(args.debugplot) > 10:
            title = 'Slitlet#' + str(islitlet) + '(median spectrum)'
            xdum = np.arange(1, naxis1_slitlet2d + 1)
            ax = ximplotxy(xdum, sp_collapsed,
                           title=title,
                           show=False, label='collapsed spectrum')
            ax.plot(xdum, sp_median, label='filtered spectrum')
            ax.plot([1, naxis1_slitlet2d], 2*[y_threshold],
                    label='threshold')
            ax.legend()
            ax.set_ylim(-0.05*ymax_spmedian, 1.05*ymax_spmedian)
            pause_debugplot(args.debugplot,
                            pltshow=True, tight_layout=True)

        # generate rectified slitlet region filled with the median spectrum
        slitlet2d_rect_spmedian = np.tile(sp_median, (naxis2_slitlet2d, 1))
        if abs(args.debugplot) > 10:
            slt.ximshow_rectified(slitlet2d_rect_spmedian)

        # unrectified image
        slitlet2d_unrect_spmedian = slt.rectify(
            slitlet2d_rect_spmedian,
            resampling=args.resampling,
            inverse=True
        )

        # normalize initial slitlet image; pixels where the un-rectified
        # median spectrum vanishes are set to 1.0 to avoid division by zero
        # (assumes the un-rectified image has the same shape as slitlet2d)
        slitlet2d_norm = np.ones_like(slitlet2d)
        nonzero = slitlet2d_unrect_spmedian != 0
        slitlet2d_norm[nonzero] = \
            slitlet2d[nonzero] / slitlet2d_unrect_spmedian[nonzero]

        if abs(args.debugplot) > 10:
            slt.ximshow_unrectified(slitlet2d_norm)

        for j in range(EMIR_NAXIS1):
            xchannel = j + 1
            y0_lower = slt.list_frontiers[0](xchannel)
            y0_upper = slt.list_frontiers[1](xchannel)
            n1, n2 = nscan_minmax_frontiers(y0_frontier_lower=y0_lower,
                                            y0_frontier_upper=y0_upper,
                                            resize=True)
            # note that n1 and n2 are scans (ranging from 1 to NAXIS2)
            nn1 = n1 - slt.bb_ns1_orig + 1
            nn2 = n2 - slt.bb_ns1_orig + 1
            image2d_flatfielded[(n1 - 1):n2, j] = \
                slitlet2d_norm[(nn1 - 1):nn2, j]

            # force to 1.0 region around frontiers
            image2d_flatfielded[(n1 - 1):(n1 + 2), j] = 1
            image2d_flatfielded[(n2 - 5):n2, j] = 1
    if args.debugplot == 0:
        print('OK!')

    # set pixels below minimum value to 1.0
    filtered = np.where(image2d_flatfielded < args.minimum_value_in_output)
    image2d_flatfielded[filtered] = 1.0

    # restore global offsets
    image2d_flatfielded = apply_integer_offsets(
        image2d=image2d_flatfielded,
        offx=-rectwv_coeff.global_integer_offset_x_pix,
        offy=-rectwv_coeff.global_integer_offset_y_pix
    )

    # save output file
    save_ndarray_to_fits(
        array=image2d_flatfielded,
        file_name=args.outfile,
        main_header=header,
        overwrite=True
    )
    print('>>> Saving file ' + args.outfile.name)
def main(args=None):
    """Overplot the refined boundary model (and optionally lines) on a FITS image.

    Reads an EMIR FITS image and a MOS library of rectification and
    wavelength calibration coefficients. It checks that grism and filter
    match, then displays the image with the slitlet boundaries and
    frontiers overplotted. Optionally it also overplots arc or OH lines,
    and it can save ds9 region files with the boundaries, frontiers
    and/or lines.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments. When None, argparse reads sys.argv[1:].

    Raises
    ------
    ValueError
        On incompatible options, unexpected image dimensions, a
        non-EMIR image, a grism/filter mismatch with the library, or a
        boundary model whose parmodel is not "multislit".

    """

    # parse command-line options
    parser = argparse.ArgumentParser(
        description='description: overplot boundary model over FITS image')

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--rect_wpoly_MOSlibrary",
                        required=True,
                        help="Input JSON file with library of rectification "
                        "and wavelength calibration coefficients",
                        type=argparse.FileType('rt'))

    # optional arguments
    parser.add_argument("--global_integer_offset_x_pix",
                        help="Global integer offset in the X direction "
                        "(default=0)",
                        default=0,
                        type=int)
    parser.add_argument("--global_integer_offset_y_pix",
                        help="Global integer offset in the Y direction "
                        "(default=0)",
                        default=0,
                        type=int)
    parser.add_argument("--arc_lines",
                        help="Overplot arc lines",
                        action="store_true")
    parser.add_argument("--oh_lines",
                        help="Overplot OH lines",
                        action="store_true")
    parser.add_argument("--ds9_frontiers",
                        help="Output ds9 region file with slitlet frontiers",
                        type=lambda x: arg_file_is_new(parser, x))
    parser.add_argument("--ds9_boundaries",
                        help="Output ds9 region file with slitlet boundaries",
                        type=lambda x: arg_file_is_new(parser, x))
    parser.add_argument("--ds9_lines",
                        help="Output ds9 region file with arc/oh lines",
                        type=lambda x: arg_file_is_new(parser, x))
    parser.add_argument("--debugplot",
                        help="Integer indicating plotting/debugging" +
                        " (default=12)",
                        type=int,
                        default=12,
                        choices=DEBUGPLOT_CODES)
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the explicit argument list (None -> sys.argv[1:])
    args = parser.parse_args(args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # ---

    # avoid incompatible options
    if args.arc_lines and args.oh_lines:
        raise ValueError("--arc_lines and --oh_lines cannot be used "
                         "simultaneously")

    # --ds9_lines requires --arc_lines or --oh_lines
    if args.ds9_lines:
        if not (args.arc_lines or args.oh_lines):
            raise ValueError("--ds9_lines requires the use of either "
                             "--arc_lines or --oh_lines")

    # read input FITS file
    hdulist = fits.open(args.fitsfile)
    image_header = hdulist[0].header
    image2d = hdulist[0].data
    hdulist.close()

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")
    if image2d.shape != (EMIR_NAXIS2, EMIR_NAXIS1):
        raise ValueError("Unexpected values for NAXIS1, NAXIS2")

    # remove path from fitsfile
    sfitsfile = os.path.basename(args.fitsfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # ---

    # generate MasterRectWave object
    master_rectwv = MasterRectWave._datatype_load(
        args.rect_wpoly_MOSlibrary.name)

    # check that grism and filter are the expected ones
    grism_ = master_rectwv.tags['grism']
    if grism_ != grism:
        raise ValueError('Unexpected grism: ' + str(grism_))
    spfilter_ = master_rectwv.tags['filter']
    if spfilter_ != spfilter:
        raise ValueError('Unexpected filter ' + str(spfilter_))

    # valid slitlet numbers
    list_valid_islitlets = list(range(1, EMIR_NBARS + 1))
    for idel in master_rectwv.missing_slitlets:
        list_valid_islitlets.remove(idel)

    # read CsuConfiguration object from FITS file
    csu_config = CsuConfiguration.define_from_fits(args.fitsfile)

    # list with csu_bar_slit_center for valid slitlets
    list_csu_bar_slit_center = [
        csu_config.csu_bar_slit_center(islitlet)
        for islitlet in list_valid_islitlets
    ]

    # define parmodel and params (validate parmodel before using it to
    # build the parameters)
    fitted_bound_param_json = {
        'contents': master_rectwv.meta_info['refined_boundary_model']
    }
    parmodel = fitted_bound_param_json['contents']['parmodel']
    if parmodel != "multislit":
        raise ValueError('parmodel = "multislit" not found')
    fitted_bound_param_json.update({'meta_info': {'parmodel': parmodel}})
    params = bound_params_from_dict(fitted_bound_param_json)

    # ---

    # define lines to be overplotted
    if args.arc_lines or args.oh_lines:

        # NOTE(review): hdulist has already been closed at this point;
        # this relies on the already-loaded header/data — confirm
        rectwv_coeff = rectwv_coeff_from_mos_library(hdulist, master_rectwv)
        rectwv_coeff.global_integer_offset_x_pix = \
            args.global_integer_offset_x_pix
        rectwv_coeff.global_integer_offset_y_pix = \
            args.global_integer_offset_y_pix

        if args.arc_lines:
            if grism == 'LR':
                catlines_file = 'lines_argon_neon_xenon_empirical_LR.dat'
            else:
                catlines_file = 'lines_argon_neon_xenon_empirical.dat'
            dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                       catlines_file)
            arc_lines_tmpfile = StringIO(dumdata.decode('utf8'))
            catlines = np.genfromtxt(arc_lines_tmpfile)
            # define wavelength and flux as separate arrays
            catlines_all_wave = catlines[:, 0]
            catlines_all_flux = catlines[:, 1]
        elif args.oh_lines:
            dumdata = pkgutil.get_data('emirdrp.instrument.configs',
                                       'Oliva_etal_2013.dat')
            oh_lines_tmpfile = StringIO(dumdata.decode('utf8'))
            catlines = np.genfromtxt(oh_lines_tmpfile)
            # define wavelength and flux as separate arrays (both
            # wavelength columns 1 and 0 are used, each one with the flux
            # in column 2)
            catlines_all_wave = np.concatenate(
                (catlines[:, 1], catlines[:, 0]))
            catlines_all_flux = np.concatenate(
                (catlines[:, 2], catlines[:, 2]))
        else:
            raise ValueError("This should not happen!")

    else:
        rectwv_coeff = None
        catlines_all_wave = None
        catlines_all_flux = None

    # ---

    # generate output ds9 region file with slitlet boundaries
    if args.ds9_boundaries is not None:
        save_boundaries_from_params_ds9(
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            uuid=master_rectwv.uuid,
            grism=grism,
            spfilter=spfilter,
            ds9_filename=args.ds9_boundaries.name,
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

    # generate output ds9 region file with slitlet frontiers
    if args.ds9_frontiers is not None:
        save_frontiers_from_params_ds9(
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            uuid=master_rectwv.uuid,
            grism=grism,
            spfilter=spfilter,
            ds9_filename=args.ds9_frontiers.name,
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

    # ---

    # display full image
    if abs(args.debugplot) % 10 != 0:
        ax = ximshow(image2d=image2d,
                     title=sfitsfile + "\ngrism=" + grism + ", filter=" +
                     spfilter + ", rotang=" + str(round(rotang, 2)),
                     image_bbox=(1, naxis1, 1, naxis2),
                     show=False)

        # overplot boundaries
        overplot_boundaries_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

        # overplot frontiers
        overplot_frontiers_from_params(
            ax=ax,
            params=params,
            parmodel=parmodel,
            list_islitlet=list_valid_islitlets,
            list_csu_bar_slit_center=list_csu_bar_slit_center,
            micolors=('b', 'b'),
            linetype='-',
            labels=False,  # already displayed with the boundaries
            global_offset_x_pix=-args.global_integer_offset_x_pix,
            global_offset_y_pix=-args.global_integer_offset_y_pix)

    else:
        ax = None

    # overplot lines
    if catlines_all_wave is not None:

        ds9_file = None
        try:
            if args.ds9_lines is not None:
                ds9_file = open(args.ds9_lines.name, 'w')
                ds9_file.write('# Region file format: DS9 version 4.1\n')
                ds9_file.write('global color=#00ffff dashlist=0 0 width=2 '
                               'font="helvetica 10 normal roman" select=1 '
                               'highlite=1 dash=0 fixed=0 edit=1 '
                               'move=1 delete=1 include=1 source=1\n')
                ds9_file.write('physical\n#\n')

                ds9_file.write(
                    '#\n# uuid..: {0}\n'.format(master_rectwv.uuid))
                ds9_file.write('# filter: {0}\n'.format(spfilter))
                ds9_file.write('# grism.: {0}\n'.format(grism))
                ds9_file.write('#\n# global_offset_x_pix: {0}\n'.format(
                    args.global_integer_offset_x_pix))
                ds9_file.write('# global_offset_y_pix: {0}\n#\n'.format(
                    args.global_integer_offset_y_pix))
                if parmodel == "longslit":
                    list_dumpar = EXPECTED_PARAMETER_LIST
                else:
                    list_dumpar = EXPECTED_PARAMETER_LIST_EXTENDED
                for dumpar in list_dumpar:
                    parvalue = params[dumpar].value
                    ds9_file.write('# {0}: {1}\n'.format(dumpar, parvalue))

            overplot_lines(ax, catlines_all_wave, list_valid_islitlets,
                           rectwv_coeff, args.global_integer_offset_x_pix,
                           args.global_integer_offset_y_pix, ds9_file,
                           args.debugplot)
        finally:
            # close the ds9 region file even if overplotting fails
            if ds9_file is not None:
                ds9_file.close()

    if ax is not None:
        # show plot
        pause_debugplot(12, pltshow=True)
Example #27
0
def _read_json_file(filename):
    """Return the parsed JSON content of *filename*, closing the file."""
    with open(filename) as fileobj:
        return json.load(fileobj)


def _check_fibid_range(fibid, total_fibers, label='fibid'):
    """Raise ValueError unless 1 <= fibid <= total_fibers.

    Guards list indexing with ``fibid - 1``: without it a fibid of 0 or a
    negative value would silently wrap around to the end of the list.
    """
    if fibid < 1 or fibid > total_fibers:
        raise ValueError('{} number outside valid range'.format(label))


def _operation_fibid_list(operation):
    """Return the fiber ids targeted by a healing *operation*.

    Uses the explicit 'fibid_list' entry when present; otherwise builds the
    inclusive range 'fibid_ini'..'fibid_end'.
    """
    if 'fibid_list' in operation:
        return operation['fibid_list']
    return range(operation['fibid_ini'], operation['fibid_end'] + 1)


def main(args=None):
    """Overplot fiber traces (and optional healing operations) on an image.

    Command-line entry point. Displays the FITS image, draws every trace
    stored in the JSON traces file (blue), applies the operations listed in
    an optional healing JSON file (drawn in green), and optionally saves the
    modified traces to a new JSON file with a fresh uuid.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments; when None, ``sys.argv[1:]`` is parsed.

    Raises
    ------
    ValueError
        On inconsistent fiber counts, conflicting global offsets, fiber ids
        outside the valid range, or unknown/invalid healing operations.
    """
    # parse command-line options
    parser = argparse.ArgumentParser(
        description="description: overplot traces"
    )
    # positional parameters
    parser.add_argument("fits_file",
                        help="FITS image containing the spectra",
                        type=argparse.FileType('r'))
    parser.add_argument("traces_file",
                        help="JSON file with fiber traces",
                        type=argparse.FileType('r'))
    # optional parameters
    parser.add_argument("--rawimage",
                        help="FITS file is a RAW image (otherwise trimmed "
                             "image is assumed)",
                        action="store_true")
    parser.add_argument("--global_offset",
                        help="Global offset polynomial coefficients "
                             "(+upwards, -downwards)")
    parser.add_argument("--fibids",
                        help="Display fiber identification number",
                        action="store_true")
    parser.add_argument("--verbose",
                        help="Enhance verbosity",
                        action="store_true")
    parser.add_argument("--healing",
                        help="JSON healing file to improve traces",
                        type=argparse.FileType('r'))
    # kept as a plain path: argparse.FileType('w') would truncate the file
    # at parse time, destroying the input traces file when both names match
    parser.add_argument("--updated_traces",
                        help="JSON file with modified fiber traces")
    parser.add_argument("--z1z2",
                        help="tuple z1,z2, minmax or None (use zscale)")
    parser.add_argument("--bbox",
                        help="bounding box tuple: nc1,nc2,ns1,ns2")
    parser.add_argument("--keystitle",
                        help="tuple of FITS keywords.format: " +
                             "key1,key2,...keyn.'format'")
    parser.add_argument("--geometry",
                        help="tuple x,y,dx,dy",
                        default="0,0,640,480")
    parser.add_argument("--pdffile",
                        help="ouput PDF file name",
                        type=argparse.FileType('w'))
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    args = parser.parse_args(args=args)

    if args.echo:
        print('\033[1m\033[31m% ' + ' '.join(sys.argv) + '\033[0m\n')

    # global_offset in command line
    if args.global_offset is None:
        args_global_offset = [0.0]
    else:
        args_global_offset = [float(dum) for dum in
                              str(args.global_offset).split(",")]

    # read pdffile
    if args.pdffile is not None:
        from matplotlib.backends.backend_pdf import PdfPages
        pdf = PdfPages(args.pdffile.name)
    else:
        pdf = None

    ax = ximshow_file(args.fits_file.name,
                      args_cbar_orientation='vertical',
                      args_z1z2=args.z1z2,
                      args_bbox=args.bbox,
                      args_keystitle=args.keystitle,
                      args_geometry=args.geometry,
                      pdf=pdf,
                      show=False)

    # trace offsets for RAW images
    if args.rawimage:
        ix_offset = 51
    else:
        ix_offset = 1

    # read and display traces from JSON file
    bigdict = _read_json_file(args.traces_file.name)

    # Load metadata from the traces
    meta_info = bigdict['meta_info']

    origin = meta_info['origin']
    insconf_uuid = origin['insconf_uuid']
    date_obs = origin['date_obs']

    tags = bigdict['tags']
    insmode = tags['insmode']

    # create instrument model
    pkg_paths = ['megaradrp.instrument.configs']
    store = asb.load_paths_store(pkg_paths)

    insmodel = asb.assembly_instrument(store, insconf_uuid, date_obs,
                                       by_key='uuid')

    pseudo_slit_config = insmodel.get_value('pseudoslit.boxes', **tags)

    fibid_with_box = assign_boxes_to_fibers(pseudo_slit_config, insmode)
    total_fibers = bigdict['total_fibers']
    if total_fibers != len(fibid_with_box):
        raise ValueError('Mismatch between number of fibers and '
                         'expected number from account from boxes')
    if 'global_offset' in bigdict:
        global_offset = bigdict['global_offset']
        if args_global_offset != [0.0] and global_offset != [0.0]:
            raise ValueError('global_offset != 0 argument cannot be employed '
                             'when global_offset != 0 in JSON file')
        elif args_global_offset != [0.0]:
            global_offset = args_global_offset
    else:
        global_offset = args_global_offset
    print('>>> Using global_offset:', global_offset)
    pol_global_offset = np.polynomial.Polynomial(global_offset)
    ref_column = bigdict.get('ref_column', 2000)

    for fiberdict in bigdict['contents']:
        fibid = fiberdict['fibid']
        fiblabel = fibid_with_box[fibid - 1]
        start = fiberdict['start']
        stop = fiberdict['stop']
        coeff = np.array(fiberdict['fitparms'])
        # skip fibers without trace
        if len(coeff) > 0:
            pol_trace = np.polynomial.Polynomial(coeff)
            y_at_ref_column = pol_trace(ref_column)
            correction = pol_global_offset(y_at_ref_column)
            coeff[0] += correction
            # update values in bigdict (JSON structure); write through the
            # entry being iterated so the update lands on the right fiber
            fiberdict['fitparms'] = coeff.tolist()
            plot_trace(ax, coeff, start, stop, ix_offset, args.rawimage,
                       args.fibids, fiblabel, colour='blue')
        else:
            print('Warning ---> Missing fiber:', fiblabel)

    # if present, read healing JSON file
    if args.healing is not None:
        healdict = _read_json_file(args.healing.name)
        contents = bigdict['contents']
        for operation in healdict['operations']:
            description = operation['description']

            if description == 'vertical_shift_in_pixels':
                for fibid in _operation_fibid_list(operation):
                    _check_fibid_range(fibid, total_fibers)
                    fiblabel = fibid_with_box[fibid - 1]
                    coeff = np.array(contents[fibid - 1]['fitparms'])
                    if len(coeff) > 0:
                        if args.verbose:
                            print('(vertical_shift_in_pixels) fibid:',
                                  fiblabel)
                        coeff[0] += operation['vshift']
                        contents[fibid - 1]['fitparms'] = coeff.tolist()
                        start = contents[fibid - 1]['start']
                        stop = contents[fibid - 1]['stop']
                        plot_trace(ax, coeff, start, stop, ix_offset,
                                   args.rawimage, True, fiblabel,
                                   colour='green')
                    else:
                        print('(vertical_shift_in_pixels SKIPPED) fibid:',
                              fiblabel)

            elif description == 'duplicate_trace':
                fibid_original = operation['fibid_original']
                _check_fibid_range(fibid_original, total_fibers,
                                   'fibid_original')
                fibid_duplicated = operation['fibid_duplicated']
                _check_fibid_range(fibid_duplicated, total_fibers,
                                   'fibid_duplicated')
                fiblabel_original = fibid_with_box[fibid_original - 1]
                fiblabel_duplicated = fibid_with_box[fibid_duplicated - 1]
                coeff = np.array(contents[fibid_original - 1]['fitparms'])
                if len(coeff) > 0:
                    if args.verbose:
                        print('(duplicated_trace) fibids:',
                              fiblabel_original, '-->', fiblabel_duplicated)
                    coeff[0] += operation['vshift']
                    contents[fibid_duplicated - 1]['fitparms'] = \
                        coeff.tolist()
                    start = contents[fibid_original - 1]['start']
                    stop = contents[fibid_original - 1]['stop']
                    contents[fibid_duplicated - 1]['start'] = start
                    contents[fibid_duplicated - 1]['stop'] = stop
                    plot_trace(ax, coeff, start, stop, ix_offset,
                               args.rawimage, True, fiblabel_duplicated,
                               colour='green')
                else:
                    print('(duplicated_trace SKIPPED) fibids:',
                          fiblabel_original, '-->', fiblabel_duplicated)

            elif description == 'extrapolation':
                for fibid in _operation_fibid_list(operation):
                    _check_fibid_range(fibid, total_fibers)
                    fiblabel = fibid_with_box[fibid - 1]
                    coeff = np.array(contents[fibid - 1]['fitparms'])
                    if len(coeff) > 0:
                        if args.verbose:
                            print('(extrapolation) fibid:', fiblabel)
                        # update values in bigdict (JSON structure)
                        start = operation['start']
                        stop = operation['stop']
                        start_orig = contents[fibid - 1]['start']
                        stop_orig = contents[fibid - 1]['stop']
                        contents[fibid - 1]['start'] = start
                        contents[fibid - 1]['stop'] = stop
                        # draw only the newly covered ranges, or the whole
                        # new range when it lies inside the original one
                        if start < start_orig:
                            plot_trace(ax, coeff, start, start_orig,
                                       ix_offset,
                                       args.rawimage, True, fiblabel,
                                       colour='green')
                        if stop_orig < stop:
                            plot_trace(ax, coeff, stop_orig, stop,
                                       ix_offset,
                                       args.rawimage, True, fiblabel,
                                       colour='green')
                        if start_orig <= start <= stop <= stop_orig:
                            plot_trace(ax, coeff, start, stop,
                                       ix_offset,
                                       args.rawimage, True, fiblabel,
                                       colour='green')
                    else:
                        print('(extrapolation SKIPPED) fibid:', fiblabel)

            elif description == 'fit_through_user_points':
                fibid = operation['fibid']
                _check_fibid_range(fibid, total_fibers)
                fiblabel = fibid_with_box[fibid - 1]
                if args.verbose:
                    print('(fit through user points) fibid:', fiblabel)
                poldeg = operation['poldeg']
                start = operation['start']
                stop = operation['stop']
                # x, y coordinates in the JSON file are image coordinates
                # starting at (1,1) in the lower left corner; subtract 1 to
                # use np.array coordinates
                user_points = operation['user_points']
                xfit = np.array([point['x'] - 1 for point in user_points])
                yfit = np.array([point['y'] - 1 for point in user_points])
                if len(xfit) <= poldeg:
                    raise ValueError('Insufficient number of points to fit'
                                     ' polynomial')
                poly, residum = polfit_residuals(xfit, yfit, poldeg)
                coeff = poly.coef
                plot_trace(ax, coeff, start, stop, ix_offset,
                           args.rawimage, args.fibids, fiblabel,
                           colour='green')
                contents[fibid - 1]['start'] = start
                contents[fibid - 1]['stop'] = stop
                contents[fibid - 1]['fitparms'] = coeff.tolist()

            elif description == 'extrapolation_through_user_points':
                fibid = operation['fibid']
                _check_fibid_range(fibid, total_fibers)
                fiblabel = fibid_with_box[fibid - 1]
                if args.verbose:
                    print('(extrapolation_through_user_points):', fiblabel)
                start_reuse = operation['start_reuse']
                stop_reuse = operation['stop_reuse']
                resampling = operation['resampling']
                poldeg = operation['poldeg']
                start = operation['start']
                stop = operation['stop']
                # resample the existing trace within [start_reuse, stop_reuse]
                # and add the user points (converted from 1-based image
                # coordinates to np.array coordinates) before refitting
                coeff = contents[fibid - 1]['fitparms']
                xfit = np.linspace(start_reuse, stop_reuse, num=resampling)
                yfit = np.polynomial.Polynomial(coeff)(xfit)
                user_points = operation['user_points']
                user_x = np.array([point['x'] - 1 for point in user_points])
                user_y = np.array([point['y'] - 1 for point in user_points])
                xfit = np.concatenate((xfit, user_x))
                yfit = np.concatenate((yfit, user_y))
                poly, residum = polfit_residuals(xfit, yfit, poldeg)
                coeff = poly.coef
                if start < start_reuse:
                    plot_trace(ax, coeff, start, start_reuse, ix_offset,
                               args.rawimage, args.fibids, fiblabel,
                               colour='green')
                if stop_reuse < stop:
                    plot_trace(ax, coeff, stop_reuse, stop, ix_offset,
                               args.rawimage, args.fibids, fiblabel,
                               colour='green')
                contents[fibid - 1]['start'] = start
                contents[fibid - 1]['stop'] = stop
                contents[fibid - 1]['fitparms'] = coeff.tolist()

            elif description == 'sandwich':
                fibid = operation['fibid']
                _check_fibid_range(fibid, total_fibers)
                fiblabel = fibid_with_box[fibid - 1]
                if args.verbose:
                    print('(sandwich) fibid:', fiblabel)
                fraction = operation['fraction']
                nf1, nf2 = operation['neighbours']
                start = operation['start']
                stop = operation['stop']
                tmpf1 = contents[nf1 - 1]
                tmpf2 = contents[nf2 - 1]
                if nf1 != tmpf1['fibid'] or nf2 != tmpf2['fibid']:
                    raise ValueError(
                        "Unexpected fiber numbers in neighbours"
                    )
                # linear interpolation between the neighbour coefficients
                coefff1 = np.array(tmpf1['fitparms'])
                coefff2 = np.array(tmpf2['fitparms'])
                coeff = coefff1 + fraction * (coefff2 - coefff1)
                plot_trace(ax, coeff, start, stop, ix_offset,
                           args.rawimage, args.fibids,
                           fiblabel, colour='green')
                # update values in bigdict (JSON structure)
                contents[fibid - 1]['start'] = start
                contents[fibid - 1]['stop'] = stop
                contents[fibid - 1]['fitparms'] = coeff.tolist()
                if fibid in bigdict.get('error_fitting', []):
                    bigdict['error_fitting'].remove(fibid)

            elif description == 'renumber_fibids_within_box':
                fibid_ini = operation['fibid_ini']
                fibid_end = operation['fibid_end']
                _check_fibid_range(fibid_ini, total_fibers, 'fibid_ini')
                _check_fibid_range(fibid_end, total_fibers, 'fibid_end')
                if fibid_ini > fibid_end:
                    raise ValueError('fibid_ini must not be greater than '
                                     'fibid_end')
                box_ini = fibid_with_box[fibid_ini - 1][4:]
                box_end = fibid_with_box[fibid_end - 1][4:]
                if box_ini != box_end:
                    print('ERROR: box_ini={}, box_end={}'.format(box_ini,
                                                                 box_end))
                    raise ValueError('fibid_ini and fibid_end correspond to '
                                     'different fiber boxes')
                fibid_shift = operation['fibid_shift']
                if fibid_shift not in [-1, 1]:
                    raise ValueError('fibid_shift in operation '
                                     'renumber_fibids_within_box '
                                     'must be -1 or 1')
                if fibid_ini + fibid_shift < 1 or \
                        fibid_end + fibid_shift > total_fibers:
                    raise ValueError('renumbered fibids outside valid range')
                # iterate so that each destination entry is overwritten only
                # after it has itself been copied to its new position
                if fibid_shift == -1:
                    fibid_sequence = range(fibid_ini, fibid_end + 1)
                else:
                    fibid_sequence = range(fibid_end, fibid_ini - 1, -1)
                for fibid in fibid_sequence:
                    inew = fibid - 1 + fibid_shift
                    fiblabel_ori = fibid_with_box[fibid - 1]
                    fiblabel_new = fibid_with_box[inew]
                    if args.verbose:
                        print('(renumber_fibids) fibid:',
                              fiblabel_ori, '-->', fiblabel_new)
                    contents[inew] = deepcopy(contents[fibid - 1])
                    contents[inew]['fibid'] += fibid_shift
                    # display updated trace
                    coeff = contents[inew]['fitparms']
                    start = contents[inew]['start']
                    stop = contents[inew]['stop']
                    plot_trace(ax, coeff, start, stop, ix_offset,
                               args.rawimage, args.fibids,
                               fiblabel_ori + '-->' + fiblabel_new,
                               colour='green')
                # the fiber vacated at the end of the shifted range keeps no
                # trace
                if fibid_shift == -1:
                    contents[fibid_end - 1]['fitparms'] = []
                else:
                    contents[fibid_ini - 1]['fitparms'] = []

            else:
                raise ValueError(
                    'Unexpected healing method: {}'.format(description)
                )

    # update trace map
    if args.updated_traces is not None:
        # avoid overwritting initial JSON file
        if args.updated_traces != args.traces_file.name:
            # new random uuid for the updated calibration
            bigdict['uuid'] = str(uuid4())
            with open(args.updated_traces, 'w') as outfile:
                json.dump(bigdict, outfile, indent=2)

    if pdf is not None:
        pdf.savefig()
        pdf.close()
    else:
        pause_debugplot(12, pltshow=True, tight_layout=True)
Example #28
0
def _parse_slitlet_numbers(tuple_slit_numbers, nbars):
    """Return the slitlet numbers described by the string 'n1[,n2[,step]]'.

    A single value yields a one-element list; two or three values yield
    ``range(n1, n2 + 1, step)`` (step defaults to 1).

    Raises
    ------
    ValueError
        If the tuple has more than three items, n1 < 1, or the last
        requested slitlet exceeds *nbars*.
    """
    tmp_str = tuple_slit_numbers.split(",")
    if len(tmp_str) not in (1, 2, 3):
        raise ValueError("Invalid tuple for slitlet numbers")
    n1 = int(tmp_str[0])
    n2 = int(tmp_str[1]) if len(tmp_str) > 1 else n1
    if n1 < 1:
        raise ValueError("Invalid slitlet number < 1")
    if n2 > nbars:
        raise ValueError("Invalid slitlet number > EMIR_NBARS")
    if len(tmp_str) == 1:
        return [n1]
    step = int(tmp_str[2]) if len(tmp_str) == 3 else 1
    return range(n1, n2 + 1, step)


def main(args=None):
    """Display an EMIR FITS image with the boundaries of selected slitlets.

    Command-line entry point. Reads the image, checks that it was taken
    with EMIR, and overplots the lower (green) and upper (blue) boundary
    polynomials returned by ``get_boundaries`` for each requested slitlet,
    together with a slitlet label.

    Parameters
    ----------
    args : list of str or None
        Command-line arguments; when None, ``sys.argv[1:]`` is parsed.

    Raises
    ------
    ValueError
        On an invalid slitlet tuple, inconsistent NAXIS1/NAXIS2, or a
        non-EMIR image.
    """
    # parse command-line options
    parser = argparse.ArgumentParser()

    # positional arguments
    parser.add_argument("fitsfile",
                        help="FITS file name to be displayed",
                        type=argparse.FileType('rb'))
    parser.add_argument("--bounddict",
                        required=True,
                        help="bounddict file name",
                        type=argparse.FileType('rt'))
    parser.add_argument("--tuple_slit_numbers",
                        required=True,
                        help="Tuple n1[,n2[,step]] to define slitlet numbers")

    # optional arguments
    parser.add_argument("--echo",
                        help="Display full command line",
                        action="store_true")

    # honour the explicit argument list (the original ignored it)
    args = parser.parse_args(args=args)

    if args.echo:
        print('\033[1m\033[31mExecuting: ' + ' '.join(sys.argv) + '\033[0m\n')

    # read slitlet numbers to be computed
    list_slitlets = _parse_slitlet_numbers(args.tuple_slit_numbers,
                                           EMIR_NBARS)

    # read input FITS file
    with fits.open(args.fitsfile.name) as hdulist:
        image_header = hdulist[0].header
        image2d = hdulist[0].data

    naxis1 = image_header['naxis1']
    naxis2 = image_header['naxis2']

    if image2d.shape != (naxis2, naxis1):
        raise ValueError("Unexpected error with NAXIS1, NAXIS2")

    # remove path from fitsfile
    sfitsfile = os.path.basename(args.fitsfile.name)

    # check that the FITS file has been obtained with EMIR
    instrument = image_header['instrume']
    if instrument != 'EMIR':
        raise ValueError("INSTRUME keyword is not 'EMIR'!")

    # read GRISM, FILTER and ROTANG from FITS header
    grism = image_header['grism']
    spfilter = image_header['filter']
    rotang = image_header['rotang']

    # display full image
    ax = ximshow(image2d=image2d,
                 title=sfitsfile + "\ngrism=" + grism + ", filter=" +
                 spfilter + ", rotang=" + str(round(rotang, 2)),
                 image_bbox=(1, naxis1, 1, naxis2),
                 show=False)

    # overplot boundaries for each slitlet
    for slitlet_number in list_slitlets:
        pol_lower_boundary, pol_upper_boundary, \
            xmin_lower, xmax_lower, xmin_upper, xmax_upper, \
            csu_bar_slit_center = \
            get_boundaries(args.bounddict, slitlet_number)
        if (pol_lower_boundary is not None) and \
                (pol_upper_boundary is not None):
            xp = np.linspace(start=xmin_lower, stop=xmax_lower, num=1000)
            yp = pol_lower_boundary(xp)
            ax.plot(xp, yp, 'g-')
            xp = np.linspace(start=xmin_upper, stop=xmax_upper, num=1000)
            yp = pol_upper_boundary(xp)
            ax.plot(xp, yp, 'b-')
            # slitlet label, placed midway between both boundaries at the
            # central image column
            yc_lower = pol_lower_boundary(EMIR_NAXIS1 / 2 + 0.5)
            yc_upper = pol_upper_boundary(EMIR_NAXIS1 / 2 + 0.5)
            tmpcolor = ['r', 'b'][slitlet_number % 2]
            # NOTE(review): 341.5 appears to be the CSU bar-position scale
            # used to map csu_bar_slit_center to image columns - confirm
            xcsu = EMIR_NAXIS1 * csu_bar_slit_center / 341.5
            ax.text(xcsu, (yc_lower + yc_upper) / 2,
                    str(slitlet_number),
                    fontsize=10,
                    va='center',
                    ha='center',
                    bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="grey"),
                    color=tmpcolor,
                    fontweight='bold',
                    backgroundcolor='white')

    # show plot
    pause_debugplot(12, pltshow=True)
Example #29
0
def useful_mos_xpixels(reduced_mos_data,
                       base_header,
                       vpix_region,
                       npix_removed_near_ohlines=0,
                       list_valid_wvregions=None,
                       debugplot=0):
    """Return a boolean mask of useful X-axis (wavelength) pixels.

    A pixel is useful when it lies inside the valid wavelength regions and
    is not within +/- ``npix_removed_near_ohlines`` pixels of an OH line
    from the Oliva et al. (2013) catalogue.

    Parameters
    ----------
    reduced_mos_data : numpy array
        2D image; rows ``nsmin..nsmax`` (1-based) are used only for the
        optional debug plots.
    base_header : FITS header
        Must provide NAXIS1, NAXIS2, CRPIX1, CRVAL1, CDELT1 (linear
        wavelength calibration) and JMNSLTnn / JMXSLTnn for the slitlets
        touched by ``vpix_region``.
    vpix_region : tuple of two floats
        Minimum and maximum vertical pixel (1-based); rounded to the
        nearest integer.
    npix_removed_near_ohlines : int
        Half width (pixels) of the region masked around each OH line;
        0 disables OH-line masking.
    list_valid_wvregions : list of (float, float) or None
        Wavelength intervals (same units as CRVAL1) to keep. When None, the
        column range JMNSLT..JMXSLT spanned by the selected slitlets is
        used instead.
    debugplot : int
        When ``abs(debugplot)`` is 21 or 22, display diagnostic plots.

    Returns
    -------
    xisok : numpy array of bool
        Mask with NAXIS1 elements; True marks a useful pixel.

    Raises
    ------
    ValueError
        If ``vpix_region`` or a wavelength region is inverted or out of
        range, or no useful pixel remains.
    """

    # get wavelength calibration from image header
    naxis1 = base_header['naxis1']
    naxis2 = base_header['naxis2']
    crpix1 = base_header['crpix1']
    crval1 = base_header['crval1']
    cdelt1 = base_header['cdelt1']

    def wave_to_pixel(wave):
        """Nearest 1-based pixel for *wave* under the linear calibration."""
        return int((wave - crval1) / cdelt1 + crpix1 + 0.5)

    # check vertical region
    nsmin = int(vpix_region[0] + 0.5)
    nsmax = int(vpix_region[1] + 0.5)
    if nsmin > nsmax:
        raise ValueError('vpix_region values in wrong order')
    elif nsmin < 1 or nsmax > naxis2:
        raise ValueError('vpix_region outside valid range')

    # minimum and maximum pixels in the wavelength direction, merged over
    # every slitlet touched by the vertical region (get_islitlet presumably
    # maps a vertical pixel to its slitlet number - defined elsewhere)
    islitlet_min = get_islitlet(nsmin)
    ncmin = base_header['jmnslt{:02d}'.format(islitlet_min)]
    ncmax = base_header['jmxslt{:02d}'.format(islitlet_min)]
    islitlet_max = get_islitlet(nsmax)
    for islitlet in range(islitlet_min + 1, islitlet_max + 1):
        ncmin = min(ncmin, base_header['jmnslt{:02d}'.format(islitlet)])
        # maximum column comes from JMXSLT (the original read JMNSLT here)
        ncmax = max(ncmax, base_header['jmxslt{:02d}'.format(islitlet)])

    # pixels within valid regions
    xisok_wvreg = np.zeros(naxis1, dtype='bool')
    if list_valid_wvregions is None:
        for ipix in range(ncmin, ncmax + 1):
            xisok_wvreg[ipix - 1] = True
    else:
        for wvregion in list_valid_wvregions:
            wvmin = float(wvregion[0])
            wvmax = float(wvregion[1])
            if wvmin > wvmax:
                raise ValueError('wvregion values in wrong order:'
                                 ' {}, {}'.format(wvmin, wvmax))
            # clip to the image; regions entirely outside are ignored
            minpix = max(wave_to_pixel(wvmin), 1)
            maxpix = min(wave_to_pixel(wvmax), naxis1)
            if minpix <= maxpix:
                xisok_wvreg[(minpix - 1):maxpix] = True
        if np.sum(xisok_wvreg) < 1:
            raise ValueError('no valid wavelength ranges provided')

    # pixels affected by OH lines
    xisok_oh = np.ones(naxis1, dtype='bool')
    npix_around = int(npix_removed_near_ohlines)
    if npix_around > 0:
        dumdata = pkgutil.get_data(
            'emirdrp.instrument.configs',
            'Oliva_etal_2013.dat'
        )
        oh_lines_tmpfile = StringIO(dumdata.decode('utf8'))
        catlines = np.genfromtxt(oh_lines_tmpfile)
        # the catalogue has two wavelength columns; use both
        catlines_all_wave = np.concatenate(
            (catlines[:, 1], catlines[:, 0]))
        for waveline in catlines_all_wave:
            expected_pixel = wave_to_pixel(waveline)
            minpix = max(expected_pixel - npix_around, 1)
            maxpix = min(expected_pixel + npix_around, naxis1)
            if minpix <= maxpix:
                xisok_oh[(minpix - 1):maxpix] = False

    # pixels in valid regions not affected by OH lines
    xisok = np.logical_and(xisok_wvreg, xisok_oh)
    if np.sum(xisok) < 1:
        raise ValueError('no valid wavelength range available after '
                         'removing OH lines')

    if abs(debugplot) in [21, 22]:
        slitlet2d = reduced_mos_data[(nsmin - 1):nsmax, :].copy()
        ximshow(slitlet2d,
                title='Rectified region',
                first_pixel=(1, nsmin),
                crval1=crval1, cdelt1=cdelt1, debugplot=debugplot)
        ax = ximshow(slitlet2d,
                     title='Rectified region\nafter blocking '
                           'removed wavelength ranges',
                     first_pixel=(1, nsmin),
                     crval1=crval1, cdelt1=cdelt1, show=False)
        for idum in range(1, naxis1 + 1):
            if not xisok_wvreg[idum - 1]:
                ax.plot([idum, idum], [nsmin, nsmax], 'g-')
        pause_debugplot(debugplot, pltshow=True)
        ax = ximshow(slitlet2d,
                     title='Rectified slitlet\nuseful regions after '
                           'removing OH lines',
                     first_pixel=(1, nsmin),
                     crval1=crval1, cdelt1=cdelt1, show=False)
        for idum in range(1, naxis1 + 1):
            if not xisok[idum - 1]:
                ax.plot([idum, idum], [nsmin, nsmax], 'm-')
        pause_debugplot(debugplot, pltshow=True)

    return xisok