def show_distortion_model_difference(self, ad, ad_ref):
        """
        Shows the difference between the distortion corrected output file and
        the corresponding reference file.

        For every pair of extensions, a synthetic image produced by
        ``generate_fake_data`` is resampled through the
        "pixels" -> "distortion_corrected" transform of each WCS, and the
        image ``reference - output`` is saved as an SVG file in
        ``self.output_folder``.

        Parameters
        ----------
        ad : AstroData
            Distortion Determined AstroData object
        ad_ref : AstroData
            Distortion Determined AstroData reference object
        """
        for num, (ext, ext_ref) in enumerate(zip(ad, ad_ref)):

            # Only the base name is used so that a filename carrying a
            # directory component cannot move the output out of output_folder
            name, _ = os.path.splitext(os.path.basename(ext.filename))
            data = generate_fake_data(ext.shape, ext.dispersion_axis() - 1)

            model_out = ext.wcs.get_transform("pixels", "distortion_corrected")
            model_ref = ext_ref.wcs.get_transform("pixels",
                                                  "distortion_corrected")

            transform_out = transform.Transform(model_out)
            transform_ref = transform.Transform(model_ref)

            data_out = transform_out.apply(data, output_shape=ext.shape)
            data_ref = transform_ref.apply(data, output_shape=ext.shape)

            # Mask NaN/inf pixels so they do not dominate the colour scale
            data_out = np.ma.masked_invalid(data_out)
            data_ref = np.ma.masked_invalid(data_ref)

            fig, ax = plt.subplots(
                dpi=150,
                num="Distortion Comparison: {:s} #{:d}".format(name, num))

            im = ax.imshow(data_ref - data_out)

            ax.set_xlabel("X [px]")
            ax.set_ylabel("Y [px]")
            ax.set_title(
                "Difference between output and reference: \n {:s} #{:d} ".
                format(name, num))

            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)

            cbar = fig.colorbar(im,
                                extend="max",
                                cax=cax,
                                orientation="vertical")
            cbar.set_label("Distortion [px]")

            fig_name = os.path.join(
                self.output_folder,
                "{:s}_{:d}_{:s}_{:.0f}_distDiff.svg".format(
                    name, num, self.grating, self.central_wavelength),
            )

            fig.savefig(fig_name)
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, figures accumulate and a later call with the same
            # ``num`` would draw new axes on top of the stale figure.
            plt.close(fig)
