Example #1
0
def test_save_tree_plot():
    """
    Save a comparative tree plot for two cmdstanpy fits plus one set of
    exact values, then check that the summary PDF was written.
    """
    # Same model fitted twice: as-is and with larger uncertainties.
    fits = [get_fit(), get_fit_larger_uncertainties()]

    # Reference markers drawn without error bars.
    exact_values = {"mu": 1.1, "tau": 3.1}

    tree_params = TreePlotParams()
    # Legend entries: first fit, second fit, exact values.
    tree_params.labels = ["Normal", "Larger uncertainties", "Exact"]

    # NOTE(review): presumably the default InfoPath used by save_tree_plot
    # resolves to this directory for this test file — confirm.
    outdir = "tarpan/cmdstanpy/model_info/tree_plot_test"

    # Remove earlier output so a stale PDF cannot make the test pass.
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)

    save_tree_plot(
        fits,
        extra_values=[exact_values],
        param_names=["mu", "tau"],
        tree_params=tree_params,
    )

    assert os.path.isfile(os.path.join(outdir, "summary.pdf"))
Example #2
0
def test_save_tree_plot():
    """
    Save a tree plot for a single DataFrame of samples plus one set of
    exact values, then check that the summary PDF was written.
    """
    # Samples for two parameters, one column per parameter.
    data = {"x": [1, 2, 3, 4, 5, 6], "y": [-1, -2, -3, -4, -5, -6]}

    df = pd.DataFrame(data)

    # Markers drawn without error bars; keys match the DataFrame columns.
    values_no_error_bars = [{
        "x": 1.1,
        "y": -3.1,
    }]

    tree_params = TreePlotParams()

    # Legend entries: the DataFrame summary, then the extra values.
    tree_params.labels = [
        "Normal",
        "Exact",
    ]

    outdir = "tarpan/shared/model_info/tree_plot_test"

    # Remove earlier output so a stale PDF cannot make the test pass.
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)

    # Request the columns that actually exist in `df`; names such as
    # "mu"/"tau" would match none of them and the DataFrame would not be
    # plotted.
    save_tree_plot([df],
                   extra_values=values_no_error_bars,
                   param_names=["x", "y"],
                   tree_params=tree_params)

    assert os.path.isfile(os.path.join(outdir, "summary.pdf"))
Example #3
0
def compare_waic_tree_plot(
    models,
    lpd_column_name=LPD_COLUMN_NAME_DEFAULT,
    tree_plot_params: TreePlotParams = TreePlotParams()):
    """
    Plot a comparison of models using WAIC
    (Widely Applicable Information Criterion).

    Parameters
    ----------

    models : dict
        key: str
            Model name.
        value: cmdstanpy.stanfit.CmdStanMCMC
            Contains the samples from cmdstanpy to compare.

        NOTE(review): other functions in this file describe `models` as a
        list of dicts with "name"/"fit" keys; confirm which shape
        `compare_waic` expects.

    lpd_column_name : str
        Prefix of the columns in Stan's output that contain the log
        probability density value for each observation. For example,
        if lpd_column_name='possum', then the output is expected to have
        columns 'possum.1', 'possum.2', ..., 'possum.33' given 33 observations.

    tree_plot_params : TreePlotParams
        Plot settings, forwarded to `compare_waic_tree_plot_from_compared`.

    Returns
    -------
    (fig, ax):
        fig: Matplotlib's figure
        ax: Matplotlib's axis
    """
    return compare_waic_tree_plot_from_compared(
        compared=compare_waic(models=models,
                              lpd_column_name=lpd_column_name),
        tree_plot_params=tree_plot_params)
Example #4
0
def save_compare_psis_tree_plot_from_compared(
    compared,
    tree_plot_params: TreePlotParams = TreePlotParams(),
    info_path=InfoPath()):
    """
    Plot an already-computed PSIS model comparison and save the figure
    to a file.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models.
        NOTE(review): entries must expose `psis_data` and
        `psis_difference_best_std_err`; confirm the concrete type name.

    tree_plot_params : TreePlotParams
        Plot settings, forwarded to `compare_psis_tree_plot_from_compared`.

    info_path : InfoPath
        Determines the location of the output file. An empty `base_name`
        falls back to 'compare_psis' and an empty `extension` to 'pdf'.
    """
    # NOTE(review): set_codefile() runs on the caller's object (possibly the
    # shared default instance) before it is copied below.
    info_path.set_codefile()

    # Copy so the fallback name/extension below are not written back into
    # the caller's InfoPath.
    output_path = InfoPath(**info_path.__dict__)

    fig, _ = compare_psis_tree_plot_from_compared(
        compared=compared, tree_plot_params=tree_plot_params)

    if not output_path.base_name:
        output_path.base_name = 'compare_psis'
    if not output_path.extension:
        output_path.extension = 'pdf'

    fig.savefig(get_info_path(output_path), dpi=output_path.dpi)
    plt.close(fig)
