예제 #1
0
def test_arf_get_x_unit():
    """An ARF built from low/high bin edges reports the bin mid-points via get_x()."""
    lo = np.array([12.0, 12.1, 12.2])
    hi = np.array([12.1, 12.2, 12.3])

    arf = Session().create_arf(lo, hi)

    # The expected grid is the arithmetic mean of each bin's edges.
    midpoints = (hi + lo)/2
    np.testing.assert_array_almost_equal(midpoints, arf.get_x())
예제 #2
0
def test_source_component_arbitrary_grid_int():
    """Plot a regridded constant model against an integrated (Data1DInt) data set.

    The plot call is expected to emit a UserWarning, and the resulting
    plot is compared against the mid-points of the data bins followed by
    the mid-points of the regrid bins.
    """
    from sherpa.astro.ui.utils import Session
    from sherpa.models import Const1D
    from sherpa.data import Data1DInt

    xlo, xhi = numpy.array([1, 2, 3]), numpy.array([2, 3, 4])
    grid_lo, grid_hi = numpy.array([10, 20, 30]), numpy.array([20, 30, 40])

    ui = Session()
    ui.load_arrays(1, xlo, xhi, [1.5, 2.5, 3.5], Data1DInt)

    const = Const1D('c')
    const.c0 = 10
    regridded = const.regrid(grid_lo, grid_hi)

    with pytest.warns(UserWarning):
        ui.plot_source_component(regridded)

    # Data bin centres first, then the centres of the regrid bins.
    expected_x = numpy.concatenate(((xlo + xhi)/2, (grid_lo + grid_hi)/2))
    expected_y = [10, 10, 10, 100, 100, 100]

    numpy.testing.assert_array_equal(ui._compsrcplot.x, expected_x)
    numpy.testing.assert_array_equal(ui._compsrcplot.y, expected_y)
예제 #3
0
def test_rmf_get_x_unit():
    """An RMF built from low/high bin edges reports the bin mid-points via get_x()."""
    lo = np.array([21.0, 21.1, 21.2])
    hi = np.array([21.1, 21.2, 21.3])

    rmf = Session().create_rmf(lo, hi)

    # The expected grid is the arithmetic mean of each bin's edges.
    midpoints = (hi + lo)/2
    np.testing.assert_array_almost_equal(midpoints, rmf.get_x())
예제 #4
0
def test_rmf_get_x_unit():
    """Check that rmf.get_x() matches the centres of the bins used to create the RMF."""
    session = Session()

    edges_lo = np.array([21.0, 21.1, 21.2])
    edges_hi = np.array([21.1, 21.2, 21.3])
    response = session.create_rmf(edges_lo, edges_hi)

    centres = (edges_hi + edges_lo) / 2
    np.testing.assert_array_almost_equal(centres, response.get_x())
예제 #5
0
def test_arf_get_x_unit():
    """Check that arf.get_x() matches the centres of the bins used to create the ARF."""
    session = Session()

    edges_lo = np.array([12.0, 12.1, 12.2])
    edges_hi = np.array([12.1, 12.2, 12.3])
    response = session.create_arf(edges_lo, edges_hi)

    centres = (edges_hi + edges_lo) / 2
    np.testing.assert_array_almost_equal(centres, response.get_x())
예제 #6
0
def test_bug_275(make_data_path):
    """Smoke test: str() of the loaded data and responses must not raise.

    A PHA file is loaded and its data, RMF and ARF are converted to
    strings; an image file is then loaded and its data is converted too.
    """
    ui = Session()

    ui.load_data(make_data_path('3c273.pi'))
    for getter in (ui.get_data, ui.get_rmf, ui.get_arf):
        str(getter())

    ui.load_data(make_data_path('img.fits'))
    str(ui.get_data())
예제 #7
0
def test_load_table_model(make_data_path):
    """Load 'double.dat' as a table model and check its ndim attribute.

    The original author notes that, although the input is not a FITS
    file, the I/O module appears to be needed for this to work.
    """
    session = Session()
    table_file = make_data_path('double.dat')
    session.load_table_model('tbl', table_file)

    component = session.get_model_component('tbl')
    assert component.ndim is None
예제 #8
0
def test_create_rmf(make_data_path):
    """Create an RMF from 'test_rmfimg.fits' and check the sizes of its arrays."""
    from sherpa.astro.ui.utils import Session

    session = Session()
    edges = np.arange(0.05, 1.1, 0.05)
    lo, hi = edges[:-1], edges[1:]
    rmf_file = make_data_path('test_rmfimg.fits')

    rmf = session.create_rmf(lo, hi, fname=rmf_file)

    assert len(rmf._fch) == 1039
    assert len(rmf._nch) == 1039
    assert len(rmf.n_grp) == 900
    assert rmf._rsp.shape[0] == 380384
