示例#1
0
def test_warning_dataplot_linecolor(caplog):
    """Plotting a DataPlot with linecolor set emits one 'unused' warning."""

    xvals = np.asarray([1, 2, 3])
    yvals = np.asarray([10, 12, 10.5])
    dataset = Data1D('tst', xvals, yvals)

    dataplot = DataPlot()
    dataplot.prepare(dataset, stat=None)

    # Capture everything at INFO and above from the 'sherpa' logger while
    # the plot is drawn.
    with caplog.at_level(logging.INFO, logger='sherpa'):
        dataplot.plot(linecolor='mousey')

    # Exactly one record, and it must be the warning about linecolor.
    assert len(caplog.records) == 1
    expected_record = (
        'sherpa.plot.pylab_backend',
        logging.WARNING,
        'The linecolor attribute, set to mousey, is unused.',
    )
    assert caplog.record_tuples[0] == expected_record
示例#2
0
文件: index.py 项目: wsf1990/sherpa
# NOTE: dump(), report(), savefig(), f, d, mdl, Fit, sinfo1 and sinfo2 are
# defined outside this fragment. dump/report are handed the expression to
# show as a string.
dump("sinfo1.numpoints")
dump("sinfo2.numpoints")

# Run the fit and say whether it succeeded.
res = f.fit()
if res.succeeded: print("Fit succeeded")
if not res.succeeded: print("**** ERRRR, the fit failed folks")

report("res.format()")
report("res")

# Plot the data and overlay the model, using the data and model held by the
# fit object.
from sherpa.plot import DataPlot, ModelPlot
dplot = DataPlot()
dplot.prepare(f.data)
mplot = ModelPlot()
mplot.prepare(f.data, f.model)
dplot.plot()
mplot.overplot()

savefig("data_model_c0_c2.png")

# Keep a reference to the current optimiser before it is replaced below.
dump("f.method.name")
original_method = f.method

# Switch the existing fit object to the NelderMead optimiser and re-fit;
# dstatval is read from the new result to report the change in statistic.
from sherpa.optmethods import NelderMead
f.method = NelderMead()
resn = f.fit()
print("Change in statistic: {}".format(resn.dstatval))

# Alternative: build a separate Fit object that uses NelderMead from the start.
fit2 = Fit(d, mdl, method=NelderMead())
fit2.fit()
示例#3
0
# NOTE: pha, dump(), savefig() and plt are defined outside this fragment.
# Show the ARF and RMF returned by the PHA object.
dump("pha.get_arf()")
dump("pha.get_rmf()")

# Show two header keywords and the size of the channel array.
dump("pha.header['INSTRUME']")
dump("pha.header['DETNAM']")
dump("pha.channel.size")

# Work in energy units and notice the range 0.3 to 7.
pha.set_analysis('energy')
pha.notice(0.3, 7)
# The inverted filter mask is passed as the tab stops for the grouping call
# (presumably so that the ignored channels are left out of the groups -
# confirm against the group_counts documentation).
tabs = ~pha.mask
pha.group_counts(20, tabStops=tabs)

# Plot the data with both axes on a log scale.
from sherpa.plot import DataPlot
dplot = DataPlot()
dplot.prepare(pha)
dplot.plot(xlog=True, ylog=True)
savefig('pha_data.png')

# get_indep returns a sequence; the trailing comma unpacks its single element.
chans, = pha.get_indep(filter=True)
counts = pha.get_dep(filter=True)
dump("chans.size, counts.size")

# Reduce the channel array with apply_filter using the private pha._middle
# helper, then show the size of the result.
gchans = pha.apply_filter(chans, pha._middle)
dump("gchans.size")

