def show_distortion_model_difference(self, ad, ad_ref):
    """
    Plot the residual between the distortion correction stored in an output
    file and the one stored in its reference file.

    For each pair of matching extensions, a synthetic frame built by
    ``generate_fake_data`` is resampled through the "pixels" ->
    "distortion_corrected" transform of each file's WCS (both onto the
    output extension's shape). The reference result minus the output result
    is drawn and saved as an SVG in ``self.output_folder``.

    Parameters
    ----------
    ad : AstroData
        Distortion Determined AstroData object
    ad_ref : AstroData
        Distortion Determined AstroData reference object
    """
    def _resample(source_ext, frame, shape):
        # Push ``frame`` through this extension's distortion-correction step
        # and mask any NaN/inf pixels in the result.
        dist_model = source_ext.wcs.get_transform("pixels",
                                                  "distortion_corrected")
        warped = transform.Transform(dist_model).apply(frame,
                                                       output_shape=shape)
        return np.ma.masked_invalid(warped)

    for idx, (ext_out, ext_cmp) in enumerate(zip(ad, ad_ref)):
        stem = os.path.splitext(ext_out.filename)[0]
        target_shape = ext_out.shape
        frame = generate_fake_data(target_shape,
                                   ext_out.dispersion_axis() - 1)

        # Both results are resampled onto the *output* extension's shape
        warped_out = _resample(ext_out, frame, target_shape)
        warped_ref = _resample(ext_cmp, frame, target_shape)

        fig, ax = plt.subplots(
            dpi=150,
            num="Distortion Comparison: {:s} #{:d}".format(stem, idx))
        image = ax.imshow(warped_ref - warped_out)
        ax.set_xlabel("X [px]")
        ax.set_ylabel("Y [px]")
        title = "Difference between output and reference: \n {:s} #{:d} "
        ax.set_title(title.format(stem, idx))

        cbar_ax = make_axes_locatable(ax).append_axes("right", size="5%",
                                                      pad=0.05)
        colorbar = fig.colorbar(image, extend="max", cax=cbar_ax,
                                orientation="vertical")
        colorbar.set_label("Distortion [px]")

        out_name = "{:s}_{:d}_{:s}_{:.0f}_distDiff.svg".format(
            stem, idx, self.grating, self.central_wavelength)
        fig.savefig(os.path.join(self.output_folder, out_name))
