コード例 #1
0
    def test_bol_Uni(self):
        """Plot the bolometric light curve obtained in two ways for comparison.

        The first set of curves is requested through the named ``'bol'`` band
        with an explicit ``wlrange``; the second through a ``ps.BandUni``
        object built over the same range, passing the min/max of its
        ``wl2args`` as ``wl_ab``.  Nothing is asserted: the figure is shown
        and a reminder warning is emitted.
        """
        import warnings
        import matplotlib.pyplot as plt

        model = ps.Stella('cat_R500_M15_Ni006_E12', path='data/stella')
        _, ax = plt.subplots()

        def draw(curves, color, ls):
            # One line per light curve, labelled with its band name.
            for lc in curves:
                ax.plot(lc.Time,
                        lc.Mag,
                        label=lc.Band.Name,
                        color=color,
                        linewidth=2,
                        ls=ls)

        # Named 'bol' band
        draw(model.curves(bands=['bol'], wlrange=(1e0, 42.), is_nfrus=False),
             'blue', '--')

        # Explicit BandUni band over the same range
        band_uni = ps.BandUni(name='bol', wlrange=(1e0, 42.), length=300)
        wl_ab = (np.min(band_uni.wl2args), np.max(band_uni.wl2args))
        draw(model.curves(bands=[band_uni], is_nfrus=False, wl_ab=wl_ab),
             'red', ':')

        ax.invert_yaxis()
        ax.legend()
        plt.show()
        warnings.warn("Should be check for shorck breakout")
コード例 #2
0
    def test_stella_curves_reddening_plot(self):
        """Compare two reddening paths for the same model and plot the result.

        Curves are computed once with ``ps.lcf.curves_compute`` and reddened
        afterwards with ``ps.lcf.curves_reddening``, and once directly with
        ``Stella.curves(ebv=...)``.  The test asserts that both sets hold the
        same band names, then draws both sets plus their per-band magnitude
        difference.
        """
        from matplotlib import gridspec

        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(abspath(__file__)), 'data', 'stella')
        bands = ('UVW1', 'UVW2', 'UVM2')
        ebv = 1

        # Unreddened magnitudes, reddened afterwards
        cs = ps.lcf.curves_compute(name, path, bands, t_diff=1.05)

        mdl = ps.Stella(name, path=path)
        is_SMC = False  # manual switch between the SMC and the default (MW) law
        if not is_SMC:
            law, mode = ps.extinction.law_default, ps.ReddeningLaw.MW
        else:
            law, mode = 'Rv2.1', ps.ReddeningLaw.SMC
        curves_mags = ps.lcf.curves_reddening(cs, ebv=ebv, law=law)
        curves = mdl.curves(bands, ebv=ebv, t_diff=1.05, mode=mode)

        names_direct = sorted(curves.BandNames)
        names_mags = sorted(curves_mags.BandNames)
        failure_msg = ("Error for the initial band names [%s] "
                       "VS secondary BandNames are %s." %
                       (' '.join(curves_mags.BandNames), ' '.join(
                           curves.BandNames)))
        self.assertTrue(np.array(names_direct == names_mags).all(),
                        msg=failure_msg)

        # Upper panel (3/4 of the height): both curve sets; lower panel: differences
        fig = plt.figure(figsize=(12, 12))
        grid = gridspec.GridSpec(4, 1)
        ax_curves = fig.add_subplot(grid[:-1, 0])
        ax_delta = fig.add_subplot(grid[3, 0])

        markers = {lc.Band.Name: 'o' for lc in curves_mags}
        ax = ps.lcp.curves_plot(curves_mags,
                                ax=ax_curves,
                                lt=markers,
                                markersize=2,
                                is_legend=False)
        xlim = ax.get_xlim()

        ps.lcp.curves_plot(curves, ax=ax_curves)

        times = curves.TimeCommon
        for bname in bands:
            delta = curves.get(bname).Mag - curves_mags.get(bname).Mag
            ax_delta.plot(times, delta, label="Delta {}".format(bname))
        ax_delta.set_xlim(xlim)
        ax_delta.legend()

        plt.grid(linestyle=':', linewidth=1)
        plt.show()
コード例 #3
0
def fit_mfl(args, curves_o, bnames, fitter, name, path, t_diff, tlim,
            is_fit_sigmas):
    """Build model light curves for *name* and fit them to the observed ones.

    :param args: options object; ``distance``, ``redshift``, ``is_curve_tt``
        and ``color_excess`` are read from it
    :param curves_o: observed curves handed to ``fitter.best_curves``
    :param bnames: band names to keep / compute
    :param fitter: object providing ``best_curves``
    :param name: model name
    :param path: directory of the model
    :param t_diff: passed through to ``Stella.curves``
    :param tlim: pair used as ``t_beg`` / ``t_end`` for ``Stella.curves``
    :param is_fit_sigmas: passed through to ``fitter.best_curves``
    :return: tuple ``(curves_m, fit_result, res)``
    """
    z = args.redshift
    distance = args.distance  # pc
    # Without an explicit distance, derive it from the redshift
    if not args.distance and z > 0:
        distance = ps.cosmology_D_by_z(z) * 1e6

    mdl = ps.Stella(name, path=path)
    if not args.is_curve_tt:
        curves_m = mdl.curves(bnames,
                              z=z,
                              distance=distance,
                              ebv=args.color_excess,
                              t_beg=tlim[0],
                              t_end=tlim[1],
                              t_diff=t_diff)
    else:
        # Curves stored in the tt-file; drop every band that was not requested
        print(
            "The curves [UBVRI+bol] was taken from tt-file {}. ".format(name) +
            "IMPORTANT: distance: 10 pc, z=0, E(B-V) = 0")
        curves_m = mdl.get_tt().read_curves()
        for unwanted in [b for b in curves_m.BandNames if b not in bnames]:
            curves_m.pop(unwanted)

    fit_result, res, _ = fitter.best_curves(curves_m,
                                            curves_o,
                                            dt0=0.,
                                            is_fit_sigmas=is_fit_sigmas)
    return curves_m, fit_result, res
コード例 #4
0
 def setUp(self):
     """Open the reference model and keep it and its ``get_tt()`` result."""
     # Alternative models used previously:
     #   'ccsn2007bi1dNi6smE23bRlDC', 'cat_R1000_M15_Ni007_E15'
     model_name = 'cat_R500_M15_Ni006_E12'
     data_dir = join(dirname(abspath(__file__)), 'data', 'stella')
     self.stella = ps.Stella(model_name, path=data_dir)
     self.tt = self.stella.get_tt()
