예제 #1
0
def save_traceplot(fit, param_names=None, info_path=None,
                   traceplot_params=TraceplotParams()):
    """
    Saves traceplots from the fit, one file per figure.

    Parameters
    ----------

    fit : cmdstanpy.stanfit.CmdStanMCMC

        Contains the samples from cmdstanpy.

    param_names : list of str

        Names of parameters to plot.

    info_path : InfoPath

        Determines the location of the output files. When None, a new
        InfoPath() is created for this call. The base name defaults to
        "traceplot" and the extension to 'pdf'. Files are numbered:
        `<base_name>_01`, `<base_name>_02`, ...

    traceplot_params : TraceplotParams

        Passed to `traceplot` as its `params` argument.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Work on a copy, so the name and extension assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)

    figures_and_axes = traceplot(
        fit, param_names=param_names, params=traceplot_params)

    base_name = info_path.base_name or "traceplot"
    info_path.extension = info_path.extension or 'pdf'

    for number, figure_and_axis in enumerate(figures_and_axes, start=1):
        info_path.base_name = f'{base_name}_{number:02d}'
        plot_path = get_info_path(info_path)
        fig = figure_and_axis[0]
        fig.savefig(plot_path, dpi=info_path.dpi)
        plt.close(fig)  # Release the figure once it is on disk
예제 #2
0
def save_compare_psis_tree_plot_from_compared(
    compared,
    tree_plot_params: TreePlotParams = TreePlotParams(),
    info_path=None):
    """
    Make a plot that compares models using PSIS
    and save it to a file.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models.

    tree_plot_params : TreePlotParams
        Passed to `compare_psis_tree_plot_from_compared`.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The base name defaults
        to 'compare_psis' and the extension to 'pdf'.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Work on a copy, so the defaults assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)

    fig, ax = compare_psis_tree_plot_from_compared(
        compared=compared, tree_plot_params=tree_plot_params)

    info_path.base_name = info_path.base_name or 'compare_psis'
    info_path.extension = info_path.extension or 'pdf'
    the_path = get_info_path(info_path)
    fig.savefig(the_path, dpi=info_path.dpi)
    plt.close(fig)  # Release the figure once it is on disk
