Beispiel #1
0
def test_constructor():
    """Check ``ExtendedSource`` construction and the default state of its position.

    Covers three things:

    * passing function classes instead of instances raises ``RuntimeError``;
    * passing a non-2D function as spatial shape raises ``RuntimeError``;
    * longitude/latitude assigned on the shape are visible through
      ``source.spatial_shape`` and both are free parameters by default.
    """

    # Sky position in degrees (the ``u.degree`` unit is attached below)
    ra, dec = (125.6, -75.3)

    # This should throw an error as we are using Powerlaw instead of Powerlaw()
    with pytest.raises(RuntimeError):

        _ = ExtendedSource("my_source", Gaussian_on_sphere, Powerlaw)

    # This should throw an error because we should use a 2D function for the spatial shape
    with pytest.raises(RuntimeError):

        _ = ExtendedSource("my_source", Powerlaw(), Powerlaw())

    # Init with RA, Dec

    shape = Gaussian_on_sphere()
    source1 = ExtendedSource('my_source', shape, Powerlaw())
    shape.lon0 = ra * u.degree
    shape.lat0 = dec * u.degree

    assert source1.spatial_shape.lon0.value == ra
    assert source1.spatial_shape.lat0.value == dec

    # Verify that the position is free by default: both coordinates must be
    # checked (the longitude used to be asserted twice, leaving the latitude
    # untested)
    assert source1.spatial_shape.lon0.free
    assert source1.spatial_shape.lat0.free
Beispiel #2
0
def _get_point_source_composite(name):
    """Return a point source at l=0, b=0 whose spectrum is the sum of two power laws."""

    two_powerlaws = Powerlaw() + Powerlaw()

    return PointSource(name, l=0, b=0, spectral_shape=two_powerlaws)
Beispiel #3
0
def test_input_output_with_complex_functions():
    """Save and reload a model whose spectrum references a particle distribution.

    A ``ParticleSource`` and a ``PointSource`` sharing the same power-law
    particle distribution are put in one ``Model``, written to a YAML file and
    read back. The reloaded model must contain the same number of sources and
    the same particle-distribution index.

    The file is written with ``overwrite=True`` (as the other tests writing
    ``__test.yml`` do) and removed afterwards, so the test can be re-run and
    does not depend on leftovers from previous runs.
    """
    import os  # local import: only needed for the cleanup below

    my_particle_distribution = Powerlaw()

    my_particle_distribution.index = -1.52

    electrons = ParticleSource('electrons',
                               distribution_shape=my_particle_distribution)

    # Now set up the synch. spectrum for our source and the source itself

    synch_spectrum = _ComplexTestFunction()

    # Use the particle distribution we created as source for the electrons
    # producing synch. emission

    synch_spectrum.particle_distribution = my_particle_distribution

    synch_source = PointSource('synch_source',
                               ra=12.6,
                               dec=-13.5,
                               spectral_shape=synch_spectrum)

    my_model = Model(electrons, synch_source)

    my_model.display()

    model_file = "__test.yml"

    try:

        # NOTE(review): without overwrite=True a stale file left by another
        # test presumably makes save() refuse to write -- confirm against the
        # Model.save documentation
        my_model.save(model_file, overwrite=True)

        new_model = load_model(model_file)

    finally:

        # Do not leave the temporary file behind, even if save/load failed
        if os.path.exists(model_file):

            os.remove(model_file)

    assert len(new_model.sources) == len(my_model.sources)

    assert my_particle_distribution.index.value == new_model.electrons.spectrum.main.shape.index.value
def test_call_with_composite_function_with_units():
    """Call point sources with composite spectra using energies carrying units.

    Each composite spectrum is wrapped in a ``PointSource``, evaluated at
    100 and 200 keV, and the result is converted to 1 / (keV cm2 s).
    XSPEC-based compositions are exercised only when ``has_xspec`` is true.
    """

    def check_units(composite):

        print("Testing %s" % composite.expression)

        source = PointSource("test", ra=0, dec=0, spectral_shape=composite)

        flux = source([100, 200] * u.keV)

        # This will fail if the units are wrong
        flux.to(1 / (u.keV * u.cm**2 * u.s))

    # Each entry builds its spectrum lazily, right before it is checked
    builders = [
        # Simple compositions
        lambda: Powerlaw() * Exponential_cutoff(),
        lambda: Band() + Blackbody(),
        # More complicated compositions
        lambda: Powerlaw() * Exponential_cutoff() + Blackbody(),
        lambda: Powerlaw() * Exponential_cutoff() * Exponential_cutoff() + Blackbody(),
    ]

    if has_xspec:

        builders.extend([
            lambda: XS_phabs() * Powerlaw(),
            lambda: XS_phabs() * XS_powerlaw(),
            lambda: XS_phabs() * XS_powerlaw() * XS_phabs(),
            lambda: XS_phabs() * XS_powerlaw() * XS_phabs() + Blackbody(),
            lambda: XS_phabs() * XS_powerlaw() * XS_phabs() + XS_powerlaw(),
        ])

    for build in builders:

        check_units(build())