# Draw the same data by hand with matplotlib: counts against gchans.
plt.clf()
plt.plot(gchans, counts, 'o')
plt.xlabel('Channel')
plt.ylabel('Counts')
savefig('pha_data_manual.png')
示例#4
0
def fit(star_name, data, model, silent=False, breakdown=False):
    """Fit a multi-part model to a spectrum using LevMar with LeastSq.

    :param star_name: Name of the target star; used as the data set name and
        as the title of the post-fit plot
    :type star_name: str
    :param data: Spectrum data in the form (wave, flux)
    :type data: tuple
    :param model: An unfit spectrum model; this same object is returned
    :type model: object
    :param silent: If true, no plots are drawn or shown and nothing is
        printed, defaults to False
    :type silent: bool
    :param breakdown: If true, also collect the individual model components
        and return them with the model, defaults to False
    :type breakdown: bool

    :return: ``model``, or ``(model, params)`` when ``breakdown`` is true,
        where ``params`` holds every component of ``model`` whose name does
        not start with "(" and is not "Cont_flux"
    :rtype: object or tuple
    """

    wave, flux = data

    # Only the literal False enables output, exactly as the repeated
    # ``silent is False`` checks did before; evaluate it once.
    verbose = silent is False

    d = Data1D(star_name, wave, flux)

    # ==========================================
    # Initial guesses
    dplot = DataPlot()
    dplot.prepare(d)

    mplot = ModelPlot()
    mplot.prepare(d, model)
    if verbose:
        # Draw the data once and overlay the unfit model. An earlier extra
        # dplot.plot() call was dropped: it was immediately redrawn here.
        dplot.plot()
        mplot.overplot()
        plt.show()

    # =========================================
    # Fitting happens here - don't break please
    start = time.time()

    stat = LeastSq()

    opt = LevMar()
    opt.verbose = 0
    opt.ftol = 1e-15
    opt.xtol = 1e-15
    opt.gtol = 1e-15
    opt.epsfcn = 1e-15

    if verbose:
        print(opt)

    vfit = Fit(d, model, stat=stat, method=opt)

    if verbose:
        print(vfit)

    vres = vfit.fit()

    if verbose:
        print()
        print()
        print(vres.format())

    # =========================================
    # Plotting after fit
    if verbose:
        fplot = FitPlot()
        # Re-prepare the model plot so it reflects the model after the fit.
        mplot.prepare(d, model)
        fplot.prepare(dplot, mplot)
        fplot.plot()

        # residual
        plt.title(star_name)
        plt.plot(wave, flux - model(wave))

        plt.xlabel("Wavelength (AA)", fontsize=12)
        plt.ylabel("Flux", fontsize=12)
        plt.tick_params(axis="both", labelsize=12)

        # The elapsed time covers the fit and the drawing above, but not
        # the (possibly blocking) plt.show() below.
        duration = time.time() - start
        print()
        print("Time taken: " + str(duration))
        print()

        # Only show a window when something was drawn; a silent call must
        # not pop up or block on figures.
        plt.show()

    if breakdown is True:
        params = []

        # The first component is used as the continuum that scales every
        # other component when it is plotted.
        cont = model[0]

        if verbose:
            plt.scatter(wave, flux, marker=".", c="black")
            plt.plot(wave, model(wave), c="C1")

        for line in model:
            # Names starting with "(" are skipped entirely.
            if line.name[0] == "(":
                continue

            if line.name == "Cont_flux":
                # The continuum is plotted on its own and not collected.
                if verbose:
                    print(line)
                    plt.plot(wave, line(wave), linestyle="--")
            else:
                params.append(line)
                if verbose:
                    print()
                    print(line)
                    plt.plot(wave, line(wave) * cont(wave), linestyle="--")

        if verbose:
            plt.show()

        return model, params

    return model