예제 #3
0
파일: waic.py 프로젝트: evgenyneu/tarpan
def save_compare_waic_txt_from_compared(compared, info_path=None):
    """
    Compare models using WAIC (Widely Applicable Information Criterion)
    to see which models are more compatible with the data. The result
    is saved in a text file as a pipe-formatted table.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The base name defaults
        to "compare_waic"; the extension is always 'txt'.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Work on a copy, so the name and extension assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)
    info_path.base_name = info_path.base_name or "compare_waic"
    info_path.extension = 'txt'

    df = waic_compared_to_df(compared)
    table = tabulate(df, headers=list(df), floatfmt=".2f", tablefmt="pipe")
    path = get_info_path(info_path)

    with open(path, "w") as text_file:
        print(table, file=text_file)
예제 #4
0
def test_get_info_path():
    """
    The path returned by get_info_path should contain the base name
    and the extension that were stored in the InfoPath.
    """

    settings = InfoPath()
    settings.base_name = 'my_basename'
    settings.extension = 'test_extension'

    expected_ending = ('tarpan/shared/model_info/info_path_test/'
                       'my_basename.test_extension')

    assert expected_ending in get_info_path(settings)
예제 #5
0
def save_compare_parameters(
    models,
    labels,
    extra_values=None,
    type: CompareParametersType = CompareParametersType.TEXT,
    param_names=None,
    info_path=None,
    summary_params=SummaryParams()):
    """
    Saves a text table that compares model parameters

    Parameters
    ----------

    models : list Panda's data frames
        List of data frames for each model, containing sample values for
        multiple parameters (one parameter is one data frame column).
        Supply multiple data frames to compare parameter distributions.

    labels : list of str
        Names of the models in `models` list.

    extra_values : list of dict
        Additional values to be shown in the table:

        [
            {
                "mu": 2.3,
                "sigma": 3.3
            }
        ]

        Treated as an empty list when None.

    type : CompareParametersType
        Format of values in the text table.

    param_names : list of str
        Names of parameters. Include all if None.

    info_path : InfoPath
        Path information for creating summaries. When None, a new
        InfoPath() is created for this call. The base name defaults
        to "parameters_compared"; the extension is always 'txt'.

    summary_params : SummaryParams
        Passed to `compare_parameters`.
    """

    if extra_values is None:
        # Avoid a list literal in the signature: it would be a single
        # object shared between calls.
        extra_values = []

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    df, table = compare_parameters(models=models,
                                   labels=labels,
                                   extra_values=extra_values,
                                   type=type,
                                   param_names=param_names,
                                   summary_params=summary_params)

    # Work on a copy, so the name and extension assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)
    info_path.base_name = info_path.base_name or "parameters_compared"
    info_path.extension = 'txt'
    path_to_txt = get_info_path(info_path)

    with open(path_to_txt, "w") as text_file:
        print(table, file=text_file)
예제 #6
0
파일: summary.py 프로젝트: evgenyneu/tarpan
def save_summary(fit,
                 param_names=None,
                 info_path=None,
                 summary_params=SummaryParams()):
    """
    Saves statistical summary of the samples using mean, std, mode, hpdi.

    Parameters
    ----------

    fit : cmdstanpy.stanfit.CmdStanMCMC

        Contains the samples from cmdstanpy.

    param_names : list of str

        Names of parameters to be included in the summary. Include all if None.

    info_path : InfoPath

        Path information for creating summaries. When None, a new
        InfoPath() is created for this call.

    summary_params : SummaryParams

        Passed to `make_summary`.

    Returns
    -------
    dict:
        df:
            Panda's data frame containing the summary
        table: str
            Summary table in text format.
        samples: Panda's data frame
            Combined samples from all chains
        path_txt: str
            Path to the text summary
        path_csv: str
            Path to summary in CSV format
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Hand a copy to the code that writes the files, so that it
    # works on its own object rather than on the caller's.
    info_path = InfoPath(**info_path.__dict__)

    df_summary, summary, samples = make_summary(fit,
                                                param_names=param_names,
                                                summary_params=summary_params)

    output = save_summary_to_disk(df_summary, summary, info_path=info_path)

    return {
        "df": df_summary,
        "table": summary,
        "samples": samples,
        "path_txt": output["path_txt"],
        "path_csv": output["path_csv"]
    }
예제 #7
0
파일: analyse.py 프로젝트: evgenyneu/tarpan
def save_diagnostic(fit, info_path=InfoPath()):
    """
    Write the output of `fit.diagnose()` into a text file.

    Parameters
    ----------

    fit :
        Object that provides the `diagnose()` method.

    info_path : InfoPath
        Determines the location of the output file. It is copied
        before use. The base name defaults to 'diagnostic' and
        the extension is always 'txt'.
    """

    target = InfoPath(**info_path.__dict__)
    target.base_name = target.base_name or 'diagnostic'
    target.extension = 'txt'
    destination = get_info_path(target)

    with open(destination, "w") as output_file:
        output_file.write(str(fit.diagnose()) + "\n")
예제 #8
0
def save_histogram(fit,
                   param_names=None,
                   info_path=None,
                   summary_params=SummaryParams(),
                   histogram_params=HistogramParams()):
    """
    Make histograms of parameter distributions.

    Parameters
    ----------

    fit : cmdstanpy.stanfit.CmdStanMCMC
        Samples from cmdstanpy.

    param_names : list of str
        Names of parameters to be included in the summary. Include all if None.

    info_path : InfoPath
        Path information for creating summaries. When None, a new
        InfoPath() is created for this call.

    summary_params : SummaryParams
        Passed to `make_summary` and `save_histogram_from_summary`.

    histogram_params : HistogramParams
        Passed to `save_histogram_from_summary`.

    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    df_summary, summary, samples = make_summary(
        fit, param_names=param_names, summary_params=summary_params)

    save_histogram_from_summary(
        samples, df_summary, param_names=param_names,
        info_path=info_path, summary_params=summary_params,
        histogram_params=histogram_params)
예제 #9
0
def save_analysis(samples, param_names=None, info_path=None,
                  summary_params=SummaryParams()):
    """
    Creates all analysis files at once: summary, tree plot,
    histograms and pair plot.

    Parameters
    -----------

    samples : Panda's DataFrame

        Each column contains samples from posterior distribution.

    param_names : list of str

        Names of parameters to plot. Plot all parameters if None.

    info_path : InfoPath

        Path information for the created files. When None, a new
        InfoPath() is created for this call. The same object is
        passed to every step.

    summary_params : SummaryParams

        Passed to the summary, tree plot and histogram steps.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    summary = save_summary(
        samples, param_names=param_names, info_path=info_path,
        summary_params=summary_params)

    # The plots below reuse the summary data frame made above
    make_tree_plot(summary['df'], param_names=param_names, info_path=info_path,
                   summary_params=summary_params)

    save_histogram_from_summary(
        samples, summary['df'], param_names=param_names,
        info_path=info_path, summary_params=summary_params)

    save_pair_plot(samples, param_names=param_names, info_path=info_path)