def test_memoizer():
    """Evaluate one power law 2000 times, each time with new random parameters.

    There are no assertions: the test only verifies that repeatedly changing
    ``index`` and ``K`` and calling the function at 1.0 does not raise.
    """

    law = Powerlaw()

    indices = np.random.uniform(-3, -1, 2000)
    normalizations = np.random.uniform(0.1, 10, 2000)

    for index, normalization in zip(indices, normalizations):

        law.index = index
        law.K = normalization

        law(1.0)
    def test_one(class_type):
        """Use an instance of ``class_type`` as a point-source spectrum and round-trip it.

        Prior functions are skipped. Otherwise the source is called with a
        plain float, a scalar quantity and an array quantity, and the array
        result must be reproduced both by a cloned model and by a model saved
        to and reloaded from ``__test.yml``.
        """

        instance = class_type()

        if instance.is_prior:

            print('Skipping prior function')

            return

        # If the function has fixed x units use those in the test,
        # otherwise fall back to keV
        x_unit_to_use = instance.fixed_units[0] if instance.has_fixed_units() else u.keV

        # Use the function as a spectrum
        ps = PointSource("test", 0, 0, instance)

        # These functions need a particle distribution before they can be called
        needs_particles = instance.name in ["Synchrotron", "_ComplexTestFunction"]

        if needs_particles:
            particleSource = ParticleSource("particles", Powerlaw())
            instance.set_particle_distribution(
                particleSource.spectrum.main.shape)

        energies = np.array([1, 2, 3]) * x_unit_to_use

        assert isinstance(ps(1.0), float)

        assert isinstance(ps(1.0 * x_unit_to_use), u.Quantity)

        result = ps(energies)

        assert isinstance(result, u.Quantity)

        model = Model(particleSource, ps) if needs_particles else Model(ps)

        # An in-memory clone must give the same fluxes...
        cloned_model = clone_model(model)

        assert np.all(cloned_model["test"](energies) == result)

        # ...and so must a model written to disk and read back
        model.save("__test.yml", overwrite=True)

        reloaded_model = load_model("__test.yml")

        assert np.all(reloaded_model["test"](energies) == result)
def test_input_output():
    """Clone, pickle and save/reload models built on a ``TemplateModel`` spectrum.

    The template named ``'__test'`` is used first on its own and then
    multiplied by a power law. In every case the copy must evaluate to the
    same values as the original on a grid of 100 points between 1 and 10.
    The YAML file written at the end is removed.
    """

    template = TemplateModel('__test')
    template.alpha = -0.95
    template.beta = -2.23

    original_model = Model(PointSource("test", ra=0.0, dec=0.0, spectral_shape=template))

    grid = np.linspace(1, 10, 100)

    # --- clone ---
    cloned = clone_model(original_model)
    cloned_shape = cloned.test.spectrum.main.shape

    assert cloned.get_number_of_point_sources() == 1
    assert template.data_file == cloned_shape.data_file

    assert cloned.test.spectrum.main.shape.alpha.value == template.alpha.value
    assert cloned.test.spectrum.main.shape.beta.value == template.beta.value

    assert np.allclose(cloned.test.spectrum.main.shape(grid),
                       original_model.test.spectrum.main.shape(grid))

    # --- pickle of the whole model ---
    unpickled = pickle.loads(pickle.dumps(cloned))

    assert unpickled.get_number_of_point_sources() == 1
    assert template.data_file == unpickled.test.spectrum.main.shape.data_file
    assert np.allclose(unpickled.test.spectrum.main.shape(grid),
                       original_model.test.spectrum.main.shape(grid))

    # --- pickle of a composite containing the template ---
    product = template * Powerlaw()

    product.index_2 = -2.256

    unpickled_product = pickle.loads(pickle.dumps(product))

    assert unpickled_product.index_2.value == product.index_2.value

    # --- save to disk and reload ---
    product_model = Model(PointSource("test", ra=0.0, dec=0.0, spectral_shape=product))

    product_model.save("__test.yml", overwrite=True)

    reloaded = load_model("__test.yml")

    assert reloaded.get_number_of_point_sources() == 1
    assert np.allclose(product_model.test.spectrum.main.shape(grid),
                       reloaded.test.spectrum.main.shape(grid))

    os.remove("__test.yml")