def test_2d_nonaffine_transform():
    """
    Check a non-affine 2D resampling (a quadratic "triangle" stretch applied
    along both axes), first with flux conservation and then without it.
    """
    quad = models.Polynomial1D(degree=2, c0=0, c1=0.5, c2=0.5)
    quad.inverse = InverseQuadratic1D(c0=0, c1=0.5, c2=0.5)
    npix = 10
    image = np.arange(npix * npix, dtype=float).reshape(npix, npix)
    stretch = models.Shift(0.5) | quad | models.Shift(-0.5)
    group = transform.DataGroup([image],
                                [transform.Transform(stretch & stretch)])

    def footprint(k):
        # Range of output pixels collapsed back onto input pixel k
        return slice(k * (k + 1) // 2, (k + 1) * (k + 2) // 2)

    def collapse(result, reducer):
        # Reduce each interior input pixel's output footprint to one value,
        # giving an (npix-2, npix-2) array laid out like the input
        inner = range(1, npix - 1)
        return np.array([[reducer(result['data'][footprint(r), footprint(c)])
                          for c in inner] for r in inner])

    expected = image[1:-1, 1:-1]

    conserved = group.transform(subsample=5, conserve=True)
    assert np.allclose(expected, collapse(conserved, lambda blk: blk.sum()),
                       rtol=0.005)

    # Averaging is a poor estimator, and worse here than in 1D because the
    # gradient is very steep in one direction, so tolerances are looser.
    averaged = group.transform()
    assert np.allclose(expected, collapse(averaged, lambda blk: blk.mean()),
                       rtol=0.03, atol=0.2)
def test_2d_affine_transform():
    """
    Check a pure 2x magnification in both axes. The output pixels produced
    from each input pixel are first averaged, then (with flux conservation)
    summed, and compared against the input.
    """
    npix = 10
    image = np.arange(npix * npix, dtype=float).reshape(npix, npix)
    # The edge pixels are still lost, leaving an 18x18-pixel output array
    zoom = models.Shift(0.5) | models.Scale(2) | models.Shift(-0.5)
    group = transform.DataGroup([image], [transform.Transform(zoom & zoom)])
    expected = image[1:-1, 1:-1]

    # Each interior input pixel corresponds to a 2x2 block of output pixels
    blocks = group.transform()['data'][1:-1, 1:-1].reshape(8, 2, 8, 2)
    assert np.array_equal(expected, blocks.mean(axis=3).mean(axis=1))

    blocks = group.transform(conserve=True)['data'][1:-1, 1:-1].reshape(
        8, 2, 8, 2)
    assert np.array_equal(expected, blocks.sum(axis=3).sum(axis=1))
def test_1d_affine_transform():
    """
    Check a pure 2x magnification of a 1D array. Each pair of output pixels
    is first averaged, then (with flux conservation) summed, and compared
    against the input.
    """
    npix = 100
    spectrum = np.arange(npix, dtype=float)
    # The half-pixel shifts make each input pixel's footprint cover whole
    # output pixels exactly. The end pixels are still lost, giving a
    # 198-pixel output array.
    zoom = models.Shift(0.5) | models.Scale(2) | models.Shift(-0.5)
    group = transform.DataGroup([spectrum], [transform.Transform(zoom)])
    expected = spectrum[1:-1]

    pairs = group.transform()['data'][1:-1].reshape(npix - 2, 2)
    assert np.array_equal(expected, pairs.mean(axis=1))

    pairs = group.transform(conserve=True)['data'][1:-1].reshape(npix - 2, 2)
    assert np.array_equal(expected, pairs.sum(axis=1))
def test_1d_nonaffine_transform():
    """
    Check a non-affine (quadratic "triangle") 1D resampling, first with flux
    conservation and then without it.
    """
    quad = models.Polynomial1D(degree=2, c0=0, c1=0.5, c2=0.5)
    quad.inverse = InverseQuadratic1D(c0=0, c1=0.5, c2=0.5)
    npix = 100
    spectrum = np.arange(npix, dtype=float)
    stretch = models.Shift(0.5) | quad | models.Shift(-0.5)
    group = transform.DataGroup([spectrum], [transform.Transform(stretch)])

    def footprint(k):
        # Range of output pixels collapsed back onto input pixel k
        return slice(k * (k + 1) // 2, (k + 1) * (k + 2) // 2)

    # Pixels below 5 are skipped to avoid ~1% numerical errors where there
    # is only limited resampling
    first = 5
    expected = spectrum[first:-1]

    conserved = group.transform(conserve=True)['data']
    summed = [conserved[footprint(k)].sum() for k in range(first, npix - 1)]
    assert np.allclose(expected, summed, rtol=0.001)

    # A mean is not really appropriate, because the output pixels are not
    # spread evenly within each input pixel. Larger differences are
    # therefore tolerated; this only confirms nothing has gone badly wrong.
    averaged = group.transform()['data']
    means = [averaged[footprint(k)].mean() for k in range(first, npix - 1)]
    assert np.allclose(expected, means, rtol=0.005)
def do_plots(ad, ad_ref):
    """
    Generate diagnostic plots of a distortion solution and its comparison
    with a reference solution.

    Two kinds of PNG files are written into
    ``./plots/geminidr/gmos/test_gmos_spect_ls_distortion_determine``:

    * ``*_distMap.png``: for each extension of ``ad``, a quiver plot of
      ``X - model(X, Y)`` (horizontal arrows only), sampled on a regular
      grid, where ``model`` is rebuilt with ``rebuild_distortion_model``.
    * ``*_distDiff.png``: for each pair of extensions, the difference between
      a synthetic frame resampled with the reference distortion model and
      with the output one.

    Parameters
    ----------
    ad : AstroData
        Object whose extensions carry the distortion model to inspect.
    ad_ref : AstroData
        Reference object, compared extension by extension with ``ad``.
    """
    n_hlines = 25
    n_vlines = 25

    output_dir = "./plots/geminidr/gmos/test_gmos_spect_ls_distortion_determine"
    os.makedirs(output_dir, exist_ok=True)

    grating = ad.disperser(pretty=True)
    bin_x = ad.detector_x_bin()
    bin_y = ad.detector_y_bin()
    central_wavelength = ad.central_wavelength() * 1e9  # in nanometers

    # DivergingNorm was renamed TwoSlopeNorm (matplotlib 3.2) and removed in
    # matplotlib 3.4; prefer the new name and fall back for old versions.
    diverging_norm = getattr(colors, "TwoSlopeNorm", None) or colors.DivergingNorm

    # -- Show distortion map ---
    for ext_num, ext in enumerate(ad):
        fname, _ = os.path.splitext(os.path.basename(ext.filename))
        n_rows, n_cols = ext.shape

        x = np.linspace(0, n_cols, n_vlines, dtype=int)
        y = np.linspace(0, n_rows, n_hlines, dtype=int)
        X, Y = np.meshgrid(x, y)

        model = rebuild_distortion_model(ext)
        U = X - model(X, Y)
        V = np.zeros_like(U)

        fig, ax = plt.subplots(
            num="Distortion Map {:s} #{:d}".format(fname, ext_num))

        # ndarray.ptp() was removed in NumPy 2.0; np.ptp() works everywhere.
        # If U has zero spread, vmin == vcenter == vmax, which the diverging
        # norm rejects, so fall back to a unit spread.
        spread = np.ptp(U) or 1.0
        vmin = U.min() if U.min() < 0 else -0.1 * spread
        vmax = U.max() if U.max() > 0 else +0.1 * spread
        vcen = 0

        Q = ax.quiver(
            X, Y, U, V, U, cmap="coolwarm",
            norm=diverging_norm(vcenter=vcen, vmin=vmin, vmax=vmax))

        ax.set_xlabel("X [px]")
        ax.set_ylabel("Y [px]")
        ax.set_title(
            "Distortion Map\n{:s} #{:d}- Bin {:d}x{:d}".format(
                fname, ext_num, bin_x, bin_y))

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        cbar = fig.colorbar(Q, extend="max", cax=cax, orientation="vertical")
        cbar.set_label("Distortion [px]")

        fig.tight_layout()
        fig_name = os.path.join(
            output_dir, "{:s}_{:d}_{:s}_{:.0f}_distMap.png".format(
                fname, ext_num, grating, central_wavelength))

        fig.savefig(fig_name)
        # pyplot keeps a reference to every open figure, so ``del`` alone
        # would not free it; close it explicitly.
        plt.close(fig)

    # -- Show distortion model difference ---
    for num, (ext, ext_ref) in enumerate(zip(ad, ad_ref)):
        # Take the basename, as above: os.path.join() would discard
        # output_dir if ext.filename were an absolute path.
        name, _ = os.path.splitext(os.path.basename(ext.filename))
        shape = ext.shape
        data = generate_fake_data(shape, ext.dispersion_axis() - 1)

        model_out = remap_distortion_model(
            rebuild_distortion_model(ext), ext.dispersion_axis() - 1)

        model_ref = remap_distortion_model(
            rebuild_distortion_model(ext_ref), ext_ref.dispersion_axis() - 1)

        transform_out = transform.Transform(model_out)
        transform_ref = transform.Transform(model_ref)

        # Both are resampled onto the output extension's shape
        data_out = transform_out.apply(data, output_shape=ext.shape)
        data_ref = transform_ref.apply(data, output_shape=ext.shape)

        data_out = np.ma.masked_invalid(data_out)
        data_ref = np.ma.masked_invalid(data_ref)

        fig, ax = plt.subplots(
            dpi=150, num="Distortion Comparison: {:s} #{:d}".format(name, num))

        im = ax.imshow(data_ref - data_out)

        ax.set_xlabel("X [px]")
        ax.set_ylabel("Y [px]")
        ax.set_title(
            "Difference between output and reference: \n {:s} #{:d} ".format(
                name, num))

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        cbar = fig.colorbar(im, extend="max", cax=cax, orientation="vertical")
        cbar.set_label("Distortion [px]")

        fig_name = os.path.join(
            output_dir, "{:s}_{:d}_{:s}_{:.0f}_distDiff.png".format(
                name, num, grating, central_wavelength))

        fig.savefig(fig_name)
        plt.close(fig)
def applyStackedObjectMask(self, adinputs=None, **params):
    """
    This primitive takes an image with an OBJMASK and transforms that
    OBJMASK onto the pixel planes of the input images, using their WCS
    information. If the first image is a stack, this allows us to mask
    fainter objects than can be detected in the individual input images.

    Extensions of the source image without an OBJMASK are skipped. Each
    transferred mask is binarized: pixels whose interpolated value exceeds
    ``threshold`` (in absolute value) become 1, all others 0.

    Parameters
    ----------
    suffix: str
        suffix to be added to output files
    source: str
        name of stream containing single stacked image, or the name of a
        file holding that image (tried if no such stream exists)
    order: int (0-5)
        order of interpolation
    threshold: float
        threshold above which an interpolated pixel should be flagged
    """
    log = self.log
    log.debug(gt.log_message("primitive", self.myself(), "starting"))
    source = params["source"]
    order = params["order"]
    threshold = params["threshold"]
    sfx = params["suffix"]
    # Only the affine_transform path below is currently exercised; the
    # general transform.Transform path is kept but disabled.
    force_affine = True

    try:
        source_stream = self.streams[source]
    except KeyError:
        try:
            ad_source = astrodata.open(source)
        except Exception:
            # Any failure to open is treated as "no such file"; unlike the
            # previous bare except, this lets KeyboardInterrupt/SystemExit
            # propagate.
            log.warning(f"Cannot find stream or file named {source}. Continuing.")
            return adinputs
    else:
        if len(source_stream) != 1:
            log.warning(f"Stream {source} does not contain single "
                        "AstroData object. Continuing.")
            return adinputs
        ad_source = source_stream[0]

    # There's no reason why we can't handle multiple extensions
    if any(len(ad) != len(ad_source) for ad in adinputs):
        log.warning("At least one AstroData input has a different number "
                    "of extensions to the reference. Continuing.")
        return adinputs

    for ad in adinputs:
        for ext, source_ext in zip(ad, ad_source):
            # The mask being transferred is the *source's*, so that is the
            # one that must exist (the input need not have an OBJMASK yet).
            source_objmask = getattr(source_ext, 'OBJMASK', None)
            if source_objmask is None:
                continue
            # source pixels -> world -> pixels of this input extension
            t_align = source_ext.wcs.forward_transform | ext.wcs.backward_transform
            if force_affine:
                # Affine approximation over this extension's own shape
                # (previously ad[0].shape was used for every extension)
                affine = adwcs.calculate_affine_matrices(t_align.inverse,
                                                         ext.shape)
                objmask = affine_transform(source_objmask.astype(np.float32),
                                           affine.matrix, affine.offset,
                                           output_shape=ext.shape,
                                           order=order, cval=0)
            else:
                objmask = transform.Transform(t_align).apply(
                    source_objmask.astype(np.float32),
                    output_shape=ext.shape, order=order, cval=0)
            ext.OBJMASK = np.where(abs(objmask) > threshold,
                                   1, 0).astype(np.uint8)
        # We will deliberately keep the input image's OBJCAT (if it
        # exists) since this will be required for aligning the inputs.
        ad.update_filename(suffix=sfx, strip=True)
    return adinputs
def tileArrays(self, adinputs=None, **params):
    """
    This primitive combines extensions by tiling (no interpolation).
    The array_section() and detector_section() descriptors are used
    to derive the geometry of the tiling, so outside help (from the
    instrument's geometry_conf module) is only required if there are
    multiple arrays being tiled together, as the gaps need to be
    specified.

    Inputs with a single extension, or (with tile_all=False) whose arrays
    each have only one amplifier, are passed through unchanged.

    Parameters
    ----------
    suffix: str
        suffix to be added to output files
    tile_all: bool
        tile to a single extension, rather than one per array?
        (array=physical detector)
    sci_only: bool
        tile only the data plane?
    """
    log = self.log
    log.debug(gt.log_message("primitive", self.myself(), "starting"))
    timestamp_key = self.timestamp_keys[self.myself()]
    suffix = params['suffix']
    tile_all = params['tile_all']
    attributes = ['data'] if params["sci_only"] else None

    adoutputs = []
    for ad in adinputs:
        if len(ad) == 1:
            log.warning("{} has only one extension, so there's nothing "
                        "to tile".format(ad.filename))
            adoutputs.append(ad)
            continue

        # Get information to calculate the output geometry
        # TODO: Think about arbitrary ROIs
        array_info = gt.array_information(ad)
        detshape = array_info.detector_shape
        if not tile_all and set(array_info.array_shapes) == {(1, 1)}:
            # The filename placeholder was previously never filled in
            log.warning("{} has nothing to tile, as tile_all=False but "
                        "each array has only one amplifier.".format(
                            ad.filename))
            adoutputs.append(ad)
            continue

        blocks = [transform.Block(ad[arrays], shape=shape) for arrays, shape in
                  zip(array_info.extensions, array_info.array_shapes)]
        offsets = [ad[exts[0]].array_section()
                   for exts in array_info.extensions]

        if tile_all and detshape != (1, 1):  # We need gaps!
            geotable = import_module('.geometry_conf', self.inst_lookups)
            chip_gaps = geotable.tile_gaps[ad.detector_name()]
            try:
                xgap, ygap = chip_gaps
            except TypeError:  # single number, applies to both
                xgap = ygap = chip_gaps
            transforms = []
            for i, (origin, offset) in enumerate(zip(array_info.origins,
                                                     offsets)):
                # Array i sits in column i % detshape[1] and row
                # i // detshape[1] of the detector grid; add one gap per
                # preceding column/row, then convert to binned pixels.
                xshift = (origin[1] + offset.x1 +
                          xgap * (i % detshape[1])) // ad.detector_x_bin()
                yshift = (origin[0] + offset.y1 +
                          ygap * (i // detshape[1])) // ad.detector_y_bin()
                transforms.append(transform.Transform(models.Shift(xshift) &
                                                      models.Shift(yshift)))
            adg = transform.AstroDataGroup(blocks, transforms)
            adg.set_reference()
            ad_out = adg.transform(attributes=attributes, process_objcat=True)
        else:
            # ADG.transform() produces full AD objects so we start with
            # the first one, and then append the single extensions created
            # by later calls to it.
            for i, block in enumerate(blocks):
                # Simply create a single tiled array
                adg = transform.AstroDataGroup([block])
                adg.set_reference()
                if i == 0:
                    ad_out = adg.transform(attributes=attributes,
                                           process_objcat=True)
                else:
                    ad_out.append(adg.transform(attributes=attributes,
                                                process_objcat=True)[0])

        gt.mark_history(ad_out, primname=self.myself(), keyword=timestamp_key)
        ad_out.orig_filename = ad.filename
        ad_out.update_filename(suffix=suffix, strip=True)
        adoutputs.append(ad_out)
    return adoutputs
def applyQECorrection(self, adinputs=None, **params):
    """
    This primitive applies a wavelength-dependent QE correction to
    a 2D spectral image, based on the wavelength solution of an
    associated processed_arc.

    It is only designed to work on FLATs, and therefore unmosaicked data.
    The correction is multiplied into frames tagged FLAT and divided out of
    any other frame. Frames already timestamped by this primitive and frames
    with e2v CCDs are left unchanged. If the arc is missing, has a different
    binning, or lacks a wavelength/distortion solution, an approximate
    solution (or no distortion correction) is used instead. In 'sq' mode
    this raises an error.

    Parameters
    ----------
    suffix: str
        suffix to be added to output files
    arc: AstroData/str or list thereof, or None
        processed arc(s) providing WAVECAL and FITCOORD tables; if None,
        they are retrieved with getProcessedArc()
    """
    log = self.log
    log.debug(gt.log_message("primitive", self.myself(), "starting"))
    timestamp_key = self.timestamp_keys[self.myself()]
    sfx = params["suffix"]
    arc = params["arc"]

    # Get a suitable arc frame (with distortion map) for every science AD
    if arc is None:
        self.getProcessedArc(adinputs, refresh=False)
        arc_list = self._get_cal(adinputs, 'processed_arc')
    else:
        arc_list = arc

    for ad, arc in zip(*gt.make_lists(adinputs, arc_list, force_ad=True)):
        if ad.phu.get(timestamp_key):
            log.warning(
                "No changes will be made to {}, since it has "
                "already been processed by applyQECorrection".format(
                    ad.filename))
            continue

        if 'e2v' in ad.detector_name(pretty=True):
            log.warning("{} has the e2v CCDs, so no QE correction "
                        "is necessary".format(ad.filename))
            continue

        # A fresh model for every frame: its inverse is mutated below, so a
        # model shared across frames would carry one arc's distortion
        # solution into the next frame.
        distort_model = models.Identity(2)

        # Determines whether to multiply or divide by QE correction
        is_flat = 'FLAT' in ad.tags

        # If the arc's binning doesn't match, we may still be able to
        # fall back to the approximate solution
        xbin, ybin = ad.detector_x_bin(), ad.detector_y_bin()
        if arc is not None and (arc.detector_x_bin() != xbin or
                                arc.detector_y_bin() != ybin):
            log.warning(
                "Science frame {} and arc {} have different binnings, "
                "so cannot use arc".format(ad.filename, arc.filename))
            arc = None

        # OK, we definitely want to try to do this, get a wavelength solution
        try:
            wavecal = arc[0].WAVECAL
        except (TypeError, AttributeError):
            # arc is None, or has no WAVECAL table
            wave_model = None
        else:
            model_dict = dict(zip(wavecal['name'], wavecal['coefficients']))
            wave_model = astromodels.dict_to_chebyshev(model_dict)
            if not isinstance(wave_model, models.Chebyshev1D):
                log.warning("Problem reading wavelength solution from arc "
                            "{}".format(arc.filename))

        if wave_model is None:
            if 'sq' in self.mode:
                raise OSError("No wavelength solution for {}".format(
                    ad.filename))
            else:
                log.warning("Using approximate wavelength solution for "
                            "{}".format(ad.filename))

        try:
            fitcoord = arc[0].FITCOORD
        except (TypeError, AttributeError):
            # distort_model already has Identity inverse so nothing required
            pass
        else:
            # TODO: This is copied from determineDistortion() and will need
            # to be refactored out. Or we might be able to simply replace it
            # with a gWCS.pixel_to_world() call
            model_dict = dict(zip(fitcoord['inv_name'],
                                  fitcoord['inv_coefficients']))
            m_inverse = astromodels.dict_to_chebyshev(model_dict)
            if not isinstance(m_inverse, models.Chebyshev2D):
                log.warning("Problem reading distortion model from arc "
                            "{}".format(arc.filename))
            else:
                distort_model.inverse = models.Mapping(
                    (0, 1, 1)) | (m_inverse & models.Identity(1))

        if distort_model.inverse == distort_model:  # Identity(2)
            if 'sq' in self.mode:
                raise OSError("No distortion model for {}".format(
                    ad.filename))
            else:
                log.warning(
                    "Proceeding without a distortion correction for "
                    "{}".format(ad.filename))

        ad_detsec = ad.detector_section()
        # NOTE(review): geotable is not defined in this method; presumably a
        # module-level import of the geometry_conf lookup - confirm.
        adg = transform.create_mosaic_transform(ad, geotable)
        if arc is not None:
            arc_detsec = arc.detector_section()[0]
            shifts = [c1 - c2 for c1, c2 in
                      zip(np.array(ad_detsec).min(axis=0), arc_detsec)]
            xshift, yshift = shifts[0] / xbin, shifts[2] / ybin  # x1, y1
            if xshift or yshift:
                log.stdinfo("Found a shift of ({},{}) pixels between "
                            "{} and the calibration.".format(
                                xshift, yshift, ad.filename))
            add_shapes, add_transforms = [], []
            for (arr, trans) in adg:
                # Try to work out shape of this Block in the unmosaicked
                # arc, and then apply a shift to align it with the
                # science Block before applying the same transform.
                if xshift == 0:
                    add_shapes.append(
                        ((arc_detsec.y2 - arc_detsec.y1) // ybin,
                         arr.shape[1]))
                else:
                    add_shapes.append(
                        (arr.shape[0],
                         (arc_detsec.x2 - arc_detsec.x1) // xbin))
                t = transform.Transform(
                    models.Shift(-xshift) & models.Shift(-yshift))
                t.append(trans)
                add_transforms.append(t)
            adg.calculate_output_shape(
                additional_array_shapes=add_shapes,
                additional_transforms=add_transforms)
            origin_shift = models.Shift(-adg.origin[1]) & models.Shift(
                -adg.origin[0])
            for t in adg.transforms:
                t.append(origin_shift)

        # Irrespective of arc or not, apply the distortion model (it may
        # be Identity), recalculate output_shape and reset the origin
        for t in adg.transforms:
            t.append(distort_model.copy())
        adg.calculate_output_shape()
        adg.reset_origin()

        # Now we know the shape of the output, we can construct the
        # approximate wavelength solution; ad.dispersion() returns a list!
        if wave_model is None:
            wave_model = (
                models.Shift(-0.5 * adg.output_shape[1]) |
                models.Scale(ad.dispersion(asNanometers=True)[0]) |
                models.Shift(ad.central_wavelength(asNanometers=True)))

        for ccd, (block, trans) in enumerate(adg, start=1):
            # The second block is never corrected (presumably the QE model
            # is relative to that CCD - confirm)
            if ccd == 2:
                continue
            for ext, corner in zip(block, block.corners):
                ygrid, xgrid = np.indices(ext.shape)
                xgrid += corner[1]  # No need for ygrid
                xnew = trans(xgrid, ygrid)[0]
                # Some unit-based stuff here to prepare for gWCS
                waves = wave_model(xnew) * u.nm
                try:
                    qe_correction = qeModel(ext)(
                        (waves / u.nm).to(u.dimensionless_unscaled).value)
                except TypeError:  # qeModel() returns None
                    msg = "No QE correction found for {}:{}".format(
                        ad.filename, ext.hdr['EXTVER'])
                    if 'sq' in self.mode:
                        raise ValueError(msg)
                    log.warning(msg)
                    # Skip this extension: falling through would reuse the
                    # previous extension's qe_correction (or raise NameError
                    # on the first one).
                    continue
                log.fullinfo(
                    "Mean relative QE of EXTVER {} is {:.5f}".format(
                        ext.hdr['EXTVER'], qe_correction.mean()))
                if not is_flat:
                    qe_correction = 1. / qe_correction
                # Zero out unphysical (negative or > 10) factors
                qe_correction[qe_correction < 0] = 0
                qe_correction[qe_correction > 10] = 0
                ext.multiply(qe_correction)

        # Timestamp and update the filename
        gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
        ad.update_filename(suffix=sfx, strip=True)
    return adinputs