예제 #10
0
def save_pair_plot(fit,
                   param_names=None,
                   info_path=None,
                   pair_plot_params=PairPlotParams()):
    """
    Save a pair plot of distributions of parameters. It helps
    to see correlations between parameters and spot funnel
    shaped distributions that can result in sampling problems.

    Parameters
    ----------

    fit : cmdstanpy.stanfit.CmdStanMCMC
        Samples from cmdstanpy.

    param_names : list of str
        Names of parameters. Include all if None.

    info_path : InfoPath
        Path information for creating summaries. When None, a new
        InfoPath() is created for this call.

    pair_plot_params : PairPlotParams
        Passed to `shared_save_pair_plot`.

    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Resolve the names first, then fetch only those columns from the fit
    param_names = filter_param_names(fit.column_names, param_names)
    samples = fit.get_drawset(params=param_names)

    shared_save_pair_plot(samples, param_names=param_names,
                          info_path=info_path,
                          pair_plot_params=pair_plot_params)
예제 #11
0
def save_summary(samples, param_names=None, info_path=None,
                 summary_params=SummaryParams()):
    """
    Generates and saves statistical summary of the samples using mean, std, mode, hpdi.

    Parameters
    ----------

    samples : Panda's dataframe

        Each column contains samples for a parameter.

    param_names : list of str

        Names of parameters to be included in the summary. Include all if None.

    info_path : InfoPath

        Path information for creating summaries. When None, a new
        InfoPath() is created for this call.

    summary_params : SummaryParams

        Passed to `sample_summary`.

    Returns
    -------
    The value returned by `save_summary_to_disk`.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()
    column_names = list(samples)
    param_names = filter_param_names(column_names, param_names)
    samples = samples[param_names]  # Filter by column names
    df_summary, table = sample_summary(samples, params=summary_params)
    return save_summary_to_disk(df_summary, table, info_path)
예제 #12
0
def save_psis_pareto_k_plot(
    fit,
    name,
    lpd_column_name=LPD_COLUMN_NAME_DEFAULT,
    pareto_k_plot_params: ParetoKPlotParams = ParetoKPlotParams(),
    info_path=None):
    """
    Make a plot that shows values of Pareto K index generated by PSIS
    method. This is used to see if there are values of K higher than 0.7,
    which means that there are possible outliers and PSIS calculations
    may not be reliable.

    Parameters
    ----------

    fit : cmdstanpy.stanfit.CmdStanMCMC
        Contains the samples from cmdstanpy.

    name : str
        Model name.

    lpd_column_name : str
        Passed to `psis` together with the fit.

    pareto_k_plot_params : ParetoKPlotParams
        Passed to `save_psis_pareto_k_plot_from_psis_data`.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()
    psis_data = psis(fit=fit, lpd_column_name=lpd_column_name)

    save_psis_pareto_k_plot_from_psis_data(
        psis_data=psis_data,
        name=name,
        pareto_k_plot_params=pareto_k_plot_params,
        info_path=info_path)
예제 #13
0
def save_psis_pareto_k_plot_from_compared(
    compared,
    pareto_k_plot_params: ParetoKPlotParams = ParetoKPlotParams(),
    info_path=None):
    """
    Make multiple plots that show values of Pareto K index generated by PSIS
    method for each compared model.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models. The `psis_data` and `name` attributes
        of each model are used.

    pareto_k_plot_params : ParetoKPlotParams
        Passed to `save_psis_pareto_k_plot_from_psis_data`.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The same object is used
        for every model.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    for model in compared:
        psis_data = model.psis_data

        save_psis_pareto_k_plot_from_psis_data(
            psis_data=psis_data,
            name=model.name,
            pareto_k_plot_params=pareto_k_plot_params,
            info_path=info_path)
