# Example #1
def test_invalid_operands():
    """
    Test that certain operators do not work with models whose inputs/outputs do
    not match up correctly.
    """

    # Composition ('|') of mismatched models must be rejected at build time.
    with pytest.raises(ModelDefinitionError):
        Rotation2D(90) | Gaussian1D(1, 0, 0.1)

    # So must arithmetic combination ('+') of mismatched models.
    with pytest.raises(ModelDefinitionError):
        Rotation2D(90) + Gaussian1D(1, 0, 0.1)
def test_replace_submodel():
    """
    Replace a model in a Compound model
    """
    S1 = Shift(2, name='shift2') | Scale(
        3, name='scale3')  # First shift then scale
    S2 = Scale(2, name='scale2') | Shift(
        3, name='shift3')  # First scale then shift

    m = S1 & S2
    assert m(1, 2) == (9, 7)

    m2 = m.replace_submodel('scale3', Scale(4, name='scale4'))
    assert m2(1, 2) == (12, 7)
    # The original compound model must be left untouched.
    assert m(1, 2) == (9, 7)
    # Check the inverse has been updated
    assert m2.inverse(12, 7) == (1, 2)

    # Produce the same result by replacing a single model with a compound
    m3 = m.replace_submodel('shift2', Shift(2) | Scale(2))
    assert m(1, 2) == (9, 7)
    assert m3(1, 2) == (18, 7)
    # Check the inverse has been updated
    assert m3.inverse(18, 7) == (1, 2)

    # Test with arithmetic model compounding operator
    m = S1 + S2
    assert m(1) == 14
    m2 = m.replace_submodel('scale2', Scale(4, name='scale4'))
    assert m2(1) == 16

    # Test with fix_inputs()
    R = fix_inputs(Rotation2D(angle=90, name='rotate'), {0: 1})
    m4 = S1 | R
    assert_allclose(m4(0), (-6, 1))

    m5 = m4.replace_submodel('rotate', Rotation2D(180))
    assert_allclose(m5(0), (-1, -6))

    # Check we get a value error when model name doesn't exist
    with pytest.raises(ValueError):
        m.replace_submodel('not_there', Scale(2))

    # And now a model set
    P = Polynomial1D(degree=1, n_models=2, name='poly')
    S = Shift([1, 2], n_models=2)
    m = P | S
    assert_array_equal(m([0, 1]), (1, 2))
    # Replacing a 2-model set with a single model must be rejected.
    with pytest.raises(ValueError):
        m.replace_submodel('poly', Polynomial1D(degree=1, c0=1))
    m2 = m.replace_submodel('poly',
                            Polynomial1D(degree=1, c0=[1, 2], n_models=2))
    assert_array_equal(m2([0, 1]), (2, 4))
# Example #3
def test_mapping_inverse():
    """Check that a compound model containing a `Mapping` can be inverted."""
    # Rotate two of the coordinates and scale the third, permute the axes,
    # then rotate/scale again on the permuted axes.  The numbers carry no
    # physical meaning; only the forward/inverse round trip is checked.
    first_stage = Rotation2D(12.1) & Scale(13.2)
    second_stage = Rotation2D(14.3) & Scale(15.4)
    model = first_stage | Mapping([2, 0, 1]) | second_stage

    forward = model(0, 1, 2)
    assert_allclose((0, 1, 2), model.inverse(*forward), atol=1e-08)
# Example #4
def test_parameters_compound_models():
    """Build a Rotation2D composed with a native-to-celestial rotation."""
    projection = Pix2Sky_TAN()
    center = coord.SkyCoord(ra=5.6, dec=-72, unit=u.deg)
    pole = 180 * u.deg
    native_to_celestial = RotateNative2Celestial(center.ra, center.dec, pole)
    # NOTE(review): no assertions follow; this snippet may be truncated.
    compound = Rotation2D(23) | native_to_celestial