def test_create_rmf(make_data_path):
    """Build an RMF with create_rmf(..., fname=...) and verify its array lengths."""
    from sherpa.astro.ui.utils import Session

    ui = Session()
    grid = np.arange(0.05, 1.1, 0.05)
    low_edges = grid[:-1]
    high_edges = grid[1:]
    path = make_data_path('test_rmfimg.fits')

    response = ui.create_rmf(low_edges, high_edges, fname=path)

    # Expected sizes of the arrays stored on the returned RMF object.
    assert len(response._fch) == 1039
    assert len(response._nch) == 1039
    assert len(response.n_grp) == 900
    assert response._rsp.shape[0] == 380384
예제 #10
0
def test_list_ids():
    """The integer id 1 and the string id "1" identify two separate data sets."""
    session = Session()
    session.load_arrays(1, TEST, TEST)
    session.load_arrays("1", TEST, TEST2)

    # Compare as sets: the order of 1 and "1" is not determined.
    ids = session.list_data_ids()
    assert {1, "1"} == set(ids)

    string_id_data = session.get_data('1')
    assert_array_equal(TEST2, string_id_data.get_dep())

    int_id_data = session.get_data(1)
    assert_array_equal(TEST, int_id_data.get_dep())
예제 #11
0
def test_bug_275(make_data_path):
    """Check that loaded data sets and responses can be converted to strings."""
    session = Session()

    pha_file = make_data_path('3c273.pi')
    session.load_data(pha_file)
    str(session.get_data())
    str(session.get_rmf())
    str(session.get_arf())

    image_file = make_data_path('img.fits')
    session.load_data(image_file)
    str(session.get_data())
예제 #12
0
파일: test_data.py 프로젝트: wsf1990/sherpa
def test_load_arrays_no_errors(data_no_errors):
    """load_arrays creates a new data object with the same independent/dependent axes.

    The data set is loaded from DATA_NO_ERRORS_ARGS plus the class of the
    fixture object, and retrieved using the fixture object's name as the id.
    """
    from sherpa.astro.ui.utils import Session

    ui = Session()
    original = data_no_errors
    load_args = DATA_NO_ERRORS_ARGS + (original.__class__, )
    ui.load_arrays(*load_args)

    loaded = ui.get_data(original.name)
    assert loaded is not original

    # DATA-NOTE: Do we need an equality operator for data classes? These tests are very partial
    # Note that when they are created with load_arrays they seem to lose the name, which becomes the ID
    numpy.testing.assert_array_equal(loaded.get_indep(), original.get_indep())
    numpy.testing.assert_array_equal(loaded.get_dep(), original.get_dep())
예제 #13
0
def test_arf_rmf_get_x(make_data_path):
    """The first ten get_x() values of the unpacked ARF and RMF match a fixed reference."""
    arf_file = make_data_path('3c120_heg_-1.arf')
    rmf_file = make_data_path('3c120_heg_-1.rmf')

    ui = Session()
    arf = ui.unpack_arf(arf_file)
    rmf = ui.unpack_rmf(rmf_file)

    # Reference values shared by both responses.
    reference = [
        0.57724115, 0.57730836, 0.57737556, 0.5774428, 0.57751006,
        0.57757729, 0.57764456, 0.57771185, 0.57777914, 0.57784647,
    ]
    arf_head, rmf_head = arf.get_x()[0:10], rmf.get_x()[0:10]

    np.testing.assert_array_almost_equal(reference, arf_head)
    np.testing.assert_array_almost_equal(reference, rmf_head)
예제 #14
0
def test_zero_division_calc_stat():
    """calc_stat_info must not raise for an all-zero PHA data set, and rstat is NaN.

    A DataPHA data set with 100 zero-valued channels is loaded, grouped
    with group_counts(1, 100), and given a Const1D full model before the
    statistic information is requested.
    """
    ui = AstroSession()
    x = numpy.arange(100)
    y = numpy.zeros(100)
    ui.load_arrays(1, x, y, DataPHA)
    ui.group_counts(1, 100)
    ui.set_full_model(1, Const1D("const"))

    # In principle there is no need to call calc_stat_info(): _get_stat_info
    # alone reproduces the issue. However, _get_stat_info is not a public
    # method, so first run calc_stat_info to make sure it raises no exception.
    # Then, since calc_stat_info only logs something and returns nothing, use
    # a white-box approach to get the result from _get_stat_info.
    ui.calc_stat_info()

    # Use isnan rather than "is numpy.nan": identity only holds when the
    # very same NaN object is returned, whereas any NaN float satisfies isnan.
    assert numpy.isnan(ui._get_stat_info()[0].rstat)