예제 #14
0
파일: waic.py 프로젝트: evgenyneu/tarpan
def save_compare_waic_txt(models,
                          lpd_column_name=LPD_COLUMN_NAME_DEFAULT,
                          info_path=None):
    """
    Compare models using WAIC (Widely Applicable Information Criterion)
    to see which models are more compatible with the data. The result
    is saved in a text file.

    Parameters
    ----------

    models : dict
        key: str
            Model name.
        value: cmdstanpy.stanfit.CmdStanMCMC
            Contains the samples from cmdstanpy to compare.

    lpd_column_name : str
        Prefix of the columns in Stan's output that contain log
        probability density value for each observation. For example,
        if lpd_column_name='possum', then output is expected to have
        columns 'possum.1', 'possum.2', ..., 'possum.33' given 33 observations.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()
    compared = compare_waic(models=models, lpd_column_name=lpd_column_name)
    save_compare_waic_txt_from_compared(compared=compared, info_path=info_path)
예제 #15
0
def save_histogram(samples, param_names=None,
                   info_path=None,
                   histogram_params=HistogramParams(),
                   summary_params=SummaryParams()):
    """
    Make histograms for the parameters from posterior distribution.

    Parameters
    -----------

    samples : Panda's DataFrame

        Each column contains samples from posterior distribution.

    param_names : list of str

        Names of the parameters for plotting. If None, all will be plotted.

    info_path : InfoPath

        Path information for the created files. When None, a new
        InfoPath() is created for this call.

    histogram_params : HistogramParams

        Passed to `save_histogram_from_summary`.

    summary_params : SummaryParams

        Used both to make the summary and to plot it.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # The summary must be made with the same parameters that are later
    # used for plotting it, so `summary_params` is passed here as well.
    df_summary, _ = sample_summary(df=samples, params=summary_params)

    save_histogram_from_summary(samples, df_summary,
                                param_names=param_names,
                                info_path=info_path,
                                histogram_params=histogram_params,
                                summary_params=summary_params)