Beispiel #8
0
def test_call():
    """Check that a two-component point source evaluates to the sum of its components."""

    # Multi-component

    first_law = Powerlaw()
    second_law = Powerlaw()

    first_component = SpectralComponent("component1", first_law)
    second_component = SpectralComponent("component2", second_law)

    point_source = PointSource("test_source", 125.4, -22.3,
                               components=[first_component, second_component])

    # Each component must reproduce the function it wraps
    assert np.all(point_source.spectrum.component1([1, 2, 3]) == first_law([1, 2, 3]))
    assert np.all(point_source.spectrum.component2([1, 2, 3]) == second_law([1, 2, 3]))

    flux_one = point_source.spectrum.component1([1, 2, 3])
    flux_two = point_source.spectrum.component2([1, 2, 3])

    # The source as a whole must be exactly the sum of the two components
    residual = np.abs(flux_one + flux_two - point_source([1, 2, 3]))

    assert np.all(residual == 0)
Beispiel #9
0
def test_constructor_1source():
    """Build a model from one point source and reject a source called "name"."""

    # Test with one point source
    model = Model(_get_point_source())

    # A point source called "name" is expected to be rejected with an
    # AssertionError as soon as it is constructed
    with pytest.raises(AssertionError):

        _ = PointSource("name", 0, 0, Powerlaw())

    assert len(model.sources) == 1
Beispiel #10
0
def test_pickling_unpickling():
    """Pickle and unpickle 1D, 2D, 3D and composite functions.

    For the 1D and the composite function, modified parameter values must
    survive the round trip; for the 2D and 3D functions the round trip only
    has to succeed.
    """

    def roundtrip(obj):
        # Serialize to a bytes string and immediately deserialize
        return pickle.loads(pickle.dumps(obj))

    # 1d function
    powerlaw = Powerlaw()

    powerlaw.K = 5.35

    powerlaw_copy = roundtrip(powerlaw)

    assert powerlaw_copy.K.value == powerlaw.K.value

    # 2d function
    _ = roundtrip(Gaussian_on_sphere())

    # 3d function
    _ = roundtrip(Continuous_injection_diffusion())

    # composite function
    po2 = Powerlaw()
    li = Line()
    composite = po2 * li + po2 - li + 2 * po2 / li  # type: Function1D

    # Change some parameter
    composite.K_1 = 3.2
    composite.a_2 = 1.56

    composite_copy = roundtrip(composite)

    assert composite_copy.K_1.value == composite.K_1.value
    assert composite_copy.a_2.value == composite.a_2.value
Beispiel #11
0
def test_add_and_remove_independent_variable():
    """Add, remove and use an independent variable on the ``ModelGetter`` model.

    After the add/remove checks, the normalization of source ``one`` is linked
    to the "time" variable through a power law, and the linked value is
    compared with the law evaluated directly for 100 times between 0 and 10.
    """

    getter = ModelGetter()
    model = getter.model

    # Create an independent variable
    time_variable = IndependentVariable("time", 1.0, u.s)

    # Try to add it, then to remove it
    model.add_independent_variable(time_variable)
    model.remove_independent_variable("time")

    # A Parameter is not accepted in place of an IndependentVariable
    with pytest.raises(AssertionError):

        model.add_independent_variable(Parameter("time", 1.0))

    # Try to add it twice, which shouldn't fail
    for _ in range(2):

        model.add_independent_variable(time_variable)

    # Try to display it just to make sure it works
    model.display()

    # Now try to use it
    link_law = Powerlaw()

    link_law.K.value = 1.0
    link_law.index.value = -1.0

    n_free_before_link = len(model.free_parameters)

    model.link(model.one.spectrum.main.Powerlaw.K, time_variable, link_law)

    # The power law adds two parameters, but the link removes one, so
    # the net change is one more free parameter
    assert len(model.free_parameters) == n_free_before_link + 1

    # Now see if it works
    for instant in np.linspace(0, 10, 100):

        time_variable.value = instant

        assert model.one.spectrum.main.Powerlaw.K.value == link_law(instant)
