def test_arbitrary_rescale_arb_norm_is_np_arange(start, stop, step):
    """arb_norm must equal np.arange(start, stop, step); the model gains one trailing axis."""
    original = np.random.random((4, 3, 5))
    rescaled, norms = arbitrary_rescale(original, start, stop, step)

    expected_norms = np.arange(start, stop, step)
    assert np.all(norms == expected_norms)

    # The rescaled model stacks one copy per normalization value on the last axis.
    expected_shape = original.shape + (len(norms),)
    assert rescaled.shape == expected_shape
def test_arbitrary_rescale():
    """Check shape, normalization values and each rescaled slice for a small model.

    With start=1, stop=2.1, step=0.5 the normalizations are [1.0, 1.5, 2.0].
    Slice k along the new last axis must equal model * arb_norm[k].
    """
    model = np.arange(0, 12).reshape(3, 4)
    new_model, arb_norm = arbitrary_rescale(model, 1, 2.1, 0.5)

    assert new_model.shape == (3, 4, 3)
    assert np.allclose(arb_norm, np.asarray([1.0, 1.5, 2.0]))
    assert np.allclose(new_model[:, :, 0], model)
    assert np.allclose(new_model[:, :, 1], model * 1.5)
    assert np.allclose(new_model[:, :, 2], model * 2)
