def test_compound_model_input_units_equivalencies_defaults(model):
    """
    Models built from ``model`` (compound or with fixed inputs) report no
    input unit equivalencies when the base model has none.
    """
    base = model['class'](**model['parameters'])
    assert base.input_units_equivalencies is None

    def check_single_x(compound):
        # The compound exposes a single input named 'x'; neither its
        # submodel nor the fixed-input version should carry equivalencies.
        assert compound.inputs_map()['x'][0].input_units_equivalencies is None
        fixed = fix_inputs(compound, {'x': 1})
        assert fixed.input_units_equivalencies is None

    check_single_x(base + base)
    check_single_x(base - base)

    # '&' concatenates inputs, so they are renamed x0, x1.
    joined = base & base
    assert joined.inputs_map()['x1'][0].input_units_equivalencies is None
    fixed_joined = fix_inputs(joined, {'x0': 1})
    assert fixed_joined.inputs_map()['x1'][0].input_units_equivalencies is None
    assert fixed_joined.input_units_equivalencies is None

    # 'base | base' is only attempted when outputs line up with inputs.
    if base.n_outputs == base.n_inputs:
        check_single_x(base | base)
def _models_with_units():
    """
    Build compound / fixed-input models together with their expected units.

    Returns
    -------
    numpy.ndarray
        Object array whose rows are ``(model, input_units, return_units)``.
    """
    def metre_inputs(*names):
        # A fresh dict mapping each input name to metres.
        return {name: u.Unit("m") for name in names}

    def speed_outputs(*names):
        # A fresh dict mapping each output name to metres per second.
        return {name: u.Unit("m / s") for name in names}

    pair = _ExampleModel() & _ExampleModel()
    summed = _ExampleModel() + _ExampleModel()

    poly = Polynomial1D(1)
    poly._input_units = {'x': u.m / u.s}
    poly._return_units = {'y': u.m / u.s}
    piped = _ExampleModel() | poly

    # Fix the first input of the '&' model by name and by position.
    fixed_by_name = fix_inputs(pair, {'x0': 1})
    fixed_by_index = fix_inputs(pair, {0: 1})

    models = [pair, summed, piped, fixed_by_name, fixed_by_index]
    input_units = [
        metre_inputs('x0', 'x1'),
        metre_inputs('x'),
        metre_inputs('x'),
        metre_inputs('x1'),
        metre_inputs('x1'),
    ]
    return_units = [
        speed_outputs('y0', 'y1'),
        speed_outputs('y'),
        speed_outputs('y'),
        speed_outputs('y0', 'y1'),
        speed_outputs('y0', 'y1'),
    ]
    table = np.array([models, input_units, return_units], dtype=object)
    return table.T
def test_fix_inputs(tmpdir):
    """
    Round-trip ``fix_inputs`` compound models through ASDF, fixing the
    first input by name ('x') and by positional index (0).
    """
    pipeline = astmodels.Pix2Sky_TAN() | astmodels.Rotation2D()
    selectors = (('compound', 'x'), ('compound1', 0))
    tree = {label: fix_inputs(pipeline, {key: 45}) for label, key in selectors}
    helpers.assert_roundtrip_tree(tree, tmpdir)
def test_fix_inputs(tmpdir):
    """
    Round-trip ``fix_inputs`` compound models through ASDF, silencing the
    missing-schema warnings emitted by old asdf releases.
    """
    with warnings.catch_warnings():
        # Some schema files are missing from asdf<=2.4.2 which causes warnings
        asdf_is_old = LooseVersion(asdf.__version__) <= '2.4.2'
        if asdf_is_old:
            warnings.filterwarnings('ignore', 'Unable to locate schema file')

        pipeline = astmodels.Pix2Sky_TAN() | astmodels.Rotation2D()
        tree = {}
        tree['compound'] = fix_inputs(pipeline, {'x': 45})
        tree['compound1'] = fix_inputs(pipeline, {0: 45})
        helpers.assert_roundtrip_tree(tree, tmpdir)