Example #5
0
def save_tree_plot(fits,
                   extra_values=None,
                   param_names=None,
                   info_path=InfoPath(),
                   summary_params=SummaryParams(),
                   tree_params=TreePlotParams()):
    """
    Save a tree plot that summarises parameter distributions.

    Summaries from several models can be compared when multiple fits are
    supplied. Additional markers (drawn without error bars) can be added
    with the `extra_values` parameter.

    Parameters
    ----------

    fits : list of cmdstanpy.stanfit.CmdStanMCMC

        Contains the samples from cmdstanpy. One summary is made per fit.

    extra_values : list of dict, optional
        Additional markers to be shown on tree plot, without error bars:

        [
            {
                "mu": 2.3,
                "sigma": 3.3
            }
        ]

        None (the default) means no extra markers.

    param_names : list of str, optional

        Names of parameters. Include all if None. The same request is
        applied to each fit independently.

    info_path : InfoPath

        Path information for creating summaries.

    summary_params : SummaryParams

        Passed to `sample_summary` for every fit.

    tree_params : TreePlotParams

        Passed to `make_comparative_tree_plot`; for example, `labels`
        supplies the legend entries (fits first, then extra values).
    """

    info_path.set_codefile()
    summaries = []

    for fit in fits:
        # Filter against the caller's original request for every fit.
        # Reassigning `param_names` here would make later fits be filtered
        # against the first fit's result instead of the request (and would
        # turn "None = include all" into "only the first fit's columns").
        fit_param_names = filter_param_names(fit.column_names, param_names)
        samples = fit.get_drawset(params=fit_param_names)
        summary, _ = sample_summary(samples, params=summary_params)
        summaries.append(summary)

    # `None` default avoids a shared mutable list default.
    for values in extra_values or []:
        summaries.append(summary_from_dict(values))

    make_comparative_tree_plot(summaries,
                               info_path=info_path,
                               tree_params=tree_params)
Example #6
0
def run_model():
    """
    Fit the eight-schools model twice, the second time with the measurement
    uncertainties doubled, and save a tree plot comparing both fits.
    """
    model = CmdStanModel(stan_file="eight_schools.stan")

    # Identical sampler settings for both runs so the fits are comparable.
    sampler_settings = dict(chains=4,
                            cores=4,
                            seed=1,
                            sampling_iters=1000,
                            warmup_iters=1000)

    data = {
        "J": 8,
        "y": [28, 8, -3, 7, -1, 1, 18, 12],
        "sigma": [15, 10, 16, 11, 9, 11, 10, 18]
    }

    fit1 = model.sample(data=data, **sampler_settings)

    # Second run: every uncertainty doubled.
    data["sigma"] = [sigma * 2 for sigma in data["sigma"]]
    fit2 = model.sample(data=data, **sampler_settings)

    # Optional: extra markers shown on the tree plot for comparison.
    reference_values = {"mu": 2.2, "tau": 1.3}

    # Optional: legend labels (both fits, then the extra markers).
    tree_params = TreePlotParams()
    tree_params.labels = ["Model 1", "Model 2", "Exact"]

    save_tree_plot(fits=[fit1, fit2],
                   extra_values=[reference_values],
                   param_names=['mu', 'tau'],
                   tree_params=tree_params)
Example #7
0
def save_compare_waic_tree_plot(
    models,
    lpd_column_name=LPD_COLUMN_NAME_DEFAULT,
    tree_plot_params: TreePlotParams = TreePlotParams(),
    info_path=InfoPath()):
    """
    Compare models using WAIC (Widely Applicable Information Criterion)
    and save the comparison plot to a file.

    Parameters
    ----------

    models : list of dict
        List of model samples from cmdstanpy to compare.

        The dictionary has keys:
            name: str
                Model name
            fit: cmdstanpy.stanfit.CmdStanMCMC
                Contains the samples from cmdstanpy.

    lpd_column_name : str
        Prefix of the columns in Stan's output that contain the log
        probability density value for each observation. For example,
        if lpd_column_name='possum', then the output is expected to have
        columns 'possum.1', 'possum.2', ..., 'possum.33' given 33 observations.

    tree_plot_params : TreePlotParams
        Plot settings, forwarded to the plotting function.

    info_path : InfoPath
        Determines the location of the output file.
    """
    # NOTE(review): set_codefile() mutates the caller's InfoPath (possibly
    # the shared default instance) before it is handed on.
    info_path.set_codefile()

    waic_compared = compare_waic(models=models,
                                 lpd_column_name=lpd_column_name)

    save_compare_waic_tree_plot_from_compared(
        compared=waic_compared,
        tree_plot_params=tree_plot_params,
        info_path=info_path)