コード例 #5
0
def compute_tcolor(
        name,
        path,
        bands,
        d=ps.phys.pc2cm(10.),
        z=0.,
        t_cut=(1., np.inf),
        t_diff=1.1,
):
    """Collect colour-temperature related quantities of a model into one table.

    :param name: model name
    :param path: directory of the model
    :param bands: bands passed to ``mags_bands`` and ``compute_Tcolor_zeta``
    :param d: distance passed to the magnitude computation
        (default: ``ps.phys.pc2cm(10.)``, evaluated once at import time)
    :param z: redshift passed to the magnitude computation
    :param t_cut: ``(t_beg, t_end)`` time window applied to ph- and tt-data
    :param t_diff: passed through to ``model.get_ph``
    :return: structured array with the float fields ``time``, ``Tcol``,
        ``zeta``, ``Tnu``, ``Teff`` and ``W``; ``None`` when the model has no
        ph-data or no tt-data
    """
    model = ps.Stella(name, path=path)

    if not model.is_ph:
        print("No ph-data for: " + str(model))
        return None

    if not model.is_tt:
        print("No tt-data for: " + str(model))
        return None

    t_lo, t_hi = t_cut[0], t_cut[1]

    # Spectra inside the time window and the band magnitudes derived from them
    serial_spec = model.get_ph(t_diff=t_diff, t_beg=t_lo, t_end=t_hi)
    mags = serial_spec.mags_bands(bands, z=z, d=d)

    # tt-table: cleaned with tbl_rm_equal_el on 'time', then cut to the window
    tt = tbl_rm_equal_el(model.get_tt().load(), 'time')
    after_start = t_lo <= tt['time']
    before_end = tt['time'] <= t_hi
    tt = tt[np.logical_and(after_start, before_end)]

    Tnu, Teff, W = compute_Tnu_w(serial_spec, tt=tt)

    Tcolors, zetaR, times = compute_Tcolor_zeta(mags,
                                                tt=tt,
                                                bands=bands,
                                                freq=serial_spec.Freq,
                                                d=d,
                                                z=z)

    # Pack everything into one structured array, one float64 field per quantity
    columns = (
        ('time', times),
        ('Tcol', Tcolors),
        ('zeta', zetaR),
        ('Tnu', Tnu),
        ('Teff', Teff),
        ('W', W),
    )
    res = np.zeros(len(Tcolors),
                   dtype=np.dtype([(field, np.float64)
                                   for field, _ in columns]))
    for field, values in columns:
        res[field] = values

    return res
コード例 #6
0
    def test_stella_curves_tbeg(self):
        """Check that ``Stella.curves(t_beg=...)`` returns no point before ``t_beg``.

        The assertion requires *every* time of the first curve to satisfy
        ``Time >= t_beg``.  Using ``np.any`` here would pass as soon as a
        single point is late enough, which does not test the cut at all.
        """
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(abspath(__file__)), 'data', 'stella')
        bands = ('U', 'B', 'V')

        mdl = ps.Stella(name, path=path)
        t_beg = 1.
        curves = mdl.curves(bands, t_beg=t_beg)

        times = ps.first(curves).Time
        print(times[:3])
        self.assertTrue(
            np.all(times >= t_beg),
            msg="There ara values Time less then t_beg = {}".format(t_beg))
コード例 #7
0
    def test_stella_curves(self):
        """Check that ``Stella.curves`` returns exactly the requested band names."""
        bands = ('U', 'B', 'V')
        data_dir = join(dirname(abspath(__file__)), 'data', 'stella')
        model = ps.Stella('cat_R500_M15_Ni006_E12', path=data_dir)
        curves = model.curves(bands)

        print(ps.first(curves).Time[:300])

        # Order of the bands is irrelevant, so compare the sorted names
        same_names = sorted(curves.BandNames) == sorted(bands)
        failure_msg = ("Error for the initial band names [%s] "
                       "VS secondary BandNames are %s." %
                       (' '.join(bands), ' '.join(curves.BandNames)))
        self.assertTrue(np.array(same_names).all(), msg=failure_msg)
コード例 #8
0
    def test_lc_bol(self):
        """Plot the bolometric light curve from three sources for comparison.

        Drawn on one axis: the ``'bol'`` curve from ``Stella.curves``, the
        ``Mbol`` column of the tt-table, and a curve integrated here from the
        ph-spectra.  Nothing is asserted; the figure is shown and a reminder
        warning is emitted.
        """
        import warnings
        import matplotlib.pyplot as plt
        # ``simps`` was removed from recent SciPy releases; ``simpson`` is the
        # current name.  Fall back to the old one only for old installations.
        try:
            from scipy.integrate import simpson
        except ImportError:
            from scipy.integrate import simps as simpson

        m1 = ps.Stella('cat_R500_M15_Ni006_E12', path='data/stella')
        curves = m1.curves(bands=['bol'], t_diff=1.0000001)
        ax = ps.lcp.curves_plot(curves,
                                xlim=(-10, 155),
                                ylim=(-14, -24),
                                is_line=False)
        # tt
        tt1 = m1.get_tt().load()
        ax.plot(tt1['time'],
                tt1['Mbol'],
                label='tt-bolometric LC ',
                color='red',
                lw=2,
                ls=':')
        # ph: integrate every spectrum over frequency (arrays are reversed
        # before integrating) and convert the result to a magnitude.
        # NOTE(review): 4.75 and 3.86e33 look like the solar bolometric
        # magnitude and luminosity [erg/s] - confirm.
        ph = m1.get_ph()
        m_bol = []
        for _, spec in ph:
            lum = simpson(spec.Flux[::-1], x=spec.Freq[::-1])
            m_bol.append(4.75 - 2.5 * np.log10(np.abs(lum) / 3.86e33))
        ax.plot(ph.Time,
                m_bol,
                label='ph-bolometric LC ',
                color='green',
                lw=2,
                ls='-.')
        ax.legend()
        plt.show()
        warnings.warn("Should be check for shorck breakout")