예제 #16
0
파일: kde.py 프로젝트: evgenyneu/tarpan
def save_scatter_and_kde(values,
                         uncertainties,
                         title=None,
                         xlabel=None,
                         ylabel=None,
                         info_path=None,
                         scatter_kde_params=ScatterKdeParams(),
                         legend_labels=None,
                         plot_fn=None):
    """
    Create a scatter plot and a KDE plot under it.
    The plot is saved to a file.
    The KDE plot uses uncertainties of each individual observation.

    Parameters
    ----------
    values: list of lists
        List of values to plot. Supply more than one list to
        see distributions shown with different colors and markers.
    uncertainties: list of lists
        Uncertainties corresponding to the `values`.
    title, xlabel, ylabel, legend_labels:
        Passed to `scatter_and_kde`.
    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The base name defaults
        to "scatter_kde" and the extension to 'pdf'.
    scatter_kde_params : ScatterKdeParams
        Passed to `scatter_and_kde`.
    plot_fn: [function(fig, axes, params), params]
        function:
            A function that can be used to add extra information to the
            plot before it is saved.

            Parameters
            ----------

            fig: Matplotlib's figure object
            axes: list of Matplotlib's axes objects
            params: custom parameters that are passed to the function
        params: custom parameters that will be passed to the function

    Returns
    --------
    fig : Matplotlib's figure object
    axes : list of Matplotlib's axes
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # it is modified below.
        info_path = InfoPath()

    fig, axes = scatter_and_kde(values=values,
                                uncertainties=uncertainties,
                                title=title,
                                xlabel=xlabel,
                                ylabel=ylabel,
                                scatter_kde_params=scatter_kde_params,
                                legend_labels=legend_labels)

    if plot_fn is not None:
        plot_fn[0](fig, axes, params=plot_fn[1])

    info_path.set_codefile()

    # Work on a copy, so the name and extension assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)
    info_path.base_name = info_path.base_name or "scatter_kde"
    info_path.extension = info_path.extension or 'pdf'
    plot_path = get_info_path(info_path)
    fig.savefig(plot_path, dpi=info_path.dpi)

    # The figure is returned to the caller, so it is not closed here
    return fig, axes
예제 #17
0
def save_pair_plot(samples, param_names=None,
                   info_path=InfoPath(),
                   pair_plot_params=PairPlotParams()):
    """
    Make a pair plot for the parameters from posterior distribution
    and save it to a file.

    Parameters
    -----------

    samples : Panda's DataFrame

        Each column contains samples from posterior distribution.

    param_names : list of str

        Names of the parameters for plotting. If None, all will be plotted.

    info_path : InfoPath

        Determines the location of the output file. It is copied
        before use. The base name defaults to "pair_plot" and
        the extension to 'pdf'.

    pair_plot_params : PairPlotParams

        Passed to `make_pair_plot`.
    """

    output = InfoPath(**info_path.__dict__)
    output.set_codefile()

    grid = make_pair_plot(
        samples, param_names=param_names,
        pair_plot_params=pair_plot_params)

    output.base_name = output.base_name or "pair_plot"
    output.extension = output.extension or 'pdf'
    destination = get_info_path(output)
    grid.savefig(destination, dpi=output.dpi)
예제 #18
0
def save_tree_plot(models,
                   extra_values=None,
                   param_names=None,
                   info_path=None,
                   summary_params=SummaryParams(),
                   tree_params=TreePlotParams()):
    """
    Save a tree plot that summarises parameter distributions.
    Can compare summaries from multiple models, when multiple samples are
    supplied. One can also supply additional markers
    to be compared with using `extra_values` parameter.

    Parameters
    ----------

    models : list Panda's data frames

        List of data frames for each model, containing sample values for
        multiple parameters (one parameter is one data frame column).
        Supply multiple data frames to see their distribution summaries
        compared on the tree plot.

    extra_values : list of dict
        Additional markers to be shown on tree plot, without error bars:

        [
            {
                "mu": 2.3,
                "sigma": 3.3
            }
        ]

        Treated as an empty list when None.

    param_names : list of str

        Names of parameters. Include all if None.

    info_path : InfoPath

        Path information for creating summaries. When None, a new
        InfoPath() is created for this call.

    summary_params : SummaryParams

        Passed to `sample_summary`.

    tree_params : TreePlotParams

        Passed to `make_comparative_tree_plot`.

    """

    if extra_values is None:
        # Avoid a list literal in the signature: it would be a single
        # object shared between calls.
        extra_values = []

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()
    summaries = []

    for samples in models:
        column_names = list(samples)
        param_names = filter_param_names(column_names, param_names)

        # Summarise the requested parameters only. Without this filter
        # `param_names` had no effect on the plot.
        samples = samples[param_names]

        summary, _ = sample_summary(samples, params=summary_params)
        summaries.append(summary)

    for values in extra_values:
        summaries.append(summary_from_dict(values))

    make_comparative_tree_plot(summaries,
                               info_path=info_path,
                               tree_params=tree_params)
0
def save_psis_pareto_k_plot_from_psis_data(
    psis_data: PsisData,
    name,
    pareto_k_plot_params: ParetoKPlotParams = ParetoKPlotParams(),
    info_path=None):
    """
    Make a plot that shows values of Pareto K index generated by PSIS
    method. This is used to see if there are values of K higher than 0.7,
    which means that there are possible outliers and PSIS calculations
    may not be reliable.

    Parameters
    ----------

    psis_data : PsisData
        Calculated PSIS values for the model.

    name : str
        Model name. It is also used to make the default file name:
        `pareto_k_<name>`, with the name in lower case and with
        spaces and the characters \\ / ? + * replaced by underscores.

    pareto_k_plot_params : ParetoKPlotParams
        Passed to `psis_pareto_k_plot_from_psis_data`.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The extension defaults
        to 'pdf'.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Work on a copy, so the name and extension assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)

    fig, ax = psis_pareto_k_plot_from_psis_data(
        psis_data=psis_data,
        name=name,
        pareto_k_plot_params=pareto_k_plot_params)

    # Characters that are replaced to get a file-name-friendly model name
    replace_characters = [" ", "\\", "/", "?", "+", "*"]
    model_name_sanitised = name.lower()

    for character in replace_characters:
        model_name_sanitised = model_name_sanitised.replace(character, "_")

    base_name = f'pareto_k_{model_name_sanitised}'
    info_path.base_name = info_path.base_name or base_name
    info_path.extension = info_path.extension or 'pdf'
    the_path = get_info_path(info_path)
    fig.savefig(the_path, dpi=info_path.dpi)
    plt.close(fig)  # Release the figure once it is on disk
