def test_load_spectrum(fname):
    """Check that a handy_spectra test file loads into a sensible Spectrum.

    ``fname`` is resolved relative to ``tests/testdata/handy_spectra``.  The
    loaded object must be a ``Spectrum`` whose header OBJECT is HD30501
    (case-insensitive), whose wavelengths lie strictly between 2110 and
    2130 nm and whose flux lies in the interval [0, 2).
    """
    spectrum_path = os.path.join("tests", "testdata", "handy_spectra", fname)
    spectrum = load_spectrum(spectrum_path)

    assert isinstance(spectrum, Spectrum)

    target_name = spectrum.header["OBJECT"].upper()
    assert target_name == "HD30501"

    # Wavelength window (nm).
    assert np.all(spectrum.xaxis > 2110)
    assert np.all(spectrum.xaxis < 2130)

    # Flux window.
    assert np.all(spectrum.flux < 2)
    assert np.all(spectrum.flux >= 0)
Exemple #2
0
def main(chip=None, verbose=False, error_off=False, disable_wav_scale=False, renormalize=False, norm_method="scalar",
         betasigma=False):
    """Run the two component model (tcm) analysis on HD211847, observation 2.

    Parameters
    ----------
    chip : int, optional
        Detector number; 4 is used when None.
    verbose : bool
        Passed through to ``tcm_analysis``.
    error_off : bool
        Passed through to ``spectrum_error``.
    disable_wav_scale : bool
        When True, ``tcm_analysis`` is called with ``wav_scale=False``.
    renormalize : bool
        Passed to ``tcm_analysis`` as ``norm``.
    norm_method : str
        Passed through to ``tcm_analysis``.
    betasigma : bool
        Use ``betasigma_error`` instead of ``spectrum_error`` for the errors.

    Returns
    -------
    int
        Always 0.

    Notes
    -----
    ``alphas``, ``rvs`` and ``gammas`` are not defined in this function; they
    are looked up from the enclosing (module) scope.
    """
    wav_scale = not disable_wav_scale

    obsnum = 2
    star = "HD211847".upper()
    setup_tcm_dirs(star)
    chip = 4 if chip is None else chip

    obs_name, params, output_prefix = tcm_helper_function(star, obsnum, chip)
    logging.debug(__("The observation used is {} \n", obs_name))

    # Host and companion share logg and fe_h; only the temperature differs.
    host_pars = [params["temp"], params["logg"], params["fe_h"]]
    companion_pars = [params["comp_temp"], params["logg"], params["fe_h"]]
    nearest_host = closest_model_params(*host_pars)
    nearest_companion = closest_model_params(*companion_pars)

    # Model parameter sets surrounding the closest grid models.
    model1_pars = list(generate_close_params(nearest_host, small=True))
    model2_pars = list(generate_close_params(nearest_companion, small=True))

    # Load the observation and mask out its bad portions.
    obs_spec = spectrum_masking(load_spectrum(obs_name), star, obsnum, chip)
    # Barycentric correction; the returned spectrum is not used below.
    _obs_spec = barycorr_crires_spectrum(obs_spec, extra_offset=None)

    # Spectrum errors: a missing key leaves the errors unset.
    try:
        if not betasigma:
            errors = spectrum_error(star, obsnum, chip, error_off=error_off)
            logging.info(__("File obtained error value = {}", errors))
        else:
            errors = betasigma_error(obs_spec)
            logging.info(__("Beta-Sigma error value = {:6.5f}", errors))
    except KeyError:
        errors = None

    # Total number of parameter combinations explored by tcm_analysis.
    param_iter = 1
    for grid in (alphas, rvs, gammas, model2_pars, model1_pars):
        param_iter *= len(grid)
    logging.info(__("STARTING tcm_analysis\nWith {0} parameter iterations", param_iter))
    logging.debug(__("model1_pars = {}, model2_pars = {} ", len(model1_pars), len(model2_pars)))

    tcm_analysis(
        obs_spec,
        model1_pars,
        model2_pars,
        alphas,
        rvs,
        gammas,
        errors=errors,
        verbose=verbose,
        norm=renormalize,
        prefix=output_prefix,
        wav_scale=wav_scale,
        norm_method=norm_method,
    )

    return 0
def test_load_no_filename_fits():
    """An empty filename is not a valid file: loading it must raise ValueError."""
    empty_filename = ""
    with pytest.raises(ValueError):
        load_spectrum(empty_filename)
