Exemplo n.º 1
0
    def test_FitMCMC_best_curves_dt_SN1999em(self):
        """Run FitMCMC.best_curves (dt0 only) on SN1999em and plot the result.

        The observed curves are magnitude-shifted by -5*log10(D)+5 for
        D = 11.5e6 pc, the STELLA model 'cat_R500_M15_Ni006_E12' is computed
        in the observed bands, and the fitted 'dt' is printed together with
        the fit statistics. The observations are then shifted by -dt and
        drawn over the model curves. Ends with a blocking plt.show().
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Observations
        distance_pc = 11.5e6  # pc
        mag_shift = -5. * np.log10(distance_pc) + 5
        obs = sn1999em.read_curves()
        obs.set_mshift(mag_shift)

        # Model, read from <parent of this file's directory>/data/stella
        model_name = 'cat_R500_M15_Ni006_E12'
        model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        mdl = lcf.curves_compute(model_name,
                                 model_dir,
                                 distance=10,
                                 bands=obs.BandNames)

        # Fitter setup; the debug branch also fixes the MCMC chain sizes.
        debug = True
        fitter = ps.FitMCMC()
        fitter.is_info = True
        if debug:
            fitter.is_debug = debug
            fitter.nwalkers = 100  # number of MCMC walkers
            fitter.nburn = 20  # "burn-in" period to let chains stabilize
            fitter.nsteps = 200  # number of MCMC steps to take

        # NOTE(review): the 4th item is handed to plot_corner as-is here;
        # confirm whether is_sampler=True returns samples or a sampler object.
        _, res, (th, e1, e2), samples = fitter.best_curves(mdl,
                                                           obs,
                                                           dt0=0.,
                                                           threads=1,
                                                           is_sampler=True)
        fitter.plot_corner(samples, labels=('dt', ), bnames=obs.BandNames)

        # Two-line text report: the time shift with its two error terms, then the statistics.
        shift_line = '{:10s} {:.4f} ^{:.4f}_{:.4f} \n'.format(
            'tshift:', th.dt, e1.dt, e2.dt)
        stat_line = '{:10s} chi2= {:.1f}  BIC= {:.1f} AIC= {:.1f} dof= {} accept= {:.3f}\n'.format(
            'stat:', res['chi2'], res['bic'], res['aic'], res['dof'],
            res['acceptance_fraction'])
        txt = shift_line + stat_line
        print(txt)

        # Plot the model, then the observations shifted by -dt as 'o' markers.
        obs.set_tshift(-th.dt)
        ax = lcp.curves_plot(mdl)

        markers = {}
        for lc in obs:
            markers[lc.Band.Name] = 'o'
        lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)
        ax.text(0.1, 0.1, txt, transform=ax.transAxes)
        plt.show()
Exemplo n.º 2
0
    def test_fit_mpfit_curves_Stella_SN1999em(self):
        """Run FitMPFit.fit_curves on SN1999em against a STELLA model and plot it.

        The observed curves are magnitude-shifted by -5*log10(D)+5 for
        D = 11.5e6 pc. The fitted time and magnitude shifts are printed with
        their sigmas, then the observations (shifted by ``res.tshift``) are
        drawn over the model curves. Ends with a blocking plt.show().
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Get observations
        D = 11.5e6  # pc
        dm = -5. * np.log10(D) + 5
        curves_obs = sn1999em.read_curves()
        curves_obs.set_mshift(dm)

        # Get model, read from <parent of this file's directory>/data/stella
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')

        curves_mdl = lcf.curves_compute(name, path, curves_obs.BandNames)

        # fit
        fitter = ps.FitMPFit()
        fitter.is_debug = True
        res = fitter.fit_curves(curves_obs, curves_mdl)

        # Report. The second line shows the magnitude shift and its sigma, so
        # it is labelled 'mshift:' (it was mislabelled 'msigma:').
        txt = '{:10s} {:.4f}+-{:.4f} \n'.format('tshift:', res.tshift, res.tsigma) + \
              '{:10s} {:.4f}+-{:.4f}\n'.format('mshift:', res.mshift, res.msigma) + \
              '{0}\n'.format(res.comm)

        print(txt)

        # plot model, then the shifted observations as 'o' markers
        curves_obs.set_tshift(res.tshift)
        ax = lcp.curves_plot(curves_mdl)

        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)
        ax.text(0.1, 0.1, txt, transform=ax.transAxes)
        plt.show()
