Example #1
0
def test_save_tree_plot():
    """
    save_tree_plot on a single DataFrame (plus one set of extra values
    without error bars) writes summary.pdf into the output directory.
    """
    data = {"x": [1, 2, 3, 4, 5, 6], "y": [-1, -2, -3, -4, -5, -6]}

    df = pd.DataFrame(data)

    # Extra markers drawn without error bars; keys match the columns of df.
    values_no_error_bars = [{
        "x": 1.1,
        "y": -3.1,
    }]

    tree_params = TreePlotParams()

    # Legend labels: presumably one per DataFrame, followed by one per
    # extra_values entry.
    tree_params.labels = [
        "Normal",
        "Exact",
    ]

    # NOTE(review): save_tree_plot picks its output directory itself; this
    # path (relative to the working directory) must match what it chooses.
    outdir = "tarpan/shared/model_info/tree_plot_test"

    # Start from a clean directory so a stale file cannot satisfy the check.
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)

    # param_names must name columns that exist in df ("x" and "y");
    # "mu"/"tau" are not columns of this DataFrame.
    save_tree_plot([df],
                   extra_values=values_no_error_bars,
                   param_names=["x", "y"],
                   tree_params=tree_params)

    assert os.path.isfile(os.path.join(outdir, "summary.pdf"))
Example #2
0
def test_save_tree_plot():
    """Saving a tree plot of two fits plus exact values creates summary.pdf."""
    # Fit with normal uncertainties first, then with larger ones.
    fits = [get_fit(), get_fit_larger_uncertainties()]

    # Reference ("exact") values, drawn as markers without error bars.
    exact_values = [dict(mu=1.1, tau=3.1)]

    params = TreePlotParams()
    # Legend: one entry per fit, then one for the exact values.
    params.labels = ["Normal", "Larger uncertainties", "Exact"]

    outdir = "tarpan/cmdstanpy/model_info/tree_plot_test"
    summary_path = os.path.join(outdir, "summary.pdf")

    # Remove output from previous runs so a stale file can't pass the check.
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)

    save_tree_plot(
        fits,
        extra_values=exact_values,
        param_names=["mu", "tau"],
        tree_params=params,
    )

    assert os.path.isfile(summary_path)
Example #3
0
def run_model():
    """
    Sample the eight-schools model twice (original and doubled
    uncertainties) and save a tree plot comparing ``mu`` and ``tau``.
    """
    model = CmdStanModel(stan_file="eight_schools.stan")

    # Sampler settings shared by both runs.
    # NOTE(review): `cores`, `sampling_iters` and `warmup_iters` look like
    # older cmdstanpy argument names (newer releases use `parallel_chains`,
    # `iter_sampling`, `iter_warmup`) — confirm against the pinned version.
    sampler_kwargs = dict(chains=4,
                          cores=4,
                          seed=1,
                          sampling_iters=1000,
                          warmup_iters=1000)

    data = {
        "J": 8,
        "y": [28, 8, -3, 7, -1, 1, 18, 12],
        "sigma": [15, 10, 16, 11, 9, 11, 10, 18]
    }

    fit1 = model.sample(data=data, **sampler_kwargs)

    # Same observations, every uncertainty doubled.
    wide_data = dict(data, sigma=[s * 2 for s in data["sigma"]])
    fit2 = model.sample(data=wide_data, **sampler_kwargs)

    # Optional: extra reference markers shown for comparison.
    extra_values = [dict(mu=2.2, tau=1.3)]

    # Optional: legend labels — one per fit, then one for extra_values.
    tree_params = TreePlotParams()
    tree_params.labels = ["Model 1", "Model 2", "Exact"]

    save_tree_plot(
        fits=[fit1, fit2],
        extra_values=extra_values,
        param_names=['mu', 'tau'],
        tree_params=tree_params,
    )