Exemple #4
0
def compare_spectra(table, params):
    """Plot the min chi2 result against the observations.

    For every chi2 column in ``chi2_names[0:-2]`` the database row with the
    smallest chi2 is extracted and the matching single-component (bhm) model
    is rebuilt.  It is plotted with the observation of detector ``ii + 1``
    and with the models at the smallest and largest gamma found in the
    table.  Each figure is saved under ``params["path"]/plots``; afterwards
    the bare observation is shown interactively.

    Parameters
    ----------
    table
        Database table handed to ``DBExtractor``.
    params : dict
        Must provide the keys "star", "obsnum", "chip", "suffix" and "path".
    """
    # Imported once here instead of on every loop iteration.
    from mingle.utilities.chisqr import chi_squared

    def normalized_model(model_func, xaxis):
        """Evaluate model_func on xaxis and return it nan-free and normalized."""
        flux = model_func(xaxis).squeeze()
        spectrum = Spectrum(flux=flux, xaxis=xaxis)
        return spectrum.remove_nans().normalize(method="exponential")

    extractor = DBExtractor(table)
    gamma_df = extractor.simple_extraction(columns=["gamma"])
    extreme_gammas = [min(gamma_df.gamma.values), max(gamma_df.gamma.values)]

    # Vertical offset added to the plotted observation.  The legend label is
    # built from this value so the two cannot disagree (the label used to
    # claim 0.05 while 0.01 was applied).
    obs_offset = 0.01

    for ii, chi2_val in enumerate(chi2_names[0:-2]):
        # Row with the smallest value of this chi2 column.
        df = extractor.ordered_extraction(
            columns=["teff_1", "logg_1", "feh_1", "gamma", chi2_val],
            order_by=chi2_val,
            limit=1,
            asc=True)

        params1 = [
            float(df[column].values[0])
            for column in ("teff_1", "logg_1", "feh_1")
        ]
        gamma = df["gamma"].values

        obs_name, obs_params, output_prefix = bhm_helper_function(
            params["star"], params["obsnum"], ii + 1)
        print(obs_name)
        obs_spec = load_spectrum(obs_name)

        # Mask out bad portion of observed spectra
        # obs_spec = spectrum_masking(obs_spec, params["star"], params["obsnum"], ii + 1)

        # Barycentric correct spectrum; the returned spectrum is not used.
        # NOTE(review): confirm whether the correction should be applied to obs_spec.
        _obs_spec = barycorr_crires_spectrum(obs_spec, extra_offset=None)
        normalization_limits = [obs_spec.xaxis[0] - 5, obs_spec.xaxis[-1] + 5]

        # Host model for the best-fit parameters.
        mod1 = load_starfish_spectrum(params1,
                                      limits=normalization_limits,
                                      hdr=True,
                                      normalize=False,
                                      area_scale=True,
                                      flux_rescale=True)

        bhm_grid_func = one_comp_model(mod1.xaxis, mod1.flux, gammas=gamma)
        upper_gamma_func = one_comp_model(mod1.xaxis,
                                          mod1.flux,
                                          gammas=extreme_gammas[1])
        lower_gamma_func = one_comp_model(mod1.xaxis,
                                          mod1.flux,
                                          gammas=extreme_gammas[0])

        model_spec = normalized_model(bhm_grid_func, obs_spec.xaxis)
        model_spec_full = normalized_model(bhm_grid_func, mod1.xaxis)
        bhm_lower_gamma = normalized_model(lower_gamma_func, obs_spec.xaxis)
        bhm_upper_gamma = normalized_model(upper_gamma_func, obs_spec.xaxis)

        chisqr = chi_squared(obs_spec.flux, model_spec.flux)

        print("Recomputed chi^2 = {0}".format(chisqr))
        print("Database chi^2 = {0}".format(df[chi2_val]))

        fig, _ = plt.subplots(1, 1, figsize=(15, 8))
        plt.plot(obs_spec.xaxis,
                 obs_spec.flux + obs_offset,
                 label="{0} + Observation, {1}".format(obs_offset, obs_name))
        # Raw string: "\c" is not a valid escape sequence in a normal string.
        plt.plot(model_spec.xaxis,
                 model_spec.flux,
                 label=r"Minimum \chi^2 model")
        plt.plot(model_spec_full.xaxis,
                 model_spec_full.flux,
                 "--",
                 label="Model_full_res")
        plt.plot(bhm_lower_gamma.xaxis,
                 bhm_lower_gamma.flux,
                 "-.",
                 label="gamma={}".format(extreme_gammas[0]))
        plt.plot(bhm_upper_gamma.xaxis,
                 bhm_upper_gamma.flux,
                 ":",
                 label="gamma={}".format(extreme_gammas[1]))
        plt.title("bhm spectrum")
        plt.legend()

        fig.tight_layout()
        name = "{0}-{1}_{2}_{3}_bhm_min_chi2_spectrum_comparison_{4}.png".format(
            params["star"], params["obsnum"], params["chip"], chi2_val,
            params["suffix"])
        plt.savefig(os.path.join(params["path"], "plots", name))
        plt.close()

        # Show the bare observation in a fresh figure.
        plt.plot(obs_spec.xaxis, obs_spec.flux, label="Observation")
        plt.show()