Beispiel #12
0
def test_constructor_1source():
    """Check that ``Model`` accepts one valid source and rejects invalid input."""

    # Test with one point source
    valid_source = _get_point_source()

    m = Model(valid_source)

    # Test with a point source with an invalid name
    badly_named = PointSource("name", 0, 0, Powerlaw())

    with pytest.raises(InvalidInput):

        _ = Model(badly_named)

    # Test with two identical point sources, which should raise, as sources must have unique names.
    # NOTE(review): the duplicated source is the badly named one from above, so
    # this check alone cannot tell a duplicate-name failure from an invalid-name one.
    with pytest.raises(InvalidInput):

        _ = Model(badly_named, badly_named)
Beispiel #13
0
def test_display():
    """Call ``Model.display`` in several configurations to check it does not raise."""

    getter = ModelGetter()
    model = getter.model

    # Default and complete display of the reference model
    model.display()

    model.display(complete=True)

    # Now display a model without free parameters
    model = Model(PointSource("test", 0.0, 0.0, Powerlaw()))

    for par in model.parameters.values():

        par.fix = True

    model.display()

    # Now display a model without fixed parameters (very unlikely)
    for par in model.parameters.values():

        par.free = True

    model.display()
Beispiel #14
0
def test_duplicate():
    """Check that ``duplicate()`` returns an equivalent but independent function."""

    original = Powerlaw()
    original.index = -2.25
    original.K = 0.5

    # Duplicate it
    copy = original.duplicate()

    # Check that we have the same results
    assert copy(2.25) == original(2.25)

    # Check that the parameters are not linked anymore: changing the
    # original must leave the copy untouched
    original.index = -1.12

    assert original.index.value != copy.index.value

    print(original)
    print(copy)
Beispiel #15
0
def _get_point_source(name="test"):
    """Return a point source at ra=0, dec=0 with a default power-law spectrum."""

    return PointSource(name, ra=0, dec=0, spectral_shape=Powerlaw())
Beispiel #16
0
def _get_particle_source(name="test_part"):
    """Return a particle source whose distribution is a default power law."""

    return ParticleSource(name, Powerlaw())
Beispiel #17
0
def _get_extended_source(name="test_ext"):
    """Return an extended source with a Gaussian-on-sphere shape and a power-law spectrum."""

    return ExtendedSource(name, Gaussian_on_sphere(), Powerlaw())
Beispiel #18
0
def test_call():
    """Check extended sources built with every known non-prior 2D spatial function.

    For each shape, a two-component source must give a spatially integrated
    flux equal to the sum of its components, and a differential flux equal to
    spectrum times spatial shape. ``Latitude_galactic_diffuse`` is skipped in
    the loop and is instead expected to fail the checks with an
    ``AssertionError``.
    """

    # Multi-component

    po1 = Powerlaw()
    po2 = Powerlaw()

    c1 = SpectralComponent("component1", po1)
    c2 = SpectralComponent("component2", po2)

    ra, dec = (125.6, -75.3)

    def check_shape(class_type, name):

        print("testing %s ..." % name)

        shape = class_type()
        source = ExtendedSource('test_source_%s' % name,
                                shape,
                                components=[c1, c2])

        if name == "SpatialTemplate_2D":
            # The template has no lon0/lat0: generate a template file and load it
            make_test_template(ra, dec, "__test.fits")
            shape.load_file("__test.fits")
            shape.K = 1.0

        else:
            shape.lon0 = ra * u.degree
            shape.lat0 = dec * u.degree

        assert np.all(source.spectrum.component1([1, 2, 3]) == po1([1, 2, 3]))
        assert np.all(source.spectrum.component2([1, 2, 3]) == po2([1, 2, 3]))

        flux_one = source.spectrum.component1([1, 2, 3])
        flux_two = source.spectrum.component2([1, 2, 3])

        # check spectral components
        integrated = source.get_spatially_integrated_flux([1, 2, 3])

        assert np.all(np.abs(flux_one + flux_two - integrated) == 0)

        # check spectral and spatial components, first at the reference
        # position and then at a position scaled by 1.01
        for scale in (None, 1.01):

            if scale is None:
                lon, lat = ra, dec
            else:
                lon, lat = ra * scale, dec * scale

            total = source([lon] * 3, [lat] * 3, [1, 2, 3])
            spectrum = flux_one + flux_two
            spatial = source.spatial_shape([lon] * 3, [lat] * 3)

            assert np.all(np.abs(total - spectrum * spatial) == 0)

    for key in _known_functions:

        if key == "Latitude_galactic_diffuse":
            # not testing latitude galactic diffuse for now.
            continue

        candidate = _known_functions[key]

        if candidate._n_dim == 2 and not candidate().is_prior:

            check_shape(candidate, key)

    with pytest.raises(AssertionError):
        # this will fail because the Latitude_galactic_diffuse function isn't normalized.
        check_shape(_known_functions["Latitude_galactic_diffuse"],
                    "Latitude_galactic_diffuse")
