Example #1
0
 def compute_hpd(plotters, iteration=-1, credible_interval=.98):
     """Return (x_min, x_max, y_min, y_max) HPD bounds for two sample arrays.

     Flattens the sample arrays stored at ``plotters[0][2]`` and
     ``plotters[1][2]``, truncates each at ``iteration``, and computes the
     highest-posterior-density interval of both at ``credible_interval``.
     """
     xs = plotters[0][2].flatten()[:iteration]
     ys = plotters[1][2].flatten()[:iteration]
     x_lo, x_hi = hpd(xs, credible_interval=credible_interval)
     y_lo, y_hi = hpd(ys, credible_interval=credible_interval)
     return x_lo, x_hi, y_lo, y_hi
Example #2
0
def save2npy_point_estimate_by_subj(chains_list,
                                    all_vp_list,
                                    def_args,
                                    credible_interval,
                                    fname,
                                    CI=False,
                                    perc_last_samples=75,
                                    logzeta=False):
    """
    Saves point estimates as a dictionary by subject. A point estimate is the
    center of the credible interval. The final dictionary will have a key for
    each subject.

    Parameters
    ----------
    chains_list : list
        list of pydream estimations (each element is one subject's estimation)
    all_vp_list : list
        list of all subject indexes
    def_args : dict
        dictionary of default arguments (non estimated parameters); a value of
        None marks a parameter that was estimated
    credible_interval : float
        credible interval (example 0.5 if 50% of datapoints should be in the
        interval)
    fname : str
        file name
    CI : bool
        if True, store [point estimate, half interval width] per parameter
        instead of the point estimate alone
    perc_last_samples : int
        percentage of samples that are not considered burn in
    logzeta : bool
        separately convert log zeta in the output (the file name is then
        prefixed with "logz")

    Returns
    -------
    dict
        with subjects as keys and per-parameter dictionaries as values
    """
    from collections import OrderedDict
    import numpy as np
    from arviz.stats import hpd

    vp_params = OrderedDict({})
    for chains, vp in zip(chains_list, all_vp_list):
        # A None entry means no estimation is available for this subject.
        if chains is None:
            continue
        param_dict1 = OrderedDict({})
        param_ix = 0
        for param_name in def_args.keys():
            if def_args[param_name] is None:
                # Estimated parameter: discard burn-in and pool the 3 chains.
                chain_len = chains.shape[1]
                samp_ix = int(chain_len -
                              (chain_len * perc_last_samples / 100))
                allchains = np.hstack((chains[0, samp_ix:, param_ix],
                                       chains[1, samp_ix:, param_ix],
                                       chains[2, samp_ix:, param_ix]))
                if logzeta and param_name == "zeta":
                    allchains = np.log10(allchains)
                # Highest Posterior Density interval; the point estimate is
                # its midpoint.
                hpd_all = hpd(allchains, credible_interval)
                mpde = (hpd_all[1] + hpd_all[0]) / 2
                if CI:
                    param_dict1[param_name] = [mpde, hpd_all[1] - mpde]
                else:
                    param_dict1[param_name] = mpde
                param_ix += 1
            else:
                # Fixed parameter: copy the default value through unchanged.
                param_dict1[param_name] = def_args[param_name]

        vp_params[vp] = param_dict1
    if logzeta:
        np.save("logz" + fname, vp_params)
        return vp_params
    np.save(fname, vp_params)
    return vp_params
Example #3
0
def save2pd_overall_point_estimates(chains_list,
                                    all_vp_list,
                                    def_args,
                                    priors,
                                    sw,
                                    credible_interval,
                                    fname,
                                    perc_last_samples=75,
                                    logzeta=False):
    """
    Saves point estimates as a pandas table averaged over subjects. A point
    estimate is the center of the credible interval.

    Parameters
    ----------
    chains_list : list
        list of pydream estimations (each element is one subject's estimation)
    all_vp_list : list
        list of all subject indexes
    def_args : dict
        dictionary of default arguments (non estimated parameters)
    priors : dict
        estimated parameters; keys are the parameter names present in the
        chains, in chain-column order
    sw : object
        settings object providing the fixed constants tau_pre, tau_post and
        foR_size appended to the table
    credible_interval : float
        credible interval (example 0.5 if 50% of datapoints should be in the
        interval)
    fname : str
        file name
    perc_last_samples : int
        percentage of samples that are not considered burn in
    logzeta : bool
        separately convert log zeta in the output (the file name is then
        prefixed with "logz")

    Returns
    -------
    pandas.DataFrame
        one row per parameter: point estimate, interval half width and the
        left/right HPD bounds
    """
    import numpy as np
    from arviz.stats import hpd
    import pandas as pd

    param_ix = 0
    rows_list = []
    for param_name in def_args.keys():
        if param_name in priors:
            print(param_name)
            # Pool the post-burn-in samples of all subjects and all 3 chains.
            allvps = []
            for vp in all_vp_list:
                chains = chains_list[vp]
                # A None entry means no estimation for this subject.
                if chains is None:
                    continue
                chain_len = chains.shape[1]
                samp_ix = int(chain_len -
                              (chain_len * perc_last_samples / 100))
                allvps.extend(chains[0, samp_ix:, param_ix])
                allvps.extend(chains[1, samp_ix:, param_ix])
                allvps.extend(chains[2, samp_ix:, param_ix])
            allvps = np.array(allvps)
            if logzeta and param_name == "zeta":
                allvps = np.log10(allvps)

            # Highest Posterior Density interval; the point estimate is its
            # midpoint.
            hpd_all = hpd(allvps, credible_interval)
            mpde = (hpd_all[1] + hpd_all[0]) / 2
            dict1 = {
                "param_name": param_name,
                "mpde": mpde,
                "interv": mpde - hpd_all[0],
                "left": hpd_all[0],
                "right": hpd_all[1]
            }
            param_ix += 1
        else:
            # Fixed parameter: report the default value, no interval.
            dict1 = {
                "param_name": param_name,
                "mpde": def_args[param_name],
                "interv": np.nan,
                "left": np.nan,
                "right": np.nan
            }
        rows_list.append(dict1)
    # Append the fixed model constants taken from sw (no credible interval).
    for const_name, const_val in (("tau_pre", sw.tau_pre),
                                  ("tau_post", sw.tau_post),
                                  ("foR_size", sw.foR_size)):
        rows_list.append({
            "param_name": const_name,
            "mpde": const_val,
            "interv": np.nan,
            "left": np.nan,
            "right": np.nan
        })

    hpde_df = pd.DataFrame(rows_list)

    if logzeta:
        hpde_df.to_csv("logz" + fname)
        return hpde_df
    hpde_df.to_csv(fname)
    return hpde_df
