示例#1
0
    def test_stella_curves_tbeg(self):
        """Curves computed with ``t_beg`` must not contain times before it.

        Loads the Stella model ``cat_R500_M15_Ni006_E12`` from the test data
        directory, computes U, B, V curves starting at ``t_beg`` and checks
        the Time grid of the first curve against that lower bound.
        """
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(abspath(__file__)), 'data', 'stella')
        bands = ('U', 'B', 'V')

        mdl = ps.Stella(name, path=path)
        t_beg = 1.
        curves = mdl.curves(bands, t_beg=t_beg)

        times = ps.first(curves).Time
        print(times[:3])
        # np.all, not np.any: every time point must respect the lower bound;
        # with np.any a single late point would make the test pass.
        self.assertTrue(
            np.all(times >= t_beg),
            msg="There are Time values less than t_beg = {}".format(t_beg))
示例#2
0
    def test_stella_curves(self):
        """Curves returned by ``Stella.curves`` must have exactly the requested bands.

        Loads the Stella model ``cat_R500_M15_Ni006_E12`` from the test data
        directory, computes U, B, V curves and compares the (order-insensitive)
        set of band names of the result with the requested bands.
        """
        name = 'cat_R500_M15_Ni006_E12'
        path = join(dirname(abspath(__file__)), 'data', 'stella')
        bands = ('U', 'B', 'V')

        mdl = ps.Stella(name, path=path)
        curves = mdl.curves(bands)

        print(ps.first(curves).Time[:300])
        # Compare sorted lists directly: the previous np.array(...).all()
        # wrapper only re-wrapped this same Python bool.
        self.assertEqual(
            sorted(curves.BandNames), sorted(bands),
            msg="Error for the initial band names [%s] "
            "VS secondary BandNames are %s." %
            (' '.join(bands), ' '.join(curves.BandNames)))