Beispiel #19
0
def test_links():
    """Link and unlink the normalizations of sources ``one`` and ``two``.

    The link is tested first as a plain equality and then through a power
    law, checking the number of free parameters, the propagated value and the
    warnings issued on misuse.
    """

    getter = ModelGetter()

    model = getter.model

    n_free_before_link = len(model.free_parameters)

    # Link as equal (default)
    model.link(model.one.spectrum.main.Powerlaw.K,
               model.two.spectrum.main.Powerlaw.K)

    assert len(model.free_parameters) == n_free_before_link - 1

    # Try to display it just to make sure it works
    model.display()

    # Now test the link

    # This should print a warning, as trying to change the value of a linked parameters does not have any effect
    with pytest.warns(RuntimeWarning):

        model.one.spectrum.main.Powerlaw.K = 1.23456

    # This instead should work
    driving_value = 1.23456
    model.two.spectrum.main.Powerlaw.K.value = driving_value

    assert model.one.spectrum.main.Powerlaw.K.value == driving_value

    # Now try to remove the link

    # First we remove it from the wrong parameters, which should issue a warning
    with pytest.warns(RuntimeWarning):

        model.unlink(model.two.spectrum.main.Powerlaw.K)

    # Remove it from the right parameter
    model.unlink(model.one.spectrum.main.Powerlaw.K)

    assert len(model.free_parameters) == n_free_before_link

    # Redo the same, but with a powerlaw law
    link_law = Powerlaw()

    link_law.K.value = 1.0
    link_law.index.value = -1.0

    n_free_before_link = len(model.free_parameters)

    model.link(model.one.spectrum.main.Powerlaw.K,
               model.two.spectrum.main.Powerlaw.K,
               link_law)

    # The power law adds two parameters, but the link removes one, so
    # the net change is one more free parameter
    assert len(model.free_parameters) == n_free_before_link + 1

    # Check that the link works
    driving_value = 1.23456
    model.two.spectrum.main.Powerlaw.K.value = driving_value

    assert model.one.spectrum.main.Powerlaw.K.value == link_law(driving_value)

    # Remove the link
    model.unlink(model.one.spectrum.main.Powerlaw.K)
Beispiel #20
0
def _get_point_source_gal(name):
    """Return a point source at l=0, b=0 with a default power-law spectrum."""

    return PointSource(name, l=0, b=0, spectral_shape=Powerlaw())
Beispiel #21
0
def test_call_with_composite_function_with_units():
    """Call point sources with composite spectra using energies carrying units.

    Each composite spectrum is wrapped in a ``PointSource``, evaluated at
    100 and 200 keV, and the result is converted to 1 / (keV cm2 s).
    XSPEC-based compositions are exercised only when ``has_xspec`` is true.
    """

    def check_units(composite):

        print("Testing %s" % composite.expression)

        # The x unit is always keV here; picking it from the fixed units of
        # the expression is not implemented.
        energy_unit = u.keV

        source = PointSource("test", ra=0, dec=0, spectral_shape=composite)

        flux = source([100, 200] * energy_unit)

        # This will fail if the units are wrong
        flux.to(1 / (u.keV * u.cm**2 * u.s))

    # Each entry builds its spectrum lazily, right before it is checked
    builders = [
        # Simple compositions
        lambda: Powerlaw() * Exponential_cutoff(),
        lambda: Band() + Blackbody(),
        # More complicated compositions
        lambda: Powerlaw() * Exponential_cutoff() + Blackbody(),
        lambda: Powerlaw() * Exponential_cutoff() * Exponential_cutoff() + Blackbody(),
    ]

    if has_xspec:

        builders += [
            lambda: XS_phabs() * Powerlaw(),
            lambda: XS_phabs() * XS_powerlaw(),
            lambda: XS_phabs() * XS_powerlaw() * XS_phabs(),
            lambda: XS_phabs() * XS_powerlaw() * XS_phabs() + Blackbody(),
            lambda: XS_phabs() * XS_powerlaw() * XS_phabs() + XS_powerlaw(),
        ]

    for build in builders:

        check_units(build())