Exemplo n.º 3
0
    def test_FitMPFit_best_curves_gp_SN1999em(self):
        """Run FitMPFit.best_curves_gp (dt0=0, dm0=0) on SN1999em and plot it.

        The observed curves are magnitude-shifted by -5*log10(D)+5 for
        D = 11.5e6 pc. The fitted 'dt' and 'dtsig' are printed, then the
        observations (shifted by 'dt') are drawn over the model curves.
        Ends with a blocking plt.show().
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Observations
        distance_pc = 11.5e6  # pc
        ebv_sn = 0.1  # not used by anything below
        mag_shift = -5. * np.log10(distance_pc) + 5
        obs = sn1999em.read_curves()
        obs.set_mshift(mag_shift)

        # Model, read from <parent of this file's directory>/data/stella
        model_name = 'cat_R500_M15_Ni006_E12'
        model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        mdl = lcf.curves_compute(model_name, model_dir, obs.BandNames)

        # Fit
        fitter = ps.FitMPFit()
        fitter.is_info = True
        fitter.is_debug = True
        res = fitter.best_curves_gp(mdl, obs, dt0=0., dm0=0.)

        # Report the fitted shift and its sigma.
        report = [
            '{0:10} {1:.4e} \n'.format('tshift:', res['dt']),
            '{0:10} {1:.4e} \n'.format('tsigma:', res['dtsig']),
        ]
        txt = ''.join(report)
        print(txt)

        # Plot the model, then the shifted observations as 'o' markers.
        obs.set_tshift(res['dt'])
        ax = lcp.curves_plot(mdl)

        markers = {}
        for lc in obs:
            markers[lc.Band.Name] = 'o'
        lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)
        plt.show()
Exemplo n.º 4
0
def plot_curves_vel(curves_o, vels_o, res_models, res_sorted, vels_m,
                    **kwargs):
    """Plot each fitted model's light curves and velocity against observations.

    One subplot is created per entry of ``res_sorted`` on a two-column grid.
    Each subplot shows the model light curves with the observed curves
    overplotted as 'o' markers, annotated with the model key, the best time
    shift and the fit measure. A twin y-axis shows the model velocity and the
    observed velocities.

    :param curves_o: observed light curves, or None to skip drawing them.
        Their time shift is modified in place (initial shift + best shift).
    :param vels_o: iterable of observed velocity series, or None to skip
        drawing them. Each item's ``tshift`` is overwritten in place.
    :param res_models: mapping key -> model light curves; a None value skips
        the light-curve part of that subplot.
    :param res_sorted: mapping key -> fit result exposing ``tshift`` and
        ``measure``; its iteration order is the subplot order.
    :param vels_m: mapping key -> model velocity series exposing ``Time``
        and ``V``.
    :param kwargs: ``font_size`` (10), ``linewidth`` (2.0), ``markersize`` (5)
        and ``xlim`` (None). When ``xlim`` is not given, the x-range of the
        first subplot that draws observed curves is reused for the observed
        curves of all later subplots.
    :return: the matplotlib figure.
    """
    from pystella.rf import light_curve_plot as lcp
    from matplotlib import pyplot as plt

    font_size = kwargs.get('font_size', 10)
    linewidth = kwargs.get('linewidth', 2.0)
    markersize = kwargs.get('markersize', 5)
    xlim = kwargs.get('xlim', None)

    ncol = 2
    num = len(res_sorted)
    # Ceiling division: the former int(num / 2.1) + 1 gave too few rows once
    # num reached 23, which made add_subplot fail. Keep at least one row so
    # the figure height stays positive for an empty result set.
    nrow = max(1, (num + ncol - 1) // ncol)
    fig = plt.figure(figsize=(12, nrow * 4))
    plt.matplotlib.rcParams.update({'font.size': font_size})

    # Remember the initial shifts so each model's best shift is applied on top
    # of them instead of accumulating from subplot to subplot.
    tshift_lc = 0.
    if curves_o is not None:
        tshift_lc = ps.first(curves_o).tshift
    tshift_vel = 0.
    if vels_o is not None:
        tshift_vel = ps.first(vels_o).tshift

    for i, (k, v) in enumerate(res_sorted.items(), start=1):
        #  Plot UBV
        axUbv = fig.add_subplot(nrow, ncol, i)

        tshift_best = v.tshift
        axUbv.text(0.99,
                   0.94,
                   k,
                   horizontalalignment='right',
                   transform=axUbv.transAxes)
        axUbv.text(0.98,
                   0.85,
                   "dt={:.2f}".format(tshift_best),
                   horizontalalignment='right',
                   transform=axUbv.transAxes)
        # Raw string: "\c" is not a valid escape sequence in a normal literal.
        axUbv.text(0.01,
                   0.05,
                   r"$\chi^2: {:.2f}$".format(v.measure),
                   horizontalalignment='left',
                   transform=axUbv.transAxes,
                   bbox=dict(facecolor='green', alpha=0.3))

        curves = res_models[k]
        if curves is not None:
            lcp.curves_plot(curves,
                            ax=axUbv,
                            figsize=(12, 8),
                            linewidth=1,
                            is_legend=False)
            # Observed curves are optional, as the None check above implies.
            if curves_o is not None:
                lt = {lc.Band.Name: 'o' for lc in curves_o}
                curves_o.set_tshift(tshift_lc + tshift_best)
                if xlim is None:
                    xlim = axUbv.get_xlim()
                lcp.curves_plot(curves_o,
                                axUbv,
                                xlim=xlim,
                                lt=lt,
                                markersize=2,
                                is_legend=False,
                                is_line=False)
            # legend
            axUbv.legend(curves.BandNames,
                         loc='lower right',
                         frameon=False,
                         ncol=min(5, len(curves.BandNames)),
                         fontsize='small',
                         borderpad=1)

        # Right-hand column: drop the magnitude ticks and label.
        if i % ncol == 0:
            axUbv.yaxis.tick_right()
            axUbv.set_ylabel('')
            axUbv.yaxis.set_ticks([])
        else:
            axUbv.yaxis.tick_left()
            axUbv.yaxis.set_label_position("left")
        axUbv.grid(linestyle=':')

        #  Plot Vel on a twin axis sharing the time axis
        axVel = axUbv.twinx()
        axVel.set_ylim((0., 29))
        axVel.set_ylabel('Velocity [1000 km/s]')

        ts_vel = vels_m[k]
        axVel.plot(ts_vel.Time,
                   ts_vel.V,
                   label='Vel  %s' % k,
                   color='blue',
                   ls="-",
                   linewidth=linewidth)

        # Obs. vel (optional); the marker cycle restarts in every subplot
        if vels_o is not None:
            markers_cycler = cycle(markers_style)
            for vel_o in vels_o:
                vel_o.tshift = tshift_vel + tshift_best
                marker = next(markers_cycler)
                if vel_o.IsErr:
                    yyerr = abs(vel_o.Err)
                    axVel.errorbar(vel_o.Time,
                                   vel_o.V,
                                   yerr=yyerr,
                                   label='{0}'.format(vel_o.Name),
                                   color="blue",
                                   ls='',
                                   markersize=markersize,
                                   fmt=marker,
                                   markeredgewidth=2,
                                   markerfacecolor='none',
                                   markeredgecolor='blue')
                else:
                    axVel.plot(vel_o.Time,
                               vel_o.V,
                               label=vel_o.Name,
                               color='blue',
                               ls='',
                               marker=marker,
                               markersize=markersize)

    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
Exemplo n.º 5
0
def plot_curves(curves_o, res_models, res_sorted, **kwargs):
    """Plot each fitted model's light curves with the observations overplotted.

    One subplot is created per entry of ``res_sorted``. The grid has at most
    four columns (``int(sqrt(num))``, but never fewer than one). Each subplot
    is annotated with the model key, its 1-based index, the best time shift
    and the fit measure.

    :param curves_o: observed light curves, drawn as 'o' markers.
    :param res_models: mapping key -> model light curves; each model's time
        shift is set in place to the best shift of its fit result.
    :param res_sorted: mapping key -> fit result exposing ``tshift`` and
        ``measure``; its iteration order is the subplot order. It may be
        empty, in which case an empty figure is returned.
    :param kwargs: ``font_size`` (8) and ``xlim`` (None). When ``xlim`` is
        None or its upper bound is infinite, the x-range of the first subplot
        is used for all later subplots.
    :return: the matplotlib figure.
    """
    from pystella.rf import light_curve_plot as lcp
    from matplotlib import pyplot as plt
    import math

    font_size = kwargs.get('font_size', 8)
    xlim = kwargs.get('xlim', None)
    num = len(res_sorted)
    # At least one column: int(sqrt(0)) == 0 used to cause a ZeroDivisionError
    # below when there was nothing to plot.
    ncol = max(1, min(4, int(np.sqrt(num))))
    nrow = math.ceil(num / ncol)
    fig = plt.figure(figsize=(min(ncol, 2) * 5, max(nrow, 2) * 4))
    plt.matplotlib.rcParams.update({'font.size': font_size})

    for i, (k, v) in enumerate(res_sorted.items(), start=1):
        ax = fig.add_subplot(nrow, ncol, i)

        tshift_best = v.tshift
        curves = res_models[k]
        curves.set_tshift(tshift_best)
        lcp.curves_plot(curves,
                        ax=ax,
                        figsize=(12, 8),
                        linewidth=1,
                        is_legend=False)
        # Without a usable xlim, adopt the range chosen for this subplot.
        if xlim is None or xlim[1] == float('inf'):
            xlim = ax.get_xlim()
        else:
            ax.set_xlim(xlim)
        lt = {lc.Band.Name: 'o' for lc in curves_o}
        lcp.curves_plot(curves_o,
                        ax,
                        xlim=xlim,
                        lt=lt,
                        markersize=4,
                        is_legend=False,
                        is_line=False)

        ax.text(0.99,
                0.94,
                k,
                horizontalalignment='right',
                transform=ax.transAxes)
        ax.text(0.98,
                0.85,
                "{} dt={:.2f}".format(i, tshift_best),
                horizontalalignment='right',
                transform=ax.transAxes)
        ax.text(0.01,
                0.05,
                r"$\chi^2: {:.2f}$".format(v.measure),
                horizontalalignment='left',
                transform=ax.transAxes,
                bbox=dict(facecolor='green', alpha=0.3))

        # fix axes: label the first column, put ticks on the right for the
        # last column, hide tick labels for the columns in between
        if ncol == 1:
            ax.yaxis.tick_left()
            ax.yaxis.set_label_position("left")
            ax.set_ylabel('Magnitude')
        elif i % ncol == 0:
            ax.yaxis.tick_right()
            ax.yaxis.set_label_position("right")
        elif i % ncol == 1:
            ax.yaxis.tick_left()
            ax.yaxis.set_label_position("left")
            ax.set_ylabel('Magnitude')
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])
            ax.yaxis.set_ticks_position('both')
        # legend
        ax.legend(curves.BandNames,
                  loc='lower right',
                  frameon=False,
                  ncol=min(5, len(curves.BandNames)),
                  fontsize='small',
                  borderpad=1)

    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
Exemplo n.º 6
0
    def test_FitMCMC_best_curves_dtdm_SN1999em(self):
        """Run FitMCMC.best_curves (dt0 and dm0) on SN1999em and plot the result.

        The STELLA model 'cat_R500_M15_Ni006_E12' is computed in the observed
        bands for D = 11.5e6 pc and ebv = 0.1. The fitted theta and its two
        error vectors are unpacked with ``fitter.theta2arr`` into dt, dm and
        per-band sigmas, which are printed. The model is then shifted by
        dt/dm and drawn with the sigmas as errors, with the observations on
        top. Ends with a blocking plt.show().
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Observations (no magnitude shift applied here)
        distance_pc = 11.5e6  # pc
        ebv_sn = 0.1
        obs = sn1999em.read_curves()

        # Model, read from <parent of this file's directory>/data/stella
        model_name = 'cat_R500_M15_Ni006_E12'
        model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        mdl = ps.Stella(model_name, path=model_dir).curves(obs.BandNames,
                                                           distance=distance_pc,
                                                           ebv=ebv_sn)

        # Fitter setup; the debug branch also fixes the MCMC chain sizes.
        debug = True
        fitter = ps.FitMCMC()
        fitter.is_info = True
        if debug:
            fitter.is_debug = debug
            fitter.nwalkers = 100  # number of MCMC walkers
            fitter.nburn = 20  # "burn-in" period to let chains stabilize
            fitter.nsteps = 200  # number of MCMC steps to take

        _, res, (th, ep, em), sampler = fitter.best_curves(mdl,
                                                           obs,
                                                           dt0=0.,
                                                           dm0=0.,
                                                           threads=3,
                                                           is_sampler=True)

        # Split the best-fit vector and both error vectors into components.
        n_bands = obs.Length
        dt, dm, sigs = fitter.theta2arr(th, n_bands)
        ep_dt, ep_dm, ep_sigs = fitter.theta2arr(ep, n_bands)
        em_dt, em_dm, em_sigs = fitter.theta2arr(em, n_bands)

        # Corner plot of the flat chain with its first nburn rows dropped.
        samples = sampler.flatchain[fitter.nburn:, :]
        fitter.plot_corner(samples,
                           bnames=obs.BandNames,
                           bins=55,
                           alpha=0.5,
                           verbose=fitter.is_info)

        # Text report: statistics, dt, dm, then one sigma line per band.
        parts = [
            'chi2= {:.1f}  BIC= {:.1f} AIC= {:.1f} measure= {:.3f}\n'.format(
                res['chi2'], res['bic'], res['aic'], res['measure']),
            '\n dt= {:.1f} ^{{{:.1f}}}_{{{:.1f}}} '.format(dt, ep_dt, em_dt),
            '\n dm= {:.2f} ^{{{:.2f}}}_{{{:.2f}}} '.format(dm, ep_dm, em_dm),
        ]
        for idx, bname in enumerate(obs.BandNames):
            parts.append(
                '\n sigma_{{{:s}}} = {:.2f}^{{{:.2f}}}_{{{:.2f}}} '.format(
                    bname, sigs[idx], ep_sigs[idx], em_sigs[idx]))
        txt = ''.join(parts)
        print(txt)

        mdl.set_tshift(dt)
        mdl.set_mshift(dm)

        # Each band gets its fitted sigma repeated for every model point.
        errs = {
            bn: [sigs[idx]] * mdl[bn].Length
            for idx, bn in enumerate(mdl.BandNames)
        }
        mdl_with_err = mdl.clone(err=errs)
        ax = ps.curves_plot(mdl_with_err,
                            ax=None,
                            is_legend=False,
                            is_fill=True,
                            alpha=0.2)

        # Observations as 'o' markers
        markers = {}
        for lc in obs:
            markers[lc.Band.Name] = 'o'
        lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)

        ax.text(0.05, 0.05, txt, transform=ax.transAxes)
        plt.show()
