def engines(nm=None):
    """Return a fitting engine by name, or the names of all available engines.

    :param nm: engine name ('mpfit' or 'mcmc'). If None, the engine names
               are returned instead of an engine.
    :return: a freshly constructed fitter for ``nm`` (None when the name is
             unknown), or a list of engine names when ``nm`` is None.
    """
    # Map names to factories rather than instances, so only the requested
    # fitter is constructed and listing the names builds no fitter at all.
    switcher = {
        'mpfit': lambda: ps.FitMPFit(),
        'mcmc': lambda: ps.FitMCMC(),
    }
    if nm is not None:
        factory = switcher.get(nm)
        return factory() if factory is not None else None
    return list(switcher.keys())
def test_fit_mpfit_curves_Stella_SN1999em(self):
    """Fit Stella model curves to SN 1999em observations with FitMPFit and plot them.

    The observed curves get a magnitude shift computed from D = 11.5e6 pc.
    The model 'cat_R500_M15_Ni006_E12' is fitted with ``fit_curves``. The
    observations are then shifted in time by the fitted ``tshift`` and
    over-plotted on the model curves, with the fit summary drawn on the axes.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Observations
    D = 11.5e6  # pc
    # dm = -(5 log10 D - 5), i.e. the negated distance modulus for D in pc.
    dm = -5. * np.log10(D) + 5
    curves_obs = sn1999em.read_curves()
    curves_obs.set_mshift(dm)

    # Model
    name = 'cat_R500_M15_Ni006_E12'
    path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    curves_mdl = lcf.curves_compute(name, path, curves_obs.BandNames)

    # Fit
    fitter = ps.FitMPFit()
    fitter.is_debug = True
    res = fitter.fit_curves(curves_obs, curves_mdl)

    # Report. Each row is "<label> value+-sigma". The second row prints
    # mshift with its sigma, so it is labelled 'mshift:' to match the
    # 'tshift:' row.
    txt = '{:10s} {:.4f}+-{:.4f} \n'.format('tshift:', res.tshift, res.tsigma) + \
          '{:10s} {:.4f}+-{:.4f}\n'.format('mshift:', res.mshift, res.msigma) + \
          '{0}\n'.format(res.comm)
    print(txt)

    # Plot the model, then the time-shifted observations as markers only.
    curves_obs.set_tshift(res.tshift)
    ax = lcp.curves_plot(curves_mdl)
    lt = {lc.Band.Name: 'o' for lc in curves_obs}
    lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
    ax.text(0.1, 0.1, txt, transform=ax.transAxes)
    plt.show()
def test_FitMPFit_best_curves_gp_SN1999em(self):
    """Fit SN 1999em observations to Stella model curves with ``best_curves_gp`` and plot.

    The observed curves get a magnitude shift computed from D = 11.5e6 pc.
    ``FitMPFit.best_curves_gp`` is run starting from dt0 = dm0 = 0. The
    observations are then shifted in time by the fitted 'dt' and
    over-plotted on the model curves, with the fit summary drawn on the axes.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Observations
    D = 11.5e6  # pc
    # dm = -(5 log10 D - 5), i.e. the negated distance modulus for D in pc.
    dm = -5. * np.log10(D) + 5
    curves_obs = sn1999em.read_curves()
    curves_obs.set_mshift(dm)

    # Model
    name = 'cat_R500_M15_Ni006_E12'
    path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    curves_mdl = lcf.curves_compute(name, path, curves_obs.BandNames)

    # Fit
    fitter = ps.FitMPFit()
    fitter.is_info = True
    fitter.is_debug = True
    res = fitter.best_curves_gp(curves_mdl, curves_obs, dt0=0., dm0=0.)

    # Report
    txt = '{0:10} {1:.4e} \n'.format('tshift:', res['dt']) + \
          '{0:10} {1:.4e} \n'.format('tsigma:', res['dtsig'])
    print(txt)

    # Plot the model, then the time-shifted observations as markers only.
    curves_obs.set_tshift(res['dt'])
    ax = lcp.curves_plot(curves_mdl)
    lt = {lc.Band.Name: 'o' for lc in curves_obs}
    lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
    # Annotate the plot with the fit summary, as the sibling tests do.
    ax.text(0.1, 0.1, txt, transform=ax.transAxes)
    plt.show()
def test_FitMPF_best_curves_SN1999em(self):
    """Fit Stella model curves to SN 1999em observations with ``FitMPFit.best_curves`` and plot.

    The model curves are computed by ``ps.Stella(...).curves`` at
    distance D = 11.5e6 pc with ``ebv`` = 0. ``best_curves`` is run from
    dt0 = 0. The model curves are shifted in time by the fitted 'dt' and
    the observations are over-plotted, with dt/dm and their sigmas drawn
    on the axes.
    """
    from pystella.rf import light_curve_plot as lcp

    # Observations
    D = 11.5e6  # pc
    ebv_sn = 0.  # passed to Stella.curves as `ebv`
    curves_obs = sn1999em.read_curves()

    # Model
    name = 'cat_R500_M15_Ni006_E12'
    path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    curves_mdl = ps.Stella(name, path=path).curves(curves_obs.BandNames, distance=D, ebv=ebv_sn)

    # Fit
    fitter = ps.FitMPFit()
    fitter.is_info = True
    fitter.is_debug = True
    fitter.is_quiet = True
    # best_curves returns a 3-tuple. Only the middle element, the dict of
    # fitted shifts and their sigmas, is used here.
    _, res, _ = fitter.best_curves(curves_mdl, curves_obs, dt0=0.)
    dt, dtsig = res['dt'], res['dtsig']
    dm, dmsig = res['dm'], res['dmsig']

    # Plot the time-shifted model, then the observations as markers only.
    curves_mdl.set_tshift(dt)
    ax = lcp.curves_plot(curves_mdl)
    lt = {lc.Band.Name: 'o' for lc in curves_obs}
    lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)

    # Report
    txt = '{:10} {:.2f} +- {:.4f}'.format('dt:', dt, dtsig)
    txt += '\n{:10} {:.2f} +- {:.4f}'.format('dm:', dm, dmsig)
    print(txt)

    title_font = {
        'size': '12',
        'color': 'black',
        'weight': 'normal',
        'verticalalignment': 'bottom',
    }
    ax.text(0.03, 0.95, txt, transform=ax.transAxes,
            bbox={'facecolor': 'none', 'alpha': 0.2, 'edgecolor': 'none'},
            **title_font)
    plt.show()