Beispiel #22
0
def test_time_domain_integration():
    """Check point-source fluxes obtained with a ``tag`` over a time interval.

    A power-law parameter is linked to a "time" independent variable through
    a ``Line`` and the fluxes are requested with ``tag=(time, 0, 10)``:
    first with the index linked, then with the normalization linked to a
    constant line, and finally with the normalization linked to a line with
    non-zero ``a``, whose result is compared with an analytical average.
    """

    po = Powerlaw()

    # Untouched power law used as the reference for default parameter values
    reference_law = Powerlaw()

    model = Model(PointSource("test", ra=0.0, dec=0.0, spectral_shape=po))  # type: model.Model

    # Add time independent variable
    time = IndependentVariable("time", 0.0, u.s)

    model.add_independent_variable(time)

    # Now link one of the parameters with a simple line law
    index_law = Line()

    index_law.a = 0.0

    model.link(po.index, time, index_law)

    # Test the display just to make sure it doesn't crash
    model.display()

    # Now test the average with the integral

    energies = np.linspace(1, 10, 10)
    interval = (time, 0, 10)

    fluxes = model.get_point_source_fluxes(0, energies, tag=interval)  # type: np.ndarray

    assert np.all(fluxes == 1.0)

    # Now test the linking of the normalization, first with a constant then with a line with a certain
    # angular coefficient

    model.unlink(po.index)

    po.index.value = reference_law.index.value

    norm_law = Line()

    norm_law.a = 0.0
    norm_law.b = 1.0

    model.link(po.K, time, norm_law)

    time.value = 1.0

    fluxes = model.get_point_source_fluxes(0, energies, tag=interval)

    assert np.allclose(fluxes, reference_law(energies))

    # Now make something actually happen
    norm_law.a = 1.0
    norm_law.b = 1.0

    fluxes = model.get_point_source_fluxes(0, energies, tag=interval)  # type: np.ndarray

    # Compare with analytical result: the antiderivative of a * x + b
    def antiderivative(x):
        return norm_law.a.value / 2.0 * x**2 + norm_law.b.value * x

    # Average of the line over [0, 10]
    mean_norm = (antiderivative(10) - antiderivative(0)) / 10.0

    expected = reference_law(energies) * mean_norm  # type: np.ndarray

    assert np.allclose(expected, fluxes)
def test_constructor():
    """Check ``PointSource`` construction from equatorial and Galactic coordinates.

    Covers: rejection of a function class given instead of an instance,
    consistency between (ra, dec) and (l, b) for the same sky position,
    multi-component spectra, and rejection of out-of-range coordinates.
    """

    # RA, Dec and L,B of the same point in the sky

    ra, dec = (125.6, -75.3)
    l, b = (288.44190139183564, -20.717313145391525)

    # Tolerance used when comparing converted coordinates
    tolerance = 1e-7

    # This should throw as we are using Powerlaw instead of Powerlaw()
    with pytest.raises(TypeError):

        _ = PointSource("my_source", ra, dec, Powerlaw)

    # Init with RA, Dec

    equatorial = PointSource('my_source', ra, dec, Powerlaw())

    assert equatorial.position.get_ra() == ra
    assert equatorial.position.get_dec() == dec

    assert abs(equatorial.position.get_l() - l) < tolerance
    assert abs(equatorial.position.get_b() - b) < tolerance

    assert equatorial.position.ra.value == ra
    assert equatorial.position.dec.value == dec

    # Init with l,b

    galactic = PointSource('my_source', l=l, b=b, spectral_shape=Powerlaw())

    assert galactic.position.get_l() == l
    assert galactic.position.get_b() == b

    assert abs(galactic.position.get_ra() - ra) < tolerance
    assert abs(galactic.position.get_dec() - dec) < tolerance

    assert galactic.position.l.value == l
    assert galactic.position.b.value == b

    # Multi-component

    po1 = Powerlaw()
    po2 = Powerlaw()

    c1 = SpectralComponent("component1", po1)
    c2 = SpectralComponent("component2", po2)

    multi = PointSource("test_source", ra, dec, components=[c1, c2])

    assert np.all(multi.spectrum.component1([1, 2, 3]) == po1([1, 2, 3]))
    assert np.all(multi.spectrum.component2([1, 2, 3]) == po2([1, 2, 3]))

    # Out-of-range coordinates, given as (positional args, keyword args)
    illegal_positions = (
        ((720.0, -15.0), {}),              # Illegal RA
        ((120.0, 180.0), {}),              # Illegal Dec
        ((), {"l": -195, "b": -15.0}),     # Illegal l
        ((), {"l": 120.0, "b": -180.0}),   # Illegal b
    )

    for positional, keywords in illegal_positions:

        with pytest.raises(AssertionError):

            _ = PointSource("test", *positional, components=[c1, c2], **keywords)