コード例 #9
0
    def test_stella_curves_VS_tt_plot(self):
        """Overlay the computed UBVRI curves with the magnitudes of the tt-table.

        Visual check only: the curves from ``Stella.curves`` are drawn with
        ``ps.lcp.curves_plot`` and the ``M<band>`` columns of the tt-table are
        added as star markers in the band colour.
        """
        from pystella.rf.band import colors

        bands = ('U', 'B', 'V', 'R', 'I')
        data_dir = join(dirname(abspath(__file__)), 'data', 'stella')
        model = ps.Stella('cat_R500_M15_Ni006_E12', path=data_dir)

        curves = model.curves(bands)
        tt = model.get_tt().load()

        ax = ps.lcp.curves_plot(curves)
        for bname in bands:
            marker_style = dict(label="tt " + bname,
                                color=colors(bname),
                                marker='*',
                                markersize=3,
                                ls='')
            ax.plot(tt['time'], tt['M' + bname], **marker_style)
        ax.legend()

        plt.grid(linestyle=':', linewidth=1)
        plt.show()
コード例 #10
0
def fit_mfl_vel(args,
                curves_o,
                vels_o,
                bnames,
                fitter,
                name,
                path,
                t_diff,
                tlim,
                Vnorm=1e8,
                A=0.):
    """Fit model light curves and velocities to the observed ones together.

    Model and observed series are collected into two ``ps.SetTimeSeries``
    (``"Models"`` and ``"Obs"``) under matching names and handed to
    ``fitter.fit_tss``.

    :param args: options object; ``distance``, ``redshift``, ``is_curve_tt``
        and ``color_excess`` are read from it
    :param curves_o: observed light curves, or ``None`` to fit velocities only
    :param vels_o: iterable of observed velocity curves
    :param bnames: band names to keep / compute
    :param fitter: object providing ``fit_tss``
    :param name: model name
    :param path: directory of the model
    :param t_diff: passed through to ``Stella.curves``
    :param tlim: pair used as ``t_beg`` / ``t_end`` for ``Stella.curves``
    :param Vnorm: the model velocities are divided by this value
    :param A: passed through to ``fitter.fit_tss``
    :return: tuple ``(curves_m, res, vel_m)``; ``curves_m`` is ``None`` when
        ``curves_o`` is ``None``, ``vel_m`` is ``None`` when no model velocity
        could be computed and none was needed
    :raises ValueError: if observed velocities are given but the model
        velocity could not be computed
    """
    distance = args.distance  # pc
    z = args.redshift
    # Without an explicit distance, derive it from the redshift
    if not args.distance and z > 0:
        distance = ps.cosmology_D_by_z(z) * 1e6

    tss_m = ps.SetTimeSeries("Models")
    tss_o = ps.SetTimeSeries("Obs")
    curves_m = None

    # light curves
    if curves_o is not None:
        mdl = ps.Stella(name, path=path)
        if args.is_curve_tt:  # tt
            print("The curves [UBVRI+bol] was taken from tt-file {}. ".format(
                name) + "IMPORTANT: distance: 10 pc, z=0, E(B-V) = 0")
            curves_m = mdl.get_tt().read_curves()
            excluded = [bn for bn in curves_m.BandNames if bn not in bnames]
            for bn in excluded:
                curves_m.pop(bn)
        else:
            curves_m = mdl.curves(bnames,
                                  z=z,
                                  distance=distance,
                                  ebv=args.color_excess,
                                  t_beg=tlim[0],
                                  t_end=tlim[1],
                                  t_diff=t_diff)

        for lc in curves_m:
            tss_m.add(lc)

        for lc in curves_o:
            # clone() is unpacked as (curve, tshift, mshift); only the curve is used
            lc_copy, _tshift, _mshift = lc.clone()
            tss_o.add(lc_copy)

    # Model velocities.  compute_vel_swd may return None when there is no
    # data, so check the table itself: a check on the constructed
    # VelocityCurve can never fire.
    vel_m = None
    try:
        tbl = ps.vel.compute_vel_swd(name, path)
        if tbl is None:
            raise ValueError(
                'No velocity data for model {} in {}'.format(name, path))
        vel_m = ps.vel.VelocityCurve('Vel', tbl['time'], tbl['vel'] / Vnorm)
    except ValueError as ext:
        print(ext)

    # Pair every observed velocity with a copy of the model velocity.
    if curves_m is not None:
        # One pair per light curve, to increase the weight of the
        # velocities in the fit.
        pairs = [(i, vel_o) for i in range(curves_m.Length)
                 for vel_o in vels_o]
    else:
        pairs = list(enumerate(vels_o))

    if pairs and vel_m is None:
        # Fail with a clear message instead of an AttributeError on None
        raise ValueError(
            'Observed velocities cannot be fitted: no model velocity '
            'for {} in {}'.format(name, path))

    for i, vel_o in pairs:
        key = 'Vel{:d}{}'.format(i, vel_o.Name)
        tss_m.add(vel_m.copy(name=key))
        tss_o.add(vel_o.copy(name=key))

    # fit
    res = fitter.fit_tss(tss_m, tss_o, A=A)

    return curves_m, res, vel_m
