コード例 #1
0
ファイル: test_session.py プロジェクト: wmclaugh/sherpa
def test_issue_16():
    """Check one of the examples from #16.

    Two datasets, each with its own power-law source, are created and then
    removed one at a time (model, then component, then data). After each
    step the data ids, model ids, and model component names must match.
    """

    session = Session()
    session._add_model_types(sherpa.models.basic)

    session.dataspace1d(0.01, 11, 0.01, id=1)
    session.dataspace1d(2, 5, 0.1, id="tst")
    session.set_source(1, 'powlaw1d.pl1')
    session.set_source('tst', 'powlaw1d.pltst')

    def check_state(ids, components):
        # The data and model identifiers are expected to stay in step.
        assert session.list_data_ids() == ids
        assert session.list_model_ids() == ids
        assert session.list_model_components() == components

    check_state([1, 'tst'], ['pl1', 'pltst'])

    # (dataset id, component name, ids left afterwards, components left)
    removals = [
        ('tst', 'pltst', [1], ['pl1']),
        (1, 'pl1', [], []),
    ]
    for idval, cptname, ids_left, cpts_left in removals:
        session.delete_model(id=idval)
        session.delete_model_component(cptname)
        session.delete_data(id=idval)
        check_state(ids_left, cpts_left)
コード例 #2
0
ファイル: test_session.py プロジェクト: wmclaugh/sherpa
def test_paramprompt_eof(caplog):
    """What happens when we end early?

    The input answers the prompts for bob.c0, fred.fwhm and fred.pos but
    not fred.ampl, so reading the last prompt raises EOFError. The values
    entered before that point are kept and fred.ampl is left unchanged.
    """

    s = Session()
    s._add_model_types(sherpa.models.basic)

    cpt1 = s.create_model_component('const1d', 'bob')
    cpt2 = s.create_model_component('gauss1d', 'fred')

    s.paramprompt(True)
    assert len(caplog.records) == 0

    # Use the returned components rather than the bob/fred symbols that the
    # session adds to the global table, so the expression does not depend
    # on that side effect (and is not flagged as using undefined names).
    with pytest.raises(EOFError):
        with SherpaVerbosity('INFO'):
            with patch("sys.stdin", StringIO("\n5,1,5\n2\n")):
                s.set_source(cpt1 + cpt2)

    assert len(caplog.records) == 0

    # Empty response: bob.c0 keeps its value and unbounded limits.
    assert cpt1.c0.val == pytest.approx(1)
    assert cpt1.c0.min < -3e38
    assert cpt1.c0.max > 3e38

    # "5,1,5": value, min, and max all set.
    assert cpt2.fwhm.val == pytest.approx(5)
    assert cpt2.fwhm.min == pytest.approx(1)
    assert cpt2.fwhm.max == pytest.approx(5)

    # "2": only the value is set.
    assert cpt2.pos.val == pytest.approx(2)
    assert cpt2.pos.min < -3e38
    assert cpt2.pos.max > 3e38

    # No input left for fred.ampl, so it is untouched.
    assert cpt2.ampl.val == pytest.approx(1)
    assert cpt2.ampl.min < -3e38
    assert cpt2.ampl.max > 3e38

    # remove the bob and fred symbols from the global table
    s.clean()
コード例 #3
0
ファイル: test_session.py プロジェクト: wmclaugh/sherpa
def test_paramprompt_single_parameter_check_invalid_max_out_of_bound(caplog):
    """Check an invalid max, then a max below the current value.

    Note this creates two log messages. The first is an INFO message
    rejecting the non-numeric max. The second is a WARNING when the new
    maximum of -200 is below the current value, which is reset to -200.
    """

    s = Session()
    s._add_model_types(sherpa.models.basic)

    s.paramprompt(True)
    assert len(caplog.records) == 0

    with SherpaVerbosity('INFO'):
        with patch("sys.stdin", StringIO(",,typo\n,,-200")):
            s.set_source("scale1d.bob")

    assert len(caplog.records) == 2
    lname, lvl, msg = caplog.record_tuples[0]
    assert lname == "sherpa.ui.utils"
    assert lvl == logging.INFO
    assert msg == "Please provide a float value; could not convert string to float: 'typo'"

    lname, lvl, msg = caplog.record_tuples[1]
    assert lname == "sherpa.models.parameter"
    # Use the documented level name (logging.WARN is only an alias).
    assert lvl == logging.WARNING
    assert msg == "parameter bob.c0 greater than new maximum; bob.c0 reset to -200"

    mdl = s.get_model_component('bob')
    assert mdl.c0.val == pytest.approx(-200)
    assert mdl.c0.min < -3e38
    assert mdl.c0.max == pytest.approx(-200)

    # remove the bob symbol from the global table
    s.clean()
