def test_FitMCMC_best_curves_dt_SN1999em(self):
    """Fit only a time shift of a Stella model to SN 1999em with FitMCMC and plot it.

    The observed curves are shifted by the distance modulus for D = 11.5 Mpc,
    the model curves are computed at 10 pc in the observed bands, and
    ``FitMCMC.best_curves`` fits ``dt``.  A corner plot, a text summary and the
    aligned light curves are shown.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Observations, moved by the distance modulus (D = 12.e6 pc would give dm = -30.4).
    dist_pc = 11.5e6
    dist_mod = -5. * np.log10(dist_pc) + 5
    obs = sn1999em.read_curves()
    obs.set_mshift(dist_mod)

    # Model light curves at 10 pc, in the same bands as the observations.
    model_name = 'cat_R500_M15_Ni006_E12'
    model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    mdl = lcf.curves_compute(model_name, model_dir, distance=10, bands=obs.BandNames)

    # MCMC fitter set-up.
    debug = True
    fitter = ps.FitMCMC()
    fitter.is_info = True
    if debug:
        fitter.is_debug = debug
        fitter.nwalkers = 100  # number of MCMC walkers
        fitter.nburn = 20  # "burn-in" period to let chains stabilize
        fitter.nsteps = 200  # number of MCMC steps to take
    n_threads = 1

    fit_result, stats, (theta, err_plus, err_minus), samples = fitter.best_curves(
        mdl, obs, dt0=0., threads=n_threads, is_sampler=True)
    corner_fig = fitter.plot_corner(samples, labels=('dt', ), bnames=obs.BandNames)

    # Text summary: best dt with its upper/lower errors, then the fit statistics.
    dt_line = '{:10s} {:.4f} ^{:.4f}_{:.4f} \n'.format(
        'tshift:', theta.dt, err_plus.dt, err_minus.dt)
    stat_line = '{:10s} chi2= {:.1f} BIC= {:.1f} AIC= {:.1f} dof= {} accept= {:.3f}\n'.format(
        'stat:', stats['chi2'], stats['bic'], stats['aic'], stats['dof'],
        stats['acceptance_fraction'])
    summary = dt_line + stat_line
    print(summary)

    # Align the observations with the model and draw both on one axis.
    obs.set_tshift(-theta.dt)
    ax = lcp.curves_plot(mdl)
    markers = {lc.Band.Name: 'o' for lc in obs}
    lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)
    ax.text(0.1, 0.1, summary, transform=ax.transAxes)
    plt.show()
def test_fit_mpfit_curves_Stella_SN1999em(self):
    """Fit time and magnitude shifts of a Stella model to SN 1999em with FitMPFit.fit_curves.

    The observations are shifted by the distance modulus for D = 11.5 Mpc,
    the model curves are computed with ``curves_compute`` in the observed
    bands, and the fitted ``tshift``/``mshift`` (with their sigmas) are printed
    and drawn on the resulting light-curve plot.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Get observations (D = 12.e6 pc would give dm = -30.4)
    D = 11.5e6  # pc
    dm = -5. * np.log10(D) + 5
    curves_obs = sn1999em.read_curves()
    curves_obs.set_mshift(dm)

    # Get model
    name = 'cat_R500_M15_Ni006_E12'
    path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    curves_mdl = lcf.curves_compute(name, path, curves_obs.BandNames)

    # fit
    fitter = ps.FitMPFit()
    fitter.is_debug = True
    res = fitter.fit_curves(curves_obs, curves_mdl)

    # print: the second line reports the magnitude shift and its sigma,
    # so it is labelled 'mshift:' (it was mislabelled 'msigma:').
    txt = '{:10s} {:.4f}+-{:.4f} \n'.format('tshift:', res.tshift, res.tsigma) + \
          '{:10s} {:.4f}+-{:.4f}\n'.format('mshift:', res.mshift, res.msigma) + \
          '{0}\n'.format(res.comm)
    print(txt)

    # plot model with the time-shifted observations
    curves_obs.set_tshift(res.tshift)
    ax = lcp.curves_plot(curves_mdl)
    lt = {lc.Band.Name: 'o' for lc in curves_obs}
    lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
    ax.text(0.1, 0.1, txt, transform=ax.transAxes)
    plt.show()