예제 #15
0
def test_source_component_arbitrary_grid():
    """Plot a regridded constant model for several data/regrid grid combinations.

    For every case the plotted x values must equal the data grid, and the
    plotted y values must (almost) equal the expected values listed below.
    """
    ui = Session()
    model = Const1D('c')
    model.c0 = 10

    def check(x, y, grid, expected_y):
        # Reload the data, plot the model evaluated on `grid`, and compare.
        ui.load_arrays(1, x, y)
        ui.plot_source_component(model.regrid(grid))
        numpy.testing.assert_array_equal(ui._compsrcplot.x, x)
        numpy.testing.assert_array_almost_equal(ui._compsrcplot.y, expected_y)

    check([1, 2, 3], [1, 2, 3], [10, 20, 30], [0, 0, 0])

    # The remaining cases share one data grid and vary only the regrid grid.
    xs = numpy.linspace(1, 10, 10)
    cases = (
        (numpy.linspace(5, 15, 15), [0, 0, 0, 0, 10, 10, 10, 10, 10, 10]),
        (numpy.linspace(1, 5, 15), [10, 10, 10, 10, 10, 0, 0, 0, 0, 0]),
        (numpy.linspace(3, 5, 15), [0, 0, 10, 10, 10, 0, 0, 0, 0, 0]),
    )
    for grid, expected in cases:
        check(xs, xs, grid, expected)
예제 #16
0
def test_zero_division_calc_stat():
    """Check calc_stat_info on zero-valued PHA data: no exception, and rstat is NaN.

    The data set has 100 channels of zeros, is grouped with
    group_counts(1, 100), and uses a Const1D instance as its full model.
    """
    ui = AstroSession()
    x = numpy.arange(100)
    y = numpy.zeros(100)
    ui.load_arrays(1, x, y, DataPHA)
    ui.group_counts(1, 100)
    ui.set_full_model(1, Const1D("const"))

    # Calling calc_stat_info() is not strictly required, since _get_stat_info
    # reproduces the issue on its own, but _get_stat_info is not public, so
    # the public entry point is run first to confirm it raises no exception.
    # calc_stat_info only logs and returns nothing, hence the white-box call
    # to _get_stat_info to inspect the result.
    ui.calc_stat_info()

    # numpy.isnan accepts any NaN value; the previous identity check against
    # numpy.nan passed only if that exact object was stored in rstat.
    rstat = ui._get_stat_info()[0].rstat
    assert numpy.isnan(rstat)
예제 #17
0
def test_list_ids():
    """Data loaded under id 1 and id "1" are listed and retrieved independently."""
    ui = Session()
    ui.load_arrays(1, TEST, TEST)
    ui.load_arrays("1", TEST, TEST2)

    # order of 1 and "1" is not determined, so compare as sets
    listed = set(ui.list_data_ids())
    assert {1, "1"} == listed

    assert_array_equal(TEST2, ui.get_data('1').get_dep())
    assert_array_equal(TEST, ui.get_data(1).get_dep())
예제 #18
0
def test_arf_rmf_get_x(make_data_path):
    """Compare the leading get_x() values of an unpacked ARF and RMF to known numbers."""
    arf_path = make_data_path('3c120_heg_-1.arf')
    rmf_path = make_data_path('3c120_heg_-1.rmf')

    session = Session()
    arf = session.unpack_arf(arf_path)
    rmf = session.unpack_rmf(rmf_path)

    expected_head = [0.57724115, 0.57730836, 0.57737556, 0.5774428,
                     0.57751006, 0.57757729, 0.57764456, 0.57771185,
                     0.57777914, 0.57784647]

    # Fetch both grids before asserting on either of them.
    arf_x = arf.get_x()[0:10]
    rmf_x = rmf.get_x()[0:10]

    np.testing.assert_array_almost_equal(expected_head, arf_x)
    np.testing.assert_array_almost_equal(expected_head, rmf_x)