Example #4
0
def save2pd_subj_point_estimates(chains_list,
                                 all_vp_list,
                                 priors,
                                 credible_interval,
                                 fname,
                                 perc_last_samples=75):
    """
    Saves point estimates as a pandas table with separate fits for each
    subject. A point estimate is the center of the credible interval.

    Parameters
    ----------
    chains_list : list
        list of pydream estimations (each element is one subject's estimation)
    all_vp_list : list
        list of all subject indexes
    priors : dict
        estimated parameters; keys are the parameter names present in the
        chains, in chain-column order
    credible_interval : float
        credible interval (example 0.5 if 50% of datapoints should be in the
        interval)
    fname : str
        file name
    perc_last_samples : int
        percentage of samples that are not considered burn in

    Returns
    -------
    pandas.DataFrame
        one row per subject/parameter pair: point estimate, interval half
        width and the left/right HPD bounds
    """
    import numpy as np
    from arviz.stats import hpd
    import pandas as pd

    rows_list = []
    for vp in all_vp_list:
        chains = chains_list[vp]
        # A None entry means no estimation is available for this subject.
        if chains is None:
            continue
        param_ix = 0
        for param_name in priors.keys():
            # Discard burn-in and pool the three chains of this parameter.
            chain_len = chains.shape[1]
            samp_ix = int(chain_len - (chain_len * perc_last_samples / 100))
            allchains = np.hstack((chains[0, samp_ix:, param_ix],
                                   chains[1, samp_ix:, param_ix],
                                   chains[2, samp_ix:, param_ix]))
            # Highest Posterior Density interval; the point estimate is its
            # midpoint.
            hpd_all = hpd(allchains, credible_interval)
            mpde = (hpd_all[1] + hpd_all[0]) / 2

            rows_list.append({
                "vp": vp,
                "param_name": param_name,
                "mpde": mpde,
                "interv": mpde - hpd_all[0],
                "left": hpd_all[0],
                "right": hpd_all[1]
            })
            param_ix += 1
    hpde_df = pd.DataFrame(rows_list)
    hpde_df.to_csv(fname)
    return hpde_df
Example #5
0
def save2dict_overall_point_estimates(chains_list,
                                      all_vp_list,
                                      def_args,
                                      priors,
                                      sw,
                                      credible_interval,
                                      fname,
                                      perc_last_samples=75):
    """
    Saves point estimates as a dictionary averaged over subjects. A point
    estimate is the center of the credible interval. The final dictionary has
    a key for each parameter.

    Parameters
    ----------
    chains_list : list
        list of pydream estimations (each element is one subject's estimation)
    all_vp_list : list
        list of all subject indexes
    def_args : dict
        dictionary of default arguments (non estimated parameters)
    priors : dict
        estimated parameters; keys are the parameter names present in the
        chains, in chain-column order
    sw : object
        settings object (accepted for signature parity with the table-saving
        variants; not read here)
    credible_interval : float
        credible interval (example 0.5 if 50% of datapoints should be in the
        interval)
    fname : str
        file name
    perc_last_samples : int
        percentage of samples that are not considered burn in

    Returns
    -------
    dict
        with parameter names as keys and point estimates (or defaults for
        non-estimated parameters) as values
    """
    import numpy as np
    from arviz.stats import hpd

    param_ix = 0
    dict1 = {}
    for param_name in def_args.keys():
        if param_name in priors:
            print(param_name)
            # Pool the post-burn-in samples of all subjects and all 3 chains.
            allvps = []
            for vp in all_vp_list:
                chains = chains_list[vp]
                # A None entry means no estimation for this subject.
                if chains is None:
                    continue
                chain_len = chains.shape[1]
                samp_ix = int(chain_len -
                              (chain_len * perc_last_samples / 100))
                allvps.extend(chains[0, samp_ix:, param_ix])
                allvps.extend(chains[1, samp_ix:, param_ix])
                allvps.extend(chains[2, samp_ix:, param_ix])
            allvps = np.array(allvps)
            # Highest Posterior Density interval; the point estimate is its
            # midpoint.
            hpd_all = hpd(allvps, credible_interval)
            dict1[param_name] = (hpd_all[1] + hpd_all[0]) / 2
            param_ix += 1
        else:
            # Fixed parameter: copy the default value through unchanged.
            dict1[param_name] = def_args[param_name]
    np.save(fname, dict1)
    return dict1