Esempio n. 1
0
def test_compound_model_input_units_equivalencies_defaults(model):
    """Check that ``input_units_equivalencies`` defaults to ``None``.

    The default is verified on the bare model, on compound models built
    from it with ``+``, ``-``, ``&`` and (when applicable) ``|``, and on
    the result of ``fix_inputs`` applied to each compound model.

    Parameters
    ----------
    model : dict
        Provides the model class under ``'class'`` and the keyword
        arguments used to instantiate it under ``'parameters'``.
    """
    single = model['class'](**model['parameters'])
    assert single.input_units_equivalencies is None

    # Addition and subtraction are checked the same way, through the 'x'
    # input of the compound model.
    for combine in (lambda a, b: a + b, lambda a, b: a - b):
        combined = combine(single, single)
        mapped = combined.inputs_map()['x'][0]
        assert mapped.input_units_equivalencies is None

        pinned = fix_inputs(combined, {'x': 1})
        assert pinned.input_units_equivalencies is None

    # Joining with ``&``: 'x0' is fixed, 'x1' is inspected before and after.
    joined = single & single
    assert joined.inputs_map()['x1'][0].input_units_equivalencies is None
    pinned = fix_inputs(joined, {'x0': 1})
    assert pinned.inputs_map()['x1'][0].input_units_equivalencies is None
    assert pinned.input_units_equivalencies is None

    # Chaining with ``|`` is only exercised when the number of outputs
    # equals the number of inputs.
    if single.n_outputs == single.n_inputs:
        chained = single | single
        assert chained.inputs_map()['x'][0].input_units_equivalencies is None

        pinned = fix_inputs(chained, {'x': 1})
        assert pinned.input_units_equivalencies is None
Esempio n. 2
0
def _models_with_units():
    """Build (model, input units, return units) triples.

    Returns
    -------
    numpy.ndarray
        Object array obtained by transposing the stacked lists of models,
        input-unit dicts and return-unit dicts, so that each row holds one
        model together with its two unit dictionaries.
    """
    joined = _ExampleModel() & _ExampleModel()
    summed = _ExampleModel() + _ExampleModel()

    # A polynomial with explicitly assigned units, used as the second
    # stage of a chained model.
    poly = Polynomial1D(1)
    poly._input_units = {'x': u.m / u.s}
    poly._return_units = {'y': u.m / u.s}
    chained = _ExampleModel() | poly

    # The same input of the joined model, fixed once by name and once by
    # position.
    fixed_by_name = fix_inputs(joined, {'x0': 1})
    fixed_by_index = fix_inputs(joined, {0: 1})

    all_models = [joined, summed, chained, fixed_by_name, fixed_by_index]

    expected_inputs = [
        {'x0': u.Unit("m"), 'x1': u.Unit("m")},
        {'x': u.Unit("m")},
        {'x': u.Unit("m")},
        {'x1': u.Unit("m")},
        {'x1': u.Unit("m")},
    ]

    expected_returns = [
        {'y0': u.Unit("m / s"), 'y1': u.Unit("m / s")},
        {'y': u.Unit("m / s")},
        {'y': u.Unit("m / s")},
        {'y0': u.Unit("m / s"), 'y1': u.Unit("m / s")},
        {'y0': u.Unit("m / s"), 'y1': u.Unit("m / s")},
    ]

    stacked = np.array([all_models, expected_inputs, expected_returns],
                       dtype=object)
    return stacked.T
Esempio n. 3
0
def test_fix_inputs(tmpdir):
    """Run the round-trip helper on ``fix_inputs`` models.

    The same input of a ``Pix2Sky_TAN | Rotation2D`` chain is fixed once
    by name (``'x'``) and once by position (``0``).
    """
    chain = astmodels.Pix2Sky_TAN() | astmodels.Rotation2D()

    by_name = fix_inputs(chain, {'x': 45})
    by_index = fix_inputs(chain, {0: 45})

    helpers.assert_roundtrip_tree(
        {'compound': by_name, 'compound1': by_index},
        tmpdir,
    )
Esempio n. 4
0
def test_fix_inputs(tmpdir):
    """Run the round-trip helper on ``fix_inputs`` models.

    The same input of a ``Pix2Sky_TAN | Rotation2D`` chain is fixed once
    by name (``'x'``) and once by position (``0``).  The whole test runs
    inside ``warnings.catch_warnings()`` so any filter installed here is
    local to the test.
    """
    with warnings.catch_warnings():
        # asdf releases up to and including 2.4.2 lack some schema files,
        # which produces warnings; ignore those on such versions.
        asdf_lacks_schemas = LooseVersion(asdf.__version__) <= '2.4.2'
        if asdf_lacks_schemas:
            warnings.filterwarnings('ignore', 'Unable to locate schema file')

        chain = astmodels.Pix2Sky_TAN() | astmodels.Rotation2D()
        by_name = fix_inputs(chain, {'x': 45})
        by_index = fix_inputs(chain, {0: 45})

        helpers.assert_roundtrip_tree(
            {'compound': by_name, 'compound1': by_index},
            tmpdir,
        )