예제 #19
0
파일: test_data.py 프로젝트: wsf1990/sherpa
def data_for_load_arrays(request):
    """Build a session, load_arrays arguments and a data instance for request.param.

    request.param is used as the data class; its constructor arguments are
    looked up in INSTANCE_ARGS. Presumably this is a parametrized pytest
    fixture whose decorator is outside this view.

    Returns
    -------
    session, args, data
        A new Session, the constructor arguments followed by the data
        class, and an instance of the data class built from those arguments.
    """
    from sherpa.astro.ui.utils import Session

    cls = request.param
    session = Session()
    ctor_args = INSTANCE_ARGS[cls]
    load_args = ctor_args + (cls, )
    return session, load_args, cls(*ctor_args)
예제 #20
0
def setup():
    """Return a new Session, a non-integrated MyModel and a frozen zero constant.

    Returns
    -------
    session, my_model, const
        The Const1D component has c0 set to 0 and frozen; the MyModel
        component has its integrate flag switched off.
    """
    zero = Const1D("const")
    zero.c0 = 0
    zero.c0.freeze()

    custom = MyModel("my_model")
    custom.integrate = False

    session = Session()
    return session, custom, zero
예제 #21
0
def test_source_component_arbitrary_grid():
    """Plot a regridded constant model; the plot holds the data grid plus the regrid grid.

    The plot call is expected to emit a UserWarning.
    """
    from sherpa.astro.ui.utils import Session
    from sherpa.models import Const1D

    data_x = [1, 2, 3]
    data_y = [1, 2, 3]
    grid = [10, 20, 30]

    ui = Session()
    ui.load_arrays(1, data_x, data_y)

    const = Const1D('c')
    const.c0 = 10
    regridded = const.regrid(grid)

    with pytest.warns(UserWarning):
        ui.plot_source_component(regridded)

    # List concatenation: data points first, then the regrid points.
    numpy.testing.assert_array_equal(ui._compsrcplot.x, data_x + grid)
    numpy.testing.assert_array_equal(ui._compsrcplot.y, [10] * 6)
예제 #22
0
def test_source_component_arbitrary_grid_int():
    """Integrated-data version: plot a regridded constant model and check x and y.

    A UserWarning is expected from the plot call. The expected x values
    are the data bin centres followed by the regrid bin centres.
    """
    from sherpa.astro.ui.utils import Session
    from sherpa.models import Const1D
    from sherpa.data import Data1DInt

    ui = Session()

    bins = (numpy.array([1, 2, 3]), numpy.array([2, 3, 4]))
    values = [1.5, 2.5, 3.5]
    grid_bins = (numpy.array([10, 20, 30]), numpy.array([20, 30, 40]))

    ui.load_arrays(1, bins[0], bins[1], values, Data1DInt)

    component = Const1D('c')
    component.c0 = 10
    on_grid = component.regrid(*grid_bins)

    with pytest.warns(UserWarning):
        ui.plot_source_component(on_grid)

    centres = [(lo + hi) / 2 for lo, hi in (bins, grid_bins)]
    expected_x = numpy.concatenate(centres)

    plot = ui._compsrcplot
    numpy.testing.assert_array_equal(plot.x, expected_x)
    numpy.testing.assert_array_equal(plot.y, [10, 10, 10, 100, 100, 100])
예제 #23
0
def test_instrument_model(make_data_path):
    """Check that both the source and the full model expression report ndim == 1."""

    from sherpa.astro import xspec

    session = Session()
    session.load_pha(make_data_path('3c273.pi'))

    powerlaw = xspec.XSpowerlaw()
    absorber = xspec.XSwabs()
    session.set_source(absorber * powerlaw)

    # Retrieve both expressions before checking either one.
    expressions = (session.get_source(), session.get_model())
    for expression in expressions:
        assert expression.ndim == 1
예제 #24
0
def setup2d():
    """Return a Session, a MyModel2D, a frozen zero Const2D and 2D test arrays.

    Returns
    -------
    session, my_model, const, (x, y, xhi, yhi, z)
        The arrays are plain lists of four elements each.
    """
    zero = Const2D("const")
    zero.c0 = 0
    zero.c0.freeze()

    xlo, ylo = [2, 3, 2, 3], [2, 2, 3, 3]
    xhi, yhi = [2.1, 3.5, 2.1, 3.5], [2.1, 2.1, 3.5, 3.5]

    # This is the result when rebinning [100, ] * 4
    z = [225] * 4

    custom = MyModel2D("my_model")
    arrays = (xlo, ylo, xhi, yhi, z)

    return Session(), custom, zero, arrays