Beispiel #24
0
def test_function_composition():
    """Check that composite functions evaluate like the operations on their parts.

    Covers +, -, *, /, ``.of``, **, unary minus, ``abs``, a number raised to a
    function, a number times a function, a number divided by a function, and
    a composite made of composites.
    """

    Test_function = get_a_function_class()

    custom = Test_function()
    powerlaw = Powerlaw()

    # test +, with different kinds of input (list, quantity, scalar, array)
    summed = powerlaw + custom

    summed.set_units(u.keV, 1.0 / (u.keV * u.cm**2 * u.s))

    inputs = ([1, 2, 3, 4],
              [1, 2, 3, 4] * u.keV,
              1.0,
              np.array([1.0, 2.0, 3.0, 4.0]))

    for x in inputs:

        assert np.all(summed(x) == custom(x) + powerlaw(x))

    po = Powerlaw()
    li = Line()

    # Evaluation point shared by most of the checks below
    x0 = 2.25

    # Test -
    difference = po - li

    assert difference(1.0) == (po(1.0) - li(1.0))

    # test *
    product = po * li

    assert product(x0) == po(x0) * li(x0)

    # test /
    ratio = po / li

    assert ratio(x0) == po(x0) / li(x0)

    # test .of
    nested = po.of(li)

    assert nested(x0) == po(li(x0))

    # test power
    power = po**li

    assert power(x0) == po(x0)**li(x0)

    # test negation
    negated = -po

    assert negated(x0) == -po(x0)

    # test abs
    negative_line = Line()
    negative_line.b = -10.0

    absolute = abs(negative_line)

    assert negative_line(1.0) < 0
    assert absolute(1.0) == abs(negative_line(1.0))

    # test rpower
    exponential = 2.0**negative_line

    assert exponential(x0) == 2.0**(negative_line(x0))

    # test multiplication by a number
    doubled = 2.0 * po

    assert doubled(x0) == 2.0 * po(x0)

    # Number divided by
    reciprocal = 1.0 / li

    assert reciprocal(x0) == 1.0 / li(x0)

    # Composite of composite
    everything = po * li + po - li + 2 * po / li

    assert everything(x0) == po(x0) * li(x0) + po(x0) - li(x0) + 2 * po(x0) / li(x0)

    print(everything)
def test_call_with_units():
    """Call functions with and without units, directly and through a point source.

    A bare power law must reject input with units until it is used as the
    spectrum of a ``PointSource``. Then every known 1D function (except the
    XSPEC models other than ``XS_powerlaw``, and ``TemplateModel``) is used
    as a spectrum and called with quantities and with a plain float.
    """

    po = Powerlaw()

    assert po(1.0).ndim == 0

    with pytest.raises(AssertionError):

        # This raises because the units of the function have not been set up

        _ = po(1.0 * u.keV)

    # Use the function as a spectrum; the source itself is not needed, only
    # the effect that creating it has on the function
    ps = PointSource("test", 0, 0, po)

    assert isinstance(po(1.0 * u.keV), u.Quantity)

    assert isinstance(po(np.array([1, 2, 3]) * u.keV), u.Quantity)

    # Now test all the functions
    def check_function(class_type):

        # Use the function as a spectrum
        source = PointSource("test", 0, 0, class_type())

        assert isinstance(source(1.0 * u.keV), u.Quantity)

        assert isinstance(source(np.array([1, 2, 3]) * u.keV), u.Quantity)

        assert isinstance(source(1.0), float)

    for key in _known_functions:

        function_class = _known_functions[key]

        # Test only the power law of XSpec, which is the only one we know we can test at 1 keV
        # (the other XSpec models might need other parameters during initialization)
        if key.startswith("XS") and key != "XS_powerlaw":

            continue

        # The TemplateModel function has its own test
        if key.startswith("TemplateModel"):

            continue

        if function_class._n_dim == 1:

            print("testing %s ..." % key)

            check_function(_known_functions[key])