コード例 #11
0
def main():
    """Command-line entry point: print/plot data of the swd-file of Stella models.

    Options come from ``get_parser()``.  For every model, depending on the
    flags, the function either computes the photospheric parameters
    (``is_uph``: print them, then write a ``.uph`` file or plot them), draws
    a multi-panel cartoon (``is_mult``), or plots the shock details with
    optional chemistry (``frho``) and optical-depth (``tau``) overlays.
    At the end the last figure is saved when ``is_save`` selected a file,
    otherwise the figures are shown.

    The ``.uph`` data file and the figure file are tracked in separate
    variables, so writing the data no longer makes the final step call
    ``savefig`` on a figure that was never created.
    """
    import sys
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None

    fig = None
    fplot = None  # file the figure is saved to; None means "show instead"
    dic_axes = None
    ylim_par = None
    is_legend = True
    ls_cycle = cycle(ps.linestyles_extend)
    marker_cycle = cycle(ps.lcp.markers)

    parser = get_parser()
    args, unknownargs = parser.parse_known_args()

    if args.ylim_par:
        ylim_par = ps.str2interval(args.ylim_par, llim=0, rlim=9.9, sep=':')

    if args.path:
        pathDef = os.path.expanduser(args.path)
    else:
        pathDef = os.getcwd()

    # Set model names
    names = []
    if args.input:
        # Each entry of args.input is indexed with [0]
        # (presumably a one-element list per option - verify in get_parser).
        for nm in args.input:
            names.append(nm[0])
    elif len(unknownargs) > 0:
        names.append(unknownargs[0])

    if len(names) == 0:
        parser.print_help()
        sys.exit(2)

    times = list(map(float, args.times.split(':')))

    for i, nm in enumerate(names):
        path, name = os.path.split(nm)
        if len(path) == 0:
            path = pathDef
        name = name.replace('.swd', '')  # remove extension

        print("Run swd-model %s %s for %s moments" % (path, name, args.times))
        stella = ps.Stella(name, path=path)
        swd = stella.get_swd().load()

        if args.is_uph:
            logger.info(' Compute and print uph')
            duph = swd.params_ph()
            # print uph
            print(duph.keys())
            for row in zip(*duph.values()):
                print(['{:12f}'.format(x) for x in row])
            if args.is_write:
                # Data file only: no figure is produced on this path
                fuph = os.path.join(path, "{0}.uph".format(name))
                print("Save uph to {0}".format(fuph))
                uph_write(duph, fuph)
            else:
                fig = plot_uph(duph, vnorm=args.vnorm, label=name)
                if args.is_save:
                    fplot = os.path.expanduser("~/uph_{0}.pdf".format(name))
        elif args.is_mult:
            make_cartoon(swd,
                         times,
                         vnorm=args.vnorm,
                         rnorm=args.rnorm,
                         lumnorm=args.lumnorm,
                         is_legend=is_legend)
        else:
            # dic_axes is passed back in so that the call for the next
            # model receives the axes returned for the previous one.
            fig, dic_axes = ps.lcp.plot_shock_details(swd,
                                                      times=times,
                                                      vnorm=args.vnorm,
                                                      rnorm=args.rnorm,
                                                      tnorm=args.tnorm,
                                                      lumnorm=args.lumnorm,
                                                      is_legend=is_legend,
                                                      is_axes=True,
                                                      ylim_par=ylim_par,
                                                      dic_axes=dic_axes,
                                                      ls=next(ls_cycle))
            if args.frho is not None:
                ps.lcp.plot_swd_chem(dic_axes, args.frho, stella.Path)

            if args.tau:
                ps.Band.load_settings()
                # Set band names
                bnames = ('B', )
                if args.bnames:
                    bnames = []
                    for bname in args.bnames.split(':'):
                        if not ps.band.is_exist(bname):
                            print('No such band: ' + bname)
                            parser.print_help()
                            sys.exit(2)
                        bnames.append(bname)

                ps.lcp.plot_swd_tau(dic_axes,
                                    stella,
                                    times=times,
                                    bnames=bnames,
                                    tau_ph=args.tau,
                                    is_obs_time=False,
                                    marker=next(marker_cycle),
                                    vnorm=args.vnorm,
                                    tnorm=args.tnorm)

            if args.is_save:
                fplot = os.path.expanduser("~/swd_{0}_t{1}.pdf".format(
                    name, str.replace(args.times, ':', '-')))

    # Save the figure or show
    if fplot is not None and fig is not None:
        print("Save plot to {0}".format(fplot))
        fig.savefig(fplot, bbox_inches='tight')
    elif plt is not None:
        plt.show()