def iam_wrapper(num, params1, model2_pars, rvs, gammas, obs_spec, norm=False,
                verbose=False, save_only=True, chip=None, prefix=None, errors=None,
                area_scale=True, wav_scale=True, grid_slices=False, norm_method="scalar",
                fudge=None):
    """Wrapper for iteration loop of iam. params1 fixed, model2_pars are many.

    fudge is multiplicative on companion spectrum.

    For every companion parameter set in ``model2_pars`` the host (``params1``)
    and companion model spectra are prepared and combined into an
    inherent-alpha model grid over ``rvs`` and ``gammas``. The grid is
    continuum normalized, rescaled by the arbitrary-normalization grid
    (``simulators.sim_grid["arb_norm"]``) and compared with ``obs_spec`` by
    chi-squared. Each iteration's results are passed to
    ``save_full_iam_chisqr`` together with ``save_filename``.

    Parameters
    ----------
    num: int
        Part number; used in the output file name.
    params1: sequence
        Fixed host parameters. The first three are used in the file name.
    model2_pars: iterable
        Companion parameter sets to iterate over. Must support ``len()`` when
        ``save_only`` is False.
    rvs, gammas:
        Grids passed to ``observation_rv_limits`` and ``inherent_alpha_model``.
    obs_spec:
        Observed spectrum. Its ``header["OBJECT"]`` and ``header["MJD-OBS"]``
        are read when ``prefix`` is None.
    norm: bool
        Passed as ``normalize`` to ``renormalization``.
    verbose: bool
        Print progress for each iteration.
    save_only: bool
        If True and the output file already exists, nothing is computed. If
        False, the minimum chi-squared of each companion set is also returned.
    chip:
        Detector chip. Used in the file name and in ``continuum_alpha``; chip 4
        additionally receives a quadratic renormalization.
    prefix: str or None
        Output path prefix. If None it is built from
        ``simulators.paths["output_dir"]`` and the object name.
    errors:
        Passed as ``error`` to ``chi_squared``.
    area_scale, wav_scale: bool
        Passed to ``prepare_iam_model_spectra``.
    grid_slices: bool
        If True, produce (slow) grid-slice plots.
    norm_method: str
        ``method`` passed to ``renormalization``.
    fudge: float or None
        Multiplicative factor applied to the companion flux. When not None a
        diagnostic plot is saved and a warning is issued.

    Returns
    -------
    None if ``save_only`` is True, otherwise a numpy array with the minimum
    chi-squared value for each entry of ``model2_pars``.
    """
    if prefix is None:
        star = obs_spec.header["OBJECT"].upper()
        save_filename = os.path.join(
            simulators.paths["output_dir"], star,
            "iam_{0}_{1}-{2}_part{6}_host_pars_[{3}_{4}_{5}].csv".format(
                star, int(obs_spec.header["MJD-OBS"]), chip,
                params1[0], params1[1], params1[2], num))
        prefix = os.path.join(simulators.paths["output_dir"], star)  # for fudge
    else:
        save_filename = "{0}_part{4}_host_pars_[{1}_{2}_{3}].csv".format(
            prefix, params1[0], params1[1], params1[2], num)

    if save_only and os.path.exists(save_filename):
        print("'{0}' exists, so not repeating calculation.".format(save_filename))
        return None

    if not save_only:
        iam_grid_chisqr_vals = np.empty(len(model2_pars))

    # The observation preparation does not depend on params2, so do it once.
    # rv_limits are taken from the spectrum before NaN removal (as the first
    # loop iteration previously did) so every companion model uses the same
    # limits.
    rv_limits = observation_rv_limits(obs_spec, rvs, gammas)
    obs_spec = obs_spec.remove_nans()
    assert ~np.any(np.isnan(obs_spec.flux)), "Observation has nan"

    def axis_continuum(flux):
        """Continuum to apply along axis with predefined variables parameters."""
        return continuum(obs_spec.xaxis, flux, splits=20, method="exponential", top=20)

    for jj, params2 in enumerate(model2_pars):
        if verbose:
            print(("Starting iteration with parameters: "
                   "{0}={1},{2}={3}").format(num, params1, jj, params2))

        # Load phoenix models and scale by area and wavelength limit
        mod1_spec, mod2_spec = prepare_iam_model_spectra(
            params1, params2, limits=rv_limits,
            area_scale=area_scale, wav_scale=wav_scale)
        # Estimated flux ratio from models
        inherent_alpha = continuum_alpha(mod1_spec, mod2_spec, chip)

        if fudge is not None:
            fudge_factor = float(fudge)
            mod2_spec.flux *= fudge_factor  # fudge factor multiplication

            # Diagnostic plot of the fudged companion on its own fresh figure.
            plt.figure()
            mod2_spec.plot(label="fudged {0}".format(params2))
            plt.title("fudges models")
            plt.legend()

            fudge_prefix = os.path.basename(os.path.normpath(prefix))
            plot_dir = os.path.join(simulators.paths["output_dir"],
                                    obs_spec.header["OBJECT"].upper(), "iam", "fudgeplots")
            os.makedirs(plot_dir, exist_ok=True)  # savefig does not create directories
            fname = os.path.join(
                plot_dir,
                "{1}_fudged_model_spectra_factor={0}_num={2}_iter_{3}.png".format(
                    fudge_factor, fudge_prefix, num, jj))
            plt.savefig(fname)
            plt.close()
            warnings.warn("Using a fudge factor = {0}".format(fudge_factor))

        # Combine model spectra with iam model
        iam_grid_func = inherent_alpha_model(mod1_spec.xaxis, mod1_spec.flux, mod2_spec.flux,
                                             rvs=rvs, gammas=gammas)
        iam_grid_models = iam_grid_func(obs_spec.xaxis)

        # Continuum normalize all iam_grid_models along the wavelength axis (axis 0)
        iam_grid_continuum = np.apply_along_axis(axis_continuum, 0, iam_grid_models)
        iam_grid_models = iam_grid_models / iam_grid_continuum

        # RE-NORMALIZATION
        # Keep the chip-4 result in a local so obs_spec is not overwritten and
        # the quadratic renormalization does not compound across iterations.
        renorm_input = obs_spec
        if chip == 4:
            # Quadratically renormalize anyway
            renorm_input = renormalization(obs_spec, iam_grid_models,
                                           normalize=True, method="quadratic")
        obs_flux = renormalization(renorm_input, iam_grid_models,
                                   normalize=norm, method=norm_method)

        if grid_slices:
            # Long execution plotting.
            plot_iam_grid_slices(obs_spec.xaxis, rvs, gammas, iam_grid_models,
                                 star=obs_spec.header["OBJECT"].upper(),
                                 xlabel="wavelength", ylabel="rv", zlabel="gamma",
                                 suffix="iam_grid_models", chip=chip)

        old_shape = iam_grid_models.shape
        # Arbitrary normalization of the model grid (adds a trailing axis)
        iam_grid_models, arb_norm = arbitrary_rescale(iam_grid_models,
                                                      *simulators.sim_grid["arb_norm"])
        assert iam_grid_models.shape == (*old_shape, len(arb_norm))

        # Calculate Chi-squared
        obs_flux = np.expand_dims(obs_flux, -1)  # expand on last axis to match rescale
        iam_norm_grid_chisquare = chi_squared(obs_flux, iam_grid_models, error=errors)

        # Take minimum chi-squared value along Arbitrary normalization axis
        iam_grid_chisquare, arbitrary_norms = arbitrary_minimums(iam_norm_grid_chisquare, arb_norm)

        npix = obs_flux.shape[0]  # Number of pixels used

        if grid_slices:
            # Long execution plotting.
            plot_iam_grid_slices(rvs, gammas, arb_norm, iam_norm_grid_chisquare,
                                 star=obs_spec.header["OBJECT"].upper(),
                                 xlabel="rv", ylabel="gamma", zlabel="Arbitrary Normalization",
                                 suffix="iam_grid_chisquare", chip=chip)

        if not save_only:
            iam_grid_chisqr_vals[jj] = np.min(iam_grid_chisquare)

        save_full_iam_chisqr(save_filename, params1, params2,
                             inherent_alpha, rvs, gammas,
                             iam_grid_chisquare, arbitrary_norms, npix, verbose=verbose)

    if save_only:
        return None
    return iam_grid_chisqr_vals
def test_arbitrary_rescale_on_different_size_models(shape):
    """Rescaling must append exactly one trailing axis, whatever the model shape."""
    random_model = np.random.random(shape)
    scaled, norms = arbitrary_rescale(random_model, 0.8, 1.2, 0.05)
    assert scaled.shape == random_model.shape + (len(norms),)
def test_arbitrary_rescale_arb_norm_examples(start, stop, step, expected):
    """Compare the returned normalization values with hand-computed examples."""
    _, norms = arbitrary_rescale(np.random.random((3, 2)), start, stop, step)
    assert np.allclose(norms, expected)