Exemple #5
0
def compare_spectra(table, params):
    """Plot the min chi2 result against the observations.

    For every chi2 column in ``chi2_names[0:-2]`` the database row with the
    minimum chi2 is fetched and the matching two-component model
    (``inherent_alpha_model``) is rebuilt.  It is plotted with the
    observation of detector ``ii + 1``.  Each figure is saved under
    ``params["path"]/plots``; afterwards the bare observation is shown
    interactively.

    Parameters
    ----------
    table
        Database table handed to ``DBExtractor``.
    params : dict
        Must provide the keys "star", "obsnum", "chip", "suffix" and "path".
    """
    # Imported once here instead of on every loop iteration.
    from simulators.iam_module import iam_helper_function

    extractor = DBExtractor(table)
    for ii, chi2_val in enumerate(chi2_names[0:-2]):
        df = extractor.minimum_value_of(chi2_val)
        df = df[["teff_1", "logg_1", "feh_1", "gamma",
                 "teff_2", "logg_2", "feh_2", "rv", chi2_val]]

        # Host (1) and companion (2) model parameters as plain floats.
        params1 = [float(df[column].values[0])
                   for column in ("teff_1", "logg_1", "feh_1")]
        params2 = [float(df[column].values[0])
                   for column in ("teff_2", "logg_2", "feh_2")]

        gamma = df["gamma"].values
        rv = df["rv"].values

        obs_name, obs_params, output_prefix = iam_helper_function(params["star"], params["obsnum"], ii + 1)

        obs_spec = load_spectrum(obs_name)

        # Mask out bad portion of observed spectra
        # obs_spec = spectrum_masking(obs_spec, params["star"], params["obsnum"], ii + 1)

        # Barycentric correct spectrum; the returned spectrum is not used.
        # NOTE(review): confirm whether the correction should be applied to obs_spec.
        _obs_spec = barycorr_crires_spectrum(obs_spec, extra_offset=None)
        normalization_limits = [obs_spec.xaxis[0] - 5, obs_spec.xaxis[-1] + 5]

        # Models for both components, loaded with identical options.
        mod1 = load_starfish_spectrum(params1, limits=normalization_limits,
                                      hdr=True, normalize=False, area_scale=True,
                                      flux_rescale=True)

        mod2 = load_starfish_spectrum(params2, limits=normalization_limits,
                                      hdr=True, normalize=False, area_scale=True,
                                      flux_rescale=True)

        iam_grid_func = inherent_alpha_model(mod1.xaxis, mod1.flux, mod2.flux,
                                             rvs=rv, gammas=gamma)

        # Evaluate the model on the observed grid and on the full model grid.
        iam_grid_model = iam_grid_func(obs_spec.xaxis).squeeze()
        iam_grid_model_full = iam_grid_func(mod1.xaxis).squeeze()

        model_spec_full = Spectrum(flux=iam_grid_model_full, xaxis=mod1.xaxis)
        model_spec = Spectrum(flux=iam_grid_model, xaxis=obs_spec.xaxis)

        model_spec = model_spec.remove_nans()
        model_spec = model_spec.normalize(method="exponential")
        model_spec_full = model_spec_full.remove_nans()
        model_spec_full = model_spec_full.normalize(method="exponential")

        fig, _ = plt.subplots(1, 1)
        plt.plot(obs_spec.xaxis, obs_spec.flux, label="Observation")
        # Raw string: "\c" is not a valid escape sequence in a normal string.
        plt.plot(model_spec.xaxis, model_spec.flux, label=r"Minimum \chi^2 model")
        plt.plot(model_spec_full.xaxis, model_spec_full.flux, "--", label="Model_full_res")

        plt.legend()

        fig.tight_layout()
        name = "{0}-{1}_{2}_{3}_min_chi2_spectrum_comparison_{4}.png".format(
            params["star"], params["obsnum"], params["chip"], chi2_val, params["suffix"])
        plt.savefig(os.path.join(params["path"], "plots", name))
        plt.close()

        # Show the bare observation in a fresh figure.
        plt.plot(obs_spec.xaxis, obs_spec.flux, label="Observation")
        plt.show()