コード例 #12
0
def main(name=None, model_ext='.ph'):
    """Command-line entry point: compute, save and plot model light curves.

    Options are parsed from ``sys.argv`` with ``getopt``:

    * ``-b``  band names separated by ``-``; ``name:shift`` adds a magnitude
      shift, a ``_`` in the shift makes it negative
    * ``-c``  callback specification (several ``-c`` are chained)
    * ``-d``  distance [pc];  ``-z`` redshift;  ``-e`` colour excess;
      ``-m`` magnification
    * ``-g``  view option: ``single``, ``grid``, ``gridl`` or ``gridm``
    * ``-i``  model name or file pattern (may be repeated);  ``-p`` model dir
    * ``-l``  plot label;  ``-x`` / ``-y`` axis limits (``-x`` may contain
      ``log``)
    * ``-s``  save the plot to a file;  ``-w`` save the magnitudes
      (``-w 1`` builds the file name automatically)
    * ``-t``  plot time points;  ``-v`` velocity mode (``swd`` or ``ttres...``)
    * ``-q``  quiet, no plot;  ``-h`` help
    * ``--dt`` time step factor;  ``--curve-old`` / ``--curve-tt`` choose
      the source of the curves

    :param name: model name; when given, ``-i`` is ignored
    :param model_ext: extension used to look up model files
    """
    import sys
    import os
    import getopt
    import fnmatch

    is_quiet = False
    is_save_mags = False
    is_save_plot = False
    is_plot_time_points = False
    is_extinction = False
    is_curve_old = False
    is_curve_tt = False
    is_axes_right = False
    is_grid = False

    vel_mode = None
    view_opts = ('single', 'grid', 'gridl', 'gridm')
    opt_grid = view_opts[0]
    t_diff = 1.01
    linestyles = ['-']

    label = None
    fsave = None
    fname = None
    path = os.getcwd()
    z = 0.
    e = 0.
    magnification = 1.
    distance = None  # pc; resolved after option parsing
    callback = None
    xlim = None
    ylim = None
    xtype = 'lin'
    bshift = None

    ps.band.Band.load_settings()

    try:
        opts, args = getopt.getopt(sys.argv[1:],
                                   "hqtb:c:d:e:g:i:l:o:m:p:v:s:w:x:y:z:",
                                   ['dt=', 'curve-old', 'curve-tt'])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        usage()
        sys.exit(2)

    if len(args) > 0:
        # A positional argument is the model file: split into dir and name
        path, name = os.path.split(str(args[0]))
        path = os.path.expanduser(path)
        name = name.replace('.ph', '')
    elif len(opts) == 0:
        usage()
        sys.exit(2)

    bnames = ['U', 'B', 'V', 'R']

    for opt, arg in opts:
        if opt == '-e':
            e = float(arg)
            is_extinction = True
            continue
        if opt == '-b':
            bnames = []
            bshift = {}
            for b in str(arg).split('-'):
                # extract band shift
                if ':' in b:
                    bname, shift = b.split(':')
                    if '_' in shift:
                        # '_' stands for a minus sign, since '-' separates bands
                        bshift[bname] = -float(shift.replace('_', ''))
                    else:
                        bshift[bname] = float(shift)
                else:
                    bname = b
                if not ps.band.is_exist(bname):
                    print('No such band: ' + bname)
                    sys.exit(2)
                bnames.append(bname)
            continue
        if opt == '-c':
            c = ps.cb.lc_wrapper(str(arg))
            if callback is not None:
                c = ps.cb.CallBackArray((callback, c))
            callback = c
            continue
        if opt == '--dt':
            t_diff = float(arg)
            continue
        if opt == '-q':
            is_quiet = True
            continue
        if opt == '-g':
            opt_grid = str.strip(arg).lower()
            if opt_grid not in view_opts:
                # Fixed: the list of allowed options had no placeholder
                print('No such view option: {0}. Can be {1}'.format(
                    opt_grid, '|'.join(view_opts)))
                sys.exit(2)
            continue
        if opt == '-s':
            is_save_plot = True
            fsave = str.strip(arg)
            continue
        if opt == '--curve-old':
            is_curve_old = True
            continue
        if opt == '--curve-tt':
            is_curve_tt = True
            continue
        if opt == '-w':
            is_save_mags = True
            if arg != '1':
                fname = arg.strip()
            continue
        if opt == '-t':
            is_plot_time_points = True
            continue
        if opt == '-v':
            vel_mode = arg.strip()
            continue
        if opt == '-m':
            magnification = float(arg)
            continue
        if opt == '-z':
            z = float(arg)
            continue
        if opt == '-d':
            distance = float(arg)
            continue
        if opt == '-l':
            label = str.strip(arg)
            continue
        if opt == '-x':
            if 'log' in arg:
                xtype = 'log'
                s12 = arg.replace(xtype, '')
                s12 = s12.rstrip(':')
            else:
                s12 = arg
            xlim = ps.str2interval(s12, llim=0, rlim=float('inf'))
            continue
        if opt == '-y':
            ylim = ps.str2interval(arg, llim=-10, rlim=-22)
            continue
        if opt == '-p':
            path = os.path.expanduser(str(arg))
            if not (os.path.isdir(path) and os.path.exists(path)):
                print("No such directory: " + path)
                sys.exit(2)
            continue
        elif opt == '-h':
            usage()
            sys.exit(2)

    # Set model names
    names = []
    is_set_model = False
    if name is None:
        for opt, arg in opts:
            if opt == '-i':
                nm = os.path.splitext(os.path.basename(str(arg)))[0]
                if os.path.exists(os.path.join(path, nm + model_ext)):
                    names.append(nm)
                else:
                    # Not an existing model: treat the argument as a pattern
                    files = [
                        f for f in os.listdir(path)
                        if os.path.isfile(os.path.join(path, f))
                        and fnmatch.fnmatch(f, arg)
                    ]
                    for f in files:
                        nm = os.path.splitext(os.path.basename(f))[0]
                        names.append(nm)
                    names = list(set(names))
                    print('Input {} models: {}'.format(len(names),
                                                       ' '.join(names)))
                is_set_model = True
    else:
        print(name)
        names.append(name)
        is_set_model = True

    if len(names) == 0 and not is_set_model:  # run for all files in the path
        names = ps.path.get_model_names(path, model_ext)

    # Set distance and redshift
    if distance is None:
        if z > 0:
            distance = ps.cosmology_D_by_z(z) * 1e6
            print(
                "Plot magnitudes on z={0:F} with D(z)={1:E} pc (cosmological)".
                format(z, distance))
        else:
            distance = 10  # pc
    else:
        print("Plot magnitudes on z={0:F} at distance={1:E}".format(
            z, distance))
        if z > 0:
            print("  Cosmology D(z)={0:E} Mpc".format(ps.cosmology_D_by_z(z)))

    # Run models
    if len(names) > 0:
        models_mags = {}
        models_vels = {}
        for i, name in enumerate(names):
            mdl = ps.Stella(name, path=path)

            if is_curve_tt:  # tt
                print(
                    "The curves [UBVRI+bol] was taken from tt-file. IMPORTANT: distance: 10 pc, z=0, E(B-V) = 0"
                )
                curves = mdl.get_tt().read_curves()
            elif is_curve_old:  # old
                print("Use old proc for Stella magnitudes")
                curves = ps.lcf.curves_compute(name,
                                               path,
                                               bnames,
                                               z=z,
                                               distance=distance,
                                               magnification=magnification,
                                               t_diff=t_diff)
                if is_extinction:
                    curves = ps.lcf.curves_reddening(curves, ebv=e, z=z)
            else:
                curves = mdl.curves(bnames,
                                    z=z,
                                    distance=distance,
                                    ebv=e,
                                    magnification=magnification,
                                    t_diff=t_diff)

            models_mags[name] = curves

            if vel_mode is not None:
                if vel_mode == 'swd':
                    vels = ps.vel.compute_vel_swd(name, path, z=z)
                elif vel_mode.startswith('ttres'):
                    vels = ps.vel.compute_vel_res_tt(name,
                                                     path,
                                                     z=z,
                                                     is_info=False,
                                                     is_new_std='old'
                                                     not in vel_mode.lower())
                else:
                    raise ValueError(
                        'This mode [{}] for velocity is not supported'.format(
                            vel_mode))

                if vels is None:
                    sys.exit("Error: no data for: %s in %s" % (name, path))
                models_vels[name] = vels
                print("[%d/%d] Done mags & velocity for %s" %
                      (i + 1, len(names), name))
            else:
                models_vels = None
                print("[%d/%d] Done mags for %s" % (i + 1, len(names), name))

        if label is None:
            if callback is not None:
                label = "ts=%s z=%4.2f D=%6.2e mu=%3.1f ebv=%4.2f" % (
                    callback.arg_totext(0), z, distance, magnification, e)
            else:
                label = "z=%4.2f D=%6.2e mu=%3.1f ebv=%4.2f" % (
                    z, distance, magnification, e)

        # save curves to files
        if is_save_mags:
            for curves in models_mags.values():
                # Build the automatic name per model: keeping the first
                # model's name would overwrite one file with every model.
                fout = fname
                if fout is None:
                    fout = os.path.join(path, curves.Name)
                    if z > 0.:
                        fout = '{}_Z{:.2g}'.format(fout, z)
                    if distance > 10.:
                        fout = '{}_D{:.2e}'.format(fout, distance)
                    if e > 0:
                        fout = '{}_E{:0.2g}'.format(fout, e)
                    fout = '{}{}'.format(fout, '.ubv')
                if ps.lcf.curves_save(curves, fout):
                    print("Magnitudes of {} have been saved to {}".format(
                        curves.Name, fout))
                else:
                    # Fixed: the message used to print the model name as the file
                    print("Error with Magnitudes of {} saved to {}".format(
                        curves.Name, fout))
        # plot
        elif not is_quiet:
            if opt_grid in view_opts[1:]:
                # The last letter selects the separator: 'grid' -> 'd' -> 'l',
                # 'gridl' -> 'l', 'gridm' -> 'm'.  (A slice [:-1] can never
                # equal 'd', so the check below was unreachable.)
                sep = opt_grid[-1]
                if sep == 'd':
                    sep = 'l'  # line separator
                fig = plot_grid(models_mags,
                                bnames,
                                call=callback,
                                xlim=xlim,
                                xtype=xtype,
                                ylim=ylim,
                                title=label,
                                sep=sep,
                                is_grid=False)
            else:
                fig = plot_all(models_vels,
                               models_mags,
                               bnames,
                               d=distance,
                               call=callback,
                               xlim=xlim,
                               xtype=xtype,
                               ylim=ylim,
                               is_time_points=is_plot_time_points,
                               title=label,
                               bshift=bshift,
                               is_axes_right=is_axes_right,
                               is_grid=is_grid,
                               legloc=1,
                               fontsize=14,
                               lines=linestyles)

            if is_save_plot:
                if len(fsave) == 0:
                    # Default file name built from the last processed model
                    if vel_mode is not None:
                        fsave = "ubv_vel_%s" % name
                    else:
                        fsave = "ubv_%s" % name

                if is_extinction and e > 0:
                    fsave = "%s_e0%2d" % (fsave, int(e * 100)
                                          )  # bad formula for name

                fsave = os.path.expanduser(fsave)
                fsave = os.path.splitext(fsave)[0] + '.pdf'

                print("Save plot to %s " % os.path.abspath(fsave))
                fig.savefig(fsave, bbox_inches='tight', format='pdf')
            else:
                import matplotlib.pyplot as plt
                plt.show()

    else:
        print(
            "There are no such models in the directory: %s with extension: %s "
            % (path, model_ext))