Exemplo n.º 7
0
    def test_FitMPF_best_curves_SN1999em(self):
        """Run FitMPFit.best_curves (dt0 only) on SN1999em and plot the result.

        The STELLA model 'cat_R500_M15_Ni006_E12' is computed in the observed
        bands for D = 11.5e6 pc and ebv = 0. The model is shifted in time by
        the fitted 'dt' and drawn with the observations on top. The 'dt' and
        'dm' values are printed with their sigmas and written onto the axes.
        Ends with a blocking plt.show().
        """
        from pystella.rf import light_curve_func as lcf
        from pystella.rf import light_curve_plot as lcp

        # Observations
        distance_pc = 11.5e6  # pc
        ebv_sn = 0.
        obs = sn1999em.read_curves()

        # Model, read from <parent of this file's directory>/data/stella
        model_name = 'cat_R500_M15_Ni006_E12'
        model_dir = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        mdl = ps.Stella(model_name, path=model_dir).curves(obs.BandNames,
                                                           distance=distance_pc,
                                                           ebv=ebv_sn)

        # Fit
        fitter = ps.FitMPFit()
        fitter.is_info = True
        fitter.is_debug = True
        fitter.is_quiet = True
        _fit_result, res, _dum = fitter.best_curves(mdl, obs, dt0=0.)

        # Read all four result entries up front, as the report needs them.
        dt = res['dt']
        dt_sigma = res['dtsig']
        dm = res['dm']
        dm_sigma = res['dmsig']

        # Plot the shifted model, then the observations as 'o' markers.
        mdl.set_tshift(dt)
        ax = lcp.curves_plot(mdl)

        markers = {}
        for lc in obs:
            markers[lc.Band.Name] = 'o'
        lcp.curves_plot(obs, ax, lt=markers, xlim=(-10, 300), is_line=False)

        # Report
        report = [
            '{:10} {:.2f} +- {:.4f}'.format('dt:', dt, dt_sigma),
            '\n{:10} {:.2f} +- {:.4f}'.format('dm:', dm, dm_sigma),
        ]
        txt = ''.join(report)
        print(txt)

        # Text box anchored near the top-left of the axes, without a frame.
        box_style = {'facecolor': 'none', 'alpha': 0.2, 'edgecolor': 'none'}
        ax.text(0.03,
                0.95,
                txt,
                transform=ax.transAxes,
                bbox=box_style,
                size='12',
                color='black',
                weight='normal',
                verticalalignment='bottom')
        plt.show()