예제 #25
0
파일: test_plot.py 프로젝트: wsf1990/sherpa
def test_plot_model_arbitrary_grid_integrated():
    """plot_model with a regridded constant on integrated data warns and plots 10s.

    The plotted x values are expected to be the centres of the data bins.
    """
    xlo, xhi = [1, 2, 3], [2, 3, 4]
    grid_lo, grid_hi = [10, 20, 30], [20, 30, 40]

    ui = Session()
    ui.load_arrays(1, xlo, xhi, [1, 2, 3], Data1DInt)

    const = Const1D('c')
    const.c0 = 10
    ui.set_model(const.regrid(grid_lo, grid_hi))

    with pytest.warns(UserWarning):
        ui.plot_model()

    plot = ui._modelplot
    numpy.testing.assert_array_equal(plot.x, [1.5, 2.5, 3.5])
    numpy.testing.assert_array_equal(plot.y, [10, 10, 10])
예제 #26
0
def psf_fixture(request):
    """Prepare a session with image data, a SigmaGauss2D source and a PSF.

    request.param supplies the configuration passed to make_images and
    the source amplitude, sigma and position.

    Returns
    -------
    ui, source, approx_sigma
        The session, the source model, and an approx() wrapper around the
        configured sigma with a relative tolerance of 3e-2.
    """
    config = request.param
    images = make_images(config)

    session = Session()
    session.set_data(1, images.image)

    sigma = config.source_sigma
    sigma_check = approx(sigma, rel=3e-2)

    # Set the source model as a 2D Gaussian, and set the PSF in Sherpa.
    # The same sigma is used on both axes and the same value for x and y.
    position = config.source_position
    source = SigmaGauss2D('source')
    source.ampl = config.source_amplitude
    source.sigma_a = source.sigma_b = sigma
    source.xpos = source.ypos = position

    session.set_source(source)
    session.set_psf(images.psf_model)

    return session, source, sigma_check
예제 #27
0
def test_source_component_arbitrary_grid_int():
    """Plot a regridded constant on integrated data; expect data bin centres and zeros."""
    xlo, xhi = numpy.array([1, 2, 3]), numpy.array([2, 3, 4])
    grid_lo, grid_hi = numpy.array([10, 20, 30]), numpy.array([20, 30, 40])

    ui = Session()
    ui.load_arrays(1, xlo, xhi, [1.5, 2.5, 3.5], Data1DInt)

    const = Const1D('c')
    const.c0 = 10
    ui.plot_source_component(const.regrid(grid_lo, grid_hi))

    centres = (xlo + xhi) / 2.0
    plot = ui._compsrcplot
    numpy.testing.assert_array_equal(plot.x, centres)
    numpy.testing.assert_array_equal(plot.y, [0., 0., 0.])
예제 #28
0
def test_plot_model_arbitrary_grid_integrated():
    """plot_model with a regridded constant on integrated data, for three grid layouts.

    In every case the plotted x values must equal the data bin centres
    and the plotted y values must (almost) equal the expected array.
    """
    ui = Session()
    model = Const1D('c')
    model.c0 = 10

    def as_bins(edges):
        # Split an edge array into (low edges, high edges).
        return edges[:-1], edges[1:]

    def check(x, y, grid, expected_y):
        ui.load_arrays(1, x[0], x[1], y, Data1DInt)
        ui.set_model(model.regrid(*grid))
        ui.plot_model()
        centres = 0.5 * (x[0] + x[1])
        numpy.testing.assert_array_equal(ui._modelplot.x, centres)
        numpy.testing.assert_array_almost_equal(ui._modelplot.y, expected_y)

    x = as_bins(numpy.arange(1, 5, 1))
    y = x[0]
    grid = as_bins(numpy.arange(10, 50, 10))
    check(x, y, grid, [0, 0, 0])

    x = as_bins(numpy.arange(1, 20, 1))
    y = x[0]
    grid = as_bins(numpy.arange(1, 20, 0.5))
    check(x, y, grid, len(y) * [10])

    x = as_bins(numpy.arange(1, 20, 1))
    y = x[0]
    grid = as_bins(numpy.arange(10, 20, 0.5))
    half = len(y) // 2
    expected = numpy.append(half * [0.], half * [10.])
    check(x, y, grid, expected)
예제 #29
0
def test_plot_model_arbitrary_grid_integrated():
    """Plot a regridded constant model set on integrated data and check the plot.

    plot_model is expected to emit a UserWarning; the plot must contain
    the data bin centres and a y value of 10 for each of them.
    """
    from sherpa.astro.ui.utils import Session
    from sherpa.models import Const1D
    from sherpa.data import Data1DInt

    ui = Session()

    bins = ([1, 2, 3], [2, 3, 4])
    counts = [1, 2, 3]
    grid_bins = ([10, 20, 30], [20, 30, 40])

    ui.load_arrays(1, bins[0], bins[1], counts, Data1DInt)

    component = Const1D('c')
    component.c0 = 10
    on_grid = component.regrid(*grid_bins)
    ui.set_model(on_grid)

    with pytest.warns(UserWarning):
        ui.plot_model()

    numpy.testing.assert_array_equal(ui._modelplot.x, [1.5, 2.5, 3.5])
    numpy.testing.assert_array_equal(ui._modelplot.y, [10, 10, 10])