# Example #5
def compute_spec_transform(fiducial, refwcs):
    """
    Compute a simple transform given a fiducial point in a spatial-spectral wcs.

    Parameters
    ----------
    fiducial : sequence
        ``fiducial[0][0]`` and ``fiducial[0][1]`` are passed as the first two
        (longitude/latitude) arguments of ``RotateNative2Celestial``.
    refwcs : WCS
        Reference wcs.  Its ``wcsinfo`` supplies ``cdelt1``, ``cdelt2``,
        ``cdelt3`` and ``roll_ref``; it is also evaluated on the grid from
        ``grid_from_spec_domain`` to find the minimum wavelength.

    Returns
    -------
    transform : CompoundModel
        Three-output transform whose outputs are named
        ``('ra', 'dec', 'lamda')``.
    """
    # cdelt1/cdelt2 are divided by 3600 (presumably arcsec -> deg; confirm).
    cdelt1 = refwcs.wcsinfo.cdelt1 / 3600.
    cdelt2 = refwcs.wcsinfo.cdelt2 / 3600.
    cdelt3 = refwcs.wcsinfo.cdelt3
    roll_ref = refwcs.wcsinfo.roll_ref

    # Only the wavelength is needed: it anchors the spectral axis.
    y, x = grid_from_spec_domain(refwcs)
    _, _, lam = refwcs(x, y)
    min_lam = np.nanmin(lam)

    offset = Shift(0.) & Shift(0.)
    rot = Rotation2D(roll_ref)
    scale = Scale(cdelt1) & Scale(cdelt2)
    tan = Pix2Sky_TAN()
    skyrot = RotateNative2Celestial(fiducial[0][0], fiducial[0][1], 180.0)
    spatial = offset | rot | scale | tan | skyrot
    spectral = Scale(cdelt3) | Shift(min_lam)

    # Input 1 feeds both spatial axes and input 0 the spectral axis; the
    # explicit inverse selects outputs (2, 1) to recover the two inputs.
    mapping = Mapping((1, 1, 0))
    mapping.inverse = Mapping((2, 1))
    transform = mapping | (spatial & spectral)
    transform.outputs = ('ra', 'dec', 'lamda')
    return transform
# Example #6
def test_basic_compound_inverse():
    """
    Test basic inversion of compound models in the limited sense supported for
    models made from compositions and joins only.
    """

    t = (Shift(2) & Shift(3)) | (Scale(2) & Scale(3)) | Rotation2D(90)
    # Forward then inverse must reproduce the original inputs.
    assert_allclose(t.inverse(*t(0, 1)), (0, 1))
# Example #7
def test_mapping_basic_permutations():
    """
    Tests a couple basic examples of the Mapping model--specifically examples
    that merely permute the outputs.
    """

    x, y = Rotation2D(90)(1, 2)

    # Mapping((1, 0)) swaps the two outputs of the rotation.
    rs = Rotation2D(90) | Mapping((1, 0))
    x_prime, y_prime = rs(1, 2)
    assert_allclose((x, y), (y_prime, x_prime))

    # A more complicated permutation
    m = Rotation2D(90) & Scale(2)
    x, y, z = m(1, 2, 3)

    ms = m | Mapping((2, 0, 1))
    x_prime, y_prime, z_prime = ms(1, 2, 3)
    assert_allclose((x, y, z), (y_prime, z_prime, x_prime))