예제 #20
0
def get_fit2():
    """
    Returns fit for unit tests: the result of `run` for `run_model`
    with the data from `get_data2()`.
    """

    location = InfoPath(path='temp_data',
                        dir_name="a02_gaussian_mixture2",
                        sub_dir_name=InfoPath.DO_NOT_CREATE)

    fit = run(info_path=location, func=run_model, data=get_data2())
    return fit
예제 #21
0
def get_fit():
    """
    Returns fit for unit tests: the result of `run` for `run_model`
    with the data from `get_data()`.
    """

    run_arguments = {
        "info_path": InfoPath(path='temp_data',
                              dir_name="a01_eight_schools",
                              sub_dir_name=InfoPath.DO_NOT_CREATE),
        "func": run_model,
        "data": get_data(),
    }

    return run(**run_arguments)
예제 #22
0
def make_tree_plot(df_summary,
                   param_names=None,
                   info_path=InfoPath(),
                   tree_params: TreePlotParams = TreePlotParams(),
                   summary_params=SummaryParams()):
    """
    Make tree plot of parameters and save it to a file.

    Parameters
    ----------

    df_summary :
        Summary of the parameters, passed to `extract_tree_plot_data`.

    param_names : list of str
        Passed to `extract_tree_plot_data`.

    info_path : InfoPath
        Determines the location of the output file. It is copied
        before use. The base name defaults to 'summary' and
        the extension to 'pdf'.

    tree_params : TreePlotParams
        Passed to `tree_plot` as its `params` argument.

    summary_params : SummaryParams
        Passed to `extract_tree_plot_data`.
    """

    output = InfoPath(**info_path.__dict__)

    plot_data = extract_tree_plot_data(df_summary,
                                       param_names=param_names,
                                       summary_params=summary_params)

    figure, _axis = tree_plot(plot_data, params=tree_params)

    output.base_name = output.base_name or 'summary'
    output.extension = output.extension or 'pdf'
    destination = get_info_path(output)
    figure.savefig(destination, dpi=output.dpi)
    plt.close(figure)
예제 #23
0
def save_tree_plot(fits,
                   extra_values=None,
                   param_names=None,
                   info_path=None,
                   summary_params=SummaryParams(),
                   tree_params=TreePlotParams()):
    """
    Save a tree plot that summarises parameter distributions.
    Can compare summaries from multiple models, when multiple fits are
    supplied. One can also supply additional markers
    to be compared with using `extra_values` parameter.

    Parameters
    ----------

    fits : list of cmdstanpy.stanfit.CmdStanMCMC

        Contains the samples from cmdstanpy.

    extra_values : list of dict
        Additional markers to be shown on tree plot, without error bars:

        [
            {
                "mu": 2.3,
                "sigma": 3.3
            }
        ]

        Treated as an empty list when None.

    param_names : list of str

        Names of parameters. Include all if None.

    info_path : InfoPath

        Path information for creating summaries. When None, a new
        InfoPath() is created for this call.

    summary_params : SummaryParams

        Passed to `sample_summary`.

    tree_params : TreePlotParams

        Passed to `make_comparative_tree_plot`.

    """

    if extra_values is None:
        # Avoid a list literal in the signature: it would be a single
        # object shared between calls.
        extra_values = []

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()
    summaries = []

    for fit in fits:
        # NOTE(review): `param_names` is reassigned here, so the names
        # resolved for one fit are what the next fit is filtered with.
        # Confirm this is intended when fits have different columns.
        param_names = filter_param_names(fit.column_names, param_names)
        samples = fit.get_drawset(params=param_names)
        summary, _ = sample_summary(samples, params=summary_params)
        summaries.append(summary)

    for values in extra_values:
        summaries.append(summary_from_dict(values))

    make_comparative_tree_plot(summaries,
                               info_path=info_path,
                               tree_params=tree_params)