示例#5
0
def multifit(star_name, data_list, model_list, silent=False):
    """Fit two models to two spectra simultaneously.

    This was created to fit the NaI doublets at ~3300 and ~5890 Angstroms.
    Only the first two entries of ``data_list`` and ``model_list`` are used.

    :param star_name: Name of the target star; the two data sets are named
        ``star_name + " 1"`` and ``star_name + " 2"``
    :type star_name: str
    :param data_list: List of spectrum data in the form
        [(wave, flux), (wave, flux)]
    :type data_list: list
    :param model_list: A list of unfit spectrum models, one per spectrum
    :type model_list: list
    :param silent: If true, no plots will generate, defaults to False. The
        optimiser, fit and result summaries are printed either way.
    :type silent: bool

    :return: ``model_list``, the same list object that was passed in
    :rtype: list
    """

    wave1, flux1 = data_list[0]
    wave2, flux2 = data_list[1]

    model1 = model_list[0]
    model2 = model_list[1]

    name_1 = star_name + " 1"
    name_2 = star_name + " 2"

    d1 = Data1D(name_1, wave1, flux1)
    d2 = Data1D(name_2, wave2, flux2)

    # Combined data and model objects so both spectra are fit together.
    dall = DataSimulFit("combined", (d1, d2))
    mall = SimulFitModel("combined", (model1, model2))

    # ==========================================
    # Initial guesses

    # Dataset 1
    dplot1 = DataPlot()
    dplot1.prepare(d1)

    mplot1 = ModelPlot()
    mplot1.prepare(d1, model1)
    if silent is False:
        # Draw the data once and overlay the unfit model. An earlier extra
        # dplot1.plot() call was dropped: it was immediately redrawn here.
        dplot1.plot()
        mplot1.overplot()
        plt.show()

    # Dataset 2
    dplot2 = DataPlot()
    dplot2.prepare(d2)

    mplot2 = ModelPlot()
    mplot2.prepare(d2, model2)
    if silent is False:
        dplot2.plot()
        mplot2.overplot()
        plt.show()

    # =========================================
    # Fitting happens here - don't break please
    stat = LeastSq()

    opt = LevMar()
    opt.verbose = 0
    opt.ftol = 1e-15
    opt.xtol = 1e-15
    opt.gtol = 1e-15
    opt.epsfcn = 1e-15
    print(opt)

    vfit = Fit(dall, mall, stat=stat, method=opt)
    print(vfit)
    vres = vfit.fit()

    print()
    print()
    print("Did the fit succeed? [bool]")
    print(vres.succeeded)
    print()
    print()
    print(vres.format())

    # =========================================
    # Plotting after fit
    if silent is False:
        # Dataset 1: re-prepare the model plot so it reflects the model
        # after the fit.
        fplot1 = FitPlot()
        mplot1.prepare(d1, model1)
        fplot1.prepare(dplot1, mplot1)
        fplot1.plot()

        # residual
        plt.title("Data 1")
        plt.plot(wave1, flux1 - model1(wave1))
        plt.show()

        # Dataset 2
        fplot2 = FitPlot()
        mplot2.prepare(d2, model2)
        fplot2.prepare(dplot2, mplot2)
        fplot2.plot()

        # residual
        plt.title("Data 2")
        plt.plot(wave2, flux2 - model2(wave2))
        plt.show()

        # both datasets - no residuals
        splot = SplitPlot()
        splot.addplot(fplot1)
        splot.addplot(fplot2)

        plt.tight_layout()
        plt.show()

    return model_list
示例#6
0
        # Both "left" and "right" come out as the same index: half the
        # length of mplot2.y, rounded to an integer. The slices below
        # therefore split the model values into a first and a second half.
        leftrange = len(mplot2.y) / (2)
        left = int(round(leftrange))
        rightrange = len(mplot2.y) / (2)
        right = int(round(rightrange))

        ### test to see if eq calc knows only to do filled in area

        # Straight-line continuum estimate: evenly spaced values running
        # from the maximum of the first half to the maximum of the second
        # half, one value per entry of mplot2.x.
        y_cont = np.linspace(max(mplot2.y[:left]),
                             max(mplot2.y[right:]),
                             num=len(mplot2.x))
        y_model = mplot2.y
        # Model values divided by the continuum estimate.
        y_cn = y_model / y_cont
        dcontplot = DataPlot()
        # NOTE(review): the x values come from mplot while the y values are
        # derived from mplot2 - confirm that both share the same grid.
        dcont = Data1D("normalized", mplot.x, y_cn)
        dcontplot.prepare(dcont)
        dcontplot.plot()
        # Temporarily change into the per-image output directory (an
        # absolute, machine-specific path) to save the figure, then change
        # back to the programs directory.
        os.chdir(r"/home/dtyler/Desktop/DocumentsDT/outputs/standard_" +
                 fits_image_filename)
        # The last four characters of the file name are dropped before the
        # suffix is appended (presumably the extension - confirm).
        plt.savefig(fits_image_filename[:-4] + "norm_continuum.png")
        os.chdir(r"/home/dtyler/Desktop/DocumentsDT/Programs")

        ###################################
        ###  calculate equivalent width
        # Wrap the normalized values in a Spectrum1D: dimensionless flux
        # on a spectral axis in Angstrom units.
        spectrum = Spectrum1D(flux=y_cn * u.dimensionless_unscaled,
                              spectral_axis=mplot.x * u.AA)
        eqw = equivalent_width(spectrum)
        """print("EW", eqw.value) 
		print('norm factor', norm_factor)
		for par in bmdl.pars:
			print(par.fullname, par.val)"""