Beispiel #26
0
def test_auto_unlink():
    """Check that removing a source warns and unlinks parameters linked to it.

    Source ``two`` is removed from a fresh ``ModelGetter`` model in three
    scenarios: two equality links, one link through a power law, and one
    link driving two parameters at once. Each removal must issue a
    ``RuntimeWarning`` and leave the expected number of free parameters.
    """

    model = ModelGetter().model

    n_free_before_link = len(model.free_parameters)

    # Link as equal (default)
    model.link(model.one.spectrum.main.Powerlaw.K,
               model.two.spectrum.main.Powerlaw.K)
    model.link(model.one.spectrum.main.Powerlaw.index,
               model.two.spectrum.main.Powerlaw.index)

    n_free_source2 = len(model.two.free_parameters)

    # Free parameters expected once source two is gone and the links are undone
    expected_free = n_free_before_link - n_free_source2

    # This should give a warning about automatically unlinking 2 parameters
    with pytest.warns(RuntimeWarning):

        model.remove_source(model.two.name)

    assert len(model.free_parameters) == expected_free

    # Redo the same, but with a powerlaw law

    model = ModelGetter().model

    link_law = Powerlaw()

    link_law.K.value = 1.0
    link_law.index.value = -1.0

    model.link(model.one.spectrum.main.Powerlaw.K,
               model.two.spectrum.main.Powerlaw.K,
               link_law)

    with pytest.warns(RuntimeWarning):

        model.remove_source(model.two.name)

    assert len(model.free_parameters) == expected_free

    # Redo the same, but with two linked parameters

    model = ModelGetter().model

    model.link([model.one.spectrum.main.Powerlaw.K,
                model.ext_one.spectrum.main.Powerlaw.K],
               model.two.spectrum.main.Powerlaw.K)

    with pytest.warns(RuntimeWarning):

        model.remove_source(model.two.name)

    assert len(model.free_parameters) == expected_free

    # Unlinking a list of parameters must also be accepted
    model.unlink([model.one.spectrum.main.Powerlaw.K,
                  model.ext_one.spectrum.main.Powerlaw.K])
Beispiel #27
0
def test_call_with_units():
    """Call functions with and without units, and round-trip them through a model.

    A bare power law must reject input with units until it is used as the
    spectrum of a ``PointSource``. Then every known 1D function that is not
    skipped is used as a spectrum, called with and without units, and its
    array result must be reproduced by a cloned model and by a model saved to
    and reloaded from ``__test.yml``.
    """

    po = Powerlaw()

    assert po(1.0).ndim == 0

    with pytest.raises(AssertionError):

        # This raises because the units of the function have not been set up

        _ = po(1.0 * u.keV)

    # Use the function as a spectrum; the source itself is not needed, only
    # the effect that creating it has on the function
    ps = PointSource("test", 0, 0, po)

    assert isinstance(po(1.0 * u.keV), u.Quantity)

    assert isinstance(po(np.array([1, 2, 3]) * u.keV), u.Quantity)

    # Now test all the functions
    def check_function(class_type):

        instance = class_type()

        if instance.is_prior:

            print('Skipping prior function')

            return

        # If the function has fixed x units use those in the test,
        # otherwise fall back to keV
        x_unit_to_use = instance.fixed_units[0] if instance.has_fixed_units() else u.keV

        # Use the function as a spectrum
        source = PointSource("test", 0, 0, instance)

        # These functions need a particle distribution before they can be called
        needs_particles = instance.name in ["Synchrotron", "_ComplexTestFunction"]

        if needs_particles:
            particleSource = ParticleSource("particles", Powerlaw())
            instance.set_particle_distribution(
                particleSource.spectrum.main.shape)

        energies = np.array([1, 2, 3]) * x_unit_to_use

        assert isinstance(source(1.0), float)

        assert isinstance(source(1.0 * x_unit_to_use), u.Quantity)

        result = source(energies)

        assert isinstance(result, u.Quantity)

        model = Model(particleSource, source) if needs_particles else Model(source)

        # An in-memory clone must give the same fluxes...
        cloned_model = clone_model(model)

        assert np.all(cloned_model["test"](energies) == result)

        # ...and so must a model written to disk and read back
        model.save("__test.yml", overwrite=True)

        reloaded_model = load_model("__test.yml")

        assert np.all(reloaded_model["test"](energies) == result)

    for key in _known_functions:

        function_class = _known_functions[key]

        # Test only the power law of XSpec, which is the only one we know we can test at 1 keV
        # (the other XSpec models might need other parameters during initialization).
        # Multiplicative models are skipped as well.
        is_other_xspec_model = key.startswith("XS") and key != "XS_powerlaw"

        if is_other_xspec_model or (key in _multiplicative_models):

            continue

        # The TemplateModel function has its own test
        if key.startswith("TemplateModel"):

            continue

        # NOTE: Synchrotron is not skipped here, although it arguably
        # deserves a test of its own.

        if function_class._n_dim == 1:

            print("testing %s ..." % key)

            check_function(_known_functions[key])