コード例 #1
0
    def test_popov_SN1999em_emcee(self):
        """Fit the time shift between a Popov bolometric model and SN 1999em V with FitMCMC.

        A Popov analytic model is evaluated on a log-spaced time grid, the
        observed V-band curve of SN 1999em (shifted by a fixed distance modulus)
        is fitted with ``FitMCMC.fit_lc``, and model and observations are plotted.
        """
        # 100 log-spaced epochs between 0.1 and 200 (days for the Popov model).
        n_points = 100
        t_first, t_last = 0.1, 200.
        # Magnitude shift applied to the observations: D = 7.5e6 pc
        # (use -30.4 for D = 12.e6 pc).
        dm = -29.38
        time = np.exp(np.linspace(np.log(t_first), np.log(t_last), n_points))

        popov = Popov('test', R=450., M=15., Mni=0.04, E=0.7)
        lc_m = popov.LCBol(time)

        lc_o = sn1999em.read_curves().get('V')
        lc_o.mshift = dm
        print("Run: find tshift with bayesian: obs band %s with %s ..." %
              (lc_o.Band.Name, popov))

        fitter = FitMCMC()
        fitter.is_debug = True
        result = fitter.fit_lc(lc_o, lc_m)
        print("Result: tshift= %s tsigma= %s ..." % (result.tshift, result.tsigma))

        # Model bolometric curve, model light curve and observed points.
        ax = popov.plot_Lbol(time)
        lcp.lc_plot(lc_m, ax)
        lcp.lc_plot(lc_o, ax, is_line=False)
        plt.show()
コード例 #2
0
    def test_fit_popov_SN1999em(self):
        """Fit a Popov model to the SN 1999em R-band light curve and plot both.

        The observations are shifted by the distance modulus for D = 11.5 Mpc.
        ``popov_fit`` returns the fitted model and a time shift. The fitted
        parameters are printed, and the shift is applied to the observed curve
        before it is drawn over the model.
        """
        D = 11.5e6  # pc
        dm = -5. * np.log10(D) + 5
        lc = sn1999em.read_curves().get('R')
        lc.mshift = dm

        ppv, tshift = popov_fit(lc, R0=1200., M0=10., Mni0=0.01, E0=1., dt0=0.)

        # Report the fitted parameters in solar units where applicable.
        rows = (
            '{0:10} {1:.4e} R_sun\n'.format('R:', ppv.R0 / ps.phys.R_sun),
            '{0:10} {1:.4e} M_sun\n'.format('Mtot:', ppv.Mtot / ps.phys.M_sun),
            '{0:10} {1:.4e} M_sun\n'.format('Mni:', ppv.Mni / ps.phys.M_sun),
            '{0:10} {1} ergs\n'.format('Etot:', ppv.Etot),
            '{0:10} {1} d'.format('tshift:', tshift),
        )
        print(''.join(rows))

        ax = ppv.plot_Lbol(lc.Time)
        # Move the observations onto the model time axis.
        lc.tshift += tshift
        ax.plot(lc.Time, lc.Mag, label='%s SN 1999em' % lc.Band.Name,
                ls="", color='red', markersize=8, marker="o")
        plt.show()
コード例 #3
0
    def test_fit_time_popov_SN1999em(self):
        """Fit only the time shift between a fixed Popov model and SN 1999em V.

        The model magnitudes are evaluated on the observed epochs measured from
        the first observation. ``myfit`` returns a time shift, which is added to
        the observed times for plotting.
        """
        dm = -29.38  # D = 7.5e6 pc (use -30.4 for D = 12.e6 pc)
        lc = sn1999em.read_curves().get('V')
        lc.mshift = dm

        time = lc.Time - lc.Tmin
        popov = Popov('test', R=450., M=15., Mni=0.04, E=0.7)
        mags = popov.MagBol(time)

        tshift = myfit(mags, lc)

        ax = popov.plot_Lbol(time)
        ax.plot(lc.Time + tshift, lc.Mag, label='%s SN 1999em' % lc.Band.Name,
                ls=":", color='red', markersize=8, marker="o")
        plt.show()