示例#7
0
    print("# dump")
    print("{}".format(name))
    print(repr(eval(name)))
    print("----------------------------------------")


# Nine bin edges and eight values: one value per pair of adjacent edges.
edges = np.asarray([-10, -5, 5, 12, 17, 20, 30, 56, 60])
y = np.asarray([28, 62, 17, 4, 2, 4, 125, 55])

# Build the integrated data set: edges[:-1] are the low edges of the bins
# and edges[1:] the high edges.
from sherpa.data import Data1DInt
d = Data1DInt('example histogram', edges[:-1], edges[1:], y)

# Plot the data set and save the figure (savefig is defined outside this
# fragment).
from sherpa.plot import DataPlot
dplot = DataPlot()
dplot.prepare(d)
dplot.plot()

savefig('dataplot_histogram.png')

# Overlay a histogram of the same bins and values on the existing plot.
from sherpa.plot import Histogram
hplot = Histogram()
hplot.overplot(d.xlo, d.xhi, d.y)

savefig('dataplot_histogram_overplot.png')

# Model: a constant minus a Gaussian. The parameters of the combined
# expression are set by position (presumably base.c0, line.fwhm, line.pos
# and line.ampl in that order - confirm against mdl.pars).
from sherpa.models.basic import Const1D, Gauss1D
mdl = Const1D('base') - Gauss1D('line')
mdl.pars[0].val = 10
mdl.pars[1].val = 25
mdl.pars[2].val = 22
mdl.pars[3].val = 10
示例#8
0
    print("# dump")
    print(name)
    print(repr(eval(name)))


# Build a small data set from plain Python lists (not ndarrays).
from sherpa.data import Data1D
x = [1, 1.5, 2, 4, 8, 17]
y = [1, 1.5, 1.75, 3.25, 6, 16]
d = Data1D('interpolation', x, y)

# report() and savefig() are defined outside this fragment; report is
# handed the statement to show as a string.
report("print(d)")

# Plot the data and save the figure.
from sherpa.plot import DataPlot
dplot = DataPlot()
dplot.prepare(d)
dplot.plot()
savefig('data.png')

# Note: can not print(dplot) as there is a problem with the fact
#       the input to the data object is a list, not ndarray
#       Sherpa 4.10.0

# Create a polynomial model with its default settings and show it.
from sherpa.models.basic import Polynom1D
mdl = Polynom1D()
report("print(mdl)")

# Thaw the c2 parameter (presumably so that it can vary in a later fit).
mdl.c2.thaw()

# Prepare a model plot that evaluates mdl on the data set d.
from sherpa.plot import ModelPlot
mplot = ModelPlot()
mplot.prepare(d, mdl)
示例#9
0
文件: example.py 项目: wsf1990/sherpa
# Read the 'Fig4data' sheet of the workbook.
from openpyxl import load_workbook
wb = load_workbook('pone.0171996.s001.xlsx')
fig4 = wb['Fig4data']
# Skip the first two rows (presumably headers - confirm against the sheet)
# and collect columns 0, 3 and 4 of every remaining row.
t = []; y = []; dy = []
for r in list(fig4.values)[2:]:
    t.append(r[0])
    y.append(r[3])
    dy.append(r[4])

# Data set with t as the independent axis, y as the dependent values and
# dy passed as the fourth argument (presumably the errors on y).
from sherpa.data import Data1D
d = Data1D('NaNO_3', t, y, dy)

# Plot the data and save the figure (savefig, report and dump are defined
# outside this fragment).
from sherpa.plot import DataPlot
dplot = DataPlot()
dplot.prepare(d)
dplot.plot()
savefig("data.png")

report("d")
dump("d.get_filter()")

# Ignore everything up to 1 on the independent axis, then show the
# resulting filter twice: with the default format and with '%d'.
d.ignore(None, 1)
dump("d.get_filter()")

dump("d.get_filter(format='%d')")

# Re-prepare the plot object now that the filter has changed.
dplot.prepare(d)

# Create the two model components.
from sherpa.models.basic import Const1D, Exp
plateau = Const1D('plateau')
rise = Exp('rise')