def measure_spf_errors(yaml_file_str,
                       Number_rand_mcmc,
                       Norm_90_inplot=1.,
                       save=False):
    """
    Measure the median Henyey-Greenstein (HG) SPF and its error bars
    from the MCMC chain described by a yaml parameter file.

    The "best" SPF uses the median (50th percentile) of each parameter
    over the end of the chain. The error bars are the envelope (min and
    max at every scattering angle) of Number_rand_mcmc SPFs built from
    samples drawn at random in that same part of the chain.

    Relies on `basedir`, `scattered_angles` and (when save is true)
    `folder_save_pdf` being available from the surrounding scope: they
    are not set before use in this function -- confirm where they are
    defined.

    NOTE(review): another function with the same name is defined later
    in this file; if both live in the same module, the later definition
    replaces this one.

    Args:
        yaml_file_str: name of the yaml parameter file, without the
                        '.yaml' extension, looked up in the
                        'initialization_files' directory
        Number_rand_mcmc: number of randomly selected chain samples
                        used to measure the error bars
        Norm_90_inplot: normalization passed to the HG function for the
                        best SPF when save is false; the random SPFs
                        are then rescaled by Norm_90_inplot / (median
                        norm) so that they scatter around it
        save: if true, use the norm fitted by the MCMC instead of
                        Norm_90_inplot and write the angles, best SPF,
                        upper and lower envelopes to
                        '<FILE_PREFIX>_spf.txt' in folder_save_pdf

    Returns:
        a dic that contains the 'best_spf', 'errorbar_sup',
                                'errorbar_inf', 'errorbar'

    Raises:
        ValueError: if SPF_MODEL in the yaml file is not one of
                        'hg_1g', 'hg_2g', 'hg_3g'
    """

    # The file is parsed with yaml.FullLoader, as before; it is assumed to
    # be a trusted local configuration file.
    with open(os.path.join('initialization_files', yaml_file_str + '.yaml'),
              'r') as yaml_file:
        params_mcmc_yaml = yaml.load(yaml_file, Loader=yaml.FullLoader)

    nwalkers = params_mcmc_yaml['NWALKERS']
    file_prefix = params_mcmc_yaml['FILE_PREFIX']
    spf_model = params_mcmc_yaml['SPF_MODEL']  # type of description for the SPF

    # For each model: the HG function and the chain columns holding its
    # parameters, listed in the order the function expects them
    # (hg_3g takes g1, g2, g3, alpha1, alpha2).
    if spf_model == 'hg_1g':
        spf_function, shape_columns = hg_1g, [8]
    elif spf_model == 'hg_2g':
        spf_function, shape_columns = hg_2g, [8, 9, 10]
    elif spf_model == 'hg_3g':
        spf_function, shape_columns = hg_3g, [8, 9, 11, 10, 12]
    else:
        raise ValueError(
            "SPF_MODEL must be 'hg_1g', 'hg_2g' or 'hg_3g', got "
            + repr(spf_model))

    chain_name = os.path.join(basedir, params_mcmc_yaml['BAND_DIR'],
                              'results_MCMC',
                              file_prefix + '_backend_file_mcmc.h5')
    reader = backends.HDFBackend(chain_name)

    # we only extract the last iterations, assuming the chain converged:
    # the last (10 * Number_rand_mcmc // nwalkers) iterations are kept
    burnin = max(reader.iteration - 10 * Number_rand_mcmc // nwalkers, 0)
    chain_flat = reader.get_chain(discard=burnin, flat=True)

    # column 7 is exponentiated to obtain the norm
    norm_chain = np.exp(chain_flat[:, 7])
    shape_chain = chain_flat[:, shape_columns]

    bestmodel_norm = np.percentile(norm_chain, 50)
    bestmodel_shape = np.percentile(shape_chain, 50, axis=0)

    # we normalize the best model either by the value found by the MCMC
    # if we want to save, or by Norm_90_inplot if we want to plot
    best_hg_mcmc = spf_function(scattered_angles, *bestmodel_shape,
                                bestmodel_norm if save else Norm_90_inplot)

    # every sample of the retained chain can be drawn
    random_param_number = np.random.randint(0, len(norm_chain),
                                            Number_rand_mcmc)

    hg_mcmc_rand = np.zeros((len(best_hg_mcmc), Number_rand_mcmc))

    for num_model, chain_index in enumerate(random_param_number):
        norm_here = norm_chain[chain_index]

        # same normalization rule as for the best model, applied to the
        # norm of this particular sample
        if save:
            normalization = norm_here
        else:
            normalization = norm_here * Norm_90_inplot / bestmodel_norm

        hg_mcmc_rand[:, num_model] = spf_function(scattered_angles,
                                                  *shape_chain[chain_index],
                                                  normalization)

    # envelope of the random SPFs at every scattering angle
    errorbar_sup = np.max(hg_mcmc_rand, axis=1)
    errorbar_inf = np.min(hg_mcmc_rand, axis=1)
    errorbar = (errorbar_sup - errorbar_inf) / 2.

    dico_return = dict()
    dico_return['best_spf'] = best_hg_mcmc
    dico_return['errorbar_sup'] = errorbar_sup
    dico_return['errorbar_inf'] = errorbar_inf
    dico_return['errorbar'] = errorbar

    if save:
        savefortext = np.transpose(
            [scattered_angles, best_hg_mcmc, errorbar_sup, errorbar_inf])
        path_and_name_txt = os.path.join(folder_save_pdf,
                                         file_prefix + '_spf.txt')

        np.savetxt(path_and_name_txt, savefortext, delimiter=',',
                   fmt='%10.2f')  # save the array in a txt
    return dico_return
    min_scat = 13.3
    max_scat = 166.7
    nb_random_models = 1000

    folder_save_pdf = os.path.join(basedir, 'Spf_plots_produced')

    scattered_angles = np.arange(np.round(max_scat - min_scat)) + np.round(
        np.min(min_scat))

    # ###########################################################################
    # ### SPF injected real value
    # ###########################################################################
    g1_injected = 0.825
    g2_injected = -0.201
    alph1_injected = 0.298
    injected_hg = hg_2g(scattered_angles, g1_injected, g2_injected,
                        alph1_injected, 1.0)

    # ###########################################################################
    # ### SPHERE H2 exctracted by Milli et al. 2017
    # ###########################################################################
    # sphere spf extracted by Milli et al. 2017
    angles_sphere_extractJulien = np.zeros(49)
    spf_shpere_extractJulien = np.zeros(49)
    errors_sphere_extractJulien = np.zeros(49)

    i = 0
    with open(
            os.path.join(basedir, 'SPHERE_Hdata',
                         'SPHERE_extraction_Milli.csv'), 'rt') as f:
        readercsv = csv.reader(f)
        for row in readercsv:
def best_model_plot(params_mcmc_yaml, hdr):
    """ Make the best models plot and save fits of
        BestModel
        BestModel_Conv
        BestModel_FM
        BestModel_Res

    The best model is the chain sample with the highest log-probability
    (after BURNIN and THIN). Four FITS files and one six-panel figure
    ('<FILE_PREFIX>_backend_file_mcmc_BestModel_Plot.jpg') are written to
    mcmcresultdir.

    Uses `mcmcresultdir` and `klipdir`, which are not set in this
    function and must be available from the surrounding scope. Several
    attributes of the `diskfit_mcmc` module are overwritten as a side
    effect (DISTANCE_STAR, PIXSCALE_INS, ALIGNED_CENTER, SPF_MODEL,
    WHEREMASK2GENERATEDISK, DIMENSION and, for 'spf_fix', F_SPF).

    Args:
        params_mcmc_yaml: dic, all the parameters of the MCMC and klip
                            read from yaml file
        hdr: the header obtained from create_header; 'ARGPE_MC' and
                            'PA_MC' are read from it and it is attached
                            to every FITS file written

    Returns:
        None
    """

    # The model is going to be generated by diskfit_mcmc, which needs
    # these module-level values to be defined first
    diskfit_mcmc.DISTANCE_STAR = params_mcmc_yaml['DISTANCE_STAR']
    diskfit_mcmc.PIXSCALE_INS = params_mcmc_yaml['PIXSCALE_INS']
    diskfit_mcmc.ALIGNED_CENTER = params_mcmc_yaml['ALIGNED_CENTER']
    diskfit_mcmc.SPF_MODEL = params_mcmc_yaml[
        'SPF_MODEL']  # type of description for the SPF

    quality_plot = params_mcmc_yaml['QUALITY_PLOT']
    file_prefix = params_mcmc_yaml['FILE_PREFIX']
    band_name = params_mcmc_yaml['BAND_NAME']
    instrument = params_mcmc_yaml['INSTRUMENT']
    numbasis = [params_mcmc_yaml['KLMODE_NUMBER']]
    thin = params_mcmc_yaml['THIN']
    burnin = params_mcmc_yaml['BURNIN']

    name_h5 = file_prefix + '_backend_file_mcmc'

    reader = backends.HDFBackend(os.path.join(mcmcresultdir, name_h5 + '.h5'))
    chain_flat = reader.get_chain(discard=burnin, thin=thin, flat=True)
    log_prob_samples_flat = reader.get_log_prob(discard=burnin,
                                                flat=True,
                                                thin=thin)

    # first sample reaching the highest log-probability (NaN ignored)
    theta_ml = chain_flat[np.nanargmax(log_prob_samples_flat), :]

    if diskfit_mcmc.SPF_MODEL == 'spf_fix':
        # we fix the SPF using a HG parametrization with the parameters
        # of the init file
        n_points = 21  # odd number to ensure that scattangl=pi/2 is in the list for normalization
        scatt_angles = np.linspace(0, np.pi, n_points)

        # 2g henyey greenstein, normalized at 1 at 90 degrees
        spf_norm90 = hg_2g(np.degrees(scatt_angles),
                           params_mcmc_yaml['g1_init'],
                           params_mcmc_yaml['g2_init'],
                           params_mcmc_yaml['alpha1_init'], 1)

        # measure the spline and save it as a module-level value
        diskfit_mcmc.F_SPF = phase_function_spline(scatt_angles, spf_norm90)

    psf = fits.getdata(os.path.join(klipdir, file_prefix + '_SmallPSF.fits'))

    mask2generatedisk = fits.getdata(
        os.path.join(klipdir, file_prefix + '_mask2generatedisk.fits'))

    # pixels equal to 0 in the mask file are flagged (True) in
    # WHEREMASK2GENERATEDISK
    mask2generatedisk[mask2generatedisk == 0.] = np.nan
    diskfit_mcmc.WHEREMASK2GENERATEDISK = np.isnan(mask2generatedisk)

    # load the data
    reduced_data = fits.getdata(
        os.path.join(klipdir, file_prefix + '-klipped-KLmodes-all.fits'))[
            0]  ### we take only the first KL mode

    diskfit_mcmc.DIMENSION = reduced_data.shape[1]

    # load the noise
    # NOTE(review): the noise map is divided by 3; the reason is not
    # visible from this function -- confirm.
    noise = fits.getdata(os.path.join(klipdir,
                                      file_prefix + '_noisemap.fits')) / 3.

    disk_ml = diskfit_mcmc.call_gen_disk(theta_ml)

    fits.writeto(os.path.join(mcmcresultdir, name_h5 + '_BestModel.fits'),
                 disk_ml,
                 header=hdr,
                 overwrite=True)

    # find the position of the pericenter in the model: rotate the model
    # by (argpe + pa) and look for the brightest pixel along the column
    # of the star, starting from the star row
    argpe = hdr['ARGPE_MC']
    pa = hdr['PA_MC']

    model_rot = np.clip(
        rotate(disk_ml, argpe + pa, mode='wrap', reshape=False), 0., None)

    argpe_direction = model_rot[int(diskfit_mcmc.ALIGNED_CENTER[0]):,
                                int(diskfit_mcmc.ALIGNED_CENTER[1])]

    # a scalar radius (first maximum), so the marker below always has a
    # single well-defined centre even when several pixels tie
    radius_argpe = int(np.nanargmax(argpe_direction))

    x_peri_true = radius_argpe * np.cos(
        np.radians(argpe + pa + 90))  # distance to star, in pixel
    y_peri_true = radius_argpe * np.sin(
        np.radians(argpe + pa + 90))  # distance to star, in pixel

    # convolve by the PSF
    disk_ml_convolved = convolve(disk_ml, psf, boundary='wrap')

    fits.writeto(os.path.join(mcmcresultdir, name_h5 + '_BestModel_Conv.fits'),
                 disk_ml_convolved,
                 header=hdr,
                 overwrite=True)

    # load the KL numbers
    diskobj = DiskFM(None,
                     numbasis,
                     None,
                     disk_ml_convolved,
                     basis_filename=os.path.join(klipdir,
                                                 file_prefix + '_klbasis.h5'),
                     load_from_basis=True)

    # do the FM
    diskobj.update_disk(disk_ml_convolved)
    disk_ml_FM = diskobj.fm_parallelized()[0]
    ### we take only the first KL mode

    fits.writeto(os.path.join(mcmcresultdir, name_h5 + '_BestModel_FM.fits'),
                 disk_ml_FM,
                 header=hdr,
                 overwrite=True)

    # measure the residuals
    residuals = reduced_data - disk_ml_FM
    snr_residuals = residuals / noise

    # the residuals are saved in absolute value
    fits.writeto(os.path.join(mcmcresultdir, name_h5 + '_BestModel_Res.fits'),
                 np.abs(residuals),
                 header=hdr,
                 overwrite=True)

    # set the color scale from the forward-modelled disk
    vmin = 0.3 * np.min(disk_ml_FM)
    vmax = 0.9 * np.max(disk_ml_FM)

    dim_crop_image = int(4 * params_mcmc_yaml['OWA'] // 2) + 1

    disk_ml_crop = crop_center_odd(disk_ml, dim_crop_image)
    disk_ml_convolved_crop = crop_center_odd(disk_ml_convolved, dim_crop_image)
    disk_ml_FM_crop = crop_center_odd(disk_ml_FM, dim_crop_image)
    reduced_data_crop = crop_center_odd(reduced_data, dim_crop_image)
    residuals_crop = crop_center_odd(residuals, dim_crop_image)
    snr_residuals_crop = crop_center_odd(snr_residuals, dim_crop_image)

    caracsize = 40 * quality_plot / 2.

    fig = plt.figure(figsize=(6.4 * 2 * quality_plot, 4.8 * 2 * quality_plot))

    def _finish_panel(ax, image, title, ticks=None, ticklabels=None):
        # Title, colorbar and hidden axes, identical for the six panels.
        ax.set_title(title, fontsize=caracsize, pad=caracsize / 3.)
        cbar = fig.colorbar(image,
                            ax=ax,
                            ticks=ticks,
                            fraction=0.046,
                            pad=0.04)
        cbar.ax.tick_params(labelsize=caracsize * 3 / 4.)
        if ticklabels is not None:
            cbar.ax.set_yticklabels(ticklabels)
        ax.axis('off')

    # Colormaps are given by name: matplotlib.cm.get_cmap is not available
    # any more in recent Matplotlib releases.

    # The data
    ax_data = fig.add_subplot(235)
    image = ax_data.imshow(reduced_data_crop + 0.1,
                           origin='lower',
                           vmin=int(np.round(vmin)),
                           vmax=int(np.round(vmax)),
                           cmap='viridis')
    if file_prefix == 'Hband_hd48524_fake':
        data_title = "Injected Disk (KLIP)"
    else:
        data_title = "KLIP reduced data"
    _finish_panel(ax_data, image, data_title)

    # The residuals
    ax_res = fig.add_subplot(233)
    image = ax_res.imshow(residuals_crop,
                          origin='lower',
                          vmin=0,
                          vmax=int(np.round(vmax) // 3),
                          cmap='viridis')

    # make the colorbar ticks integer only
    # NOTE(review): the original comment said this was "only for gpi" but
    # the condition tests for 'SPHERE' -- confirm which one is intended.
    if instrument == 'SPHERE':
        tick_int = list(np.arange(int(np.round(vmax) // 2) + 1))
        tick_int_st = [str(i) for i in tick_int]
        _finish_panel(ax_res,
                      image,
                      "Residuals",
                      ticks=tick_int,
                      ticklabels=tick_int_st)
    else:
        _finish_panel(ax_res, image, "Residuals")

    # The SNR of the residuals
    ax_snr = fig.add_subplot(236)
    image = ax_snr.imshow(snr_residuals_crop,
                          origin='lower',
                          vmin=0,
                          vmax=2,
                          cmap='viridis')
    _finish_panel(ax_snr,
                  image,
                  "SNR Residuals",
                  ticks=[0, 1, 2],
                  ticklabels=['0', '1', '2'])

    # The model
    ax_model = fig.add_subplot(231)
    vmax_model = int(np.round(np.max(disk_ml_crop) / 1.5))
    if instrument == 'SPHERE':
        vmax_model = 433  # hard-coded upper cut for SPHERE
    image = ax_model.imshow(disk_ml_crop,
                            origin='lower',
                            vmin=-2,
                            vmax=vmax_model,
                            cmap='plasma')
    _finish_panel(ax_model, image, "Best Model")

    # green marker: pericenter found above, red marker: centre of the crop
    pos_argperi = plt.Circle(
        (x_peri_true + dim_crop_image // 2, y_peri_true + dim_crop_image // 2),
        3,
        color='g',
        alpha=0.8)
    pos_star = plt.Circle((dim_crop_image // 2, dim_crop_image // 2),
                          2,
                          color='r',
                          alpha=0.8)
    ax_model.add_artist(pos_argperi)
    ax_model.add_artist(pos_star)

    # The convolved model, with the PSF (scaled by 2 * vmax) pasted in a
    # corner of the cropped image and outlined by a white rectangle.
    # This happens after the convolved model was saved and forward
    # modelled, so only the plot is affected.
    rect = Rectangle((9.5, 9.5),
                     psf.shape[0],
                     psf.shape[1],
                     edgecolor='white',
                     facecolor='none',
                     linewidth=2)

    disk_ml_convolved_crop[10:10 + psf.shape[0],
                           10:10 + psf.shape[1]] = 2 * vmax * psf

    ax_conv = fig.add_subplot(234)
    image = ax_conv.imshow(disk_ml_convolved_crop,
                           origin='lower',
                           vmin=int(np.round(vmin)),
                           vmax=int(np.round(vmax * 2)),
                           cmap='viridis')
    ax_conv.add_patch(rect)
    _finish_panel(ax_conv, image, "Model Convolved")

    # The FM convolved model
    ax_fm = fig.add_subplot(232)
    image = ax_fm.imshow(disk_ml_FM_crop,
                         origin='lower',
                         vmin=int(np.round(vmin)),
                         vmax=int(np.round(vmax)),
                         cmap='viridis')
    _finish_panel(ax_fm, image, "Model Convolved + FM")

    fig.subplots_adjust(hspace=-0.4, wspace=0.2)

    fig.suptitle(band_name + ': Best Model and Residuals',
                 fontsize=5 / 4. * caracsize,
                 y=0.985)

    # NOTE(review): tight_layout recomputes the subplot spacing, so the
    # subplots_adjust call above probably has no visible effect -- kept
    # in the original order.
    fig.tight_layout()

    fig.savefig(os.path.join(mcmcresultdir, name_h5 + '_BestModel_Plot.jpg'))
    plt.close(fig)
def measure_spf_errors(params_mcmc_yaml,
                       Number_rand_mcmc,
                       Norm_90_inplot=1.,
                       median_or_max='median',
                       save=False):
    """
    Measure the best Henyey-Greenstein (HG) SPF and its error bars from
    the MCMC chain stored in the HDF5 backend file.

    The SPF is evaluated on a 1 degree grid of scattering angles that
    starts at round(90 - inc_init) and has round(2 * inc_init) points.
    The error bars are the envelope (min and max at every angle) of
    Number_rand_mcmc SPFs built from samples drawn at random in the
    chain after burn-in.

    Uses `basedir` from the surrounding scope and sets
    diskfit_mcmc.SPF_MODEL as a side effect.

    Args:
        params_mcmc_yaml: dic, all the parameters of the MCMC and klip
                            read from yaml file
        Number_rand_mcmc: number of randomly selected chain samples
                        used to measure the error bars
        Norm_90_inplot: normalization passed to the HG function for the
                        best SPF when save is false; the random SPFs
                        are then rescaled by Norm_90_inplot / (best
                        norm) so that they scatter around it
        median_or_max: 'median' or 'max', use the 50% percentile of each
                        parameter or the sample of maximum likelihood
                        as "best model". default 'median'
        save: if true, normalize with the norm fitted by the MCMC
                        instead of Norm_90_inplot. Nothing is written
                        to disk by this function.

    Returns:
        a dic that contains the 'best_spf', 'errorbar_sup',
                                'errorbar_inf', 'errorbar',
                                'all_rando_spfs' (one column per random
                                SPF) and 'scattered_angles'

    Raises:
        ValueError: if median_or_max is not 'median' or 'max', or if
                        SPF_MODEL is not 'hg_1g', 'hg_2g' or 'hg_3g'
    """

    if median_or_max not in ('median', 'max'):
        raise ValueError("median_or_max must be 'median' or 'max', got "
                         + repr(median_or_max))

    spf_model = params_mcmc_yaml['SPF_MODEL']  # type of description for the SPF
    diskfit_mcmc.SPF_MODEL = spf_model

    # For each model: the HG function and the chain columns holding its
    # parameters, listed in the order the function expects them
    # (hg_3g takes g1, g2, g3, alpha1, alpha2).
    if spf_model == 'hg_1g':
        spf_function, shape_columns = hg_1g, [8]
    elif spf_model == 'hg_2g':
        spf_function, shape_columns = hg_2g, [8, 9, 10]
    elif spf_model == 'hg_3g':
        spf_function, shape_columns = hg_3g, [8, 9, 11, 10, 12]
    else:
        raise ValueError(
            "SPF_MODEL must be 'hg_1g', 'hg_2g' or 'hg_3g', got "
            + repr(spf_model))

    burnin = params_mcmc_yaml['BURNIN']
    file_prefix = params_mcmc_yaml['FILE_PREFIX']

    chain_name = os.path.join(basedir, params_mcmc_yaml['BAND_DIR'],
                              'results_MCMC',
                              file_prefix + '_backend_file_mcmc.h5')
    reader = backends.HDFBackend(chain_name)

    min_scat = 90 - params_mcmc_yaml['inc_init']
    max_scat = 90 + params_mcmc_yaml['inc_init']

    scattered_angles = np.arange(np.round(max_scat - min_scat)) + np.round(
        np.min(min_scat))

    # we only extract the iterations after burn-in, assuming it converged
    chain_flat = reader.get_chain(discard=burnin, flat=True)

    # column 7 is exponentiated to obtain the norm
    norm_chain = np.exp(chain_flat[:, 7])
    shape_chain = chain_flat[:, shape_columns]

    if median_or_max == 'median':
        bestmodel_norm = np.percentile(norm_chain, 50)
        bestmodel_shape = np.percentile(shape_chain, 50, axis=0)
    else:
        # the 'best model' is the first sample reaching the highest
        # log-probability
        log_prob_samples_flat = reader.get_log_prob(discard=burnin, flat=True)
        wheremax0 = np.nanargmax(log_prob_samples_flat)
        bestmodel_norm = norm_chain[wheremax0]
        bestmodel_shape = shape_chain[wheremax0]

    # we normalize the best model either by the value found by the MCMC
    # if we want to save, or by Norm_90_inplot if we want to plot
    best_hg_mcmc = spf_function(scattered_angles, *bestmodel_shape,
                                bestmodel_norm if save else Norm_90_inplot)

    # every sample of the retained chain can be drawn
    random_param_number = np.random.randint(0, len(norm_chain),
                                            Number_rand_mcmc)

    hg_mcmc_rand = np.zeros((len(best_hg_mcmc), Number_rand_mcmc))

    for num_model, chain_index in enumerate(random_param_number):
        norm_here = norm_chain[chain_index]

        # same normalization rule as for the best model, applied to the
        # norm of this particular sample
        if save:
            normalization = norm_here
        else:
            normalization = norm_here * Norm_90_inplot / bestmodel_norm

        hg_mcmc_rand[:, num_model] = spf_function(scattered_angles,
                                                  *shape_chain[chain_index],
                                                  normalization)

    # envelope of the random SPFs at every scattering angle
    errorbar_sup = np.max(hg_mcmc_rand, axis=1)
    errorbar_inf = np.min(hg_mcmc_rand, axis=1)
    errorbar = (errorbar_sup - errorbar_inf) / 2.

    dico_return = dict()
    dico_return['best_spf'] = best_hg_mcmc
    dico_return['errorbar_sup'] = errorbar_sup
    dico_return['errorbar_inf'] = errorbar_inf
    dico_return['errorbar'] = errorbar
    dico_return['all_rando_spfs'] = hg_mcmc_rand
    dico_return['scattered_angles'] = scattered_angles

    return dico_return
def compare_injected_spfs_plot(params_mcmc_yaml,
                               fill_or_all='all',
                               Number_rand_mcmc=50):
    """
    Plot the SPF recovered by the MCMC against the injected 2g HG SPF
    and save the figure as '<FILE_PREFIX>_comparison_spf.pdf'.

    The recovered SPF is the maximum-likelihood one returned by
    measure_spf_errors; the injected SPF is a hg_2g built from g1_init,
    g2_init and alpha1_init, normalized with 1.0. The figure is written
    to `mcmcresultdir`, which must be available from the surrounding
    scope.

    Args:
        params_mcmc_yaml: dic, all the parameters of the MCMC and klip
                            read from yaml file
        fill_or_all: 'all' draws every random SPF as a thin line,
                        'fill' shades the area between the upper and
                        lower envelopes. default 'all'
        Number_rand_mcmc: number of random SPFs requested from
                        measure_spf_errors. default 50

    Returns:
        None

    Raises:
        ValueError: if fill_or_all is neither 'fill' nor 'all'
    """

    if fill_or_all not in ('fill', 'all'):
        raise ValueError("fill_or_all must be 'fill' or 'all', got "
                         + repr(fill_or_all))

    color_injected = '#ED0052'
    color_recovered = '#00AF64'

    band_name = params_mcmc_yaml['BAND_NAME']
    file_prefix = params_mcmc_yaml['FILE_PREFIX']
    name_pdf = file_prefix + '_comparison_spf.pdf'

    # measure first, so that no empty figure is left open if this fails
    spf_fake_recovered = measure_spf_errors(params_mcmc_yaml,
                                            Number_rand_mcmc,
                                            median_or_max='max')

    scattered_angles = spf_fake_recovered['scattered_angles']

    injected_hg = hg_2g(scattered_angles, params_mcmc_yaml['g1_init'],
                        params_mcmc_yaml['g2_init'],
                        params_mcmc_yaml['alpha1_init'], 1.0)

    plt.figure()

    if fill_or_all == 'fill':
        plt.fill_between(scattered_angles,
                         spf_fake_recovered['errorbar_sup'],
                         spf_fake_recovered['errorbar_inf'],
                         facecolor=color_recovered,
                         alpha=0.1)
    else:
        # one thin line per random SPF (one per column of the array)
        plt.plot(scattered_angles,
                 spf_fake_recovered['all_rando_spfs'],
                 linewidth=1,
                 color=color_recovered,
                 alpha=0.1)

    plt.plot(scattered_angles,
             spf_fake_recovered['best_spf'],
             linewidth=3,
             color=color_recovered,
             label="SPF Recovered After MCMC")

    plt.plot(scattered_angles,
             injected_hg,
             linewidth=2,
             linestyle='-.',
             color=color_injected,
             label="Fiducial 'Zodiacal light' SPF (Hong et al. 1985)")

    # show the injected SPF first in the legend
    handles, labels = plt.gca().get_legend_handles_labels()
    order = [1, 0]
    plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order])

    plt.yscale('log')

    plt.ylim(bottom=0.3, top=30)
    plt.xlim(left=0, right=180)
    plt.xlabel('Scattering angles')
    plt.ylabel('Normalized total intensity')
    plt.title(band_name + ' SPF')

    plt.tight_layout()

    if "32297" in band_name:
        # for HD32297 add grey parts on the plots where the disk is hidden
        # behind the FP mask or for the back scattering part
        plt.axvspan(0, 7, alpha=0.2, facecolor='grey')
        plt.text(18, 0.4, 'Behind FPM', fontsize=10)
        plt.arrow(17,
                  0.42,
                  -8,
                  0,
                  head_width=0.03,
                  head_length=2,
                  fc='k',
                  ec='k')

        plt.axvspan(90, 180, alpha=0.2, facecolor='grey')
        plt.text(110, 0.4, 'HD 32297 Back Side', fontsize=10)

    plt.savefig(os.path.join(mcmcresultdir, name_pdf))

    plt.close()