# Example #8
def test_simple_two_model_compose_2d():
    """
    A simple example consisting of two rotations.
    """

    r1 = Rotation2D(45) | Rotation2D(45)

    assert isinstance(r1, CompoundModel)
    assert r1.n_inputs == 2
    assert r1.n_outputs == 2
    assert_allclose(r1(0, 1), (-1, 0), atol=1e-10)

    r2 = Rotation2D(90) | Rotation2D(90)  # Rotate twice by 90 degrees
    assert_allclose(r2(0, 1), (0, -1), atol=1e-10)

    # Compose r1 with itself: four 45-degree rotations (180 degrees total)
    r3 = r1 | r1

    assert_allclose(r3(0, 1), (0, -1), atol=1e-10)
    def evaluate(self, x, y, wavelength, order):
        """Return the dispersed pixel(s) given center x, y, lam and order

        Parameters
        ----------
        x :  int,float
            Input x location on the direct image
        y :  int,float
            Input y location on the direct image
        wavelength : float
            Wavelength to disperse
        order : list
            The order to use

        Returns
        -------
        tuple
            ``(x + dx, y + dy, x, y, order)``: the x, y in the grism image for
            the pixel at x0, y0 that was specified as input using the
            wavelength and order specified, followed by the input x, y and
            order.

        Raises
        ------
        ValueError
            If ``wavelength`` is negative or ``order`` is not available.

        Notes
        -----
        I kept the potential for rotation from NIRISS, unsure if it's actually
        needed/useful for WFC3. Original note:

        There's spatial dependence for NIRISS so the forward transform
        depends on x,y as well as the filter wheel rotation. Theta is
        usu. taken to be the difference between fwcpos_ref in the specwcs
        reference file and fwcpos from the input image.
        """
        if wavelength < 0:
            raise ValueError("Wavelength should be greater than zero")

        # ``order`` may arrive as an array (use its first element) or as a
        # plain scalar key.
        try:
            key = int(order.flatten()[0])
        except AttributeError:
            key = order
        # Do the lookup in its own try: a KeyError raised inside an
        # ``except AttributeError`` handler is NOT caught by a sibling
        # ``except KeyError`` clause, so unknown scalar orders used to leak
        # a raw KeyError instead of the documented ValueError.
        try:
            iorder = self._order_mapping[key]
        except KeyError:
            raise ValueError("Specified order is not available") from None

        t = self.lmodels[iorder](wavelength)
        xmodel = self.xmodels[iorder]
        ymodel = self.ymodels[iorder]

        dx = xmodel.evaluate(x, y, t)
        dy = ymodel.evaluate(x, y, t)

        # rotate by theta
        if self.theta != 0.0:
            rotate = Rotation2D(self.theta)
            dx, dy = rotate(dx, dy)

        return (x + dx, y + dy, x, y, order)
# Example #10
def test_identity_input():
    """
    Test a case where an Identity (or Mapping) model is the first in a chain
    of composite models and thus is responsible for handling input broadcasting

    Regression test for https://github.com/astropy/astropy/pull/3362
    """

    ident1 = Identity(1)
    shift = Shift(1)
    rotation = Rotation2D(angle=90)
    model = ident1 & shift | rotation
    assert_allclose(model(1, 2), [-3.0, 1.0])
# Example #11
def test_identity():
    """Pass one input through Identity, shift the other, then rotate by 60 deg."""
    xs = np.zeros((2, 3))
    ys = np.ones((2, 3))

    model = Identity(1) & Shift(1) | Rotation2D(angle=60)
    assert_allclose(model(1, 2), (-2.098076211353316, 2.3660254037844393))

    out_x, out_y = model(xs, ys)
    expected_x = np.full((2, 3), -1.73205081)
    expected_y = np.ones((2, 3))
    assert_allclose((out_x, out_y), (expected_x, expected_y))
    # The inverse must recover the original array inputs.
    assert_allclose(model.inverse(out_x, out_y), (xs, ys), atol=1.e-10)