コード例 #4
0
    def test_lc_gp_Ntime(self):
        """Smooth every SN 1999em band with a Gaussian process and plot the result.

        For each observed band, ``FitGP.fit_lc`` returns a resampled light curve
        with ``Ntime`` points. The observations are drawn with error bars; the GP
        prediction is drawn as points with a shaded +/-1 sigma band.
        """
        from plugin import sn1999em
        from pystella.fit.fit_gp import FitGP

        Ntime = 50
        colors = band.colors()
        fig, ax = plt.subplots(figsize=(12, 10))
        ax.invert_yaxis()  # magnitudes: brighter is up

        curves_o = sn1999em.read_curves()
        for i, bname in enumerate(curves_o.BandNames):
            lc_o = curves_o.get(bname)
            lc_o_gp, gp = FitGP.fit_lc(lc_o, Ntime=Ntime)
            # observations
            ax.errorbar(lc_o.Time, lc_o.Mag, label='{0} {1}'.format(bname, 'obs'), yerr=lc_o.MagErr,
                        fmt='s', color=colors[bname], ls='', markersize=2.5)
            # GP prediction with a +/-1 sigma band. This is a ~68% interval;
            # a 95% interval would need +/-1.96 sigma.
            x = lc_o_gp.Time
            y_pred = lc_o_gp.Mag
            sigma = lc_o_gp.MagErr
            ax.fill(np.concatenate([x, x[::-1]]),
                    np.concatenate([y_pred - sigma,
                                    (y_pred + sigma)[::-1]]),
                    alpha=.7, fc='grey', ec='None',
                    # Label the band only once so a legend does not repeat it per band.
                    label='68% (1 sigma) confidence interval' if i == 0 else '_nolegend_')
            ax.plot(x, y_pred, ls='', marker='o', markersize=2, color=colors[bname])
        plt.show()
コード例 #5
0
ファイル: test_gp.py プロジェクト: baklanovp/pystella
    def test_fit_GP_SN1999em(self):
        """Fit a Gaussian process (gptools) to the SN 1999em V-band light curve.

        Uses a squared-exponential kernel with a linear mean function. With
        ``is_mcmc`` False, the hyperparameters are optimized and the GP is
        predicted at the observed epochs; the MCMC branch is kept for
        experiments. The observations and the GP mean with a +/-2 sigma band
        are plotted.
        """
        # add data
        dm = -29.38  # D = 7.5e6 pc (use -30.4 for D = 12.e6 pc)
        curves = sn1999em.read_curves()
        lc = curves.get('V')
        lc.mshift = dm
        t = lc.Time
        y = lc.Mag
        yerr = lc.MagErr

        # Hyperparameter bounds; presumably (amplitude, length scale) for the
        # SE kernel -- NOTE(review): confirm order against the gptools docs.
        k = gptools.SquaredExponentialKernel(param_bounds=[(0., max(np.abs(y))),
                                                           (0, np.std(t))])
        gp = gptools.GaussianProcess(k, mu=gptools.LinearMeanFunction())

        gp.add_data(t, y, err_y=yerr)

        is_mcmc = False  # set True to sample the hyperparameters with MCMC
        if is_mcmc:
            # NOTE(review): with return_std=False predict() may not return a
            # (mean, std) pair, which the unpacking below expects -- verify.
            out = gp.predict(t, use_MCMC=True, full_MCMC=True,
                             return_std=False,
                             num_proc=0,
                             nsamp=200,
                             plot_posterior=True,
                             plot_chains=False,
                             burn=100,
                             thin=1)
        else:
            gp.optimize_hyperparameters(verbose=True)
            out = gp.predict(t, use_MCMC=False)

        y_star, err_y_star = out

        fig = plt.figure()
        ax = fig.add_axes((0.1, 0.3, 0.8, 0.65))
        ax.invert_yaxis()

        ax.plot(t, y, color='blue', label='L bol', lw=2.5)
        # Fill the band name into the legend label (the '%s' was never formatted).
        ax.errorbar(t, y, yerr=yerr, fmt='o', color='blue', label='%s obs.' % lc.Band.Name)

        # GP mean and a +/-2 sigma band around it.
        ax.plot(t, y_star, '-', color='gray')
        ax.fill_between(t, y_star - 2 * err_y_star, y_star + 2 * err_y_star, color='gray', alpha=0.3)

        plt.legend()
        plt.show()