def test_fix_inputs_compound_bounding_box():
    """
    ``fix_inputs`` selects the bounding box keyed by the fixed input
    value(s), whether the selector is implicit, named, or positional.
    """
    gauss = models.Gaussian2D(1, 2, 3, 4, 5)
    gauss_boxes = {2.5: (-1, 1), 3.14: (-7, 3)}
    gauss_cases = (
        ({'y': 2.5}, {}),
        ({'x': 2.5}, {}),
        ({'y': 2.5}, {'selector_args': (('y', True),)}),
        ({'x': 2.5}, {'selector_args': (('x', True),)}),
        ({'x': 2.5}, {'selector_args': ((0, True),)}),
    )
    for fixed, extra in gauss_cases:
        fixed_model = fix_inputs(gauss, fixed, bounding_boxes=gauss_boxes, **extra)
        assert fixed_model.bounding_box == (-1, 1)

    identity = models.Identity(4)
    identity_boxes = {(2.5, 1.3): ((-1, 1), (-3, 3)),
                      (2.5, 2.71): ((-3, 3), (-1, 1))}
    identity_extras = (
        {},
        {'selector_args': (('x0', True), ('x1', True))},
        {'selector_args': ((0, True), (1, True))},
    )
    for extra in identity_extras:
        fixed_model = fix_inputs(identity, {'x0': 2.5, 'x1': 1.3},
                                 bounding_boxes=identity_boxes, **extra)
        assert fixed_model.bounding_box == ((-1, 1), (-3, 3))
def fix_inputs(self, fixed):
    """
    Return a new unique WCS by fixing inputs to constant values.

    Only the transform of the first pipeline step is replaced, by
    ``fix_inputs(transform, fixed)``. The remaining pipeline steps are
    reused as-is, and the result is built with ``self.__class__``.

    Parameters
    ----------
    fixed : dict
        Mapping from input (by name, or by positional index as accepted by
        astropy's ``fix_inputs``) to the constant value it is fixed to.
        Keys are expected to correspond to `self.selector`.

    Returns
    -------
    new_wcs : `WCS`
        A new unique WCS corresponding to the values in `fixed`.

    Raises
    ------
    ImportError
        If ``fix_inputs`` is unavailable (requires astropy >= 4.0).

    Examples
    --------
    >>> w = WCS(pipeline, selector={"spectral_order": [1, 2]})  # doctest: +SKIP
    >>> new_wcs = w.fix_inputs({"spectral_order": 2})  # doctest: +SKIP
    >>> new_wcs.inputs  # doctest: +SKIP
    ("x", "y")
    """
    if not HAS_FIX_INPUTS:
        raise ImportError('"fix_inputs" needs astropy version >= 4.0.')

    # Each step is indexed as (frame, transform); only the first step's
    # transform gets the fixed inputs applied.
    step0 = self.pipeline[0]
    new_transform = fix_inputs(step0[1], fixed)
    new_pipeline = [(step0[0], new_transform)]
    new_pipeline.extend(self.pipeline[1:])
    return self.__class__(new_pipeline)
def test_fix_inputs_type():
    """
    Invalid operands are rejected when the model is constructed.

    - ``fix_inputs`` on a non-model raises ``TypeError``.
    - Combining a model with a dict via ``&`` raises ``AttributeError``.

    The test takes no ``tmpdir`` fixture, so it does not attempt an ASDF
    round-trip. The exceptions are expected before any serialization.
    """
    with pytest.raises(TypeError):
        _ = {'compound': fix_inputs(3, {'x': 45})}
    with pytest.raises(AttributeError):
        _ = {'compound': astmodels.Pix2Sky_TAN() & {'x': 45}}
def test_fix_inputs(tmpdir):
    """
    Round-trip ``fix_inputs`` compound models through ASDF.

    The first component carries input unit equivalencies. Its first input is
    fixed by name ('x') and by positional index (0).
    """
    with warnings.catch_warnings():
        # Some schema files are missing from older asdf releases, which
        # causes "Unable to locate schema file" warnings; the check below
        # silences them for asdf <= 2.5.1.
        if Version(asdf.__version__) <= Version('2.5.1'):
            warnings.filterwarnings('ignore', 'Unable to locate schema file')

        model0 = astmodels.Pix2Sky_TAN()
        model0.input_units_equivalencies = {'x': u.dimensionless_angles(),
                                            'y': u.dimensionless_angles()}
        model1 = astmodels.Rotation2D()
        model = model0 | model1

        tree = {
            'compound': fix_inputs(model, {'x': 45}),
            'compound1': fix_inputs(model, {0: 45})
        }
        helpers.assert_roundtrip_tree(tree, tmpdir)