# Example #12
def test_levmar2x2_multivariate():
    """Fit ``(Shift & Shift) | Rotation2D`` with the 2x2 LevMar fitter."""
    inputs = [np.array([10., 10., 20., 20.]), np.array([10., 20., 20., 10.])]
    # These outputs equal the inputs shifted by (4.3, -7.1) and then rotated
    # by 45 degrees, i.e. the parameters expected from the fit below.
    outputs = [
        np.array([8.06101731, 0.98994949, 8.06101731, 15.13208512]),
        np.array([12.16223664, 19.23330445, 26.30437226, 19.23330445])
    ]
    rot = Rotation2D()
    # Presumably needed so the fitter accepts the rotation -- confirm.
    rot.fittable = True
    model = (Shift() & Shift()) | rot
    fitter = linearfit._LevMarLSQFitter2x2()
    finfo = fitter(model, inputs, outputs)
    # NOTE(review): the original tolerance arguments of this assert were lost
    # in this copy; default np.allclose tolerances are used -- confirm.
    assert np.allclose(finfo.parameters,
                       np.array([4.3, -7.1, 45.]))
# Example #13
def test_tabular_in_compound():
    """
    Issue #7411 - evaluate should not change the shape of the output.
    """
    # x below (0..11) lies mostly outside the table's points [1, 7], so
    # out-of-bounds lookups must not raise.
    # NOTE(review): the continuation of this call was lost in this copy;
    # ``bounds_error=False`` is restored on that basis -- confirm.
    t = Tabular1D(points=([1, 5, 7],), lookup_table=[12, 15, 19],
                  bounds_error=False)
    rot = Rotation2D(2)
    p = Polynomial1D(1)
    x = np.arange(12).reshape((3, 4))
    # Create a compound model which does not execute Tabular.__call__,
    # but model.evaluate and is followed by a Rotation2D which
    # checks the exact shapes.
    model = p & t | rot
    x1, y1 = model(x, x)
    assert x1.ndim == 2
    assert y1.ndim == 2
def test_slicing_on_instances_2():
    """
    More slicing tests.

    Regression test for https://github.com/embray/astropy/pull/10
    """

    model_a = Shift(1, name='a')
    model_b = Shift(2, name='b')
    model_c = Rotation2D(3, name='c')
    model_d = Scale(2, name='d')
    model_e = Scale(3, name='e')

    m = (model_a & model_b) | model_c | (model_d & model_e)

    with pytest.raises(ModelDefinitionError):
        # The slice can't actually be taken since the resulting model cannot be
        # evaluated
        assert m[1:].submodel_names == ('b', 'c', 'd', 'e')

    assert m[:].submodel_names == ('a', 'b', 'c', 'd', 'e')
    assert m['a':].submodel_names == ('a', 'b', 'c', 'd', 'e')

    with pytest.raises(ModelDefinitionError):
        assert m['c':'d'].submodel_names == ('c', 'd')

    assert m[1:2].name == 'b'
    assert m[2:7].submodel_names == ('c', 'd', 'e')
    # NOTE(review): the bodies of the next two ``pytest.raises`` blocks were
    # missing in this copy.  They are restored as lookups by submodel names
    # that do not exist in ``m`` -- confirm against upstream astropy.
    with pytest.raises(IndexError):
        m['x']
    with pytest.raises(IndexError):
        m['a':'r']

    with pytest.raises(ModelDefinitionError):
        assert m[-4:4].submodel_names == ('b', 'c', 'd')

    with pytest.raises(ModelDefinitionError):
        assert m[-4:-2].submodel_names == ('b', 'c')