Exemple #6
0
def main():
    """Main function.

    Run ``bhm_analysis`` for HD211847 (observation 2, detector 4) over the
    models returned by ``generate_close_params`` and display diagnostics:
    chi-squared and gamma against each model parameter, the chi-squared
    values arranged on a (teff, logg, feh) grid, and the observation plotted
    with the minimum chi-squared, maximum cross-correlation and closest
    models.
    """
    # Import before first use.  An ``import`` anywhere in the body makes
    # ``simulators`` a local name for the whole function, so importing it
    # further down raised UnboundLocalError at the first ``simulators.paths``.
    import simulators

    star = "HD211847"
    param_file = os.path.join(simulators.paths["parameters"], "{}_params.dat".format(star))
    params = parse_paramfile(param_file, path=None)
    host_params = [params["temp"], params["logg"], params["fe_h"]]

    obsnum = 2
    chip = 4

    obs_name = os.path.join(
        simulators.paths["spectra"], "{}-{}-mixavg-tellcorr_{}.fits".format(star, obsnum, chip))
    # logging takes %-style arguments; extra positional args without a
    # placeholder make the record fail to format.
    logging.info("The observation used is %s\n", obs_name)

    closest_host_model = closest_model_params(*host_params)  # unpack temp, logg, fe_h with *

    # Model parameters close to the closest host model.
    model_pars = list(generate_close_params(closest_host_model))  # Turn to list

    print("Model parameters", model_pars)

    # Load observation
    obs_spec = load_spectrum(obs_name)
    # The returned spectrum is not used below.
    # NOTE(review): confirm whether the correction should be applied to obs_spec.
    _obs_spec = barycorr_crires_spectrum(obs_spec, extra_offset=None)
    obs_spec.flux /= 1.02
    # Mask out bad portion of observed spectra ## HACK
    if chip == 4:
        # Ignore first 40 pixels
        obs_spec.wav_select(obs_spec.xaxis[40], obs_spec.xaxis[-1])

    gammas = np.arange(*simulators.sim_grid["gammas"])

    chi2_grids = bhm_analysis(obs_spec, model_pars, gammas, verbose=True)
    (model_chisqr_vals, model_xcorr_vals, model_xcorr_rv_vals,
     broadcast_chisqr_vals, broadcast_gamma, broadcast_chisquare) = chi2_grids

    # Arrays (not lists) so that ``== value`` below is an element-wise mask
    # and the results can be indexed with that mask.
    TEFF = np.asarray([par[0] for par in model_pars])
    LOGG = np.asarray([par[1] for par in model_pars])
    FEH = np.asarray([par[2] for par in model_pars])
    broadcast_chisqr_vals = np.asarray(broadcast_chisqr_vals)
    model_chisqr_vals = np.asarray(model_chisqr_vals)

    # One pair of figures per model parameter.
    for values, chisqr_title, gamma_title in (
            (TEFF, "TEFF vs Broadcast chisqr_vals", "TEFF vs Broadcast gamma grid"),
            (LOGG, "LOGG verse Broadcast chisqr_vals", "LOGG verse Broadcast gamma grid"),
            (FEH, "FEH vs Broadcast chisqr_vals", "FEH vs Broadcast gamma grid")):
        plt.plot(values, broadcast_chisqr_vals, "+", label="broadcast")
        plt.plot(values, model_chisqr_vals, ".", label="org")
        plt.title(chisqr_title)
        plt.legend()
        plt.show()
        plt.plot(values, broadcast_gamma, "o")
        plt.title(gamma_title)
        plt.show()

    # np.unique gives sorted 1-D arrays; np.array(set(...)) would give a 0-d
    # object array with no length.
    TEFFS_unique = np.unique(TEFF)
    LOGG_unique = np.unique(LOGG)
    FEH_unique = np.unique(FEH)
    # indexing="ij" so that axis order is (teff, logg, feh), matching the
    # chi_ND[i, j, k] indexing used below.
    X, Y, Z = np.meshgrid(TEFFS_unique, LOGG_unique, FEH_unique, indexing="ij")
    print("Teff grid", X)
    print("Logg grid", Y)
    print("FEH grid", Z)
    # A complete grid holds one model per parameter combination (product of
    # the axis lengths, not their sum).
    assert len(TEFF) == len(TEFFS_unique) * len(LOGG_unique) * len(FEH_unique), \
        "model parameters do not form a complete teff/logg/feh grid"

    # Array with the shape of the grid (np.empty_like(X.shape) would instead
    # copy the shape *tuple* and give a length-3 integer array).
    chi_ND = np.empty(X.shape)
    print("chi_ND.shape", chi_ND.shape)
    print("len(TEFFS_unique)", len(TEFFS_unique))
    print("len(LOGG_unique)", len(LOGG_unique))
    print("len(FEH_unique)", len(FEH_unique))

    for i, tf in enumerate(TEFFS_unique):
        for j, lg in enumerate(LOGG_unique):
            for k, fh in enumerate(FEH_unique):
                print("i,j,k", (i, j, k))
                print("num = t", np.sum(TEFF == tf))
                print("num = lg", np.sum(LOGG == lg))
                print("num = fh", np.sum(FEH == fh))
                mask = (TEFF == tf) & (LOGG == lg) & (FEH == fh)
                print("num = tf, lg, fh", np.sum(mask))
                chi_ND[i, j, k] = broadcast_chisqr_vals[mask]
                print("broadcast val", broadcast_chisqr_vals[mask],
                      "\norg val", model_chisqr_vals[mask])

    chisqr_argmin_indx = np.argmin(model_chisqr_vals)
    xcorr_argmax_indx = np.argmax(model_xcorr_vals)

    print("Minimum  Chisqr value =", model_chisqr_vals[chisqr_argmin_indx])  # , min(model_chisqr_vals)
    # Indexed with the xcorr maximum (previously repeated the chisqr minimum).
    print("Chisqr at max correlation value", model_chisqr_vals[xcorr_argmax_indx])

    print("model_xcorr_vals = {}".format(model_xcorr_vals))
    print("Maximum Xcorr value =", model_xcorr_vals[xcorr_argmax_indx])  # , max(model_xcorr_vals)
    print("Xcorr at min Chiqsr", model_xcorr_vals[chisqr_argmin_indx])

    print("RV at max xcorr =", model_xcorr_rv_vals[xcorr_argmax_indx])
    print(pv("model_xcorr_rv_vals[chisqr_argmin_indx]"))
    print(pv("sp.stats.mode(np.around(model_xcorr_rv_vals))"))

    print("Max Correlation model = ", model_pars[xcorr_argmax_indx])
    print("Min Chisqr model = ", model_pars[chisqr_argmin_indx])

    limits = [2110, 2160]

    best_model_params = model_pars[chisqr_argmin_indx]
    best_model_spec = load_starfish_spectrum(best_model_params, limits=limits, normalize=True)

    best_xcorr_model_params = model_pars[xcorr_argmax_indx]
    best_xcorr_model_spec = load_starfish_spectrum(best_xcorr_model_params, limits=limits, normalize=True)

    # Pass the closest model parameters, not the closest_model_params function.
    close_model_spec = load_starfish_spectrum(closest_host_model, limits=limits, normalize=True)

    plt.plot(obs_spec.xaxis, obs_spec.flux, label="Observations")
    plt.plot(best_model_spec.xaxis, best_model_spec.flux, label="Best Model")
    plt.plot(best_xcorr_model_spec.xaxis, best_xcorr_model_spec.flux, label="Best xcorr Model")
    plt.plot(close_model_spec.xaxis, close_model_spec.flux, label="Close Model")
    plt.legend()
    plt.xlim(*limits)
    plt.show()

    logging.debug("After plot")