예제 #30
0
파일: test_plot.py 프로젝트: wsf1990/sherpa
def test_source_component_arbitrary_grid():
    """A regridded constant plotted as a source component warns and yields six 10s.

    The plotted x values are the data grid followed by the regrid grid.
    """
    ui = Session()

    points = [1, 2, 3]
    values = [1, 2, 3]
    grid = [10, 20, 30]
    ui.load_arrays(1, points, values)

    component = Const1D('c')
    component.c0 = 10
    on_grid = component.regrid(grid)

    with pytest.warns(UserWarning):
        ui.plot_source_component(on_grid)

    expected_x = points + grid
    expected_y = [10] * 6
    numpy.testing.assert_array_equal(ui._compsrcplot.x, expected_x)
    numpy.testing.assert_array_equal(ui._compsrcplot.y, expected_y)
예제 #31
0
def test_show_bkg_model_with_bkg(make_data_path):
    """Smoke test: show_bkg_model runs with and without an id after loading '3c273.pi' as 'foo'."""
    ui = Session()
    ui.load_data('foo', make_data_path('3c273.pi'))

    # First with no argument, then with the explicit data set id.
    for call_args in ((), ('foo',)):
        ui.show_bkg_model(*call_args)
예제 #32
0
def test_show_bkg_model():
    """Smoke test: show_bkg_model/show_bkg_source run with no id and with the id 'xx'."""
    ui = Session()
    ui.load_arrays(1, [1, 2], [1, 2])

    # Each method is called without arguments first, then with 'xx'.
    for show in (ui.show_bkg_model, ui.show_bkg_source):
        show()
        show('xx')
예제 #33
0
def test_set_log():
    """set_xlog/set_ylog/set_xlinear/set_ylinear toggle the data plot preferences."""
    ui = Session()

    def pref(key):
        # Re-read the preferences on every call, as the settings change.
        return ui.get_data_plot_prefs()[key]

    # Both axes start out linear.
    assert not pref('xlog')
    assert not pref('ylog')

    ui.set_xlog()
    assert pref('xlog')
    ui.set_ylog()
    assert pref('ylog')

    ui.set_xlinear()
    assert not pref('xlog')
    ui.set_ylinear()
    assert not pref('ylog')
예제 #34
0
def test_save_restore(tmpdir):
    """Save a session, clean it, restore it, and check the data set comes back."""
    save_file = str(tmpdir.join("sherpa.save"))

    ui = Session()
    ui.load_arrays(1, TEST, TEST2)
    ui.save(save_file, clobber=True)

    # After clean() no data ids remain.
    ui.clean()
    assert set() == set(ui.list_data_ids())

    ui.restore(save_file)
    assert {1} == set(ui.list_data_ids())
    assert_array_equal(TEST, ui.get_data(1).get_indep()[0])
    assert_array_equal(TEST2, ui.get_data(1).get_dep())
예제 #35
0
def test_show_bkg_model():
    """Check the background show_* calls do not raise, with and without the id 'xx'."""
    session = Session()
    session.load_arrays(1, [1, 2], [1, 2])

    for method_name in ('show_bkg_model', 'show_bkg_source'):
        show = getattr(session, method_name)
        for call_args in ((), ('xx',)):
            show(*call_args)
예제 #36
0
def test_show_bkg_model_with_bkg(make_data_path):
    """Check show_bkg_model does not raise once '3c273.pi' is loaded under the id 'foo'."""
    session = Session()
    pha_file = make_data_path('3c273.pi')
    session.load_data('foo', pha_file)

    session.show_bkg_model()
    session.show_bkg_model('foo')