コード例 #4
0
ファイル: test_session.py プロジェクト: wmclaugh/sherpa
def test_paramprompt_multi_parameter(caplog):
    """Check that paramprompt works with multiple parameters.

    The four responses cover bob.c0, fred.fwhm, fred.pos and fred.ampl.
    Empty lines keep the existing settings. "val,min,max" lines set all
    three values.
    """

    s = Session()
    s._add_model_types(sherpa.models.basic)

    cpt1 = s.create_model_component('const1d', 'bob')
    cpt2 = s.create_model_component('gauss1d', 'fred')

    s.paramprompt(True)
    assert len(caplog.records) == 0

    # Use the returned components rather than the bob/fred symbols that the
    # session adds to the global table, so the expression does not depend
    # on that side effect (and is not flagged as using undefined names).
    with SherpaVerbosity('INFO'):
        with patch("sys.stdin", StringIO("\n5,1,5\n\n-5, -5, 0")):
            s.set_source(cpt1 + cpt2)

    assert len(caplog.records) == 0

    assert cpt1.c0.val == pytest.approx(1)
    assert cpt1.c0.min < -3e38
    assert cpt1.c0.max > 3e38

    assert cpt2.fwhm.val == pytest.approx(5)
    assert cpt2.fwhm.min == pytest.approx(1)
    assert cpt2.fwhm.max == pytest.approx(5)

    assert cpt2.pos.val == pytest.approx(0)
    assert cpt2.pos.min < -3e38
    assert cpt2.pos.max > 3e38

    assert cpt2.ampl.val == pytest.approx(-5)
    assert cpt2.ampl.min == pytest.approx(-5)
    assert cpt2.ampl.max == pytest.approx(0)

    # remove the bob and fred symbols from the global table
    s.clean()
コード例 #5
0
def test_paramprompt_single_parameter_check_too_many_commas(caplog):
    """Check we tell users there was a problem.

    The first response has too many commas, so it is rejected with an
    INFO message. The second response then sets the value.
    """

    session = Session()
    session._add_model_types(sherpa.models.basic)

    session.paramprompt(True)
    assert len(caplog.records) == 0

    answers = StringIO(",,,,\n12")
    with SherpaVerbosity('INFO'), patch("sys.stdin", answers):
        session.set_source("scale1d.bob")

    assert len(caplog.records) == 1
    logger_name, level, message = caplog.record_tuples[0]
    assert logger_name == "sherpa.ui.utils"
    assert level == logging.INFO
    assert message == "Error: Please provide a comma-separated list of floats; e.g. val,min,max"

    c0 = session.get_model_component('bob').c0
    assert c0.val == pytest.approx(12)
    assert c0.min < -3e38
    assert c0.max > 3e38

    # remove the bob symbol from the global table
    session.clean()
コード例 #6
0
def test_309(make_data_path):
    """Check the bug309 case: fit a template model with gridsearch.

    There are no explicit assertions. The test passes as long as loading
    the template model and running the fit do not raise.
    """

    idval = 'bug309'

    # have values near unity for the data
    ynorm = 1e9

    session = Session()

    datafile = make_data_path('load_template_with_interpolation-bb_data.dat')
    session.load_data(idval, datafile)
    data = session.get_data(idval)
    data.y *= ynorm

    indexfile = 'bb_index.dat'
    templatedir = make_data_path('')

    # Need to load the data from the same directory as the index, so
    # change into it and always change back afterwards.
    startdir = os.getcwd()
    os.chdir(templatedir)
    try:
        session.load_template_model('bbtemp', indexfile)
    finally:
        os.chdir(startdir)

    session.set_source(idval, session.get_model_component('bbtemp') * ynorm)

    # Grid search with the 'sequence' option explicitly cleared.
    session.set_method('gridsearch')
    session.set_method_opt('sequence', None)
    session.fit(idval)