コード例 #6
0
    def test_fit_GP_SN1999em(self):
        """Fit a Gaussian process (gptools) to the SN 1999em V-band light curve.

        Uses a squared-exponential kernel with a linear mean function. With
        ``is_mcmc`` False, the hyperparameters are optimized and the GP is
        predicted at the observed epochs; the MCMC branch is kept for
        experiments. The observations and the GP mean with a +/-2 sigma band
        are plotted.
        """
        # add data
        dm = -29.38  # D = 7.5e6 pc (use -30.4 for D = 12.e6 pc)
        curves = sn1999em.read_curves()
        lc = curves.get('V')
        lc.mshift = dm
        t = lc.Time
        y = lc.Mag
        yerr = lc.Err

        # Hyperparameter bounds; presumably (amplitude, length scale) for the
        # SE kernel -- NOTE(review): confirm order against the gptools docs.
        k = gptools.SquaredExponentialKernel(param_bounds=[(0., max(np.abs(y))),
                                                           (0, np.std(t))])
        gp = gptools.GaussianProcess(k, mu=gptools.LinearMeanFunction())

        gp.add_data(t, y, err_y=yerr)

        is_mcmc = False  # set True to sample the hyperparameters with MCMC
        if is_mcmc:
            # NOTE(review): with return_std=False predict() may not return a
            # (mean, std) pair, which the unpacking below expects -- verify.
            out = gp.predict(t, use_MCMC=True, full_MCMC=True,
                             return_std=False,
                             num_proc=0,
                             nsamp=200,
                             plot_posterior=True,
                             plot_chains=False,
                             burn=100,
                             thin=1)
        else:
            gp.optimize_hyperparameters(verbose=True)
            out = gp.predict(t, use_MCMC=False)

        y_star, err_y_star = out

        fig = plt.figure()
        ax = fig.add_axes((0.1, 0.3, 0.8, 0.65))
        ax.invert_yaxis()

        ax.plot(t, y, color='blue', label='L bol', lw=2.5)
        # Fill the band name into the legend label (the '%s' was never formatted).
        ax.errorbar(t, y, yerr=yerr, fmt='o', color='blue', label='%s obs.' % lc.Band.Name)

        # GP mean and a +/-2 sigma band around it.
        ax.plot(t, y_star, '-', color='gray')
        ax.fill_between(t, y_star - 2 * err_y_star, y_star + 2 * err_y_star, color='gray', alpha=0.3)

        plt.legend()
        plt.show()