Esempio n. 5
0
def test_fix_inputs_compound_bounding_box():
    """Check the bounding box chosen by ``fix_inputs`` from ``bounding_boxes``.

    Each case calls ``fix_inputs`` with a dict of candidate bounding boxes
    keyed by the fixed value(s), with or without explicit
    ``selector_args``, and compares ``bounding_box`` of the result with
    the entry expected for the fixed value(s).
    """
    # --- One fixed input on a two-input model -------------------------
    gaussian = models.Gaussian2D(1, 2, 3, 4, 5)
    boxes = {2.5: (-1, 1), 3.14: (-7, 3)}

    # (inputs to fix, extra keyword arguments); order matches the
    # sequence of calls being exercised.
    single_cases = [
        ({'y': 2.5}, {}),
        ({'x': 2.5}, {}),
        ({'y': 2.5}, {'selector_args': (('y', True),)}),
        ({'x': 2.5}, {'selector_args': (('x', True),)}),
        ({'x': 2.5}, {'selector_args': ((0, True),)}),
    ]
    for fixed, extra in single_cases:
        result = fix_inputs(gaussian, fixed, bounding_boxes=boxes, **extra)
        assert result.bounding_box == (-1, 1)

    # --- Two fixed inputs on a four-input model -----------------------
    identity = models.Identity(4)
    boxes = {(2.5, 1.3): ((-1, 1), (-3, 3)), (2.5, 2.71): ((-3, 3), (-1, 1))}

    pair_cases = [
        {},
        {'selector_args': (('x0', True), ('x1', True))},
        {'selector_args': ((0, True), (1, True))},
    ]
    for extra in pair_cases:
        result = fix_inputs(identity, {'x0': 2.5, 'x1': 1.3},
                            bounding_boxes=boxes, **extra)
        assert result.bounding_box == ((-1, 1), (-3, 3))
Esempio n. 6
0
    def fix_inputs(self, fixed):
        """
        Return a new unique WCS by fixing inputs to constant values.

        Only the transform of the first pipeline step is modified: it is
        replaced by the result of calling the module-level ``fix_inputs``
        on it with ``fixed``.  All remaining steps are reused unchanged.

        Parameters
        ----------
        fixed : dict
            Mapping from an input of the first transform to the constant
            value it should take.  It is forwarded as-is to the
            module-level ``fix_inputs``.

        Returns
        -------
        new_wcs : `WCS`
            A new instance of ``self.__class__`` constructed from the
            modified pipeline alone.

        Raises
        ------
        ImportError
            If ``HAS_FIX_INPUTS`` is false, i.e. the installed astropy
            does not provide ``fix_inputs``.

        Examples
        --------
        >>> w = WCS(pipeline, selector={"spectral_order": [1, 2]}) # doctest: +SKIP
        >>> new_wcs = w.fix_inputs({"spectral_order": 2}) # doctest: +SKIP
        >>> new_wcs.inputs # doctest: +SKIP
            ("x", "y")

        """
        if not HAS_FIX_INPUTS:
            raise ImportError('"fix_inputs" needs astropy version >= 4.0.')

        # Swap in the fixed transform for the first step, keeping its frame.
        first_step = self.pipeline[0]
        fixed_step = (first_step[0], fix_inputs(first_step[1], fixed))
        return self.__class__([fixed_step, *self.pipeline[1:]])
Esempio n. 7
0
def test_fix_inputs_type():
    """Check the exceptions raised for invalid operands.

    Each ``pytest.raises`` block contains only the expression expected to
    raise, so the test fails with a clear "did not raise" message if the
    behaviour ever changes, rather than falling through to unrelated code.
    """
    # A non-model first argument to ``fix_inputs`` is expected to raise
    # TypeError.
    with pytest.raises(TypeError):
        fix_inputs(3, {'x': 45})

    # Joining a model and a plain dict with ``&`` is expected to raise
    # AttributeError.
    with pytest.raises(AttributeError):
        _ = astmodels.Pix2Sky_TAN() & {'x': 45}
Esempio n. 8
0
def test_fix_inputs(tmpdir):
    """Run the round-trip helper on ``fix_inputs`` models with equivalencies.

    The first model of the chain has ``input_units_equivalencies`` set for
    both of its inputs.  The same input of the chain is then fixed once by
    name (``'x'``) and once by position (``0``).  The whole test runs
    inside ``warnings.catch_warnings()`` so any filter installed here is
    local to the test.
    """
    with warnings.catch_warnings():
        # Missing schema files produce warnings with older asdf; they are
        # ignored for asdf versions up to and including 2.5.1.
        asdf_lacks_schemas = Version(asdf.__version__) <= Version('2.5.1')
        if asdf_lacks_schemas:
            warnings.filterwarnings('ignore', 'Unable to locate schema file')

        projection = astmodels.Pix2Sky_TAN()
        projection.input_units_equivalencies = {
            'x': u.dimensionless_angles(),
            'y': u.dimensionless_angles(),
        }
        rotation = astmodels.Rotation2D()
        chain = projection | rotation

        by_name = fix_inputs(chain, {'x': 45})
        by_index = fix_inputs(chain, {0: 45})

        helpers.assert_roundtrip_tree(
            {'compound': by_name, 'compound1': by_index},
            tmpdir,
        )