Beispiel #2
0
def test_2d_nonaffine_transform():
    """
    Check a non-affine 2D resampling, with and without flux conservation.

    The same quadratic mapping is applied along both axes. For input index k,
    the test aggregates output pixels k(k+1)//2 up to (k+1)(k+2)//2 - 1 along
    each axis and compares the result with the input value.
    """
    quadratic = models.Polynomial1D(degree=2, c0=0, c1=0.5, c2=0.5)
    quadratic.inverse = InverseQuadratic1D(c0=0, c1=0.5, c2=0.5)

    size = 10
    x = np.arange(size * size, dtype=float).reshape(size, size)
    axis_model = models.Shift(0.5) | quadratic | models.Shift(-0.5)
    dg = transform.DataGroup([x], [transform.Transform(axis_model & axis_model)])

    def footprint(k):
        # Output-pixel span aggregated for input index k along one axis
        return slice(k * (k + 1) // 2, (k + 1) * (k + 2) // 2)

    indices = range(1, size - 1)
    expected = x[1:-1, 1:-1]

    # Flux-conserving resampling: summing each footprint recovers the input
    resampled = dg.transform(subsample=5, conserve=True)['data']
    totals = np.array([resampled[footprint(i), footprint(j)].sum()
                       for i in indices for j in indices])
    assert np.allclose(expected, totals.reshape(size - 2, size - 2), rtol=0.005)

    # Averaging a footprint is only a loose check, and looser than in 1D
    # because the gradient is very steep along one direction
    resampled = dg.transform()['data']
    means = np.array([resampled[footprint(i), footprint(j)].mean()
                      for i in indices for j in indices])
    assert np.allclose(expected, means.reshape(size - 2, size - 2),
                       rtol=0.03, atol=0.2)
Beispiel #3
0
def test_2d_affine_transform():
    """Test a simple 2D transform with and without flux conservation"""
    size = 10
    # Number of interior input pixels per axis that are fully recovered
    inner = size - 2
    x = np.arange(size * size, dtype=float).reshape(size, size)
    # Each input pixel maps onto a 2x2 block of output pixels. We still lose
    # the pixels at either end, resulting in a (2*size - 2)-pixel square array
    m = models.Shift(0.5) | models.Scale(2) | models.Shift(-0.5)

    dg = transform.DataGroup([x], [transform.Transform(m & m)])

    # Without flux conservation, each 2x2 output block averages to the input
    output = dg.transform()
    y = (output['data'][1:-1, 1:-1]
         .reshape(inner, 2, inner, 2).mean(axis=3).mean(axis=1))
    assert np.array_equal(x[1:-1, 1:-1], y)

    # With flux conservation, each 2x2 output block sums to the input
    output = dg.transform(conserve=True)
    y = (output['data'][1:-1, 1:-1]
         .reshape(inner, 2, inner, 2).sum(axis=3).sum(axis=1))
    assert np.array_equal(x[1:-1, 1:-1], y)
Beispiel #4
0
def test_1d_affine_transform():
    """
    Check a 1D scale-by-two resampling, with and without flux conservation.

    Adjacent output pairs must average (plain resampling) or sum
    (flux-conserving resampling) back to the interior input values.
    """
    size = 100
    x = np.arange(size, dtype=float)
    # The half-pixel shifts around the scaling make each input pixel's
    # footprint cover whole output pixels; the pixels at either end are still
    # lost, leaving a 198-pixel array
    m = models.Shift(0.5) | models.Scale(2) | models.Shift(-0.5)
    dg = transform.DataGroup([x], [transform.Transform(m)])

    checks = (({}, "mean"), ({"conserve": True}, "sum"))
    for kwargs, reducer in checks:
        pairs = dg.transform(**kwargs)['data'][1:-1].reshape(size - 2, 2)
        assert np.array_equal(x[1:-1], getattr(pairs, reducer)(axis=1))
Beispiel #5
0
def test_1d_nonaffine_transform():
    """
    Check a quadratic 1D resampling, with and without flux conservation.

    For input index i, output pixels i(i+1)//2 up to (i+1)(i+2)//2 - 1 are
    aggregated and compared with the input value.
    """
    quadratic = models.Polynomial1D(degree=2, c0=0, c1=0.5, c2=0.5)
    quadratic.inverse = InverseQuadratic1D(c0=0, c1=0.5, c2=0.5)

    size = 100
    x = np.arange(size, dtype=float)
    m = models.Shift(0.5) | quadratic | models.Shift(-0.5)
    dg = transform.DataGroup([x], [transform.Transform(m)])

    # Indices below 5 are skipped to avoid ~1% numerical issues when there is
    # little resampling
    first = 5
    bounds = [(i * (i + 1) // 2, (i + 1) * (i + 2) // 2)
              for i in range(first, size - 1)]

    data = dg.transform(conserve=True)['data']
    assert np.allclose(x[first:-1], [data[lo:hi].sum() for lo, hi in bounds],
                       rtol=0.001)

    # A mean is not strictly appropriate because the output pixels are not
    # evenly spread within each input pixel, hence the looser tolerance; this
    # only confirms the result is sane
    data = dg.transform()['data']
    assert np.allclose(x[first:-1], [data[lo:hi].mean() for lo, hi in bounds],
                       rtol=0.005)
Beispiel #6
0
def do_plots(ad, ad_ref):
    """
    Generate diagnostic plots.

    Two PNG figures are written per extension into a fixed output directory,
    which is created if needed:

    * a quiver "distortion map" of the distortion model of ``ad``;
    * the difference between the distortion corrections of ``ad_ref`` and
      ``ad`` applied to a synthetic image.

    Parameters
    ----------
    ad : AstroData
        Object whose distortion model is plotted.

    ad_ref : AstroData
        Reference object; its extensions are paired with those of ``ad``.
    """
    output_dir = "./plots/geminidr/gmos/test_gmos_spect_ls_distortion_determine"
    os.makedirs(output_dir, exist_ok=True)

    grating = ad.disperser(pretty=True)
    bin_x = ad.detector_x_bin()
    bin_y = ad.detector_y_bin()
    central_wavelength = ad.central_wavelength() * 1e9  # in nanometers

    # -- Show distortion map ---
    for ext_num, ext in enumerate(ad):
        _plot_distortion_map(ext, ext_num, output_dir, grating,
                             central_wavelength, bin_x, bin_y)

    # -- Show distortion model difference ---
    for num, (ext, ext_ref) in enumerate(zip(ad, ad_ref)):
        _plot_distortion_difference(ext, ext_ref, num, output_dir, grating,
                                    central_wavelength)


def _plot_distortion_map(ext, ext_num, output_dir, grating, central_wavelength,
                         bin_x, bin_y, n_hlines=25, n_vlines=25):
    """Save a quiver plot of the distortion model of one extension as PNG."""
    fname, _ = os.path.splitext(os.path.basename(ext.filename))
    n_rows, n_cols = ext.shape

    x = np.linspace(0, n_cols, n_vlines, dtype=int)
    y = np.linspace(0, n_rows, n_hlines, dtype=int)

    X, Y = np.meshgrid(x, y)

    model = rebuild_distortion_model(ext)
    U = X - model(X, Y)
    V = np.zeros_like(U)

    # The diverging norm requires vmin < vcenter < vmax. When U has no
    # spread (e.g. all zeros), fall back to a unit span so the limits still
    # bracket zero instead of collapsing onto it.
    span = np.ptp(U) or 1.0
    vmin = U.min() if U.min() < 0 else -0.1 * span
    vmax = U.max() if U.max() > 0 else +0.1 * span
    vcen = 0

    # DivergingNorm was renamed TwoSlopeNorm (and later removed) in
    # matplotlib; prefer the new name and fall back for old versions.
    norm_class = getattr(colors, "TwoSlopeNorm", None) or colors.DivergingNorm

    fig, ax = plt.subplots(
        num="Distortion Map {:s} #{:d}".format(fname, ext_num))

    Q = ax.quiver(
        X, Y, U, V, U, cmap="coolwarm",
        norm=norm_class(vcenter=vcen, vmin=vmin, vmax=vmax))

    ax.set_xlabel("X [px]")
    ax.set_ylabel("Y [px]")
    ax.set_title(
        "Distortion Map\n{:s} #{:d}- Bin {:d}x{:d}".format(
            fname, ext_num, bin_x, bin_y))

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    cbar = fig.colorbar(Q, extend="max", cax=cax, orientation="vertical")
    cbar.set_label("Distortion [px]")

    fig.tight_layout()
    fig_name = os.path.join(
        output_dir, "{:s}_{:d}_{:s}_{:.0f}_distMap.png".format(
            fname, ext_num, grating, central_wavelength))

    fig.savefig(fig_name)
    # ``del`` does not free a pyplot figure; it must be closed explicitly
    plt.close(fig)


def _plot_distortion_difference(ext, ext_ref, num, output_dir, grating,
                                central_wavelength):
    """Save the reference-minus-output distortion-corrected image as PNG."""
    # basename for consistency with the distortion-map file names
    name, _ = os.path.splitext(os.path.basename(ext.filename))
    data = generate_fake_data(ext.shape, ext.dispersion_axis() - 1)

    model_out = remap_distortion_model(
        rebuild_distortion_model(ext), ext.dispersion_axis() - 1)

    model_ref = remap_distortion_model(
        rebuild_distortion_model(ext_ref), ext_ref.dispersion_axis() - 1)

    transform_out = transform.Transform(model_out)
    transform_ref = transform.Transform(model_ref)

    data_out = transform_out.apply(data, output_shape=ext.shape)
    data_ref = transform_ref.apply(data, output_shape=ext.shape)

    # Mask NaN/inf pixels so they do not dominate the colour scale
    data_out = np.ma.masked_invalid(data_out)
    data_ref = np.ma.masked_invalid(data_ref)

    fig, ax = plt.subplots(
        dpi=150, num="Distortion Comparison: {:s} #{:d}".format(name, num))

    im = ax.imshow(data_ref - data_out)

    ax.set_xlabel("X [px]")
    ax.set_ylabel("Y [px]")
    ax.set_title(
        "Difference between output and reference: \n {:s} #{:d} ".format(
            name, num))

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    cbar = fig.colorbar(im, extend="max", cax=cax, orientation="vertical")
    cbar.set_label("Distortion [px]")

    fig_name = os.path.join(
        output_dir, "{:s}_{:d}_{:s}_{:.0f}_distDiff.png".format(
            name, num, grating, central_wavelength))

    fig.savefig(fig_name)
    plt.close(fig)
    def applyStackedObjectMask(self, adinputs=None, **params):
        """
        This primitive takes an image with an OBJMASK and transforms that
        OBJMASK onto the pixel planes of the input images, using their WCS
        information. If the first image is a stack, this allows us to mask
        fainter objects than can be detected in the individual input images.

        The source image is looked up first as a stream name, then as a
        filename. If it cannot be found, is not a single AstroData object, or
        has a different number of extensions from any input, a warning is
        logged and the inputs are returned unchanged.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        source: str
            name of stream containing single stacked image
        order: int (0-5)
            order of interpolation
        threshold: float
            threshold above which an interpolated pixel should be flagged
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        source = params["source"]
        order = params["order"]
        threshold = params["threshold"]
        sfx = params["suffix"]
        # The general (non-affine) resampling branch below is kept but
        # currently disabled: the affine approximation is always used.
        force_affine = True

        try:
            source_stream = self.streams[source]
        except KeyError:
            try:
                ad_source = astrodata.open(source)
            except Exception:  # not a bare except: let KeyboardInterrupt etc. through
                log.warning(f"Cannot find stream or file named {source}. Continuing.")
                return adinputs
        else:
            if len(source_stream) != 1:
                log.warning(f"Stream {source} does not contain single "
                            "AstroData object. Continuing.")
                return adinputs
            ad_source = source_stream[0]

        # There's no reason why we can't handle multiple extensions
        if any(len(ad) != len(ad_source) for ad in adinputs):
            log.warning("At least one AstroData input has a different number "
                        "of extensions to the reference. Continuing.")
            return adinputs

        for ad in adinputs:
            for ext, source_ext in zip(ad, ad_source):
                # The mask that is resampled is the source's, so that is the
                # one that must exist; the input's OBJMASK is overwritten.
                source_mask = getattr(source_ext, 'OBJMASK', None)
                if source_mask is not None:
                    # source pixels -> world -> this extension's pixels
                    t_align = source_ext.wcs.forward_transform | ext.wcs.backward_transform
                    source_mask = source_mask.astype(np.float32)
                    if force_affine:
                        # Evaluate over this extension's own grid (not ad[0]'s)
                        # so extensions of different shapes are handled
                        affine = adwcs.calculate_affine_matrices(t_align.inverse, ext.shape)
                        objmask = affine_transform(source_mask,
                                                   affine.matrix, affine.offset,
                                                   output_shape=ext.shape, order=order,
                                                   cval=0)
                    else:
                        objmask = transform.Transform(t_align).apply(source_mask,
                                                                     output_shape=ext.shape, order=order,
                                                                     cval=0)
                    # Interpolation yields fractional values; re-binarize
                    ext.OBJMASK = np.where(abs(objmask) > threshold, 1, 0).astype(np.uint8)
                # We will deliberately keep the input image's OBJCAT (if it
                # exists) since this will be required for aligning the inputs.
            ad.update_filename(suffix=sfx, strip=True)

        return adinputs
Beispiel #8
0
    def tileArrays(self, adinputs=None, **params):
        """
        This primitive combines extensions by tiling (no interpolation).
        The array_section() and detector_section() descriptors are used
        to derive the geometry of the tiling, so outside help (from the
        instrument's geometry_conf module) is only required if there are
        multiple arrays being tiled together, as the gaps need to be
        specified.

        Inputs with a single extension, or with nothing to tile, are passed
        through unchanged (with a warning).

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        tile_all: bool
            tile to a single extension, rather than one per array?
            (array=physical detector)
        sci_only: bool
            tile only the data plane?

        Returns
        -------
        list
            the tiled AstroData objects (or the untouched inputs)
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params['suffix']
        tile_all = params['tile_all']
        # None means "transform every attribute"
        attributes = ['data'] if params["sci_only"] else None

        adoutputs = []
        for ad in adinputs:
            if len(ad) == 1:
                log.warning("{} has only one extension, so there's nothing "
                            "to tile".format(ad.filename))
                adoutputs.append(ad)
                continue

            # Get information to calculate the output geometry
            # TODO: Think about arbitrary ROIs
            array_info = gt.array_information(ad)
            detshape = array_info.detector_shape
            if not tile_all and set(array_info.array_shapes) == {(1, 1)}:
                log.warning("{} has nothing to tile, as tile_all=False but "
                            "each array has only one amplifier.".format(
                                ad.filename))
                adoutputs.append(ad)
                continue

            blocks = [transform.Block(ad[arrays], shape=shape) for arrays, shape in
                      zip(array_info.extensions, array_info.array_shapes)]
            offsets = [ad[exts[0]].array_section()
                       for exts in array_info.extensions]

            if tile_all and detshape != (1, 1):  # We need gaps!
                geotable = import_module('.geometry_conf', self.inst_lookups)
                chip_gaps = geotable.tile_gaps[ad.detector_name()]
                try:
                    xgap, ygap = chip_gaps
                except TypeError:  # single number, applies to both
                    xgap = ygap = chip_gaps
                transforms = []
                for i, (origin, offset) in enumerate(zip(array_info.origins, offsets)):
                    # Unbinned position of this array plus the accumulated
                    # gaps, converted to binned pixels
                    xshift = (origin[1] + offset.x1 + xgap * (i % detshape[1])) // ad.detector_x_bin()
                    yshift = (origin[0] + offset.y1 + ygap * (i // detshape[1])) // ad.detector_y_bin()
                    transforms.append(transform.Transform(models.Shift(xshift) & models.Shift(yshift)))
                adg = transform.AstroDataGroup(blocks, transforms)
                adg.set_reference()
                ad_out = adg.transform(attributes=attributes, process_objcat=True)
            else:
                # ADG.transform() produces full AD objects so we start with
                # the first one, and then append the single extensions created
                # by later calls to it.
                for i, block in enumerate(blocks):
                    # Simply create a single tiled array
                    adg = transform.AstroDataGroup([block])
                    adg.set_reference()
                    if i == 0:
                        ad_out = adg.transform(attributes=attributes,
                                               process_objcat=True)
                    else:
                        ad_out.append(adg.transform(attributes=attributes,
                                                    process_objcat=True)[0])

            gt.mark_history(ad_out, primname=self.myself(), keyword=timestamp_key)
            ad_out.orig_filename = ad.filename
            ad_out.update_filename(suffix=suffix, strip=True)
            adoutputs.append(ad_out)
        return adoutputs
    def applyQECorrection(self, adinputs=None, **params):
        """
        This primitive applies a wavelength-dependent QE correction to
        a 2D spectral image, based on the wavelength solution of an
        associated processed_arc.

        It is only designed to work on FLATs, and therefore unmosaicked data.

        Inputs already carrying this primitive's timestamp, or with e2v CCDs,
        are skipped. Without a usable arc, an approximate wavelength solution
        and no distortion correction are used, except in "sq" mode, where an
        OSError is raised instead.

        Parameters
        ----------
        suffix: str
            suffix to be added to output files
        arc: None or arc frame(s)
            arc(s) providing the WAVECAL and FITCOORD tables; if None, a
            processed_arc is requested for each input (presumably AstroData
            objects or filenames, as accepted by gt.make_lists -- confirm)
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        sfx = params["suffix"]
        arc = params["arc"]

        # Get a suitable arc frame (with distortion map) for every science AD
        if arc is None:
            self.getProcessedArc(adinputs, refresh=False)
            arc_list = self._get_cal(adinputs, 'processed_arc')
        else:
            arc_list = arc

        for ad, arc in zip(*gt.make_lists(adinputs, arc_list, force_ad=True)):
            if ad.phu.get(timestamp_key):
                log.warning(
                    "No changes will be made to {}, since it has "
                    "already been processed by applyQECorrection".format(
                        ad.filename))
                continue

            if 'e2v' in ad.detector_name(pretty=True):
                log.warning("{} has the e2v CCDs, so no QE correction "
                            "is necessary".format(ad.filename))
                continue

            # A fresh model per input: its inverse is set below from this
            # arc's FITCOORD, and reusing one model across the loop would
            # leak an earlier input's distortion model into later inputs.
            distort_model = models.Identity(2)

            # Determines whether to multiply or divide by QE correction
            is_flat = 'FLAT' in ad.tags

            # If the arc's binning doesn't match, we may still be able to
            # fall back to the approximate solution
            xbin, ybin = ad.detector_x_bin(), ad.detector_y_bin()
            if arc is not None and (arc.detector_x_bin() != xbin
                                    or arc.detector_y_bin() != ybin):
                log.warning(
                    "Science frame {} and arc {} have different binnings, "
                    "so cannot use arc".format(ad.filename, arc.filename))
                arc = None

            # OK, we definitely want to try to do this, get a wavelength solution
            try:
                wavecal = arc[0].WAVECAL
            except (TypeError, AttributeError):  # no arc, or no WAVECAL
                wave_model = None
            else:
                model_dict = dict(zip(wavecal['name'],
                                      wavecal['coefficients']))
                wave_model = astromodels.dict_to_chebyshev(model_dict)
                if not isinstance(wave_model, models.Chebyshev1D):
                    log.warning("Problem reading wavelength solution from arc "
                                "{}".format(arc.filename))
                    # Do not use an unreadable model; fall back below
                    wave_model = None

            if wave_model is None:
                if 'sq' in self.mode:
                    raise OSError("No wavelength solution for {}".format(
                        ad.filename))
                else:
                    log.warning("Using approximate wavelength solution for "
                                "{}".format(ad.filename))

            try:
                fitcoord = arc[0].FITCOORD
            except (TypeError, AttributeError):
                # distort_model already has Identity inverse so nothing required
                pass
            else:
                # TODO: This is copied from determineDistortion() and will need
                # to be refactored out. Or we might be able to simply replace it
                # with a gWCS.pixel_to_world() call
                model_dict = dict(
                    zip(fitcoord['inv_name'], fitcoord['inv_coefficients']))
                m_inverse = astromodels.dict_to_chebyshev(model_dict)
                if not isinstance(m_inverse, models.Chebyshev2D):
                    log.warning("Problem reading distortion model from arc "
                                "{}".format(arc.filename))
                else:
                    distort_model.inverse = models.Mapping(
                        (0, 1, 1)) | (m_inverse & models.Identity(1))

            # An Identity model is its own inverse unless one was assigned above
            if distort_model.inverse == distort_model:  # Identity(2)
                if 'sq' in self.mode:
                    raise OSError("No distortion model for {}".format(
                        ad.filename))
                else:
                    log.warning(
                        "Proceeding without a distortion correction for "
                        "{}".format(ad.filename))

            ad_detsec = ad.detector_section()
            # NOTE(review): geotable is not defined in this method; presumably
            # a module-level import of the instrument geometry_conf -- verify
            adg = transform.create_mosaic_transform(ad, geotable)
            if arc is not None:
                arc_detsec = arc.detector_section()[0]
                shifts = [
                    c1 - c2 for c1, c2 in zip(
                        np.array(ad_detsec).min(axis=0), arc_detsec)
                ]
                xshift, yshift = shifts[0] / xbin, shifts[2] / ybin  # x1, y1
                if xshift or yshift:
                    log.stdinfo("Found a shift of ({},{}) pixels between "
                                "{} and the calibration.".format(
                                    xshift, yshift, ad.filename))
                add_shapes, add_transforms = [], []
                for (arr, trans) in adg:
                    # Try to work out shape of this Block in the unmosaicked
                    # arc, and then apply a shift to align it with the
                    # science Block before applying the same transform.
                    if xshift == 0:
                        add_shapes.append(
                            ((arc_detsec.y2 - arc_detsec.y1) // ybin,
                             arr.shape[1]))
                    else:
                        add_shapes.append(
                            (arr.shape[0],
                             (arc_detsec.x2 - arc_detsec.x1) // xbin))
                    t = transform.Transform(
                        models.Shift(-xshift) & models.Shift(-yshift))
                    t.append(trans)
                    add_transforms.append(t)
                adg.calculate_output_shape(
                    additional_array_shapes=add_shapes,
                    additional_transforms=add_transforms)
                origin_shift = models.Shift(-adg.origin[1]) & models.Shift(
                    -adg.origin[0])
                for t in adg.transforms:
                    t.append(origin_shift)

            # Irrespective of arc or not, apply the distortion model (it may
            # be Identity), recalculate output_shape and reset the origin
            for t in adg.transforms:
                t.append(distort_model.copy())
            adg.calculate_output_shape()
            adg.reset_origin()

            # Now we know the shape of the output, we can construct the
            # approximate wavelength solution; ad.dispersion() returns a list!
            if wave_model is None:
                wave_model = (
                    models.Shift(-0.5 * adg.output_shape[1])
                    | models.Scale(ad.dispersion(asNanometers=True)[0])
                    | models.Shift(ad.central_wavelength(asNanometers=True)))

            for ccd, (block, trans) in enumerate(adg, start=1):
                # The second block is left uncorrected (presumably the QE
                # reference CCD -- TODO confirm)
                if ccd == 2:
                    continue
                for ext, corner in zip(block, block.corners):
                    ygrid, xgrid = np.indices(ext.shape)
                    xgrid += corner[1]  # No need for ygrid
                    xnew = trans(xgrid, ygrid)[0]
                    # Some unit-based stuff here to prepare for gWCS
                    waves = wave_model(xnew) * u.nm
                    try:
                        qe_correction = qeModel(ext)(
                            (waves / u.nm).to(u.dimensionless_unscaled).value)
                    except TypeError:  # qeModel() returns None
                        msg = "No QE correction found for {}:{}".format(
                            ad.filename, ext.hdr['EXTVER'])
                        if 'sq' in self.mode:
                            raise ValueError(msg)
                        else:
                            log.warning(msg)
                            # Without a correction for this extension, skip it
                            # rather than reusing a previous extension's values
                            continue
                    log.fullinfo(
                        "Mean relative QE of EXTVER {} is {:.5f}".format(
                            ext.hdr['EXTVER'], qe_correction.mean()))
                    if not is_flat:
                        qe_correction = 1. / qe_correction
                    # Zero out unphysical (negative or implausibly large) values
                    qe_correction[qe_correction < 0] = 0
                    qe_correction[qe_correction > 10] = 0
                    ext.multiply(qe_correction)

            # Timestamp and update the filename
            gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
            ad.update_filename(suffix=sfx, strip=True)

        return adinputs