def test_FitMPFit_best_curves_gp_SN1999em(self):
    """Fit a time shift of a Stella model to SN 1999em with FitMPFit.best_curves_gp and plot it.

    The observations are moved by the distance modulus for D = 11.5 Mpc; the
    fitted ``dt`` and ``dtsig`` are printed and ``dt`` is applied to the
    observations before plotting them over the model.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Observations (D = 12.e6 pc would give dm = -30.4).
    dist_pc = 11.5e6
    dist_mod = -5. * np.log10(dist_pc) + 5
    obs = sn1999em.read_curves()
    obs.set_mshift(dist_mod)

    # Model light curves in the observed bands.
    model_name = 'cat_R500_M15_Ni006_E12'
    model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    mdl = lcf.curves_compute(model_name, model_dir, obs.BandNames)

    # Fit.
    fitter = ps.FitMPFit()
    fitter.is_info = True
    fitter.is_debug = True
    result = fitter.best_curves_gp(mdl, obs, dt0=0., dm0=0.)

    summary_lines = [
        '{0:10} {1:.4e} \n'.format('tshift:', result['dt']),
        '{0:10} {1:.4e} \n'.format('tsigma:', result['dtsig']),
    ]
    summary = ''.join(summary_lines)
    print(summary)

    # Shift observations by the fitted dt and plot over the model.
    obs.set_tshift(result['dt'])
    ax = lcp.curves_plot(mdl)
    markers = {lc.Band.Name: 'o' for lc in obs}
    lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)
    plt.show()
def plot_curves_vel(curves_o, vels_o, res_models, res_sorted, vels_m, **kwargs):
    """Plot model light curves and velocities for each fitted model, one panel per model.

    Each panel shows the model curves ``res_models[k]`` with the observed curves
    ``curves_o`` overplotted after applying the best time shift ``v.tshift``,
    and, on a twin y-axis, the model velocity ``vels_m[k]`` together with the
    observed velocities ``vels_o`` shifted by the same amount.

    :param curves_o: observed light curves, or None to skip them.
    :param vels_o: iterable of observed velocity series, or None to skip them.
    :param res_models: mapping model name -> model curves (an entry may be None).
    :param res_sorted: ordered mapping model name -> fit result exposing
        ``tshift`` and ``measure``; panels follow its iteration order.
    :param vels_m: mapping model name -> model velocity series with ``Time`` and ``V``.
    :param kwargs: font_size (default 10), linewidth (default 2.0),
        markersize (default 5), xlim (default None: taken from the first
        panel that has model curves and observations).
    :return: matplotlib figure.
    """
    from pystella.rf import light_curve_plot as lcp
    from matplotlib import pyplot as plt
    import math

    font_size = kwargs.get('font_size', 10)
    linewidth = kwargs.get('linewidth', 2.0)
    markersize = kwargs.get('markersize', 5)
    xlim = kwargs.get('xlim', None)

    num = len(res_sorted)
    ncol = 2
    # Enough rows for every panel. The former int(num / 2.1) + 1 ran out of
    # subplot slots from num = 23 on (23 panels got only 11 x 2 = 22 slots).
    nrow = max(1, math.ceil(num / ncol))
    fig = plt.figure(figsize=(12, nrow * 4))
    plt.matplotlib.rcParams.update({'font.size': font_size})

    # Base shifts of the observations; each panel adds its own best-fit shift.
    tshift_lc = 0.
    if curves_o is not None:
        tshift_lc = ps.first(curves_o).tshift
    tshift_vel = 0.
    if vels_o is not None:
        tshift_vel = ps.first(vels_o).tshift

    for i, (k, v) in enumerate(res_sorted.items(), start=1):
        # Plot UBV
        axUbv = fig.add_subplot(nrow, ncol, i)
        tshift_best = v.tshift
        axUbv.text(0.99, 0.94, k, horizontalalignment='right', transform=axUbv.transAxes)
        axUbv.text(0.98, 0.85, "dt={:.2f}".format(tshift_best), horizontalalignment='right',
                   transform=axUbv.transAxes)
        # Raw string: "\c" is an invalid escape sequence in a plain literal.
        axUbv.text(0.01, 0.05, r"$\chi^2: {:.2f}$".format(v.measure), horizontalalignment='left',
                   transform=axUbv.transAxes, bbox=dict(facecolor='green', alpha=0.3))

        curves = res_models[k]
        if curves is not None:
            lcp.curves_plot(curves, ax=axUbv, figsize=(12, 8), linewidth=1, is_legend=False)
            if curves_o is not None:
                lt = {lc.Band.Name: 'o' for lc in curves_o}
                curves_o.set_tshift(tshift_lc + tshift_best)
                if xlim is None:
                    # Reuse the first panel's x-range for all later panels.
                    xlim = axUbv.get_xlim()
                lcp.curves_plot(curves_o, axUbv, xlim=xlim, lt=lt, markersize=2,
                                is_legend=False, is_line=False)
            # legend
            axUbv.legend(curves.BandNames, loc='lower right', frameon=False,
                         ncol=min(5, len(curves.BandNames)), fontsize='small', borderpad=1)

        # Right-column panels hide their magnitude ticks; left ones keep them.
        if i % ncol == 0:
            axUbv.yaxis.tick_right()
            axUbv.set_ylabel('')
            axUbv.yaxis.set_ticks([])
        else:
            axUbv.yaxis.tick_left()
            axUbv.yaxis.set_label_position("left")
        axUbv.grid(linestyle=':')

        # Plot Vel
        axVel = axUbv.twinx()
        axVel.set_ylim((0., 29))
        axVel.set_ylabel('Velocity [1000 km/s]')
        ts_vel = vels_m[k]
        axVel.plot(ts_vel.Time, ts_vel.V, label='Vel %s' % k, color='blue', ls="-",
                   linewidth=linewidth)

        # Obs. vel
        if vels_o is not None:
            markers_cycler = cycle(markers_style)
            for vel_o in vels_o:
                vel_o.tshift = tshift_vel + tshift_best
                marker = next(markers_cycler)
                if vel_o.IsErr:
                    yyerr = abs(vel_o.Err)
                    axVel.errorbar(vel_o.Time, vel_o.V, yerr=yyerr, label='{0}'.format(vel_o.Name),
                                   color="blue", ls='', markersize=markersize, fmt=marker,
                                   markeredgewidth=2, markerfacecolor='none',
                                   markeredgecolor='blue')
                else:
                    axVel.plot(vel_o.Time, vel_o.V, label=vel_o.Name, color='blue', ls='',
                               marker=marker, markersize=markersize)

    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
def plot_curves(curves_o, res_models, res_sorted, **kwargs):
    """Plot each fitted model's light curves with the observations overplotted.

    One panel per entry of ``res_sorted``, in its iteration order. Each panel
    draws the model curves ``res_models[k]`` after shifting them in time by
    ``v.tshift``, overplots ``curves_o`` as markers, and is annotated with the
    model name, its panel index with dt, and ``v.measure`` shown as chi^2.

    :param curves_o: observed light curves.
    :param res_models: mapping model name -> model curves.
    :param res_sorted: ordered mapping model name -> fit result exposing
        ``tshift`` and ``measure``.
    :param kwargs: font_size (default 8); xlim (default None). If xlim is None
        or its upper limit is inf, the first panel's own x-range is used for
        all later panels. ``ylim`` is accepted for call compatibility but is
        not applied.
    :return: matplotlib figure. It has no panels when ``res_sorted`` is empty.
    """
    from pystella.rf import light_curve_plot as lcp
    from matplotlib import pyplot as plt
    import math

    font_size = kwargs.get('font_size', 8)
    xlim = kwargs.get('xlim', None)

    num = len(res_sorted)
    # Near-square grid with at most 4 columns. Use at least one column, so an
    # empty res_sorted does not divide by zero when computing nrow.
    ncol = max(1, min(4, int(np.sqrt(num))))
    nrow = math.ceil(num / ncol)
    fig = plt.figure(figsize=(min(ncol, 2) * 5, max(nrow, 2) * 4))
    plt.matplotlib.rcParams.update({'font.size': font_size})

    for i, (k, v) in enumerate(res_sorted.items(), start=1):
        ax = fig.add_subplot(nrow, ncol, i)
        tshift_best = v.tshift
        curves = res_models[k]
        curves.set_tshift(tshift_best)
        lcp.curves_plot(curves, ax=ax, figsize=(12, 8), linewidth=1, is_legend=False)
        if xlim is None or xlim[1] == float('inf'):
            xlim = ax.get_xlim()
        else:
            ax.set_xlim(xlim)

        lt = {lc.Band.Name: 'o' for lc in curves_o}
        lcp.curves_plot(curves_o, ax, xlim=xlim, lt=lt, markersize=4, is_legend=False, is_line=False)

        ax.text(0.99, 0.94, k, horizontalalignment='right', transform=ax.transAxes)
        ax.text(0.98, 0.85, "{} dt={:.2f}".format(i, tshift_best), horizontalalignment='right',
                transform=ax.transAxes)
        ax.text(0.01, 0.05, r"$\chi^2: {:.2f}$".format(v.measure), horizontalalignment='left',
                transform=ax.transAxes, bbox=dict(facecolor='green', alpha=0.3))

        # Fix axes. Panels touch (wspace=0), so only the outer columns carry
        # y tick labels; the left column also carries the 'Magnitude' label.
        if ncol == 1:
            ax.yaxis.tick_left()
            ax.yaxis.set_label_position("left")
            ax.set_ylabel('Magnitude')
        elif i % ncol == 0:
            ax.yaxis.tick_right()
            ax.yaxis.set_label_position("right")
        elif i % ncol == 1:
            ax.yaxis.tick_left()
            ax.yaxis.set_label_position("left")
            ax.set_ylabel('Magnitude')
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])
            ax.yaxis.set_ticks_position('both')

        # legend
        ax.legend(curves.BandNames, loc='lower right', frameon=False,
                  ncol=min(5, len(curves.BandNames)), fontsize='small', borderpad=1)

    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
def test_FitMCMC_best_curves_dtdm_SN1999em(self):
    """Fit time shift, magnitude shift and per-band sigma of a Stella model to SN 1999em.

    Model curves are computed with ``ps.Stella(...).curves`` at D = 11.5 Mpc
    and E(B-V) = 0.1. ``FitMCMC.best_curves`` then fits dt, dm and one sigma
    per observed band. The test shows a corner plot and a text summary, then
    plots the shifted model against the observations, with the fitted sigma
    drawn as a filled error band.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Observations; the distance modulus is applied to the model via distance=.
    dist_pc = 11.5e6
    ebv = 0.1
    obs = sn1999em.read_curves()

    # Model curves at the SN distance with extinction.
    model_name = 'cat_R500_M15_Ni006_E12'
    model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    mdl = ps.Stella(model_name, path=model_dir).curves(obs.BandNames, distance=dist_pc, ebv=ebv)

    # MCMC fitter set-up.
    debug = True
    fitter = ps.FitMCMC()
    fitter.is_info = True
    if debug:
        fitter.is_debug = debug
        fitter.nwalkers = 100  # number of MCMC walkers
        fitter.nburn = 20  # "burn-in" period to let chains stabilize
        fitter.nsteps = 200  # number of MCMC steps to take
    n_threads = 3

    fit_result, stats, (theta, err_plus, err_minus), sampler = fitter.best_curves(
        mdl, obs, dt0=0., dm0=0., threads=n_threads, is_sampler=True)

    # Unpack the parameter vectors into dt, dm and the per-band sigmas.
    n_bands = obs.Length
    dt, dm, sigs = fitter.theta2arr(theta, n_bands)
    dt_plus, dm_plus, sigs_plus = fitter.theta2arr(err_plus, n_bands)
    dt_minus, dm_minus, sigs_minus = fitter.theta2arr(err_minus, n_bands)

    # NOTE(review): flatchain merges all walkers, so dropping its first nburn
    # rows may not remove each walker's burn-in -- confirm whether best_curves
    # already discards it.
    samples = sampler.flatchain[fitter.nburn:, :]
    corner_fig = fitter.plot_corner(samples, bnames=obs.BandNames, bins=55, alpha=0.5,
                                    verbose=fitter.is_info)

    # Text summary (LaTeX-style ^{+}_{-} errors).
    pieces = [
        'chi2= {:.1f} BIC= {:.1f} AIC= {:.1f} measure= {:.3f}\n'.format(
            stats['chi2'], stats['bic'], stats['aic'], stats['measure']),
        '\n dt= {:.1f} ^{{{:.1f}}}_{{{:.1f}}} '.format(dt, dt_plus, dt_minus),
        '\n dm= {:.2f} ^{{{:.2f}}}_{{{:.2f}}} '.format(dm, dm_plus, dm_minus),
    ]
    for j, band_name in enumerate(obs.BandNames):
        pieces.append('\n sigma_{{{:s}}} = {:.2f}^{{{:.2f}}}_{{{:.2f}}} '.format(
            band_name, sigs[j], sigs_plus[j], sigs_minus[j]))
    summary = ''.join(pieces)
    print(summary)

    # Apply the fitted shifts to the model.
    mdl.set_tshift(dt)
    mdl.set_mshift(dm)

    # Each band gets a constant error equal to its fitted sigma.
    errs = {bn: [sigs[j]] * mdl[bn].Length for j, bn in enumerate(mdl.BandNames)}
    mdl_with_err = mdl.clone(err=errs)
    ax = ps.curves_plot(mdl_with_err, ax=None, is_legend=False, is_fill=True, alpha=0.2)

    # Obs
    markers = {lc.Band.Name: 'o' for lc in obs}
    lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)
    ax.text(0.05, 0.05, summary, transform=ax.transAxes)
    plt.show()