Example #8
0
def save_compare_waic_tree_plot_from_compared(
    compared,
    tree_plot_params: TreePlotParams = TreePlotParams(),
    info_path=InfoPath()):
    """
    Plot an already-computed WAIC (Widely Applicable Information Criterion)
    model comparison and save the figure to a file.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models.

    tree_plot_params : TreePlotParams
        Plot settings, forwarded to `compare_waic_tree_plot_from_compared`.

    info_path : InfoPath
        Determines the location of the output file. An empty `base_name`
        falls back to 'compare_waic' and an empty `extension` to 'pdf'.
    """
    # NOTE(review): set_codefile() runs on the caller's object (possibly the
    # shared default instance) before it is copied below.
    info_path.set_codefile()

    # Copy so the fallback name/extension below are not written back into
    # the caller's InfoPath.
    output_path = InfoPath(**info_path.__dict__)

    fig, _ = compare_waic_tree_plot_from_compared(
        compared=compared, tree_plot_params=tree_plot_params)

    if not output_path.base_name:
        output_path.base_name = 'compare_waic'
    if not output_path.extension:
        output_path.extension = 'pdf'

    fig.savefig(get_info_path(output_path), dpi=output_path.dpi)
    plt.close(fig)
Example #9
0
def compare_waic_tree_plot_from_compared(
    compared, tree_plot_params: TreePlotParams = TreePlotParams()):
    """
    Plot an already-computed WAIC (Widely Applicable Information Criterion)
    model comparison.

    Each model gets two markers at its WAIC value. The first carries the
    standard error of the difference from the best model (no bars when that
    error is None). The second carries the model's own WAIC standard error.
    A dashed vertical line marks the WAIC of `compared[0]`.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models; the first entry is treated as the best one.

    tree_plot_params : TreePlotParams
        Plot settings. A copy is used; `labels`, `xlabel` and `title` get
        WAIC-specific defaults when None, and the zero line is disabled.

    Returns
    -------
    (fig, ax):
        fig: Matplotlib's figure
        ax: Matplotlib's axis
    """
    # Work on a copy so the caller's object (or the shared default) is
    # never modified.
    params = TreePlotParams(**tree_plot_params.__dict__)

    if params.labels is None:
        params.labels = ["dWAIC", "WAIC"]

    if params.xlabel is None:
        params.xlabel = "WAIC (deviance)"

    if params.title is None:
        params.title = "Model comparison (smaller is better)"

    plot_groups = []

    # NOTE(review): iterated in reverse, presumably so the best model ends up
    # in a particular position on the tree plot — confirm with tree_plot.
    for entry in reversed(compared):
        waic = entry.waic_data
        diff_err = entry.waic_difference_best_std_err

        # Marker 1: WAIC with the standard error of the difference from the
        # best model.
        if diff_err is None:
            diff_bars = []
        else:
            diff_bars = [[waic.waic - diff_err, waic.waic + diff_err]]

        # Marker 2: WAIC with its own standard error.
        own_bars = [[waic.waic - waic.waic_std_err,
                     waic.waic + waic.waic_std_err]]

        plot_groups.append(dict(
            name=entry.name,
            values=[
                dict(value=waic.waic, error_bars=diff_bars),
                dict(value=waic.waic, error_bars=own_bars),
            ]))

    params.draw_zero_line_if_in_range = False
    fig, ax = tree_plot(groups=plot_groups, params=params)

    # Dashed vertical reference line through the best model's WAIC.
    best = compared[0]

    ax.axvline(x=best.waic_data.waic,
               linestyle='dashed',
               color=params.marker_edge_colors[0])

    return fig, ax