Example #4
0
def compare_waic_tree_plot_from_compared(
        compared, tree_plot_params: TreePlotParams = None):
    """
    Make a plot that compares models using WAIC
    (Widely Applicable Information Criterion).

    Parameters
    ----------

    compared : list of WaicModelCompared
        List of compared models. The first element is treated as the best
        model: a dashed vertical line is drawn at its WAIC value. Each
        element must provide ``name``, ``waic_data`` (with ``waic`` and
        ``waic_std_err``) and ``waic_difference_best_std_err`` (``None``
        means the difference row gets no error bar).

    tree_plot_params : TreePlotParams, optional
        Plot settings. A shallow copy is used, so the caller's object is not
        modified. On the copy, ``labels``, ``xlabel`` and ``title`` receive
        WAIC-specific defaults when they are ``None``, and
        ``draw_zero_line_if_in_range`` is always set to False.
        When omitted, ``TreePlotParams()`` is used.

    Returns
    -------
    (fig, ax):
        fig: Matplotlib's figure
        ax: Matplotlib's axis

    Raises
    ------
    ValueError
        If ``compared`` is empty.
    """
    if not compared:
        raise ValueError("compared must contain at least one model")

    # A None default avoids a single TreePlotParams instance created at
    # function-definition time being shared by every call.
    if tree_plot_params is None:
        tree_plot_params = TreePlotParams()

    # Shallow copy so the settings below do not leak back to the caller.
    tree_plot_params = TreePlotParams(**tree_plot_params.__dict__)

    # One label per row in each group: the difference row, then the value.
    if tree_plot_params.labels is None:
        tree_plot_params.labels = ["dWAIC", "WAIC"]

    if tree_plot_params.xlabel is None:
        tree_plot_params.xlabel = "WAIC (deviance)"

    if tree_plot_params.title is None:
        tree_plot_params.title = "Model comparison (smaller is better)"

    plot_groups = []

    # Iterate in reverse so the best model (compared[0]) is appended last.
    for model in reversed(compared):
        waic = model.waic_data

        # WAIC difference row: marker at the model's WAIC with +/- the
        # standard error of its difference from the best model (no error
        # bar when that standard error is None).
        diff_std_err = model.waic_difference_best_std_err

        if diff_std_err is None:
            diff_error_bars = []
        else:
            diff_error_bars = [
                [waic.waic - diff_std_err, waic.waic + diff_std_err]
            ]

        # WAIC value row: marker at the model's WAIC with +/- its own
        # standard error.
        value_error_bars = [
            [waic.waic - waic.waic_std_err, waic.waic + waic.waic_std_err]
        ]

        plot_groups.append(dict(
            name=model.name,
            values=[
                dict(value=waic.waic, error_bars=diff_error_bars),
                dict(value=waic.waic, error_bars=value_error_bars),
            ]))

    tree_plot_params.draw_zero_line_if_in_range = False
    fig, ax = tree_plot(groups=plot_groups, params=tree_plot_params)

    # Dashed vertical line through the best model's WAIC, drawn in the
    # first marker edge colour.
    best = compared[0]

    ax.axvline(x=best.waic_data.waic,
               linestyle='dashed',
               color=tree_plot_params.marker_edge_colors[0])

    return fig, ax
Example #5
0
def compare_psis_tree_plot_from_compared(
        compared, tree_plot_params: TreePlotParams = None):
    """
    Make a plot that compares models using PSIS.

    Parameters
    ----------

    compared : list
        List of compared models. NOTE(review): this was previously documented
        as ``WaicModelCompared``, but the code reads PSIS attributes; confirm
        the actual type. The first element is treated as the best model: a
        dashed vertical line is drawn at its PSIS value. Each element must
        provide ``name``, ``psis_data`` (with ``psis`` and ``psis_std_err``)
        and ``psis_difference_best_std_err`` (``None`` means the difference
        row gets no error bar).

    tree_plot_params : TreePlotParams, optional
        Plot settings. A shallow copy is used, so the caller's object is not
        modified. On the copy, ``labels``, ``xlabel`` and ``title`` receive
        PSIS-specific defaults when they are ``None``, and
        ``draw_zero_line_if_in_range`` is always set to False.
        When omitted, ``TreePlotParams()`` is used.

    Returns
    -------
    (fig, ax):
        fig: Matplotlib's figure
        ax: Matplotlib's axis

    Raises
    ------
    ValueError
        If ``compared`` is empty.
    """
    if not compared:
        raise ValueError("compared must contain at least one model")

    # A None default avoids a single TreePlotParams instance created at
    # function-definition time being shared by every call.
    if tree_plot_params is None:
        tree_plot_params = TreePlotParams()

    # Shallow copy so the settings below do not leak back to the caller.
    tree_plot_params = TreePlotParams(**tree_plot_params.__dict__)

    # One label per row in each group: the difference row, then the value.
    if tree_plot_params.labels is None:
        tree_plot_params.labels = ["dPSIS", "PSIS"]

    if tree_plot_params.xlabel is None:
        tree_plot_params.xlabel = "PSIS (deviance)"

    if tree_plot_params.title is None:
        tree_plot_params.title = "Model comparison (smaller is better)"

    plot_groups = []

    # Iterate in reverse so the best model (compared[0]) is appended last.
    for model in reversed(compared):
        psis = model.psis_data

        # PSIS difference row: marker at the model's PSIS with +/- the
        # standard error of its difference from the best model (no error
        # bar when that standard error is None).
        diff_std_err = model.psis_difference_best_std_err

        if diff_std_err is None:
            diff_error_bars = []
        else:
            diff_error_bars = [
                [psis.psis - diff_std_err, psis.psis + diff_std_err]
            ]

        # PSIS value row: marker at the model's PSIS with +/- its own
        # standard error.
        value_error_bars = [
            [psis.psis - psis.psis_std_err, psis.psis + psis.psis_std_err]
        ]

        plot_groups.append(dict(
            name=model.name,
            values=[
                dict(value=psis.psis, error_bars=diff_error_bars),
                dict(value=psis.psis, error_bars=value_error_bars),
            ]))

    tree_plot_params.draw_zero_line_if_in_range = False
    fig, ax = tree_plot(groups=plot_groups, params=tree_plot_params)

    # Dashed vertical line through the best model's PSIS, drawn in the
    # first marker edge colour.
    best = compared[0]

    ax.axvline(x=best.psis_data.psis,
               linestyle='dashed',
               color=tree_plot_params.marker_edge_colors[0])

    return fig, ax