Exemple #7
0
def main(star,
         obsnum,
         chip=None,
         suffix=None,
         error_off=False,
         disable_wav_scale=False,
         renormalize=False,
         norm_method="scalar",
         betasigma=False):
    """Run the best host model (bhm) analysis for a single observation.

    Parameters
    ----------
    star : str
        Star name; upper-cased before use.
    obsnum
        Observation number, passed to the helper functions.
    chip
        Detector number, passed to the helper functions.
    suffix
        Appended (as a string) to the output prefix when not None.
    error_off : bool
        Passed through to ``spectrum_error``.
    disable_wav_scale : bool
        When True, ``bhm_analysis`` is called with ``wav_scale=False``.
    renormalize : bool
        Passed to ``bhm_analysis`` as ``norm``.
    norm_method : str
        Passed through to ``bhm_analysis``.
    betasigma : bool
        Use ``betasigma_error`` instead of ``spectrum_error`` for the errors.

    Returns
    -------
    int
        Always 0.
    """
    wav_scale = not disable_wav_scale
    star = star.upper()
    setup_bhm_dirs(star)

    # Broadcast gamma grid taken from the simulators configuration.
    gamma_grid = np.arange(*simulators.sim_grid["gammas"])

    obs_name, params, output_prefix = bhm_helper_function(star, obsnum, chip)
    if suffix is not None:
        output_prefix = output_prefix + str(suffix)
    print("The observation used is ", obs_name, "\n")

    # Host model parameters to iterate over, taken from the config file.
    host_model_pars = get_bhm_model_pars(params, method="config")

    # Load the observation and check its type.
    obs_spec = load_spectrum(obs_name)
    from spectrum_overload import Spectrum
    assert isinstance(obs_spec, Spectrum)

    # Mask out the bad portions of the observed spectrum.
    obs_spec = spectrum_masking(obs_spec, star, obsnum, chip)

    # Barycentric correction; the returned spectrum is not used below.
    _obs_spec = barycorr_crires_spectrum(obs_spec, extra_offset=None)

    # Spectrum errors: a missing key leaves the errors unset.
    try:
        if not betasigma:
            print("NOT DOING BETASIGMA ERRORS")
            errors = spectrum_error(star, obsnum, chip, error_off=error_off)
            logging.info(__("File obtained error value = {0}", errors))
        else:
            print("DOING BETASIGMA ERRORS")
            errors, derrors = betasigma_error(
                obs_spec,
                N=simulators.betasigma.get("N", 5),
                j=simulators.betasigma.get("j", 2))
            print("Beta-Sigma error value = {:6.5f}+/-{:6.5f}".format(
                errors, derrors))
            logging.info(
                __("Beta-Sigma error value = {:6.5f}+/-{:6.5f}", errors,
                   derrors))
    except KeyError:
        print("ERRORS Failed so set to None")
        errors = None

    analysis_options = {
        "errors": errors,
        "verbose": False,
        "norm": renormalize,
        "wav_scale": wav_scale,
        "prefix": output_prefix,
        "norm_method": norm_method,
    }
    bhm_analysis(obs_spec, host_model_pars, gamma_grid, **analysis_options)
    print("after bhm_analysis")

    print("\nNow use bin/coadd_bhm_db.py")
    return 0
Exemple #8
0
def main(star,
         obsnum,
         teff_1,
         logg_1,
         feh_1,
         teff_2,
         logg_2,
         feh_2,
         gamma,
         rv,
         plot_name=None):
    """Plot the four detector observations against a model with given parameters.

    A 2x2 figure is created, one panel per detector (1 to 4).  Each panel
    shows the masked observation and the normalized ``inherent_alpha_model``
    built from the host (``teff_1, logg_1, feh_1``) and companion
    (``teff_2, logg_2, feh_2``) spectra with the given ``gamma`` and ``rv``.
    When ``teff_2`` is None a zero-flux companion is used.

    Parameters
    ----------
    plot_name : str, optional
        File to save the figure to; must end in ".png" or ".pdf".  When None
        the figure is shown instead.

    Returns
    -------
    int
        Always 0.

    Raises
    ------
    ValueError
        If ``plot_name`` does not end with ".pdf" or ".png".
    """
    fig, axis = plt.subplots(2, 2, figsize=(15, 8), squeeze=False)

    for chip, panel in enumerate(axis.flatten(), start=1):
        # Observation for this detector: locate, load and mask.
        obs_name, _params, _prefix = iam_helper_function(star, obsnum, chip)
        observation = load_spectrum(obs_name)
        observation = spectrum_masking(observation, star, obsnum, chip)

        # Barycentric correction; the returned spectrum is not used below.
        _obs_spec = barycorr_crires_spectrum(observation, extra_offset=None)

        # Host model for the requested parameters.
        host = load_starfish_spectrum([teff_1, logg_1, feh_1],
                                      limits=[2110, 2165],
                                      area_scale=True,
                                      hdr=True)

        # Companion model, or a zero-flux stand-in for the bhm case.
        if teff_2 is not None:
            companion = load_starfish_spectrum([teff_2, logg_2, feh_2],
                                               limits=[2110, 2165],
                                               area_scale=True,
                                               hdr=True)
        else:
            assert logg_2 is None and feh_2 is None and rv == 0, \
                "All must be None for bhm case."
            companion = Spectrum(xaxis=host.xaxis,
                                 flux=np.zeros_like(host.flux))

        joint_model = inherent_alpha_model(host.xaxis,
                                           host.flux,
                                           companion.flux,
                                           gammas=gamma,
                                           rvs=rv)

        combined_flux = joint_model(host.xaxis).squeeze()
        model_spec = Spectrum(xaxis=host.xaxis, flux=combined_flux)
        model_spec = model_spec.remove_nans().normalize("exponential")

        # Draw observation and model on this detector's panel.
        observation.plot(axis=panel, label="{}-{}".format(star, obsnum))
        model_spec.plot(axis=panel, linestyle="--", label="Chi-squared model")
        panel.set_xlim([observation.xmin() - 0.5, observation.xmax() + 0.5])

        panel.set_title("{} obs {} chip {}".format(star, obsnum, chip))
        panel.legend()
    plt.tight_layout()

    if plot_name is None:
        fig.show()
        return 0

    if not plot_name.endswith((".png", ".pdf")):
        raise ValueError("plot_name does not end with .pdf or .png")
    fig.savefig(plot_name)

    return 0