예제 #37
0
def convert_model(expr, postfix, groups, names):
    """Extract the model components.

    Model names go from m1 to mn (when groups has a single element) or
    m1g1 to mng1 and then m1g2 to mng<ngroups>.

    Parameters
    ----------
    expr : str
        The XSPEC model expression.
    postfix : str
        Add to the model + number string.
    groups : list of int
        The groups to create. It must not be empty. We special case a
        single group, as there's no need to add an identifier.
    names : set of str
        The names we have created (will be updated). This is just
        for testing.

    Returns
    -------
    exprs : list of lists
        The model expression for each group (one list per element of
        groups). Each list contains pairs: (str, None) for a separator
        or bracket, or (name, Model) for a model component, where name
        is the component's name attribute split on '.' (described by
        the original author as the model type and the instance name).

    Raises
    ------
    RuntimeError
        If a generated component name is already in names, or a created
        component is not an XSPEC additive, multiplicative, or
        convolution model.
    ValueError
        If the session does not recognize a model name in the expression.

    Notes
    -----
    We require the Sherpa XSPEC module here.

    Problems with convolution or bracket handling are reported with
    print() rather than by raising an exception.

    """

    # Create our own session object to make it easy to
    # find out XSPEC models and the correct naming scheme.
    #
    session = Session()
    session._add_model_types(xspec,
                             (xspec.XSAdditiveModel,
                              xspec.XSMultiplicativeModel,
                              xspec.XSConvolutionKernel))


    # Let's remove the spaces (code point 32 is mapped to None).
    dbg(f"Processing model expression: {expr}")
    expr = expr.translate({32: None})

    # Pick the naming scheme: a single group needs no "g<n>" suffix, so
    # the group list is replaced by a single placeholder entry.
    if len(groups) == 1:
        groups = [None]
        def mkname(ctr, grp):
            """Return the unique name m<ctr><postfix>; grp is not used."""
            n = f"m{ctr}{postfix}"
            if n in names:
                raise RuntimeError("Unable to handle model names with this input script")

            names.add(n)
            return n

    else:
        def mkname(ctr, grp):
            """Return the unique name m<ctr><postfix>g<grp>."""
            n = f"m{ctr}{postfix}g{grp}"
            if n in names:
                raise RuntimeError("Unable to handle model names with this input script")

            names.add(n)
            return n

    # One output list per group; all of them are filled in parallel.
    out = [[] for _ in groups]

    def add_term(start, end, ctr, storage):
        """Create the model named by expr[start:end] for every group.

        storage tracks the convolution state. Returns the counter to use
        for the next term: ctr unchanged for an empty slice, otherwise
        ctr + 1.
        """

        # It's not ideal we need this
        if start == end:
            return ctr

        # The session model types carry an "xs" prefix.
        name = f"xs{expr[start:end]}"
        dbg(f"Identified model expression '{name}'")

        for i, grp in enumerate(groups, 1):
            try:
                mdl = session.create_model_component(name, mkname(ctr, grp))
            except ArgumentErr:
                # name[2:] drops the "xs" prefix; "from None" hides the
                # original exception from the traceback.
                raise ValueError(f"Unrecognized XSPEC model '{name[2:]}' in {expr}") from None

            out[i - 1].append((mdl.name.split('.'), mdl))

            # This only needs to be checked for the first group.
            #
            if i == 1:
                if isinstance(mdl, xspec.XSConvolutionKernel):
                    if storage['convolution'] is not None:
                        print("Convolution found within a convolution expession. This is not handled correctly.")
                    else:
                        # Remember the bracket depth at which the
                        # convolution started, and open its argument list.
                        # NOTE(review): this bracket is only added to the
                        # first group's output (we are inside "if i == 1")
                        # - confirm this is intended for multiple groups.
                        storage['convolution'] = storage['depth']
                        storage['lastterm'] = 'convolution'
                        out[i - 1].append(("(", None))

                elif isinstance(mdl, xspec.XSMultiplicativeModel):
                    storage['lastterm'] = 'multiplicative'

                elif isinstance(mdl, xspec.XSAdditiveModel):
                    storage['lastterm'] = 'additive'

                else:
                    raise RuntimeError(f"Unrecognized XSPEC model: {mdl.__class__}")

        return ctr + 1

    def add_sep(sep):
        """Append the separator token sep to every group's output."""
        for i, grp in enumerate(groups, 1):
            out[i - 1].append((sep, None))

    def check_end_convolution(storage):
        """Close the active convolution, if any, when the depth allows it."""
        if storage['convolution'] is None:
            return

        if storage['convolution'] > storage['depth']:
            return

        # NOTE(review): the elif below repeats the "> depth" test that has
        # already returned above, so its warning cannot be reached. One of
        # the two comparisons may have been meant to be "<" - confirm.
        if storage['convolution'] == storage['depth']:
            add_sep(")")
        elif storage['convolution'] > storage['depth']:
            print("WARNING: convolution model may not be handled correctly.")

        storage['convolution'] = None

    def in_convolution(storage):
        """Return True when a convolution is active at the current depth."""
        return storage['convolution'] is not None \
            and storage['convolution'] == storage['depth']

    # Scan the expression one character at a time: start marks the
    # beginning of the current term and end is the current position.
    maxchar = len(expr) - 1
    start = 0
    end = 0
    mnum = 1
    storage = {'convolution': None, 'depth': 0, 'lastterm': None}
    for end, char in enumerate(expr):

        # Not completely sure of the supported language.
        #
        if char not in "*+-/()":
            continue

        if char == "(":
            if end == 0:
                add_sep('(')
            else:
                mnum = add_term(start, end, mnum, storage)
                # We want to leave ( alone for convolution models.
                # This is a bit messy.
                #
                if expr[end - 1] in "+*-/":
                    add_sep(" (")
                elif storage['lastterm'] == 'convolution':
                    add_sep("(")
                else:
                    # A term directly followed by "(" is treated as an
                    # implicit multiplication.
                    add_sep(" * (")

            start = end + 1
            storage['depth'] += 1
            continue

        if char == ")":
            mnum = add_term(start, end, mnum, storage)

            # Check to see whether we can close out the convolution.
            #
            check_end_convolution(storage)

            if end == maxchar:
                add_sep(")")
            elif expr[end + 1] in "+*-/":
                add_sep(") ")
            else:
                # ")" directly followed by a term or bracket is treated
                # as an implicit multiplication.
                add_sep(") * ")

            start = end + 1
            storage['depth'] -= 1
            continue

        # The character is one of the operators * + - /.
        mnum = add_term(start, end, mnum, storage)

        # conv*a1+a2 is conv(s1) + a2
        if char == '+':
            check_end_convolution(storage)

        # If we had conv*mdl then we want to drop the *
        # but conv*m1*a1 is conv(m1*a1)
        #
        if not(char == '*' and in_convolution(storage) and out[0][-1][1] is None and out[0][-1][0] == "("):
            # Add space characters to separate out the expression,
            # even if start==end.
            #
            add_sep(f" {char} ")

        start = end + 1

    # Last name, which is not always present. Passing None as the end
    # slices to the end of the expression.
    # NOTE(review): end is the index of the final character here, so a
    # trailing term that is a single character long has start == end and
    # is skipped - presumably harmless if no model name is that short.
    #
    if start < end:
        add_term(start, None, mnum, storage)

    # I'm not sure whether this should only be done if start < end.
    # I am concerned we may have a case where this is needed.
    #
    check_end_convolution(storage)

    if storage['convolution'] is not None:
        print("WARNING: convolution model may not be handled correctly.")

    if storage['depth'] != 0:
        print("WARNING: unexpected issues handling brackets in the model")

    return out