Example #10
0
def save_compare(
        models, lpd_column_name=LPD_COLUMN_NAME_DEFAULT,
        tree_plot_params: TreePlotParams = TreePlotParams(),
        info_path=InfoPath(),
        pareto_k_plot_params: ParetoKPlotParams = ParetoKPlotParams()):
    """
    Compare multiple models using WAIC (Widely Applicable Information
    Criterion) and PSIS (Pareto-smoothed importance sampling). Saves the
    analysis data and plots.

    Parameters
    ----------

    models : list of dict
        List of model samples from cmdstanpy to compare.

        The dictionary has keys:
            name: str
                Model name
            fit: cmdstanpy.stanfit.CmdStanMCMC
                Contains the samples from cmdstanpy.

    lpd_column_name : str
        Prefix of the columns in Stan's output that contain the log
        probability density value for each observation. For example,
        if lpd_column_name='possum', then the output is expected to have
        columns 'possum.1', 'possum.2', ..., 'possum.33' given 33 observations.

    tree_plot_params : TreePlotParams
        Plot settings shared by the WAIC and PSIS tree plots.

    info_path : InfoPath
        Determines the location of the output files.

    pareto_k_plot_params : ParetoKPlotParams
        Settings for the PSIS Pareto k plot.
    """
    # NOTE(review): set_codefile() mutates the caller's InfoPath (possibly
    # the shared default instance); the same object is passed to every saver.
    info_path.set_codefile()

    # WAIC: csv/txt tables, then the tree plot
    # --------

    waic_compared = compare_waic(models=models,
                                 lpd_column_name=lpd_column_name)

    save_compare_waic_csv_from_compared(compared=waic_compared,
                                        info_path=info_path)

    save_compare_waic_txt_from_compared(compared=waic_compared,
                                        info_path=info_path)

    save_compare_waic_tree_plot_from_compared(
        compared=waic_compared,
        tree_plot_params=tree_plot_params,
        info_path=info_path)

    # PSIS: csv/txt tables, tree plot, then the Pareto k plot
    # --------

    psis_compared = compare_psis(models=models,
                                 lpd_column_name=lpd_column_name)

    save_compare_psis_csv_from_compared(compared=psis_compared,
                                        info_path=info_path)

    save_compare_psis_txt_from_compared(compared=psis_compared,
                                        info_path=info_path)

    save_compare_psis_tree_plot_from_compared(
        compared=psis_compared,
        tree_plot_params=tree_plot_params,
        info_path=info_path)

    save_psis_pareto_k_plot_from_compared(
        compared=psis_compared,
        info_path=info_path,
        pareto_k_plot_params=pareto_k_plot_params)
Example #11
0
def compare_psis_tree_plot_from_compared(
    compared, tree_plot_params: TreePlotParams = TreePlotParams()):
    """
    Plot an already-computed PSIS model comparison.

    Each model gets two markers at its PSIS value. The first carries the
    standard error of the difference from the best model (no bars when that
    error is None). The second carries the model's own PSIS standard error.
    A dashed vertical line marks the PSIS of `compared[0]`.

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models; the first entry is treated as the best one.
        NOTE(review): entries must expose `psis_data` and
        `psis_difference_best_std_err`; confirm the concrete type name.

    tree_plot_params : TreePlotParams
        Plot settings. A copy is used; `labels`, `xlabel` and `title` get
        PSIS-specific defaults when None, and the zero line is disabled.

    Returns
    -------
    (fig, ax):
        fig: Matplotlib's figure
        ax: Matplotlib's axis
    """
    # Work on a copy so the caller's object (or the shared default) is
    # never modified.
    params = TreePlotParams(**tree_plot_params.__dict__)

    if params.labels is None:
        params.labels = ["dPSIS", "PSIS"]

    if params.xlabel is None:
        params.xlabel = "PSIS (deviance)"

    if params.title is None:
        params.title = "Model comparison (smaller is better)"

    plot_groups = []

    # NOTE(review): iterated in reverse, presumably so the best model ends up
    # in a particular position on the tree plot — confirm with tree_plot.
    for entry in reversed(compared):
        psis = entry.psis_data
        diff_err = entry.psis_difference_best_std_err

        # Marker 1: PSIS with the standard error of the difference from the
        # best model.
        if diff_err is None:
            diff_bars = []
        else:
            diff_bars = [[psis.psis - diff_err, psis.psis + diff_err]]

        # Marker 2: PSIS with its own standard error.
        own_bars = [[psis.psis - psis.psis_std_err,
                     psis.psis + psis.psis_std_err]]

        plot_groups.append(dict(
            name=entry.name,
            values=[
                dict(value=psis.psis, error_bars=diff_bars),
                dict(value=psis.psis, error_bars=own_bars),
            ]))

    params.draw_zero_line_if_in_range = False
    fig, ax = tree_plot(groups=plot_groups, params=params)

    # Dashed vertical reference line through the best model's PSIS.
    best = compared[0]

    ax.axvline(x=best.psis_data.psis,
               linestyle='dashed',
               color=params.marker_edge_colors[0])

    return fig, ax