Exemple #9
0
def main():
    """Main.

    Chi-squared grid search over flux ratio (alpha) and RV for HD30501,
    observation "1", detector 1.  The observation is loaded, the PHOENIX
    host/companion models are convolved to the instrument resolution, the
    observation is corrected with ``barycorr_crires_spectrum`` using the
    negated host RV as offset, and ``parallel_chisqr`` is evaluated over the
    alpha/RV grid.  The chi-squared surface and the minimum chi-squared
    model are plotted and the results are pickled to a hard-coded directory.
    """
    star = "HD30501"
    obsnum = "1"
    chip = 1
    obs_name = select_observation(star, obsnum, chip)

    # Load observation
    observed_spectra = load_spectrum(obs_name)
    if chip == 4:
        # Ignore first 40 pixels
        observed_spectra.wav_select(observed_spectra.xaxis[40], observed_spectra.xaxis[-1])

    # Load models
    host_spectrum_model, companion_spectrum_model = load_PHOENIX_hd30501(limits=[2100, 2200], normalize=True)

    obs_resolution = crires_resolution(observed_spectra.header)

    # Convolve models to resolution of instrument
    host_spectrum_model, companion_spectrum_model = convolve_models((host_spectrum_model, companion_spectrum_model),
                                                                    obs_resolution, chip_limits=None)

    plot_obs_with_model(observed_spectra, host_spectrum_model, companion_spectrum_model,
                        show=False, title="Before BERV Correction")

    # Berv Correct
    # Calculate the star RV relative to synthetic spectum
    #                        [mean_val, K1, omega,   e,     Tau,       Period, starmass (Msun), msini(Mjup), i]
    parameters = {"HD30501": [23.710, 1703.1, 70.4, 0.741, 53851.5, 2073.6, 0.81, 90],
                  "HD211847": [6.689, 291.4, 159.2, 0.685, 62030.1, 7929.4, 0.94, 19.2, 7]}
    try:
        host_params = parameters[star]
    except KeyError:
        # A missing dict key raises KeyError (not ValueError).
        print("Parameters for {} are not in parameters list. Improve this.".format(star))
        raise
    host_params[1] = host_params[1] / 1000  # Convert K1 to km/s
    host_params[2] = np.deg2rad(host_params[2])  # Omega needs to be in radians for ajplanet

    obs_time = observed_spectra.header["DATE-OBS"]
    print(obs_time, isinstance(obs_time, str))
    print(obs_time.replace("T", " ").split("."))
    jd = ephem.julian_date(obs_time.replace("T", " ").split(".")[0])
    host_rv = pl_rv_array(jd, *host_params[0:6])[0]
    print("host_rv", host_rv, "km/s")

    offset = -host_rv
    # offset = 0
    # Bind the corrected spectrum to the name used throughout the rest of the
    # function (it was previously bound to an underscore-prefixed name,
    # leaving every later reference undefined).
    berv_corrected_observed_spectra = barycorr_crires_spectrum(observed_spectra, offset)  # Issue with air/vacuum
    # This introduces nans into the observed spectrum
    # Trim to the span between the first and last finite flux values.
    berv_corrected_observed_spectra.wav_select(*berv_corrected_observed_spectra.xaxis[
        np.isfinite(berv_corrected_observed_spectra.flux)][[0, -1]])
    # Shift to star RV

    plot_obs_with_model(berv_corrected_observed_spectra, host_spectrum_model,
                        companion_spectrum_model, title="After BERV Correction")

    # Chisquared fitting
    alphas = 10 ** np.linspace(-4, 0.1, 100)
    rvs = np.arange(-50, 50, 0.05)

    observed_limits = [np.floor(berv_corrected_observed_spectra.xaxis[0]),
                       np.ceil(berv_corrected_observed_spectra.xaxis[-1])]
    print("Observed_limits ", observed_limits)

    # Set n_jobs to None to use all but one of the available CPUs.
    n_jobs = 1
    if n_jobs is None:
        n_jobs = mprocess.cpu_count() - 1
    start_time = dt.now()

    if np.all(np.isnan(host_spectrum_model.flux)):
        print("Host spectrum is all Nans")
    if np.all(np.isnan(companion_spectrum_model.flux)):
        print("Companion spectrum is all Nans")

    print("Now performing the Chisqr grid analaysis")
    obs_chisqr_parallel = parallel_chisqr(alphas, rvs, berv_corrected_observed_spectra, alpha_model2,
                                          (host_spectrum_model, companion_spectrum_model,
                                           observed_limits, berv_corrected_observed_spectra), n_jobs=n_jobs)

    end_time = dt.now()
    print("Time to run parallel chisquared = {}".format(end_time - start_time))

    # Chi-squared surface; x holds RVs and y holds alphas, both with shape
    # (len(alphas), len(rvs)).
    x, y = np.meshgrid(rvs, alphas)
    fig = plt.figure(figsize=(7, 7))
    cf = plt.contourf(x, y, np.log10(obs_chisqr_parallel.reshape(len(alphas), len(rvs))), 100)
    fig.colorbar(cf)
    plt.title("Sigma chisquared")
    plt.ylabel("Flux ratio")
    plt.xlabel("RV (km/s)")
    plt.show()

    # Locate minimum and plot resulting model next to observation
    def find_min_chisquared(x, y, z):
        """Return (x, y, z, flat index) at the minimum of the chisqr grid z."""
        min_loc = np.argmin(z)  # index into the flattened grid
        print("min location", min_loc)

        x_sol = x.ravel()[min_loc]
        y_sol = y.ravel()[min_loc]
        z_sol = z.ravel()[min_loc]
        return x_sol, y_sol, z_sol, min_loc

    rv_solution, alpha_solution, min_chisqr, min_loc = find_min_chisquared(x, y, obs_chisqr_parallel)
    print("Minium Chisqr value {2}\n RV sol = {0}\nAlpha Sol = {1}".format(rv_solution, alpha_solution, min_chisqr))

    solution_model = alpha_model2(alpha_solution, rv_solution, host_spectrum_model, companion_spectrum_model,
                                  observed_limits)

    plt.plot(solution_model.xaxis, solution_model.flux, label="Min chisqr solution")
    plt.plot(berv_corrected_observed_spectra.xaxis, berv_corrected_observed_spectra.flux, label="Observation")
    plt.legend(loc=0)
    plt.show()

    # Dump the results into a pickle file (hard-coded, user-specific directory).
    pickle_path = "/home/jneal/.chisqrpickles/"
    pickle_name = "Chisqr_results_{0}_{1}_chip_{2}.pickle".format(star, obsnum, chip)
    with open(os.path.join(pickle_path, pickle_name), "wb") as f:
        # Pickle all the necessary parameters to store.
        pickle.dump((rvs, alphas, berv_corrected_observed_spectra, host_spectrum_model, companion_spectrum_model,
                     rv_solution, alpha_solution, min_chisqr, min_loc, solution_model), f)