コード例 #13
0
 def setUp(self):
     """Open the test model and keep the result of its ``get_tau()`` call."""
     here = os.path.dirname(os.path.abspath(__file__))
     data_dir = os.path.join(here, 'data', 'stella')
     model = ps.Stella('levJ_R450_M15_Ni004_E10', path=data_dir)
     self.tau = model.get_tau()
コード例 #14
0
    def test_FitMCMC_best_curves_dtdm_SN1999em(self):
        """Fit model light curves to the SN1999em data with ``ps.FitMCMC``.

        The fit yields a time shift ``dt``, a magnitude shift ``dm`` and one
        sigma per band (unpacked with ``fitter.theta2arr``), each with upper
        and lower errors.  A corner plot of the posterior samples is drawn,
        then the shifted model curves are plotted over the observations with
        the fit summary written on the axes.  This is a visual check: nothing
        is asserted and ``plt.show()`` is called at the end.
        """
        from pystella.rf import light_curve_plot as lcp

        # Observations
        distance = 11.5e6  # pc
        ebv_sn = 0.1
        curves_obs = sn1999em.read_curves()

        # Model, computed for the observed bands
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        curves_mdl = ps.Stella(name, path=path).curves(curves_obs.BandNames,
                                                       distance=distance,
                                                       ebv=ebv_sn)

        # Fit with a deliberately short chain so the test stays quick
        fitter = ps.FitMCMC()
        fitter.is_info = True
        fitter.is_debug = True
        fitter.nwalkers = 100  # number of MCMC walkers
        fitter.nburn = 20  # "burn-in" period to let chains stabilize
        fitter.nsteps = 200  # number of MCMC steps to take

        _, res, (th, ep, em), sampler = fitter.best_curves(curves_mdl,
                                                           curves_obs,
                                                           dt0=0.,
                                                           dm0=0.,
                                                           threads=3,
                                                           is_sampler=True)
        nb = curves_obs.Length
        dt, dm, sigs = fitter.theta2arr(th, nb)
        ep_dt, ep_dm, ep_sigs = fitter.theta2arr(ep, nb)
        em_dt, em_dm, em_sigs = fitter.theta2arr(em, nb)

        # Discard the burn-in steps of EVERY walker.  Slicing the flattened
        # chain (flatchain[nburn:]) would only drop the first rows of the
        # flattened array, not the first nburn steps of each walker.
        # NOTE(review): assumes an emcee-style sampler whose ``chain`` has
        # shape (nwalkers, nsteps, ndim) -- confirm against FitMCMC.
        chain = sampler.chain
        samples = chain[:, fitter.nburn:, :].reshape(-1, chain.shape[-1])
        fitter.plot_corner(samples,
                           bnames=curves_obs.BandNames,
                           bins=55,
                           alpha=0.5,
                           verbose=fitter.is_info)

        # Text summary of the fit
        txt = 'chi2= {:.1f}  BIC= {:.1f} AIC= {:.1f} measure= {:.3f}\n'.\
            format(res['chi2'], res['bic'], res['aic'], res['measure'])
        txt += '\n dt= {:.1f} ^{{{:.1f}}}_{{{:.1f}}} '.format(dt, ep_dt, em_dt)
        txt += '\n dm= {:.2f} ^{{{:.2f}}}_{{{:.2f}}} '.format(dm, ep_dm, em_dm)
        for i, bname in enumerate(curves_obs.BandNames):
            txt += '\n sigma_{{{:s}}} = {:.2f}^{{{:.2f}}}_{{{:.2f}}} '. \
                format(bname, sigs[i], ep_sigs[i], em_sigs[i])
        print(txt)

        # Apply the best-fit shifts to the model
        curves_mdl.set_tshift(dt)
        curves_mdl.set_mshift(dm)

        # The fitted sigma of a band is used as the error of all its points
        errs = {
            bn: [sigs[i]] * curves_mdl[bn].Length
            for i, bn in enumerate(curves_mdl.BandNames)
        }
        curves_clone = curves_mdl.clone(err=errs)
        ax = ps.curves_plot(curves_clone,
                            ax=None,
                            is_legend=False,
                            is_fill=True,
                            alpha=0.2)
        # Observations as markers on the same axes
        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)

        ax.text(0.05, 0.05, txt, transform=ax.transAxes)
        plt.show()