コード例 #7
0
def test_paramprompt_single_parameter_check_invalid_max(caplog):
    """Check a non-numeric max is reported and the prompt is repeated.

    The first response ("-2,,typo") is rejected with an INFO message. The
    second response ("-200,,-2") sets the value and the maximum.
    """

    session = Session()
    session._add_model_types(sherpa.models.basic)

    session.paramprompt(True)
    assert len(caplog.records) == 0

    responses = StringIO("-2,,typo\n-200,,-2")
    with SherpaVerbosity('INFO'), patch("sys.stdin", responses):
        session.set_source("scale1d.bob")

    assert len(caplog.records) == 1
    logger_name, level, text = caplog.record_tuples[0]
    assert logger_name == "sherpa.ui.utils"
    assert level == logging.INFO
    assert text == "Please provide a float value; could not convert string to float: 'typo'"

    c0 = session.get_model_component('bob').c0
    assert c0.val == pytest.approx(-200)
    assert c0.min < -3e38
    assert c0.max == pytest.approx(-2)

    # remove the bob symbol from the global table
    session.clean()
コード例 #8
0
def test_set_source_invalid():
    """An expression using an undefined name is reported as an ArgumentErr."""

    session = Session()
    expected = "invalid model expression: name 'made_up' is not defined"

    with pytest.raises(ArgumentErr) as excinfo:
        session.set_source('2 * made_up.foo')

    assert str(excinfo.value) == expected
コード例 #9
0
def test_309(make_data_path):
    """Check the bug309 case: fit a template model with gridsearch.

    Nothing is asserted explicitly. Success means that loading the
    template model and fitting it complete without an exception.
    """

    idval = 'bug309'
    ynorm = 1e9  # have values near unity for the data

    session = Session()

    session.load_data(
        idval, make_data_path('load_template_with_interpolation-bb_data.dat'))
    session.get_data(idval).y *= ynorm

    indexname = 'bb_index.dat'
    datadir = make_data_path('')

    def load_template_in(directory):
        # Need to load the data from the same directory as the index;
        # the working directory is restored even if loading fails.
        previous = os.getcwd()
        os.chdir(directory)
        try:
            session.load_template_model('bbtemp', indexname)
        finally:
            os.chdir(previous)

    load_template_in(datadir)

    template = session.get_model_component('bbtemp')
    session.set_source(idval, template * ynorm)

    session.set_method('gridsearch')
    session.set_method_opt('sequence', None)
    session.fit(idval)
コード例 #10
0
ファイル: test_session.py プロジェクト: wmclaugh/sherpa
def test_paramprompt_single_parameter_combo_works(caplog):
    """Check a single "val,min,max" response sets all three values."""

    s = Session()
    s._add_model_types(sherpa.models.basic)

    s.paramprompt(True)
    assert len(caplog.records) == 0

    with SherpaVerbosity('INFO'):
        with patch("sys.stdin", StringIO("-2,-10,10")):
            s.set_source("scale1d.bob")

    assert len(caplog.records) == 0
    mdl = s.get_model_component('bob')
    assert mdl.c0.val == pytest.approx(-2)
    assert mdl.c0.min == pytest.approx(-10)
    assert mdl.c0.max == pytest.approx(10)

    # remove the bob symbol from the global table
    s.clean()