# Example #15
    def build_miri_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Parameters
        ----------
        refwcs : WCS, optional
            Reference wcs.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Returns None.  Sets ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        if refwcs is None:
            refwcs = input_model.meta.wcs

        # NOTE(review): the final argument line of this call was lost in this
        # copy; ``center=True`` is assumed -- confirm.
        x, y = wcstools.grid_from_bounding_box(refwcs.bounding_box,
                                               step=(1, 1),
                                               center=True)
        ra, dec, lam = refwcs(x.flatten(), y.flatten())
        # TODO: once astropy.modeling._Tabular is fixed, take out the
        # flatten() and reshape() code above and below
        ra = ra.reshape(x.shape)
        dec = dec.reshape(x.shape)
        lam = lam.reshape(x.shape)

        # Find rotation of the slit from y axis from the wcs forward transform
        # TODO: figure out if angle is necessary for MIRI.  See for discussion
        # https://github.com/STScI-JWST/jwst/pull/347
        rotation = [m for m in refwcs.forward_transform
                    if isinstance(m, Rotation2D)]
        if rotation:
            rot_slit = functools.reduce(lambda a, b: a | b, rotation)
            unrotate = rot_slit.inverse
            refwcs_minus_rot = (refwcs.forward_transform |
                                (unrotate & Identity(1)))
            # Correct for this rotation in the wcs
            ra, dec, lam = refwcs_minus_rot(x.flatten(), y.flatten())
            ra = ra.reshape(x.shape)
            dec = dec.reshape(x.shape)
            lam = lam.reshape(x.shape)
        else:
            # Nothing to remove.  Previously this name was left undefined on
            # this path, raising NameError when the slit center is computed.
            refwcs_minus_rot = refwcs.forward_transform

        # Get the slit size at the center of the dispersion
        sky_coords = SkyCoord(ra=ra, dec=dec, unit=u.deg)
        slit_coords = sky_coords[int(sky_coords.shape[0] / 2)]
        slit_angular_size = slit_coords[0].separation(slit_coords[-1])
        log.debug('Slit angular size: {0}'.format(slit_angular_size.arcsec))

        # Compute slit center from bounding_box
        dx0 = refwcs.bounding_box[0][0]
        dx1 = refwcs.bounding_box[0][1]
        dy0 = refwcs.bounding_box[1][0]
        dy1 = refwcs.bounding_box[1][1]
        slit_center_pix = (dx1 - dx0) / 2
        dispersion_center_pix = (dy1 - dy0) / 2
        slit_center = refwcs_minus_rot(dx0 + slit_center_pix,
                                       dy0 + dispersion_center_pix)
        # NOTE(review): the remaining arguments of this call were lost in this
        # copy; ``dec=slit_center[1], unit=u.deg`` is assumed -- confirm.
        slit_center_sky = SkyCoord(ra=slit_center[0], dec=slit_center[1],
                                   unit=u.deg)
        log.debug('slit center: {0}'.format(slit_center))

        # Compute spatial and spectral scales
        spatial_scale = slit_angular_size / slit_coords.shape[0]
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        tcenter = int((dx1 - dx0) / 2)
        trace = lam[:, tcenter]
        trace = trace[~np.isnan(trace)]
        spectral_scale = np.abs((trace[-1] - trace[0]) / trace.shape[0])
        log.debug('spectral scale: {0}'.format(spectral_scale))

        # Compute transform for output frame
        log.debug('Slit center %s', slit_center_pix)
        # TODO: double-check the signs on the following rotation angles
        roll_ref = input_model.meta.wcsinfo.roll_ref * u.deg
        rot = Rotation2D(roll_ref)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra,
                                        slit_center_sky.dec, lon_pole)
        min_lam = np.nanmin(lam)
        # Duplicate the spatial input so it feeds both sky axes.
        mapping = Mapping((0, 0, 1))

        transform = (Shift(-slit_center_pix) & Identity(1) |
                     Scale(spatial_scale) & Scale(spectral_scale) |
                     Identity(1) & Shift(min_lam) | mapping |
                     (rot | tan | skyrot) & Identity(1))

        # Input names must be labels; the original assigned the pixel-grid
        # arrays ``x``/``y`` here by mistake.
        transform.inputs = ('x', 'y')
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)

        bounding_box = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(bounding_box_grid[axis])
            axis_max = np.nanmax(bounding_box_grid[axis])
            bounding_box.append((axis_min, axis_max))
        wnew.bounding_box = tuple(bounding_box)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