def main():
    """Main function.

    Compare the HD30501 observation (observation 1, detector 1) against the
    PHOENIX models returned by ``find_phoenix_model_names``.  For each model
    the spectrum is normalized, cut to within 5 nm of the observation,
    convolved to the instrument resolution, cross-correlated with the
    observation and interpolated onto it for a chi-squared value.  The
    minimum chi-squared, maximum cross-correlation and closest models are
    then plotted with the observation.
    """
    star = "HD30501"
    host_parameters = load_param_file(star)
    obsnum = 1
    chip = 1
    obs_name = select_observation(star, obsnum, chip)

    # Load observation
    observed_spectra = load_spectrum(obs_name)
    # The returned spectrum is not used below.
    # NOTE(review): confirm whether the correction should be applied to observed_spectra.
    _observed_spectra = barycorr_crires_spectrum(observed_spectra,
                                                 extra_offset=None)
    observed_spectra.flux /= 1.02

    obs_resolution = crires_resolution(observed_spectra.header)

    wav_model = fits.getdata(
        os.path.join(wav_dir, "WAVE_PHOENIX-ACES-AGSS-COND-2011.fits"))
    wav_model /= 10  # turn into nm
    logging.debug(__("Phoenix wav_model = {0}", wav_model))

    closest_model = phoenix_name_from_params(model_base_dir, host_parameters)
    original_model = "Z-0.0/lte05200-4.50-0.0.PHOENIX-ACES-AGSS-COND-2011-HiRes.fits"
    logging.debug(__("closest_model {0}", closest_model))
    logging.debug(__("original_model {0}", original_model))

    # Function to find the good models I need
    models = find_phoenix_model_names(model_base_dir, original_model)
    if isinstance(models, list):
        logging.debug(__("Number of close models returned {0}", len(models)))

    # Float result arrays, one entry per model.  np.empty_like(models) would
    # inherit the string dtype of the model file names, turning the stored
    # numbers into text and breaking argmin/argmax.
    model_chisqr_vals = np.empty(len(models), dtype=float)
    model_xcorr_vals = np.empty(len(models), dtype=float)
    model_xcorr_rv_vals = np.empty(len(models), dtype=float)

    for ii, model_name in enumerate(models):
        mod_flux = fits.getdata(model_name)
        mod_header = fits.getheader(model_name)
        mod_spectrum = Spectrum(xaxis=wav_model,
                                flux=mod_flux,
                                header=mod_header,
                                calibrated=True)

        # Normalize Phoenix Spectrum
        mod_spectrum.wav_select(2105, 2165)  # limits for simple normalization
        norm_mod_spectrum = spec_local_norm(mod_spectrum, plot=False)

        # Wav select
        norm_mod_spectrum.wav_select(
            np.min(observed_spectra.xaxis) - 5,
            np.max(observed_spectra.xaxis) +
            5)  # +- 5nm of obs for convolution

        # Convolve to resolution of instrument
        conv_mod_spectrum = convolve_models([norm_mod_spectrum],
                                            obs_resolution,
                                            chip_limits=None)[0]

        # Find crosscorrelation RV
        # # Should run though all models and find best rv to apply uniformly
        rvoffset, cc_max = xcorr_peak(observed_spectra,
                                      conv_mod_spectrum,
                                      plot=False)

        # Interpolate to obs
        conv_mod_spectrum.spline_interpolate_to(observed_spectra)
        model_chi_val = chi_squared(observed_spectra.flux,
                                    conv_mod_spectrum.flux)

        model_chisqr_vals[ii] = model_chi_val
        model_xcorr_vals[ii] = cc_max
        model_xcorr_rv_vals[ii] = rvoffset

    logging.debug(pv("model_chisqr_vals"))
    logging.debug(pv("model_xcorr_vals"))
    chisqr_argmin_indx = np.argmin(model_chisqr_vals)
    xcorr_argmax_indx = np.argmax(model_xcorr_vals)

    logging.debug(pv("chisqr_argmin_indx"))
    logging.debug(pv("xcorr_argmax_indx"))

    logging.debug(pv("model_chisqr_vals"))
    print("Minimum  Chisqr value =",
          model_chisqr_vals[chisqr_argmin_indx])  # , min(model_chisqr_vals)
    # Indexed with the xcorr maximum (previously repeated the chisqr minimum).
    print("Chisqr at max correlation value",
          model_chisqr_vals[xcorr_argmax_indx])

    print("model_xcorr_vals = {}".format(model_xcorr_vals))
    print("Maximum Xcorr value =",
          model_xcorr_vals[xcorr_argmax_indx])  # , max(model_xcorr_vals)
    print("Xcorr at min Chiqsr", model_xcorr_vals[chisqr_argmin_indx])

    logging.debug(pv("model_xcorr_rv_vals"))
    print("RV at max xcorr =", model_xcorr_rv_vals[xcorr_argmax_indx])
    print(pv("model_xcorr_rv_vals[chisqr_argmin_indx]"))

    print("Max Correlation model = ",
          models[xcorr_argmax_indx].split("/")[-2:])
    print("Min Chisqr model = ", models[chisqr_argmin_indx].split("/")[-2:])

    limits = [2110, 2160]

    # Reload the three models of interest, normalized and convolved to the
    # instrument resolution, for plotting.
    best_model = models[chisqr_argmin_indx]
    best_model_spec = load_phoenix_spectrum(best_model,
                                            limits=limits,
                                            normalize=True)
    best_model_spec = convolve_models([best_model_spec],
                                      obs_resolution,
                                      chip_limits=None)[0]

    best_xcorr_model = models[xcorr_argmax_indx]
    best_xcorr_model_spec = load_phoenix_spectrum(best_xcorr_model,
                                                  limits=limits,
                                                  normalize=True)
    best_xcorr_model_spec = convolve_models([best_xcorr_model_spec],
                                            obs_resolution,
                                            chip_limits=None)[0]

    close_model_spec = load_phoenix_spectrum(closest_model[0],
                                             limits=limits,
                                             normalize=True)
    close_model_spec = convolve_models([close_model_spec],
                                       obs_resolution,
                                       chip_limits=None)[0]

    plt.plot(observed_spectra.xaxis,
             observed_spectra.flux,
             label="Observations")
    plt.plot(best_model_spec.xaxis, best_model_spec.flux, label="Best Model")
    plt.plot(best_xcorr_model_spec.xaxis,
             best_xcorr_model_spec.flux,
             label="Best xcorr Model")
    plt.plot(close_model_spec.xaxis,
             close_model_spec.flux,
             label="Close Model")
    plt.legend()
    plt.xlim(*limits)
    plt.show()
    print("After plot")