0
class AnalysisSettings:
    """
    Settings of the analysis: locations of the data, of the Stan model
    and of the output, sampling and model constants, and plot styling.
    """

    # Data for Stan model (dictionary)
    data = None

    # Path to the CSV file with the data
    # (presumably the local copy of `data_url` - confirm against the loader)
    csv_path: str = "data/time_series_19-covid-Confirmed.csv"

    # URL to the data
    # (a single string: the backslashes only continue the literal
    # on the next line)
    data_url: str = "https://raw.githubusercontent.com/CSSEGISandData/\
COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/\
time_series_covid19_confirmed_global.csv"

    # Path to the .stan model file
    stan_model_path: str = "code/stan_model/logistic.stan"

    # Location of plots and summaries
    # NOTE(review): this InfoPath is created once, in the class body.
    # Unless a decorator outside this view changes that, it is shared by
    # all instances that do not assign their own.
    info_path: InfoPath = InfoPath()

    # Name of the directory for plots
    plots_dir: str = "plots"

    # Stan's sampling parameter
    max_treedepth: float = 10

    # Number of hours to wait before downloading the data from the Web
    max_hours_diff = 12

    # Width of HPDI (highest posterior density interval) that is used
    # to plot the shaded region around the predicted mean line.
    hpdi_width: float = 0.95

    # Maximum number of people that can be infected
    population_size: float = 2_900_000

    # Difference between the maximum number of confirmed cases
    # and the actual number of confirmed cases at which we consider
    # all people to be reported
    tolerance_cases = 1000

    # Colors of the data markers. The 8-digit hex values here and below
    # presumably carry an alpha channel in the last two digits - confirm
    # with the plotting code.
    marker_color: str = "#F4A92800"
    marker_edgecolor: str = "#F4A928"

    # Colors of the mean line and of the shaded HPDI regions
    mu_line_color: str = "#28ADF4"
    mu_hpdi_color: str = "#6ef48688"
    cases_hpdi_color: str = "#e8e8f455"

    # Plot's background color
    background_color = '#023D45'

    # Marker style
    marker: str = "o"

    # Color of the grid lines
    grid_color: str = "#aaaaaa"

    # Opacity of the grid lines
    grid_alpha: float = 0.2