示例#3
0
def plot_curves_vel(curves_o, vels_o, res_models, res_sorted, vels_m,
                    **kwargs):
    """Plot fitted model light curves and velocities against observations.

    One panel is drawn per entry of ``res_sorted`` on a two-column grid.
    Each panel shows the model light curves ``res_models[k]`` with the
    observed light curves over-plotted (shifted by the model's best-fit
    time shift) and, on a twin y-axis, the model velocity ``vels_m[k]``
    together with the observed velocities.

    Parameters
    ----------
    curves_o : observed light-curve set, or None.  NOTE: its time shift is
        changed in place (``set_tshift``) for every panel.
    vels_o : iterable of observed velocity curves, or None.  NOTE: each
        element's ``tshift`` is changed in place.
    res_models : dict, model name -> model light-curve set (or None).
    res_sorted : mapping, model name -> fit result with ``tshift`` and
        ``measure``; defines which models are plotted and in what order.
    vels_m : dict, model name -> model velocity curve (``Time``, ``V``).
    **kwargs : font_size (10), linewidth (2.0, model velocity line),
        markersize (5, observed velocities), xlim (None: taken from the
        first panel with a model and reused for all later panels),
        vel_ylim ((0., 29), y-limits of the velocity axis).

    Returns
    -------
    matplotlib.figure.Figure
    """
    from pystella.rf import light_curve_plot as lcp
    from matplotlib import pyplot as plt

    font_size = kwargs.get('font_size', 10)
    linewidth = kwargs.get('linewidth', 2.0)
    markersize = kwargs.get('markersize', 5)
    xlim = kwargs.get('xlim', None)
    vel_ylim = kwargs.get('vel_ylim', (0., 29))

    ncol = 2
    # Ceiling division so every model gets a panel (at least one row).
    # The former int(num / 2.1) + 1 under-counted rows, e.g. 23 models
    # got 11 rows = 22 slots and add_subplot failed on panel 23.
    nrow = max(1, (len(res_sorted) + ncol - 1) // ncol)
    fig = plt.figure(figsize=(12, nrow * 4))
    plt.matplotlib.rcParams.update({'font.size': font_size})

    # Base time shifts of the observations; each panel adds the model's
    # best-fit shift on top of them.
    tshift_lc = ps.first(curves_o).tshift if curves_o is not None else 0.
    tshift_vel = ps.first(vels_o).tshift if vels_o is not None else 0.

    for i, (k, v) in enumerate(res_sorted.items(), start=1):
        #  Plot light curves
        ax_lc = fig.add_subplot(nrow, ncol, i)

        tshift_best = v.tshift
        ax_lc.text(0.99, 0.94, k,
                   horizontalalignment='right',
                   transform=ax_lc.transAxes)
        ax_lc.text(0.98, 0.85, "dt={:.2f}".format(tshift_best),
                   horizontalalignment='right',
                   transform=ax_lc.transAxes)
        # Raw string: "\c" is an invalid escape sequence in a plain literal.
        ax_lc.text(0.01, 0.05, r"$\chi^2: {:.2f}$".format(v.measure),
                   horizontalalignment='left',
                   transform=ax_lc.transAxes,
                   bbox=dict(facecolor='green', alpha=0.3))

        curves = res_models.get(k)
        if curves is not None:
            lcp.curves_plot(curves,
                            ax=ax_lc,
                            figsize=(12, 8),
                            linewidth=1,
                            is_legend=False)
            # xlim is fixed by the first panel and reused for later ones.
            if xlim is None:
                xlim = ax_lc.get_xlim()
            if curves_o is not None:
                lt = {lc.Band.Name: 'o' for lc in curves_o}
                curves_o.set_tshift(tshift_lc + tshift_best)
                lcp.curves_plot(curves_o,
                                ax_lc,
                                xlim=xlim,
                                lt=lt,
                                markersize=2,
                                is_legend=False,
                                is_line=False)
            ax_lc.legend(curves.BandNames,
                         loc='lower right',
                         frameon=False,
                         ncol=min(5, len(curves.BandNames)),
                         fontsize='small',
                         borderpad=1)

        # Right-hand column: hide the light-curve y-axis ticks and label.
        if i % ncol == 0:
            ax_lc.yaxis.tick_right()
            ax_lc.set_ylabel('')
            ax_lc.yaxis.set_ticks([])
        else:
            ax_lc.yaxis.tick_left()
            ax_lc.yaxis.set_label_position("left")
        ax_lc.grid(linestyle=':')

        #  Plot velocities on a twin y-axis
        ax_vel = ax_lc.twinx()
        ax_vel.set_ylim(vel_ylim)
        ax_vel.set_ylabel('Velocity [1000 km/s]')

        vel_m = vels_m.get(k)
        if vel_m is not None:
            ax_vel.plot(vel_m.Time,
                        vel_m.V,
                        label='Vel  %s' % k,
                        color='blue',
                        ls="-",
                        linewidth=linewidth)

        if vels_o is None:
            continue
        # Observed velocities; the marker sequence restarts in every panel.
        markers_cycler = cycle(markers_style)
        for vel_o in vels_o:
            vel_o.tshift = tshift_vel + tshift_best
            marker = next(markers_cycler)
            if vel_o.IsErr:
                ax_vel.errorbar(vel_o.Time,
                                vel_o.V,
                                yerr=abs(vel_o.Err),
                                label='{0}'.format(vel_o.Name),
                                color="blue",
                                ls='',
                                markersize=markersize,
                                fmt=marker,
                                markeredgewidth=2,
                                markerfacecolor='none',
                                markeredgecolor='blue')
            else:
                ax_vel.plot(vel_o.Time,
                            vel_o.V,
                            label=vel_o.Name,
                            color='blue',
                            ls='',
                            marker=marker,
                            markersize=markersize)

    fig.subplots_adjust(wspace=0, hspace=0)
    return fig
示例#4
0
def main():
    """Command-line entry point: fit models to observed light curves/velocities.

    Steps:
      1. parse arguments (``get_parser``) and collect model names, either from
         ``-i`` patterns or from all ``.ph`` files in the path;
      2. load observations (``ps.cb.observations``) and split them into light
         curves and velocity curves (``merge_obs``);
      3. fit every model with the selected engine, sequentially or in a
         process pool when ``args.nodes > 1``;
      4. for several models: keep results whose tshift lies strictly inside
         the dtshift interval and sort them by ``measure`` (lowest first);
      5. print the results table, plot the best ``args.plotnbest`` models and
         save the figure as PDF or show it interactively.

    Exits with status 2 when models, bands, observations or fit results are
    missing.
    """
    import os
    import sys

    model_ext = '.ph'
    n_best = 15  # minimum number of results for the chi^2 overview plots
    ps.Band.load_settings()
    is_set_model = False
    parser = get_parser()
    args, unknownargs = parser.parse_known_args()

    path = os.getcwd()
    if args.path:
        path = os.path.expanduser(args.path)

    # Set model names
    names = []
    if args.input:
        for arg in args.input:
            names.extend(arg2names(arg=arg, path=path, ext=model_ext))
            is_set_model = True
        # Unique values; dict.fromkeys keeps first-seen order (set() does
        # not), so the processing order is reproducible between runs.
        names = list(dict.fromkeys(names))

    if len(names) == 0 and not is_set_model:  # run for all files in the path
        names = ps.path.get_model_names(path, model_ext)

    if len(names) == 0:
        print(f'PATH: {path}')
        print(f'-i: {args.input}')
        print(
            'I do not know the models for fitting. Please use the key -i MODEL or -i *R500* '
        )
        parser.print_help()
        sys.exit(2)

    # Set band names
    bnames = []
    if args.bnames:
        for bname in args.bnames.split('-'):
            if not ps.band.is_exist(bname):
                print('No such band: ' + bname)
                parser.print_help()
                sys.exit(2)
            bnames.append(bname)

    # Get observations
    observations = ps.cb.observations(args)
    if observations is None:
        print('No obs data. Use key: -c: ')
        parser.print_help()
        sys.exit(2)

    # Split observations into light curves and velocity curves
    curves_o = merge_obs(observations, ps.SetLightCurve)
    vels_o = merge_obs(observations, ps.vel.SetVelocityCurve)

    # By default use the observed bands that are known to ps.band
    if curves_o is not None and len(bnames) == 0:
        bnames = [bn for bn in curves_o.BandNames if ps.band.is_exist(bn)]

    # Set distance and redshift
    is_sigma = args.is_fit_sigmas
    z = args.redshift
    t_diff = args.t_diff
    if args.distance is not None:
        print("Fit magnitudes on z={0:F} at distance={1:E}".format(
            z, args.distance))
        if z > 0:
            print("  Cosmology D(z)={0:E} Mpc".format(ps.cosmology_D_by_z(z)))
    else:
        distance = 10  # pc
        if z > 0:
            distance = ps.cosmology_D_by_z(z) * 1e6
            print("Fit magnitudes on z={0:F} with cosmology D(z)={1:E} pc".
                  format(z, distance))
        args.distance = distance
    if is_sigma:
        print("Fit magnitudes with model uncertainties.")

    print("Color excess E(B-V) ={:f}".format(args.color_excess))

    # Time limits for models
    tlim = (0, float('inf'))
    if args.tlim:
        tlim = list(map(float, args.tlim.replace('\\', '').split(':')))
    print('Time limits for models: {}'.format(':'.join(map(str, tlim))))

    # The fit engine
    fitter = engines(args.engine)
    fitter.is_info = args.is_not_quiet
    if args.is_not_quiet:
        fitter.print_parameters()

    # The filter results by tshift
    if args.dtshift:
        dtshift = ps.str2interval(args.dtshift,
                                  llim=float("-inf"),
                                  rlim=float('inf'))
    else:
        dtshift = (float("-inf"), float("inf"))

    res_models = {}
    vels_m = {}
    res_chi = {}
    if len(names) == 1:
        name = names[0]
        if args.is_not_quiet:
            # tlim is always set above, so no fallback message is needed.
            print("Fitting for model %s %s for %s moments" %
                  (path, name, tlim))
        if vels_o is None:
            curves_m, res, _ = fit_mfl(args, curves_o, bnames, fitter, name,
                                       path, t_diff, tlim, is_sigma)
        else:
            curves_m, res, vel_m = fit_mfl_vel(args, curves_o, vels_o,
                                               bnames, fitter, name, path,
                                               t_diff, tlim, is_sigma,
                                               A=args.tweight)
            vels_m[name] = vel_m

        print("{}: time shift  = {:.2f}+/-{:.4f} Measure: {:.4f} {}".format(
            name, res.tshift, res.tsigma, res.measure, res.comm))
        res_models[name] = curves_m
        res_chi[name] = res
        res_sorted = res_chi
    elif len(names) > 1:
        if args.nodes > 1:
            print("Run parallel fitting: nodes={}, models  {}".format(
                args.nodes, len(names)))
            with futures.ProcessPoolExecutor(
                    max_workers=args.nodes) as executor:
                if vels_o is None:
                    future_to_name = {
                        executor.submit(fit_mfl, args, curves_o, bnames,
                                        fitter, n, path, t_diff, tlim,
                                        is_sigma): n
                        for n in names
                    }
                else:
                    # Same time weight A as the sequential code paths.
                    future_to_name = {
                        executor.submit(fit_mfl_vel, args, curves_o, vels_o,
                                        bnames, fitter, n, path, t_diff, tlim,
                                        is_sigma, A=args.tweight): n
                        for n in names
                    }

                for i, future in enumerate(
                        futures.as_completed(future_to_name), start=1):
                    name = future_to_name[future]
                    try:
                        data = future.result()
                    except Exception as exc:
                        # A failed model is reported and skipped, not fatal.
                        print('%r generated an exception: %s' % (name, exc))
                    else:
                        curves_m, res, third = data
                        res_models[name] = curves_m
                        res_chi[name] = res
                        # Only fit_mfl_vel returns a model velocity as the
                        # third value; fit_mfl returns its full result there.
                        if vels_o is not None:
                            vels_m[name] = third
                        print("[{}/{}] {:30s} -> {}".format(
                            i, len(names), name, res.comm))
        else:
            for i, name in enumerate(names, start=1):
                txt = "Fitting [{}] for model {:30s}  [{}/{}]".format(
                    fitter.Name, name, i, len(names))
                if args.is_not_quiet:
                    print(txt)
                else:
                    # ANSI cursor-left escape: rewrite the same console line.
                    sys.stdout.write(u"\u001b[1000D" + txt)
                    sys.stdout.flush()

                if vels_o is None:
                    curves_m, res, _ = fit_mfl(args, curves_o, bnames,
                                               fitter, name, path,
                                               t_diff, tlim, is_sigma)
                else:
                    curves_m, res, vel_m = fit_mfl_vel(args, curves_o,
                                                       vels_o, bnames,
                                                       fitter, name, path,
                                                       t_diff, tlim,
                                                       is_sigma,
                                                       A=args.tweight)
                    vels_m[name] = vel_m
                res_models[name] = curves_m
                res_chi[name] = res

        # Keep only results whose tshift lies strictly inside dtshift
        res_chi = {k: v for k, v in res_chi.items()
                   if dtshift[0] < v.tshift < dtshift[1]}
        # Sort by measure, lowest (best) first
        res_sorted = OrderedDict(
            sorted(res_chi.items(), key=lambda kv: kv[1].measure))
    else:
        print("No any data about models. Path: {}".format(path))
        parser.print_help()
        sys.exit(2)

    # print results
    print("\n Results (tshift in range:{:.2f} -- {:.2f})".format(
        dtshift[0], dtshift[1]))
    print("{:40s} ||{:18s}|| {:10}".format('Model', 'dt+-t_err', 'Measure'))
    for k, v in res_sorted.items():
        print("{:40s} || {:7.2f}+/-{:7.4f} || {:.4f} || {}".format(
            k, v.tshift, v.tsigma, v.measure, v.comm))

    # Every fit may have failed or been filtered out by dtshift; there is
    # then no best model to report or plot.
    if not res_sorted:
        print("No fit results: all fits failed or were rejected by the "
              "dtshift filter.")
        sys.exit(2)

    if len(res_sorted) >= n_best:
        # plot chi squared
        plot_squared_grid(res_sorted, path, is_not_quiet=args.is_not_quiet)
        plot_chi_par(res_sorted, path)

    best_mdl, res = ps.first(res_sorted.items())
    print("Best fit model:")
    print("{}: time shift  = {:.2f}+/-{:.4f} Measure: {:.4f}".format(
        best_mdl, res.tshift, res.tsigma, res.measure))
    print("{}: ".format(best_mdl), res.comm)

    # Plot only the args.plotnbest best models: popitem() drops the last
    # entry, i.e. the one with the largest measure.
    while len(res_sorted) > args.plotnbest:
        res_sorted.popitem()

    if vels_o is not None and vels_o.Length > 0:
        fig = plot_curves_vel(curves_o, vels_o, res_models, res_sorted, vels_m)
    else:
        fig = plot_curves(curves_o, res_models, res_sorted, xlim=tlim)

    if args.save_file is not None:
        fsave = args.save_file
        if len(os.path.dirname(fsave)) == 0:
            fsave = os.path.expanduser("~/{}".format(fsave))
        if not fsave.endswith('.pdf'):
            fsave += '.pdf'
        print("Save plot to {0}".format(fsave))
        fig.savefig(fsave, bbox_inches='tight')
    else:
        from matplotlib import pyplot as plt
        plt.ion()
        plt.show()
        plt.pause(0.0001)
        print('')
        input("===> Hit <return> to quit")