def test_FitMPF_best_curves_SN1999em(self):
    """Fit a Stella model to SN 1999em with FitMPFit.best_curves and plot the result.

    Model curves are computed at D = 11.5 Mpc with E(B-V) = 0. Only ``dt0`` is
    passed to ``best_curves``. The fitted dt and dm, with their sigmas, are
    printed and written on the plot, and only dt is applied to the model.
    """
    from pystella.rf import light_curve_func as lcf
    from pystella.rf import light_curve_plot as lcp

    # Observations; the distance is applied to the model via distance=.
    dist_pc = 11.5e6
    ebv = 0.  # 0.1 is the alternative value tried before
    obs = sn1999em.read_curves()

    # Model curves.
    model_name = 'cat_R500_M15_Ni006_E12'
    model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
    mdl = ps.Stella(model_name, path=model_dir).curves(obs.BandNames, distance=dist_pc, ebv=ebv)

    # Fit (an earlier variant also passed dm0=0.).
    fitter = ps.FitMPFit()
    fitter.is_info = True
    fitter.is_debug = True
    fitter.is_quiet = True
    fit_result, stats, _extra = fitter.best_curves(mdl, obs, dt0=0., )
    dt, dt_err = stats['dt'], stats['dtsig']
    dm, dm_err = stats['dm'], stats['dmsig']

    # Plot the time-shifted model with the observations.
    mdl.set_tshift(dt)
    ax = lcp.curves_plot(mdl)
    markers = {lc.Band.Name: 'o' for lc in obs}
    lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)

    summary = ('{:10} {:.2f} +- {:.4f}'.format('dt:', dt, dt_err)
               + '\n{:10} {:.2f} +- {:.4f}'.format('dm:', dm, dm_err))
    print(summary)

    text_style = dict(size='12', color='black', weight='normal', verticalalignment='bottom')
    box_style = dict(facecolor='none', alpha=0.2, edgecolor='none')
    ax.text(0.03, 0.95, summary, transform=ax.transAxes, bbox=box_style, **text_style)
    plt.show()