예제 #25
0
파일: waic.py 프로젝트: evgenyneu/tarpan
def save_compare_waic_csv_from_compared(compared, info_path=None):
    """
    Compare models using WAIC (Widely Applicable Information Criterion)
    to see which models are more compatible with the data. The result
    is saved in a CSV file.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models.
    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The base name defaults
        to "compare_waic"; the extension is always 'csv'.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Work on a copy, so the name and extension assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)
    info_path.base_name = info_path.base_name or "compare_waic"
    info_path.extension = 'csv'

    df = waic_compared_to_df(compared)
    path = get_info_path(info_path)
    df.to_csv(path, index_label='Name')  # The index column is titled 'Name'
예제 #26
0
def get_fit1_divorse_age():
    """
    Returns the result of `run` for `run_model1_divorse_age` with the
    data from `get_data1_divorse_age()`, using the numbers of sampling
    and warmup iterations from `get_iters()`.
    """

    location = InfoPath(path='temp_data',
                        dir_name="a05_divorse1_divorse_age",
                        sub_dir_name=InfoPath.DO_NOT_CREATE)

    iterations = get_iters()
    model_data = get_data1_divorse_age()

    fit = run(info_path=location,
              func=run_model1_divorse_age,
              data=model_data,
              sampling_iters=iterations["sampling_iters"],
              warmup_iters=iterations["warmup_iters"])

    return fit
예제 #27
0
def make_comparative_tree_plot(summaries,
                               param_names=None,
                               info_path=InfoPath(),
                               tree_params: TreePlotParams = TreePlotParams()):
    """
    Make tree plot that compares summaries of parameters
    and save it to a file.

    Parameters
    ----------

    summaries : list
        Summaries to compare. Each one is passed to
        `extract_tree_plot_data` together with the data
        extracted from the previous ones.

    param_names : list of str
        Passed to `extract_tree_plot_data`.

    info_path : InfoPath
        Determines the location of the output file. It is copied
        before use. The base name defaults to 'summary' and
        the extension to 'pdf'.

    tree_params : TreePlotParams
        Passed to `tree_plot` as its `params` argument.
    """

    output = InfoPath(**info_path.__dict__)

    # Data accumulated over the summaries: the result for one summary
    # is supplied as `groups` when extracting the next one.
    groups = None

    for one_summary in summaries:
        groups = extract_tree_plot_data(one_summary,
                                        param_names=param_names,
                                        groups=groups)

    figure, _axis = tree_plot(groups, params=tree_params)

    output.base_name = output.base_name or 'summary'
    output.extension = output.extension or 'pdf'
    destination = get_info_path(output)
    figure.savefig(destination, dpi=output.dpi)
    plt.close(figure)
예제 #28
0
def save_histogram_from_summary(samples, summary, param_names=None,
                                info_path=InfoPath(),
                                histogram_params=HistogramParams(),
                                summary_params=SummaryParams()):
    """
    Make histograms for the parameters from posterior distribution
    and save them, one file per figure.

    Parameters
    -----------

    samples : Panda's DataFrame

        Each column contains samples from posterior distribution.

    summary : Panda's DataFrame

        Summary information about each column.

    param_names : list of str

        Names of the parameters for plotting. If None, all will be plotted.

    info_path : InfoPath

        Determines the location of the output files. It is copied
        before use. The base name defaults to "histogram" and the
        extension to 'pdf'. Files are numbered: `<base_name>_01`,
        `<base_name>_02`, ...

    histogram_params : HistogramParams

        Passed to `make_histograms` as its `params` argument.

    summary_params : SummaryParams

        Passed to `make_histograms`.
    """

    output = InfoPath(**info_path.__dict__)

    plots = make_histograms(
        samples, summary, param_names=param_names,
        params=histogram_params,
        summary_params=summary_params)

    name_prefix = output.base_name or "histogram"
    output.extension = output.extension or 'pdf'

    for number, plot in enumerate(plots, start=1):
        output.base_name = f'{name_prefix}_{number:02d}'
        destination = get_info_path(output)
        figure = plot[0]  # Each item starts with the figure
        figure.savefig(destination, dpi=output.dpi)
        plt.close(figure)
예제 #29
0
파일: waic.py 프로젝트: evgenyneu/tarpan
def save_compare_waic_tree_plot_from_compared(
    compared,
    tree_plot_params: TreePlotParams = TreePlotParams(),
    info_path=None):
    """
    Make a plot that compares models using WAIC
    (Widely Applicable Information Criterion) and save it to a file.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models.

    tree_plot_params : TreePlotParams
        Passed to `compare_waic_tree_plot_from_compared`.

    info_path : InfoPath
        Determines the location of the output file. When None, a new
        InfoPath() is created for this call. The base name defaults
        to 'compare_waic' and the extension to 'pdf'.
    """

    if info_path is None:
        # Create the default per call: an InfoPath() built in the
        # signature would be one object shared by every call, and
        # set_codefile() below is invoked on it.
        info_path = InfoPath()

    info_path.set_codefile()

    # Work on a copy, so the defaults assigned below
    # do not change the caller's object.
    info_path = InfoPath(**info_path.__dict__)

    fig, ax = compare_waic_tree_plot_from_compared(
        compared=compared, tree_plot_params=tree_plot_params)

    info_path.base_name = info_path.base_name or 'compare_waic'
    info_path.extension = info_path.extension or 'pdf'
    the_path = get_info_path(info_path)
    fig.savefig(the_path, dpi=info_path.dpi)
    plt.close(fig)  # Release the figure once it is on disk
예제 #30
0
def get_fit_larger_uncertainties():
    """
    Returns fit for unit tests, uses data with larger uncertainties:
    the "sigma" values from `get_data()` are doubled.
    """

    location = InfoPath(path='temp_data',
                        dir_name="a01_eight_schools_large_uncert",
                        sub_dir_name=InfoPath.DO_NOT_CREATE)

    # First run the model with the unchanged data. The result is not used.
    # NOTE(review): presumably this is done for its effect on the files
    # in `location` - confirm with `run`.
    run(info_path=location, func=run_model, data=get_data())

    # Use data with increased uncertainties
    wider_data = get_data()
    doubled_sigma = [u * 2 for u in wider_data["sigma"]]
    wider_data["sigma"] = doubled_sigma

    return run(info_path=location, func=run_model, data=wider_data)