コード例 #15
0
    def test_FitMPF_best_curves_SN1999em(self):
        """Fit model light curves to the SN1999em data with ``ps.FitMPFit``.

        Only the initial time shift (``dt0``) is passed to the fitter.  The
        fitted ``dt`` is applied to the model curves, which are plotted
        together with the observations; ``dt`` and ``dm`` with their sigmas
        are printed and written on the axes.  This is a visual check: nothing
        is asserted and ``plt.show()`` is called at the end.
        """
        from pystella.rf import light_curve_plot as lcp

        # Observations
        distance = 11.5e6  # pc
        ebv_sn = 0.
        curves_obs = sn1999em.read_curves()

        # Model, computed for the observed bands
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(dirname(abspath(__file__))), 'data', 'stella')
        curves_mdl = ps.Stella(name, path=path).curves(curves_obs.BandNames,
                                                       distance=distance,
                                                       ebv=ebv_sn)

        # Fit
        fitter = ps.FitMPFit()
        fitter.is_info = True
        fitter.is_debug = True
        fitter.is_quiet = True
        _, res, _ = fitter.best_curves(curves_mdl, curves_obs, dt0=0.)
        dt, dtsig = res['dt'], res['dtsig']
        dm, dmsig = res['dm'], res['dmsig']

        # Plot the time-shifted model and the observations as markers
        curves_mdl.set_tshift(dt)
        ax = lcp.curves_plot(curves_mdl)

        lt = {lc.Band.Name: 'o' for lc in curves_obs}
        lcp.curves_plot(curves_obs, ax, lt=lt, xlim=(-10, 300), is_line=False)

        # Fit summary: printed and placed in the upper-left corner of the axes
        txt = '{:10} {:.2f} +- {:.4f}'.format('dt:', dt, dtsig)
        txt += '\n{:10} {:.2f} +- {:.4f}'.format('dm:', dm, dmsig)
        print(txt)
        title_font = {
            'size': '12',
            'color': 'black',
            'weight': 'normal',
            'verticalalignment': 'bottom'
        }
        ax.text(0.03,
                0.95,
                txt,
                transform=ax.transAxes,
                bbox={
                    'facecolor': 'none',
                    'alpha': 0.2,
                    'edgecolor': 'none'
                },
                **title_font)
        plt.show()
コード例 #16
0
ファイル: tau.py プロジェクト: baklanovp/pystella
def main():
    """Command-line entry point: print and plot tau-data of a STELLA model.

    Arguments come from ``get_parser()``; the attributes used here are
    ``path``, ``input``, ``times``, ``tau_ph``, ``phot``, ``xlim``,
    ``bnames``, ``write_prefix`` and ``is_save``.

    * With ``phot`` set, the photospheric parameters (':'-separated list) are
      computed with ``tau.params_ph``.  They are either written via
      ``tau.data_save`` (when ``write_prefix`` is given) or printed per band
      and plotted.
    * Otherwise the tau moments are plotted for the requested times.

    The figure is saved as a PDF in the home directory when ``is_save`` is
    set, otherwise it is shown.

    Exits with status 2 when no input model is given or a band is unknown.
    Returns ``None``.
    """
    import os
    import sys
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None

    ps.Band.load_settings()

    model_ext = '.tau'

    parser = get_parser()
    args, _ = parser.parse_known_args()

    path = os.path.expanduser(args.path) if args.path else os.getcwd()

    # Model name: strip the extension only when it is the suffix, so that a
    # '.tau' occurring inside the name is left intact.
    fname = args.input.strip() if args.input else ''
    if fname.endswith(model_ext):
        fname = fname[:-len(model_ext)]

    if not fname:
        parser.print_help()
        sys.exit(2)

    model = ps.Stella(fname, path=path)

    if not model.is_tau:
        print("No tau-data for: " + str(model))
        return None

    fig = None
    xlim = None
    fplot = None
    print('\n Arguments')
    times = str2float(args.times)
    print(' The time moments: ', args.times)
    print(' The optical depth ', args.tau_ph)
    if args.phot:
        print(' The photospheric parameters ', args.phot)
    if args.xlim is not None:
        xlim = str2float(args.xlim)
        print(" xlim: ", xlim)

    # Band names: '-'-separated list, every band must be known
    bnames = ('B', )
    if args.bnames:
        bnames = []
        for bname in args.bnames.split('-'):
            if not ps.band.is_exist(bname):
                print('No such band: ' + bname)
                parser.print_help()
                sys.exit(2)
            bnames.append(bname)

    tau = model.get_tau().load(is_info=False)
    print('\n Loaded data from {}'.format(tau.FName))
    print('Model has Nzone= {} Ntimes= {}'.format(tau.Nzon, tau.Ntimes))
    print("The model time interval: {:.3e} - {:.3e} days".format(
        min(tau.Times), max(tau.Times)))
    print("The bnames are  {}".format(', '.join(bnames)))

    if args.phot:
        pars = args.phot.split(':')
        # The data are requested without the 'log' marker; the original names
        # (with 'log') are passed on to the plotting routine.
        pars_data = [p.replace('log', '') for p in pars]
        tau_data = tau.params_ph(pars=pars_data,
                                 moments=times,
                                 tau_ph=args.tau_ph)

        if args.write_prefix:
            fwrite = os.path.expanduser(args.write_prefix)
            tau.data_save(fwrite, tau_data, pars_data)
            # Data were written to files: there is no figure to save or show.
            return None

        # Print parameters at the effective frequency of each band
        print('\nPhotospheric parameters:')
        freqs_eff = [ps.band.band_by_name(bn).freq_eff for bn in bnames]
        for p in pars_data:
            print('{:9s} {}'.format(
                't_real', ' '.join([f'{p}_{b:10s}' for b in bnames])))
            for t, freq, y in tau_data[p]:
                s = '{:9.4f} '.format(t)
                for fr_eff in freqs_eff:
                    # value at the frequency closest to the effective one
                    idx = (np.abs(freq - fr_eff)).argmin()
                    s += ' {:10e}'.format(y[idx])
                print(s)
        # Plot
        fig = plot_tau_phot(tau_data,
                            pars,
                            tau_ph=args.tau_ph,
                            xlim=xlim,
                            title=tau.Name,
                            bnames=bnames)
        fplot = os.path.expanduser("~/tau_{}_{}.pdf".format(
            fname, args.phot.replace(':', '-')))
    else:
        fig = plot_tau_moments(tau, moments=times, xlim=xlim)

    if args.is_save:
        if fplot is None:
            fplot = os.path.expanduser("~/tau_{0}_t{1}.pdf".format(
                fname, args.times.replace(':', '-')))
        print("Save plot to {0}".format(fplot))
        fig.savefig(fplot, bbox_inches='tight')
    elif plt is not None:
        plt.show()
    else:
        # matplotlib failed to import above
        print("matplotlib is not available: the plot cannot be shown")