from astropy.io import fits
import matplotlib.pyplot as plt
from mingle.utilities.spectrum_utils import load_spectrum
from simulators.bhm_module import bhm_helper_function

from mingle.utilities.crires_utilities import barycorr_crires_spectrum

# In[8]:

# For both test stars and detectors 1-4, overplot the FITS table read
# directly with astropy against the spectrum returned by load_spectrum.
for star in ["hdtest2", "hdtest3"]:

    for chip in range(1, 5):
        # Observation file name for observation "001" of this star/detector.
        obs_name, obs_params, output_prefix = bhm_helper_function(
            star.upper(), "001", chip)
        obs_spec = load_spectrum(obs_name)
        # Read the FITS table from a hard-coded, user-specific directory.
        # NOTE(review): presumably the same file as obs_name -- confirm.
        data = fits.getdata(
            "/home/jneal/.handy_spectra/{}-001-mixavg-tellcorr_{}.fits".format(
                star.upper(), chip))

        # The FITS flux is shifted up by 0.01, presumably so that both curves
        # remain visible when they coincide.
        plt.plot(data["wavelength"], data["flux"] + 0.01, label="fits")
        obs_spec.plot(label="load spec")
        plt.legend()
        plt.show()

# In[9]:

# Test with barycentric correction
for star in ["hdtest2", "hdtest3"]:

    for chip in range(1, 5):