예제 #38
0
def test_save_restore(tmpdir):
    """Round-trip a session through save/clean/restore and check the data set survives."""
    target = tmpdir.join("sherpa.save")
    target_name = str(target)

    session = Session()
    session.load_arrays(1, TEST, TEST2)
    session.save(target_name, clobber=True)
    session.clean()

    # clean() must leave the session without any data sets.
    remaining = set(session.list_data_ids())
    assert set() == remaining

    session.restore(target_name)
    restored_ids = set(session.list_data_ids())
    assert {1} == restored_ids

    assert_array_equal(TEST, session.get_data(1).get_indep()[0])
    assert_array_equal(TEST2, session.get_data(1).get_dep())
예제 #39
0
def test_set_log():
    """Check the xlog/ylog data plot preferences follow the set_*log/set_*linear calls."""
    session = Session()

    # Initially neither axis is logarithmic.
    for axis in ('xlog', 'ylog'):
        assert not session.get_data_plot_prefs()[axis]

    session.set_xlog()
    assert session.get_data_plot_prefs()['xlog']

    session.set_ylog()
    assert session.get_data_plot_prefs()['ylog']

    session.set_xlinear()
    assert not session.get_data_plot_prefs()['xlog']

    session.set_ylinear()
    assert not session.get_data_plot_prefs()['ylog']
예제 #40
0
def setup(make_data_path):
    """Return a new Session plus the paths of the image and the two PSF files.

    Returns
    -------
    session, image, psf_bin05, psf_bin1
        Note the bin 0.5 PSF path is returned before the bin 1 PSF path.
    """
    file_names = (
        'sim_0.5bin_2g+c.fits',
        'psf_0.0_00_bin1.img',
        'psf_0.0_00_bin0.5.img',
    )
    image, psf_bin1, psf_bin05 = [make_data_path(name) for name in file_names]

    return Session(), image, psf_bin05, psf_bin1