コード例 #17
0
def main():
    """Command-line entry point: plot or save spectral data of a STELLA model.

    Options (parsed with ``getopt`` from ``sys.argv[1:]``):

    * ``-i NAME``  model name (required; the first ``-i`` wins)
    * ``-p PATH``  directory with the model (default: current directory)
    * ``-b SETS``  band sets joined by ``_``, bands inside a set by ``-``
    * ``-t TIMES`` ':'-separated floats used as ``t_ab``; when more than one
      value is given they also replace the default ``times``
    * ``-x ARG``   wavelength interval parsed by ``interval2float``
    * ``-o OPS``   ':'-separated modes; ``fit`` and ``wl`` are recognized
    * ``-k ARG``   K-correction given as ``z:Srest:Sobs``
    * ``-w FILE``  in the fit, K-correction and wl modes: write the data and
      exit instead of plotting
    * ``-s``       save the plot as a PDF in the home directory
    * ``-h``       print usage and exit

    Mode priority is fit, then K-correction, then wl, then the default
    spectral plot.  Returns ``None``; most error paths call ``sys.exit``.

    Raises:
        ValueError: when the ``-k`` argument does not have three fields.
    """
    is_save_plot = False
    is_kcorr = False
    is_fit = False
    is_fit_wl = False
    is_write = False

    fsave = None
    fplot = None
    z_sn = 0.
    bn_rest = None
    bn_obs = None

    try:
        opts, _ = getopt.getopt(sys.argv[1:], "b:fhsup:i:k:o:t:w:x:")
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        usage()
        sys.exit(2)

    path = os.getcwd()
    ps.Band.load_settings()

    if not opts:
        usage()
        sys.exit(2)

    # The model name is taken from the first -i option
    name = next((str(arg) for opt, arg in opts if opt == '-i'), '')

    t_ab = None
    wl_ab = None
    set_bands = ['B-V', 'B-V-I', 'V-I']
    times = [5., 15., 30., 60., 90., 120.]

    for opt, arg in opts:
        if opt == '-b':
            set_bands = str(arg).split('_')
            for bset in set_bands:
                for b in bset.split('-'):
                    if not ps.band.is_exist(b):
                        print('No such band: ' + b)
                        sys.exit(2)
        elif opt == '-w':
            is_write = True
            fsave = arg
        elif opt == '-s':
            is_save_plot = True
            # NOTE(review): 's' has no ':' in the getopt string, so arg is
            # always '' here and fplot is never set from the command line.
            if arg:
                fplot = str(arg).strip()
        elif opt == '-x':
            wl_ab = interval2float(arg)
        elif opt == '-t':
            t_ab = list(map(float, arg.split(':')))
            if len(t_ab) > 1:
                times = t_ab
        elif opt == '-o':
            ops = str(arg).split(':')
            is_fit = "fit" in ops
            is_fit_wl = "wl" in ops
        elif opt == '-k':
            ops = str(arg).split(':')
            if len(ops) != 3:
                raise ValueError(
                    'Args: {} should be string as "z:Srest:Sobs"'.format(arg))
            z_sn = float(ops[0])
            bn_rest = ops[1].strip()
            bn_obs = ops[2].strip()
            is_kcorr = True
        elif opt == '-p':
            path = os.path.expanduser(str(arg))
            if not os.path.isdir(path):
                print("No such directory: " + path)
                sys.exit(2)
        elif opt == '-h':
            usage()
            sys.exit(2)

    if not name:
        print("No model. Use key -i.")
        sys.exit(2)

    model = ps.Stella(name, path=path)

    # Check that the ph-data exist BEFORE trying to load them
    if not model.is_ph:
        print("No ph-data for: " + str(model))
        return None

    # Loaded once and shared by all modes below
    series = model.get_ph(t_diff=1.05)

    if is_fit:
        if is_write:
            if not fsave:
                fsave = "spec_%s" % name
            print("Save series to %s " % fsave)
            series_cut = series.copy(t_ab=t_ab, wl_ab=wl_ab)
            # NOTE(review): fsave is only reported above; write_magAB is
            # called without it -- confirm where the data are written.
            write_magAB(series_cut)
            sys.exit(2)

        if not model.is_tt:
            print("Error in fit-band: no tt-data for: " + str(model))
            sys.exit(2)
        series_cut = series.copy(t_ab=t_ab, wl_ab=wl_ab)
        fig = plot_fit_bands(model, series_cut, set_bands, times)
    elif is_kcorr:
        times, kcorr = [], []
        for t, k in ps.rf.rad_func.kcorrection(series, z_sn, bn_rest, bn_obs):
            times.append(t)
            kcorr.append(k)
        if is_write:
            if not fsave or fsave == '1':
                fsave = os.path.join(os.path.expanduser('~/'),
                                     "kcorr_%s" % name) + '.txt'
            kcorr_save(fsave, times, kcorr)
            sys.exit(3)
        fig = plot_kcorr(times, kcorr)
    elif is_fit_wl:
        if not model.is_tt:
            print("Error in fit-wave: no tt-data for: " + str(model))
            sys.exit(2)
        if is_write:
            if not fsave or fsave == '1':
                fsave = os.path.join(os.path.expanduser('~/'),
                                     "temp_%s" % name) + '.txt'
            series = series.copy(t_ab=t_ab)
            plot_fit_wl(model, series, wl_ab, times,
                        fsave=fsave)  # just save data
            sys.exit(3)

        fig = plot_fit_wl(model, series, wl_ab, times)
    else:
        series_cut = series.copy(t_ab=t_ab, wl_ab=wl_ab)
        fig = plot_spec_poly(series_cut)
        print("Plot spectral F(t,nu): " + str(model))

    if fig is not None:
        if is_save_plot:
            if not fplot:
                fplot = "spec_%s" % name
            # The plot always goes to the home directory as a PDF
            d = os.path.expanduser('~/')
            fplot = os.path.join(d, os.path.splitext(fplot)[0]) + '.pdf'

            print("Save plot to %s " % fplot)
            fig.savefig(fplot, bbox_inches='tight')
        else:
            plt.show()