コード例 #11
0
ファイル: test_session.py プロジェクト: wmclaugh/sherpa
def test_delete_model_component_warning(caplog):
    """Check we get a warning (which ends up being issue #16).

    Deleting a component that is still part of a source expression logs
    a warning and leaves the component in place.
    """

    s = Session()
    s._add_model_types(sherpa.models.basic)

    s.set_source('const1d.mdl + gauss1d.mdl2')
    assert s.list_model_components() == ['mdl', 'mdl2']

    assert len(caplog.records) == 0
    s.delete_model_component('mdl2')

    assert len(caplog.records) == 1
    lname, lvl, msg = caplog.record_tuples[0]
    assert lname == "sherpa.ui.utils"
    assert lvl == logging.WARNING
    assert msg == "the model component 'gauss1d.mdl2' is found in model 1 and cannot be deleted"

    assert s.list_model_components() == ['mdl', 'mdl2']

    # remove the mdl and mdl2 symbols from the global table
    s.clean()
コード例 #12
0
def test_paramprompt_single_parameter_max_works(caplog):
    """Check a ",,max" response changes only the maximum."""

    session = Session()
    session._add_model_types(sherpa.models.basic)

    session.paramprompt(True)
    assert len(caplog.records) == 0

    fake_input = StringIO(",,100")
    with SherpaVerbosity('INFO'), patch("sys.stdin", fake_input):
        session.set_source("scale1d.bob")

    assert len(caplog.records) == 0

    c0 = session.get_model_component('bob').c0
    assert c0.val == pytest.approx(1)
    assert c0.min < -3e38
    assert c0.max == pytest.approx(100)

    # remove the bob symbol from the global table
    session.clean()
コード例 #13
0

def savefig(name):
    """Save the current figure with ``plt.savefig`` and print the file name."""
    created = "# Created: {}".format(name)
    plt.savefig(name)
    print(created)


# Build a session with one integrated 1D dataset (id 1). Each bin runs
# from xlo[i] to xhi[i] with value y[i]; note that no bin covers 6-7.
s = Session()
xlo = [2, 3, 5, 7, 8]
xhi = [3, 5, 6, 8, 9]
y = [10, 27, 14, 10, 14]
s.load_arrays(1, xlo, xhi, y, Data1DInt)
# Use a polynomial with c0=6 and c1=1 as the source, then plot data + model.
mdl = Polynom1D('mdl')
mdl.c0 = 6
mdl.c1 = 1
s.set_source(mdl)
s.plot_fit()

savefig('ui_plot_fit_basic.png')

# Recreate the fit plot by hand: the data in black, then the model
# overplotted with linestyle=':', alpha=0.7 and large green '*' markers.
s.plot_data(color='black')
# NOTE(review): presumably this returns the live preference dictionary,
# so the changes below apply to the following plot_model call - confirm.
p = s.get_model_plot_prefs()
p['marker'] = '*'
p['markerfacecolor'] = 'green'
p['markersize'] = 12
s.plot_model(linestyle=':', alpha=0.7, overplot=True)

savefig('ui_plot_fit_manual.png')

print("**** s.get_model_plot()")
# recalc=False presumably reuses the values from the last model plot
# instead of recomputing them - TODO confirm.
plot = s.get_model_plot(recalc=False)
コード例 #14
0
ファイル: basic_session.py プロジェクト: wsf1990/sherpa
# NOTE(review): `s` is a Session created earlier in the script (not visible
# in this fragment).
print(s.list_data_ids())

# Show the dataset returned by get_data() twice: first via repr(), then via
# str(). Both sections are labelled "# get_data()".
print("# get_data()")
print(repr(s.get_data()))

print("# get_data()")
print(s.get_data())

# Report the current statistic and optimiser names.
print("# get_stat_name/get_method_name")
print(s.get_stat_name())
print(s.get_method_name())

# Switch to the cash statistic and the simplex optimiser.
s.set_stat('cash')
s.set_method('simplex')

# Create a const1d component called mdl and use it as the source.
s.set_source('const1d.mdl')

# NOTE(review): `mdl` is never assigned in this script; presumably the
# set_source call above adds the component to the global namespace - confirm.
print("# mdl")
print(mdl)

print("# get_source")
print(s.get_source())

# Fit the source to the data and display the results.
print("# fit")
s.fit()

print("# get_fit_results")
r = s.get_fit_results()
print(r)

# Turn off the y error bars in data plots (presumably the preference
# dictionary is live, so this affects later data plots - confirm).
s.get_data_plot_prefs()['yerrorbars'] = False