コード例 #7
0
    def test_FitMCMC_best_curves_dt_SN1999em(self):
        """Fit only the time shift of a Stella model to SN 1999em with FitMCMC.

        The observations are shifted by the distance modulus for D = 11.5 Mpc.
        The model curves are computed at distance 10 in the observed bands.
        ``best_curves`` is run as a short debug MCMC. The corner plot, a text
        summary and the aligned curves are shown.
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Observations shifted by the distance modulus.
        D = 11.5e6  # pc
        dm = -5. * np.log10(D) + 5
        curves_obs = sn1999em.read_curves()
        curves_obs.set_mshift(dm)

        # Stella model curves in the observed bands.
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        curves_mdl = lcf.curves_compute(name, path, distance=10, bands=curves_obs.BandNames)

        # Short MCMC run in debug mode.
        is_debug = True
        fitter = ps.FitMCMC()
        fitter.is_info = True
        if is_debug:
            fitter.is_debug = is_debug
            # walkers, "burn-in" steps, total MCMC steps
            fitter.nwalkers, fitter.nburn, fitter.nsteps = 100, 20, 200

        fit_result, res, (th, e1, e2), samples = fitter.best_curves(
            curves_mdl, curves_obs, dt0=0., threads=1, is_sampler=True)
        fitter.plot_corner(samples, labels=('dt', ), bnames=curves_obs.BandNames)

        tshift_line = '{:10s} {:.4f} ^{:.4f}_{:.4f} \n'.format('tshift:', th.dt, e1.dt, e2.dt)
        stat_line = '{:10s} chi2= {:.1f}  BIC= {:.1f} AIC= {:.1f} dof= {} accept= {:.3f}\n'.format(
            'stat:', res['chi2'], res['bic'], res['aic'], res['dof'], res['acceptance_fraction'])
        txt = tshift_line + stat_line
        print(txt)

        # Move the observations onto the model time axis and overlay them.
        curves_obs.set_tshift(-th.dt)
        ax = lcp.curves_plot(curves_mdl)
        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
        ax.text(0.1, 0.1, txt, transform=ax.transAxes)
        plt.show()
コード例 #8
0
    def test_fit_mpfit_curves_Stella_SN1999em(self):
        """Fit a Stella model to all SN 1999em bands with FitMPFit.fit_curves.

        The observations are shifted by the distance modulus for D = 11.5 Mpc.
        The fitted time shift (``res.tshift`` +- ``res.tsigma``) and magnitude
        shift (``res.mshift`` +- ``res.msigma``) are printed. The observations
        are then plotted over the model after applying the time shift.
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp
        # Get observations
        D = 11.5e6  # pc
        dm = -5. * np.log10(D) + 5
        curves_obs = sn1999em.read_curves()
        curves_obs.set_mshift(dm)

        # Get model
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        curves_mdl = lcf.curves_compute(name, path, curves_obs.BandNames)

        # fit
        fitter = ps.FitMPFit()
        fitter.is_debug = True
        res = fitter.fit_curves(curves_obs, curves_mdl)

        # Print each value with its uncertainty. The second line reports the
        # magnitude shift, so it is labelled 'mshift:' (it was 'msigma:').
        txt = '{:10s} {:.4f}+-{:.4f} \n'.format('tshift:', res.tshift, res.tsigma) + \
              '{:10s} {:.4f}+-{:.4f}\n'.format('mshift:', res.mshift, res.msigma) + \
              '{0}\n'.format(res.comm)

        print(txt)
        # Plot the model and the time-shifted observations.
        curves_obs.set_tshift(res.tshift)
        ax = lcp.curves_plot(curves_mdl)

        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
        ax.text(0.1, 0.1, txt, transform=ax.transAxes)
        plt.show()
コード例 #9
0
ファイル: test_fit.py プロジェクト: baklanovp/pystella
    def test_fit_popov_SN1999em(self):
        """Fit a Popov model to the SN 1999em R-band light curve and plot both.

        The observations are shifted by the distance modulus for D = 11.5 Mpc.
        The time shift returned by ``popov_fit`` is applied to the observed
        curve before it is drawn as points over the model.
        """
        D = 11.5e6  # pc
        dm = -5. * np.log10(D) + 5
        # dm = -30.4  # D = 12.e6 pc
        curves = sn1999em.read_curves()
        lc = curves.get('R')
        lc.mshift = dm

        # fit
        ppv, tshift = popov_fit(lc, R0=1000., M0=10., Mni0=0.01, E0=1., dt0=0.)
        # plot model
        ax = ppv.plot_Lbol(lc.Time)
        # plot obs, moved onto the model time axis
        lc.tshift = lc.tshift + tshift
        x = lc.Time
        y = lc.Mag
        # ls="" draws markers only: "." is a marker code, not a valid linestyle,
        # and matplotlib raises ValueError for it.
        ax.plot(x, y, label='%s SN 1999em' % lc.Band.Name,
                ls="", color='red', markersize=8, marker="o")
        plt.show()