# Example #16
    def build_nirspec_output_wcs(self, refwcs=None):
        """
        Create a simple output wcs covering footprint of the input datamodels.

        Parameters
        ----------
        refwcs : WCS, optional
            Reference wcs.  Defaults to ``self.input_models[0].meta.wcs``.

        Notes
        -----
        Returns None.  Sets ``self.output_spatial_scale``,
        ``self.output_spectral_scale`` and ``self.output_wcs``.

        NOTE(review): several lines of this method are truncated or garbled
        in this copy (see the NOTE comments below) and it will not run
        until they are restored from upstream.
        """
        # TODO: generalize this for more than one input datamodel
        # TODO: generalize this for imaging modes with distorted wcs
        input_model = self.input_models[0]
        # NOTE(review): ``refwcs is None`` would be the idiomatic test.
        if refwcs == None:
            refwcs = input_model.meta.wcs

        # Generate grid of sky coordinates for area within bounding box
        bb = refwcs.bounding_box
        det = x, y = wcstools.grid_from_bounding_box(bb,
                                                     step=(1, 1),
        # NOTE(review): the closing argument(s) and ``)`` of the call above
        # are missing in this copy.
        sky = ra, dec, lam = refwcs(*det)
        # Center indices are offsets relative to the bounding-box origin.
        x_center = int((bb[0][1] - bb[0][0]) / 2)
        y_center = int((bb[1][1] - bb[1][0]) / 2)
        log.debug("Center of bounding box: {}  {}".format(x_center, y_center))

        # Compute slit angular size, slit center sky coords
        # For each row, interpolate the x position at which the row reaches
        # the wavelength found at the bounding-box center.
        xpos = []
        sz = 3
        for row in lam:
            if np.isnan(row[x_center]):
                # NOTE(review): this loop body is garbled in this copy.  Since
                # ``xpos`` is later filtered with ``~np.isnan(xpos)``, the NaN
                # branch presumably appended ``np.nan`` and an ``else:`` led
                # to the interpolation below; the second ``interp1d``
                # argument has also lost its leading part -- verify upstream.
                f = interpolate.interp1d(row[x_center - sz + 1:x_center + sz],
                                           x_center - sz + 1:x_center + sz],
                xpos.append(f(lam[y_center, x_center]))
        x_arg = np.array(xpos)[~np.isnan(lam[:, x_center])]
        y_arg = y[~np.isnan(lam[:, x_center]), x_center]
        # slit_coords, spect0 = refwcs(x_arg, y_arg, output='numericals_plus')
        slit_ra, slit_dec, slit_spec_ref = refwcs(x_arg, y_arg)
        slit_coords = SkyCoord(ra=slit_ra, dec=slit_dec, unit=u.deg)
        # Pixel index along the slit, reversed so interpolation runs from the
        # last slit sample to the first.
        pix_num = np.flipud(np.arange(len(slit_ra)))
        # pix_num = np.arange(len(slit_ra))
        interpol_ra = interpolate.interp1d(pix_num, slit_ra)
        interpol_dec = interpolate.interp1d(pix_num, slit_dec)
        slit_center_pix = len(slit_spec_ref) / 2. - 1
        log.debug('Slit center pix: {0}'.format(slit_center_pix))
        slit_center_sky = SkyCoord(ra=interpol_ra(slit_center_pix),
        # NOTE(review): the remaining arguments of the SkyCoord call above
        # (presumably the interpolated dec and units) are missing.
        log.debug('Slit center: {0}'.format(slit_center_sky))
        log.debug('Fiducial: {0}'.format(
        # NOTE(review): the argument of this log call is missing.
        angular_slit_size = np.abs(slit_coords[0].separation(slit_coords[-1]))
        log.debug('Slit angular size: {0}'.format(angular_slit_size.arcsec))
        dra, ddec = slit_coords[0].spherical_offsets_to(slit_coords[-1])
        offset_up_slit = (dra.to(u.arcsec), ddec.to(u.arcsec))
        log.debug('Offset up the slit: {0}'.format(offset_up_slit))

        # Compute spatial and spectral scales
        # Slit length in pixels: hypotenuse of the x drift and the row count.
        xposn = np.array(xpos)[~np.isnan(xpos)]
        dx = xposn[-1] - xposn[0]
        slit_npix = np.sqrt(dx**2 + np.array(len(xposn) - 1)**2)
        spatial_scale = angular_slit_size / slit_npix
        log.debug('Spatial scale: {0}'.format(spatial_scale.arcsec))
        # Wavelength step between adjacent columns at the center pixel.
        spectral_scale = lam[y_center, x_center] - lam[y_center, x_center - 1]

        # Compute slit angle relative (clockwise) to y axis
        slit_rot_angle = (np.arcsin(dx / slit_npix) * u.radian).to(u.degree)
        slit_rot_angle = slit_rot_angle.value
        log.debug('Slit rotation angle: {0}'.format(slit_rot_angle))

        # Compute transform for output frame
        roll_ref = input_model.meta.wcsinfo.roll_ref
        min_lam = np.nanmin(lam)
        offset = Shift(-slit_center_pix) & Shift(-slit_center_pix)

        # TODO: double-check the signs on the following rotation angles
        rot = Rotation2D(roll_ref + slit_rot_angle)
        scale = Scale(spatial_scale.value) & Scale(spatial_scale.value)
        tan = Pix2Sky_TAN()
        lon_pole = _compute_lon_pole(slit_center_sky, tan)
        skyrot = RotateNative2Celestial(slit_center_sky.ra.value,
        # NOTE(review): the remaining arguments of the call above (presumably
        # the dec value and ``lon_pole``) are missing.

        spatial_trans = offset | rot | scale | tan | skyrot
        spectral_trans = Scale(spectral_scale) | Shift(min_lam)
        # Input 1 feeds both spatial axes and input 0 the spectral axis.
        mapping = Mapping((1, 1, 0))
        mapping.inverse = Mapping((2, 1))
        transform = mapping | spatial_trans & spectral_trans
        transform.outputs = ('ra', 'dec', 'lamda')

        # Build the output wcs
        input_frame = refwcs.input_frame
        output_frame = refwcs.output_frame
        wnew = WCS(output_frame=output_frame, forward_transform=transform)

        # Build the bounding_box in the output frame wcs object
        bounding_box_grid = wnew.backward_transform(ra, dec, lam)
        bounding_box = []
        for axis in input_frame.axes_order:
            axis_min = np.nanmin(bounding_box_grid[axis])
            axis_max = np.nanmax(bounding_box_grid[axis])
            bounding_box.append((axis_min, axis_max))
        wnew.bounding_box = tuple(bounding_box)

        # Update class properties
        self.output_spatial_scale = spatial_scale
        self.output_spectral_scale = spectral_scale
        self.output_wcs = wnew
# Example #17
def test_drop_axes_3():
    """A two-input Mapping that keeps only its second input drops one axis."""
    keep_second = Mapping((1, ), n_inputs=2)
    assert keep_second.n_inputs == 2
    model = Rotation2D(60) | keep_second
    assert_allclose(model(1, 2), 1.86602540378)
# Example #18
# Licensed under a 3-clause BSD style license - see LICENSE.rst
# -*- coding: utf-8 -*-
import numpy as np
from astropy.modeling.models import Shift, Rotation2D
from asdf.tests import helpers
from ...import jwextension
from ...models import (AngleFromGratingEquation, WavelengthFromGratingEquation,
                       Unitless2DirCos, DirCos2Unitless, Rotation3DToGWA, Gwa2Slit,
                       Snell, Logical, V23ToSky, Slit)
import pytest

m1 = Shift(1) & Shift(2) | Rotation2D(3.1)
m2 = Shift(2) & Shift(2) | Rotation2D(23.1)

# Models that ``test_model`` round-trips through ASDF.
# NOTE(review): this list was cut off after the ``Snell`` entry in this copy
# (``m1``/``m2`` and the ``Gwa2Slit``/``Slit`` imports are otherwise unused);
# only the closing bracket has been restored -- confirm nothing else was lost.
test_models = [DirCos2Unitless(), Unitless2DirCos(),
               Rotation3DToGWA(angles=[12.1, 1.3, 0.5, 3.4], axes_order='xyzx'),
               AngleFromGratingEquation(20000, -1), WavelengthFromGratingEquation(25000, 2),
               Logical('GT', 5, 10), Logical('LT', np.ones((10,))* 5, np.arange(10)),
               V23ToSky(angles=[0.1259, -0.1037, -146.03468, 69.503032, 80.46448], axes_order="zyxyz"),
               Snell(angle=-16.5, kcoef=[0.583, 0.462, 3.891], lcoef=[0.002526, 0.01, 1200.556],
                     tcoef=[-2.66e-05, 0.0, 0.0, 0.0, 0.0, 0.0], tref=35, pref=0,
                     temperature=35, pressure=0),
               ]

@pytest.mark.parametrize(('model'), test_models)
def test_model(tmpdir, model):
    """Round-trip ``model`` through an ASDF tree using the JWST extension."""
    helpers.assert_roundtrip_tree({'model': model}, tmpdir,
                                  extensions=jwextension.JWSTExtension())
    def evaluate(self, x, y, x0, y0, order):
        """Return the valid pixel(s) and wavelengths given x0, y0, x, y, order

        Parameters
        ----------
        x0: int,float,list
            Source object x-center
        y0: int,float,list
            Source object y-center
        x :  int,float,list
            Input x location
        y :  int,float,list
            Input y location
        order : int
            Spectral order to use

        Returns
        -------
        tuple
            x0, y0, lambda, order in the direct image for the pixel that was
            specified as input using the wavelength l and spectral order

        Raises
        ------
        ValueError
            If ``order`` is not an available spectral order.

        Notes
        -----
        I kept the possibility of having a rotation like NIRISS, although I
        don't know if there is a use case for it for WFC3.

        The two `flatten` lines may need to be uncommented if we want to use
        this for array input.
        """
        # ``order`` may arrive as an array (use its first element) or as a
        # plain scalar key.
        try:
            key = int(order.flatten()[0])
        except AttributeError:
            key = order
        # Separate lookup: a KeyError raised inside an ``except`` handler is
        # not caught by a sibling ``except KeyError`` clause, so unknown
        # scalar orders used to leak a raw KeyError.
        try:
            iorder = self._order_mapping[key]
        except KeyError:
            raise ValueError("Specified order is not available") from None

        # The next two lines are to get around the fact that
        # modeling.standard_broadcasting=False does not work.
        #x00 = x0.flatten()[0]
        #y00 = y0.flatten()[0]

        t = np.linspace(0, 1, 10)  # sample t
        xmodel = self.xmodels[iorder]
        ymodel = self.ymodels[iorder]
        lmodel = self.lmodels[iorder]

        dx = xmodel.evaluate(x0, y0, t)
        dy = ymodel.evaluate(x0, y0, t)

        if self.theta != 0.0:
            rotate = Rotation2D(self.theta)
            dx, dy = rotate(dx, dy)

        # Tabulate the inverse relation dx -> t; sort so the lookup points
        # are increasing.
        so = np.argsort(dx)
        tab = Tabular1D(dx[so], t[so], bounds_error=False, fill_value=None)

        # wavelength(x, x0): subtract, look up t from the offset, then lmodel.
        dxr = astmath.SubtractUfunc()
        wavelength = dxr | tab | lmodel
        # Inputs are (x, y, x0, y0, order); the mapping routes them as
        # (x0, y0, x, x0, order) to the four parallel sub-models.
        model = Mapping(
            (2, 3, 0, 2,
             4)) | Const1D(x0) & Const1D(y0) & wavelength & Const1D(order)
        return model(x, y, x0, y0, order)