コード例 #10
0
    def test_FitMPFit_best_curves_gp_SN1999em(self):
        """Fit a Stella model to SN 1999em with ``FitMPFit.best_curves_gp``.

        The observations are shifted by the distance modulus for D = 11.5 Mpc.
        The fitted time shift ``res['dt']`` and its sigma ``res['dtsig']`` are
        printed, and the shifted observations are plotted over the model.
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Observations shifted by the distance modulus.
        D = 11.5e6  # pc
        dm = -5. * np.log10(D) + 5
        curves_obs = sn1999em.read_curves()
        curves_obs.set_mshift(dm)

        # Stella model curves in the observed bands.
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        curves_mdl = lcf.curves_compute(name, path, curves_obs.BandNames)

        fitter = ps.FitMPFit()
        fitter.is_info = True
        fitter.is_debug = True
        res = fitter.best_curves_gp(curves_mdl, curves_obs, dt0=0., dm0=0.)
        dt, dtsig = res['dt'], res['dtsig']

        txt = ''.join(['{0:10} {1:.4e} \n'.format('tshift:', dt),
                       '{0:10} {1:.4e} \n'.format('tsigma:', dtsig)])
        print(txt)

        # Overlay the time-shifted observations on the model.
        curves_obs.set_tshift(dt)
        ax = lcp.curves_plot(curves_mdl)
        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
        plt.show()
コード例 #11
0
ファイル: test_fit.py プロジェクト: baklanovp/pystella
    def test_fit_time_popov_SN1999em(self):
        """Fit only the time shift between a fixed Popov model and SN 1999em V.

        The model magnitudes are evaluated on the observed epochs measured from
        the first observation. ``myfit`` returns a time shift, which is added to
        the observed times for plotting.
        """
        dm = -29.38  # D = 7.5e6 pc
        # dm = -30.4  # D = 12.e6 pc
        curves = sn1999em.read_curves()
        lc = curves.get('V')
        lc.mshift = dm

        time = lc.Time - lc.tmin
        popov = Popov('test', R=450., M=15., Mni=0.04, E=0.7)
        mags = popov.MagBol(time)

        # fit
        tshift = myfit(mags, lc)

        # plot
        ax = popov.plot_Lbol(time)
        x = lc.Time + tshift
        y = lc.Mag
        # ls="" draws markers only: "." is a marker code, not a valid linestyle,
        # and matplotlib raises ValueError for it.
        ax.plot(x, y, label='%s SN 1999em' % lc.Band.Name,
                ls="", color='red', markersize=8, marker="o")
        plt.show()
コード例 #12
0
    def test_FitMCMC_fake_curves(self):
        """Fit a precomputed Stella ``.tt.ubv`` model to SN 1999em with FitMCMC.

        The model is read from ``data/stella`` and restricted to the observed
        bands. It is fitted with ``best_curves(..., dt0=0., dm0=0.)``, whose
        result provides ``dt``, ``dm`` and ``lnf`` with their bounds. The
        shifted model, the observations, the sample histogram and a text
        summary are plotted.
        """
        # Observations shifted by the distance modulus (D = 10.5 Mpc; alt. 11.5e6).
        D = 10.5e6  # pc
        dm = -5. * np.log10(D) + 5
        curves_o = sn1999em.read_curves()
        curves_o.set_mshift(dm)

        # Model
        fname = 'cat_R500_M15_Ni006_E12.tt.ubv'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        curves_m = ps.curves_read(join(path, fname))

        # Keep only the model bands that are also observed. Iterate over a copy
        # of the band names because pop() mutates curves_m.
        obs_bands = set(curves_o.BandNames)
        for bname in list(curves_m.BandNames):
            if bname not in obs_bands:
                curves_m.pop(bname)

        # fit
        is_debug = True
        fitter = ps.FitMCMC()
        fitter.is_info = True
        fitter.is_debug = is_debug
        if is_debug:
            fitter.nwalkers = 100  # number of MCMC walkers
            fitter.nburn = 20  # "burn-in" period to let chains stabilize
            fitter.nsteps = 200  # number of MCMC steps to take
        else:
            fitter.nwalkers = 200  # number of MCMC walkers
            fitter.nburn = 100  # "burn-in" period to let chains stabilize
            fitter.nsteps = 500  # number of MCMC steps to take

        print('is_debug ', is_debug)

        res, samples = fitter.best_curves(curves_m,
                                          curves_o,
                                          dt0=0.,
                                          dm0=0.,
                                          is_sampler=True)

        curves_o.set_tshift(0.)
        curves_o.set_mshift(dm)

        curves_m.set_tshift(-res['dt'])
        curves_m.set_mshift(-res['dm'])

        # Plot. Keep the figure from subplots(): plot_corner draws into axhist,
        # and its return value must not replace the figure that tight_layout()
        # is applied to below.
        fig, (ax, axhist) = plt.subplots(2, 1, figsize=(9, 9))
        fitter.plot_corner(samples, axhist=axhist)

        ps.curves_plot(curves_m, ax=ax)
        ps.curves_plot(curves_o, ax=ax, is_line=False, markersize=2)

        # The second line reports the magnitude shift dm, hence 'mshift:'.
        txt = '{:10s} {:.4f} ^{:.4f}_{:.4f} \n'.format('tshift:', res['dt'], res['dtsig2'], res['dtsig1']) + \
              '{:10s} {:.4f} ^{:.4f}_{:.4f}\n'.format('mshift:', res['dm'], res['dmsig2'], res['dmsig1']) + \
              '{:10s} {:.4f} ^{:.4f}_{:.4f}\n'.format('lnf:', res['lnf'], res['lnfsig2'], res['lnfsig1']) + \
              '{:10s} chi2= {:.4f} dof= {} accept= {:.3f}\n'. \
                  format('stat:', res['chi2'], res['dof'], res['acceptance_fraction'])
        print(txt)
        ax.text(0.1, 0.02, txt, transform=ax.transAxes)
        # Start the time axis at -30, keep the automatic right limit.
        ax.set_xlim(-30, ax.get_xlim()[1])
        fig.tight_layout()
        plt.show()
コード例 #13
0
    def test_FitMCMC_best_curves_dtdm_SN1999em(self):
        """Fit time shift, magnitude shift and per-band scatter of a Stella model to SN 1999em.

        The Stella model is computed in the observed bands at D = 11.5 Mpc with
        E(B-V) = 0.1. ``FitMCMC.best_curves`` returns best values and bounds
        that ``theta2arr`` splits into ``dt``, ``dm`` and one ``sigma`` per band.
        The corner plot, a text summary and the shifted model (with a ``sigma``
        band) over the observations are shown.
        """
        from pystella.rf import light_curve_func as lcf  # noqa: F401  (not used here)
        from pystella.rf import light_curve_plot as lcp
        # Get observations (no distance modulus applied: the model is built at D)
        D = 11.5e6  # pc
        ebv_sn = 0.1
        curves_obs = sn1999em.read_curves()

        # Get model
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')

        curves_mdl = ps.Stella(name, path=path).curves(curves_obs.BandNames,
                                                       distance=D,
                                                       ebv=ebv_sn)

        # fit
        is_debug = True
        fitter = ps.FitMCMC()
        fitter.is_info = True
        if is_debug:
            fitter.is_debug = is_debug
            fitter.nwalkers = 100  # number of MCMC walkers
            fitter.nburn = 20  # "burn-in" period to let chains stabilize
            fitter.nsteps = 200  # number of MCMC steps to take
        threads = 3

        fit_result, res, (th, ep,
                          em), sampler = fitter.best_curves(curves_mdl,
                                                            curves_obs,
                                                            dt0=0.,
                                                            dm0=0.,
                                                            threads=threads,
                                                            is_sampler=True)
        nb = curves_obs.Length
        dt, dm, sigs = fitter.theta2arr(th, nb)
        ep_dt, ep_dm, ep_sigs = fitter.theta2arr(ep, nb)
        em_dt, em_dm, em_sigs = fitter.theta2arr(em, nb)

        # Drop the first ``nburn`` steps of EVERY walker. Slicing
        # ``sampler.flatchain[nburn:]`` only removes ``nburn`` single samples
        # from the flattened array, which is not a burn-in cut.
        # Assumes an emcee sampler whose ``chain`` is (nwalkers, nsteps, ndim).
        chain = sampler.chain
        samples = chain[:, fitter.nburn:, :].reshape((-1, chain.shape[-1]))
        fitter.plot_corner(samples,
                           bnames=curves_obs.BandNames,
                           bins=55,
                           alpha=0.5,
                           verbose=fitter.is_info)

        # print
        txt = 'chi2= {:.1f}  BIC= {:.1f} AIC= {:.1f} measure= {:.3f}\n'.\
            format(res['chi2'], res['bic'], res['aic'], res['measure'])
        txt += '\n dt= {:.1f} ^{{{:.1f}}}_{{{:.1f}}} '.format(dt, ep_dt, em_dt)
        txt += '\n dm= {:.2f} ^{{{:.2f}}}_{{{:.2f}}} '.format(dm, ep_dm, em_dm)
        for i, bname in enumerate(curves_obs.BandNames):
            txt += '\n sigma_{{{:s}}} = {:.2f}^{{{:.2f}}}_{{{:.2f}}} '. \
                format(bname, sigs[i], ep_sigs[i], em_sigs[i])
        print(txt)

        curves_mdl.set_tshift(dt)
        curves_mdl.set_mshift(dm)

        # The fitted per-band sigma is used as the error of every model point.
        errs = {bn: [sigs[i]] * curves_mdl[bn].Length
                for i, bn in enumerate(curves_mdl.BandNames)}
        curves_clone = curves_mdl.clone(err=errs)
        ax = ps.curves_plot(curves_clone,
                            ax=None,
                            is_legend=False,
                            is_fill=True,
                            alpha=0.2)
        # Obs
        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)

        ax.text(0.05, 0.05, txt, transform=ax.transAxes)
        plt.show()
コード例 #14
0
    def test_FitMPF_best_curves_SN1999em(self):
        """Fit the time shift of a Stella model to SN 1999em with FitMPFit.best_curves.

        The model is computed at D = 11.5 Mpc without extinction (ebv = 0).
        The fit returns ``dt`` and ``dm`` with their sigmas. These are printed
        and written on the plot of the time-shifted model over the observations.
        """
        from pystella.rf import light_curve_func as lcf  # noqa: F401  (not used here)
        from pystella.rf import light_curve_plot as lcp

        D = 11.5e6  # pc
        ebv_sn = 0.  # alternative: 0.1
        curves_obs = sn1999em.read_curves()

        # Stella model curves at the SN distance.
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        stella = ps.Stella(name, path=path)
        curves_mdl = stella.curves(curves_obs.BandNames, distance=D, ebv=ebv_sn)

        fitter = ps.FitMPFit()
        fitter.is_info = True
        fitter.is_debug = True
        fitter.is_quiet = True
        _, res, _ = fitter.best_curves(curves_mdl, curves_obs, dt0=0.)
        dt, dtsig = res['dt'], res['dtsig']
        dm, dmsig = res['dm'], res['dmsig']

        # Shift the model in time and overlay the observations.
        curves_mdl.set_tshift(dt)
        ax = lcp.curves_plot(curves_mdl)
        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)

        txt = '\n'.join([
            '{:10} {:.2f} +- {:.4f}'.format('dt:', dt, dtsig),
            '{:10} {:.2f} +- {:.4f}'.format('dm:', dm, dmsig),
        ])
        print(txt)
        ax.text(0.03, 0.95, txt,
                transform=ax.transAxes,
                bbox=dict(facecolor='none', alpha=0.2, edgecolor='none'),
                size='12', color='black', weight='normal', verticalalignment='bottom')
        plt.show()
コード例 #15
0
    def test_fit_stella_model(self):
        """Example: fit a Stella model V-band light curve to SN 1999em with emcee.

        At observed epoch ``t`` the model magnitude is ``M_mdl(t - dt) + dm``,
        where ``M_mdl`` is the Stella V curve, linearly interpolated. The
        observational errors are inflated by a fractional term
        ``f * model`` with ``lnf = ln f``. The free parameters
        ``theta = (dt, dm, lnf)`` are first found by maximum likelihood
        (``scipy.optimize.minimize``) and then sampled with ``emcee``.
        """
        # obs data: shift to the model's magnitude scale with the distance
        # modulus (D = 7.5e6 pc). The sign matches the other SN 1999em tests
        # (mshift = dm); the time axis starts at the first observation.
        dist_mod = -29.38
        curves = sn1999em.read_curves()
        lc_obs = curves.get('V')
        lc_obs.mshift = dist_mod
        lc_obs.tshift = -lc_obs.Tmin

        # model light curves
        name = 'cat_R1000_M15_Ni007_E15'
        path = join(dirname(abspath(__file__)), 'data', 'stella')
        curves_model = lcf.curves_compute(name, path, bands='V')
        lc_mdl = curves_model.get('V')

        t = lc_obs.Time
        merr = lc_obs.Err
        m = lc_obs.Mag
        t_mdl = lc_mdl.Time
        m_mdl = lc_mdl.Mag

        # Initial guesses for the free parameters.
        dt_init = 0.  # initial time shift (same units as lc.Time)
        dm_init = 0.  # initial magnitude shift
        f_init = 0.534  # initial fractional extra scatter

        def model_mag(dt, dm, x):
            # Model magnitude at epochs x. np.interp assumes t_mdl is
            # increasing (NOTE(review): confirm) and clamps outside its range.
            return np.interp(x - dt, t_mdl, m_mdl) + dm

        def lnlike(theta, x, y, yerr):
            dt, dm, lnf = theta
            model = model_mag(dt, dm, x)
            inv_sigma2 = 1.0 / (yerr ** 2 + model ** 2 * np.exp(2 * lnf))
            return -0.5 * np.sum((y - model) ** 2 * inv_sigma2 - np.log(inv_sigma2))

        def nll(*args):
            return -lnlike(*args)

        result = op.minimize(
            nll, [dt_init, dm_init, np.log(f_init)], args=(t, m, merr))
        dt_ml, dm_ml, lnf_ml = result["x"]

        def lnprior(theta):
            dt, dm, lnf = theta
            # Broad flat priors. NOTE(review): bounds are illustrative -- tune them.
            if -100. < dt < 100. and -5. < dm < 5. and -10. < lnf < 1.:
                return 0.0
            return -np.inf

        def lnprob(theta, x, y, yerr):
            lp = lnprior(theta)
            if not np.isfinite(lp):
                return -np.inf
            return lp + lnlike(theta, x, y, yerr)

        ndim, nwalkers = 3, 100
        # Start the walkers in a tiny ball around the maximum-likelihood point.
        pos = [result["x"] + 1e-4 * np.random.randn(ndim) for _ in range(nwalkers)]

        sampler = emcee.EnsembleSampler(nwalkers,
                                        ndim,
                                        lnprob,
                                        args=(t, m, merr))
        sampler.run_mcmc(pos, 500)

        # Drop the first 50 steps of every walker as burn-in.
        samples = sampler.chain[:, 50:, :].reshape((-1, ndim))

        # Model curves for 100 random posterior samples, the ML model in red,
        # and the observations.
        t_grid = np.linspace(np.min(t), np.max(t), 200)
        for dt, dm, _ in samples[np.random.randint(len(samples), size=100)]:
            plt.plot(t_grid, model_mag(dt, dm, t_grid), color="k", alpha=0.1)
        plt.plot(t_grid, model_mag(dt_ml, dm_ml, t_grid), color="r", lw=2, alpha=0.8)
        plt.errorbar(t, m, yerr=merr, fmt=".k")
        plt.gca().invert_yaxis()  # magnitudes: brighter is up
        plt.show()

        # Median and 16/84 percentile offsets; report f instead of ln f.
        samples[:, 2] = np.exp(samples[:, 2])
        dt_mcmc, dm_mcmc, f_mcmc = [
            (v[1], v[2] - v[1], v[1] - v[0])
            for v in zip(*np.percentile(samples, [16, 50, 84], axis=0))]

        # A plain loop: map() is lazy in Python 3 and would print nothing.
        for label, v3 in (('dt', dt_mcmc), ('dm', dm_mcmc), ('f', f_mcmc)):
            print("%s = %f + %f - %f" % ((label,) + tuple(v3)))