def save_summary_to_disk(df_summary, txt_summary, info_path=InfoPath()):
    """
    Write an already computed statistical summary of the samples
    (e.g. mean, std, mode, HPDI) to a text file and a CSV file.

    Parameters
    ----------
    df_summary : Panda's DataFrame
        Summary for all parameters. Its index is written to the CSV file
        under the column label 'Name'.
    txt_summary : str
        Text of the summary table. It is written verbatim with ``print``,
        so a ready-formatted string is expected.
    info_path : InfoPath
        Path information for creating summaries. The object is copied,
        so the caller's instance is not modified.

    Returns
    -------
    dict
        "df" : Dataframe containing summary.
        "table" : text version of the summary.
        "path_txt": Path to txt summary file.
        "path_csv": Path to csv summary file.
    """
    # Work on a copy: the default argument is shared between calls, and the
    # caller's object must not have its base_name/extension overwritten.
    info_path = InfoPath(**info_path.__dict__)
    info_path.base_name = info_path.base_name or "summary"

    # Both files share the same base name and differ only by extension.
    info_path.extension = 'txt'
    path_to_summary_txt = get_info_path(info_path)
    info_path.extension = 'csv'
    path_to_summary_csv = get_info_path(info_path)

    with open(path_to_summary_txt, "w") as text_file:
        print(txt_summary, file=text_file)

    df_summary.to_csv(path_to_summary_csv, index_label='Name')

    return {
        "df": df_summary,
        "table": txt_summary,
        "path_txt": path_to_summary_txt,
        "path_csv": path_to_summary_csv,
    }
def save_pair_plot(samples, param_names=None, info_path=InfoPath(),
                   pair_plot_params=PairPlotParams()):
    """
    Make a pair plot of the parameters from the posterior distribution
    and save it to a file.

    Parameters
    -----------
    samples : Panda's DataFrame
        Each column contains samples from posterior distribution.
    param_names : list of str
        Names of the parameters for plotting. If None, all will be plotted.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "pair_plot" and extension "pdf" when those are not set. The object
        is copied, so the caller's instance is not modified.
    pair_plot_params : PairPlotParams
        Plotting options passed to ``make_pair_plot``.
    """
    # Copy first so neither the caller's object nor the shared default
    # argument is mutated below.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    g = make_pair_plot(
        samples,
        param_names=param_names,
        pair_plot_params=pair_plot_params)

    info_path.base_name = info_path.base_name or "pair_plot"
    info_path.extension = info_path.extension or 'pdf'
    plot_path = get_info_path(info_path)
    g.savefig(plot_path, dpi=info_path.dpi)
def save_compare_waic_txt_from_compared(compared, info_path=InfoPath()):
    """
    Compare models using WAIC (Widely Applicable Information Criterion)
    to see which models are more compatible with the data. The result is
    saved in a text file as a pipe-formatted table.

    Parameters
    ----------
    compared : list of WaicModelCompared
        List of compared models.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "compare_waic"; the extension is always "txt". The object is
        copied, so the caller's instance is not modified.
    """
    # Copy BEFORE set_codefile(): calling it on the argument itself would
    # mutate the caller's object, or the default InfoPath() that is shared
    # by every call of this function.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()
    info_path.base_name = info_path.base_name or "compare_waic"
    info_path.extension = 'txt'

    df = waic_compared_to_df(compared)
    table = tabulate(df, headers=list(df), floatfmt=".2f", tablefmt="pipe")

    path = get_info_path(info_path)

    with open(path, "w") as text_file:
        print(table, file=text_file)
def save_traceplot(fit, param_names=None, info_path=InfoPath(),
                   traceplot_params=TraceplotParams()):
    """
    Save traceplots from the fit, one file per figure.

    Parameters
    ----------
    fit : cmdstanpy.stanfit.CmdStanMCMC
        Contains the samples from cmdstanpy.
    param_names : list of str
        Names of parameters to plot.
    info_path : InfoPath
        Determines the location of the output files. Each figure is saved
        as "<base_name>_NN" (NN = 01, 02, ...), with base name defaulting
        to "traceplot" and extension to "pdf". The object is copied, so
        the caller's instance is not modified.
    traceplot_params : TraceplotParams
        Plotting options passed to ``traceplot``.
    """
    # Copy BEFORE set_codefile() so the caller's object and the shared
    # default argument are left untouched.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    figures_and_axes = traceplot(
        fit, param_names=param_names, params=traceplot_params)

    base_name = info_path.base_name or "traceplot"
    info_path.extension = info_path.extension or 'pdf'

    for i, figure_and_axis in enumerate(figures_and_axes):
        # Number the files starting from 01.
        info_path.base_name = f'{base_name}_{i + 1:02d}'
        plot_path = get_info_path(info_path)
        fig = figure_and_axis[0]
        fig.savefig(plot_path, dpi=info_path.dpi)
        plt.close(fig)
def save_compare_psis_tree_plot_from_compared(
        compared,
        tree_plot_params: TreePlotParams = TreePlotParams(),
        info_path=InfoPath()):
    """
    Make a plot that compares models using PSIS and save it to a file.

    Parameters
    ----------
    compared : list
        List of compared models, as accepted by
        ``compare_psis_tree_plot_from_compared``.
    tree_plot_params : TreePlotParams
        Plotting options for the tree plot.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "compare_psis" and extension "pdf". The object is copied, so the
        caller's instance is not modified.
    """
    # Copy BEFORE set_codefile() so the caller's object and the shared
    # default argument are left untouched.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    fig, ax = compare_psis_tree_plot_from_compared(
        compared=compared,
        tree_plot_params=tree_plot_params)

    info_path.base_name = info_path.base_name or 'compare_psis'
    info_path.extension = info_path.extension or 'pdf'
    the_path = get_info_path(info_path)
    fig.savefig(the_path, dpi=info_path.dpi)
    plt.close(fig)
def test_get_info_path():
    """
    get_info_path builds the output path from the calling code file's
    directory, "model_info", a sub-directory named after the code file,
    and "<base_name>.<extension>".
    """
    import os

    info_path = InfoPath()
    info_path.base_name = 'my_basename'
    info_path.extension = 'test_extension'

    result = get_info_path(info_path)

    # Built with os.path.join so the expectation uses the platform's
    # separator. The previous literal also contained a stray "\ " left over
    # from a broken line continuation.
    expected = os.path.join(
        'tarpan', 'shared', 'model_info', 'info_path_test',
        'my_basename.test_extension')

    assert expected in result
def save_scatter_and_kde(values,
                         uncertainties,
                         title=None,
                         xlabel=None,
                         ylabel=None,
                         info_path=InfoPath(),
                         scatter_kde_params=ScatterKdeParams(),
                         legend_labels=None,
                         plot_fn=None):
    """
    Create a scatter plot and a KDE plot under it, and save it to a file.
    The KDE plot uses uncertainties of each individual observation.

    Parameters
    ----------
    values: list of lists
        List of values to plot. Supply more than one list to see
        distributions shown with different colors and markers.
    uncertainties: list of lists
        Uncertainties corresponding to the `values`.
    title, xlabel, ylabel : str
        Plot labels, passed to ``scatter_and_kde``.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "scatter_kde" and extension "pdf". The object is copied, so the
        caller's instance is not modified.
    scatter_kde_params : ScatterKdeParams
        Plotting options passed to ``scatter_and_kde``.
    legend_labels : list of str
        Labels for the legend, passed to ``scatter_and_kde``.
    plot_fn: [function(fig, axes, params), params]
        Optional pair whose first element is called as
        ``function(fig, axes, params=params)`` before the plot is saved,
        so extra information can be added to the plot.
            fig: Matplotlib's figure object
            axes: list of Matplotlib's axes objects
            params: the second element of `plot_fn`

    Returns
    -------
    fig : Matplotlib's figure object
    axes : list of Matplotlib's axes
    """
    # Copy first: previously the caller's object (or the default InfoPath()
    # shared by every call) had its base_name/extension/codefile overwritten.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    fig, axes = scatter_and_kde(values=values,
                                uncertainties=uncertainties,
                                title=title,
                                xlabel=xlabel,
                                ylabel=ylabel,
                                scatter_kde_params=scatter_kde_params,
                                legend_labels=legend_labels)

    if plot_fn is not None:
        plot_fn[0](fig, axes, params=plot_fn[1])

    info_path.base_name = info_path.base_name or "scatter_kde"
    info_path.extension = info_path.extension or 'pdf'
    plot_path = get_info_path(info_path)
    fig.savefig(plot_path, dpi=info_path.dpi)

    # The figure is returned, not closed, so the caller can keep using it.
    return fig, axes
def save_compare_parameters(
        models,
        labels,
        extra_values=None,
        type: CompareParametersType = CompareParametersType.TEXT,
        param_names=None,
        info_path=InfoPath(),
        summary_params=SummaryParams()):
    """
    Save a text table that compares model parameters.

    Parameters
    ----------
    models : list of Panda's data frames
        List of data frames for each model, containing sample values for
        multiple parameters (one parameter is one data frame column).
        Supply multiple data frames to compare parameter distributions.
    labels : list of str
        Names of the models in `models` list.
    extra_values : list of dict, optional
        Additional values to be shown in the table. None (the default)
        means no extra values. Example:
        [
            {
                "mu": 2.3,
                "sigma": 3.3
            }
        ]
    type : CompareParametersType
        Format of values in the text table. (The name shadows the builtin;
        it is kept for backward compatibility with keyword callers.)
    param_names : list of str
        Names of parameters. Include all if None.
    info_path : InfoPath
        Path information for creating summaries. Defaults to base name
        "parameters_compared"; the extension is always "txt". The object
        is copied, so the caller's instance is not modified.
    summary_params : SummaryParams
        Options passed to ``compare_parameters``.
    """
    # A fresh list per call instead of one shared mutable default.
    if extra_values is None:
        extra_values = []

    # Copy BEFORE set_codefile() so the caller's object and the shared
    # default argument are left untouched.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    df, table = compare_parameters(models=models,
                                   labels=labels,
                                   extra_values=extra_values,
                                   type=type,
                                   param_names=param_names,
                                   summary_params=summary_params)

    info_path.base_name = info_path.base_name or "parameters_compared"
    info_path.extension = 'txt'
    path_to_txt = get_info_path(info_path)

    with open(path_to_txt, "w") as text_file:
        print(table, file=text_file)
def save_diagnostic(fit, info_path=InfoPath()):
    """
    Write the diagnostic report of a fit to a text file.

    Parameters
    ----------
    fit : object
        Fit object; the string returned by its ``diagnose()`` method is
        written to the file.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "diagnostic"; the extension is always "txt". The object is copied,
        so the caller's instance is not modified.
    """
    target = InfoPath(**info_path.__dict__)

    if not target.base_name:
        target.base_name = 'diagnostic'

    target.extension = 'txt'
    report_path = get_info_path(target)

    with open(report_path, "w") as report_file:
        print(fit.diagnose(), file=report_file)
def save_psis_pareto_k_plot_from_psis_data(
        psis_data: PsisData,
        name,
        pareto_k_plot_params: ParetoKPlotParams = ParetoKPlotParams(),
        info_path=InfoPath()):
    """
    Make a plot that shows values of Pareto K index generated by PSIS
    method and save it to a file. It is used to check for K values higher
    than 0.7, which indicate possible outliers and that the PSIS
    calculations may not be reliable.

    Parameters
    ----------
    psis_data : PsisData
        Calculated PSIS values for the model.
    name : str
        Model name. A lower-cased version with spaces and the characters
        \\ / ? + * replaced by "_" is used in the default file name
        "pareto_k_<name>".
    pareto_k_plot_params : ParetoKPlotParams
        Plotting options passed to ``psis_pareto_k_plot_from_psis_data``.
    info_path : InfoPath
        Determines the location of the output file. Defaults to extension
        "pdf". The object is copied, so the caller's instance is not
        modified.
    """
    # Copy BEFORE set_codefile() so the caller's object and the shared
    # default argument are left untouched.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    fig, ax = psis_pareto_k_plot_from_psis_data(
        psis_data=psis_data,
        name=name,
        pareto_k_plot_params=pareto_k_plot_params)

    # Replace characters that are awkward in file names, in a single pass.
    replace_table = str.maketrans({char: "_" for char in " \\/?+*"})
    model_name_sanitised = name.lower().translate(replace_table)
    base_name = f'pareto_k_{model_name_sanitised}'

    info_path.base_name = info_path.base_name or base_name
    info_path.extension = info_path.extension or 'pdf'
    the_path = get_info_path(info_path)
    fig.savefig(the_path, dpi=info_path.dpi)
    plt.close(fig)
def make_tree_plot(df_summary, param_names=None, info_path=InfoPath(),
                   tree_params: TreePlotParams = TreePlotParams(),
                   summary_params=SummaryParams()):
    """
    Draw a tree plot of parameter summaries and save it to a file.

    Parameters
    ----------
    df_summary : Panda's DataFrame
        Parameter summary, passed to ``extract_tree_plot_data``.
    param_names : list of str
        Names of parameters to include, passed to
        ``extract_tree_plot_data``.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "summary" and extension "pdf". The object is copied, so the
        caller's instance is not modified.
    tree_params : TreePlotParams
        Plotting options passed to ``tree_plot``.
    summary_params : SummaryParams
        Options passed to ``extract_tree_plot_data``.
    """
    target = InfoPath(**info_path.__dict__)

    plot_data = extract_tree_plot_data(
        df_summary,
        param_names=param_names,
        summary_params=summary_params)

    fig, _ = tree_plot(plot_data, params=tree_params)

    if not target.base_name:
        target.base_name = 'summary'

    if not target.extension:
        target.extension = 'pdf'

    output_path = get_info_path(target)
    fig.savefig(output_path, dpi=target.dpi)
    plt.close(fig)
def save_compare_waic_csv_from_compared(compared, info_path=InfoPath()):
    """
    Compare models using WAIC (Widely Applicable Information Criterion)
    to see which models are more compatible with the data. The result is
    saved in a CSV file.

    Parameters
    ----------
    compared : list of WaicModelCompared
        List of compared models.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "compare_waic"; the extension is always "csv". The object is
        copied, so the caller's instance is not modified.
    """
    # Copy BEFORE set_codefile() so the caller's object and the shared
    # default argument are left untouched.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()
    info_path.base_name = info_path.base_name or "compare_waic"
    info_path.extension = 'csv'

    df = waic_compared_to_df(compared)
    path = get_info_path(info_path)
    df.to_csv(path, index_label='Name')
def save_histogram_from_summary(samples, summary, param_names=None,
                                info_path=InfoPath(),
                                histogram_params=HistogramParams(),
                                summary_params=SummaryParams()):
    """
    Draw histograms of posterior samples and save each figure to its own
    file.

    Parameters
    -----------
    samples : Panda's DataFrame
        Each column contains samples from the posterior distribution.
    summary : Panda's DataFrame
        Summary information about each column.
    param_names : list of str
        Names of the parameters for plotting. If None, all will be plotted.
    info_path : InfoPath
        Determines the location of the output files. Each figure is saved
        as "<base_name>_NN" (NN = 01, 02, ...), with base name defaulting
        to "histogram" and extension to "pdf". The object is copied, so
        the caller's instance is not modified.
    histogram_params : HistogramParams
        Plotting options passed to ``make_histograms``.
    summary_params : SummaryParams
        Options passed to ``make_histograms``.
    """
    target = InfoPath(**info_path.__dict__)

    figures_and_axes = make_histograms(
        samples,
        summary,
        param_names=param_names,
        params=histogram_params,
        summary_params=summary_params)

    stem = target.base_name or "histogram"

    if not target.extension:
        target.extension = 'pdf'

    for number, pair in enumerate(figures_and_axes, start=1):
        target.base_name = f'{stem}_{number:02d}'
        output_path = get_info_path(target)
        figure = pair[0]
        figure.savefig(output_path, dpi=target.dpi)
        plt.close(figure)
def make_comparative_tree_plot(summaries, param_names=None,
                               info_path=InfoPath(),
                               tree_params: TreePlotParams = TreePlotParams()):
    """
    Draw one tree plot comparing the parameter summaries of several
    models and save it to a file.

    Parameters
    ----------
    summaries : list
        Parameter summaries. Each one is passed in turn to
        ``extract_tree_plot_data``, accumulating groups from the previous
        ones.
    param_names : list of str
        Names of parameters to include.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "summary" and extension "pdf". The object is copied, so the
        caller's instance is not modified.
    tree_params : TreePlotParams
        Plotting options passed to ``tree_plot``.
    """
    target = InfoPath(**info_path.__dict__)

    # Feed the groups built so far into the next extraction.
    combined = None

    for summary in summaries:
        combined = extract_tree_plot_data(
            summary, param_names=param_names, groups=combined)

    fig, _ = tree_plot(combined, params=tree_params)

    if not target.base_name:
        target.base_name = 'summary'

    if not target.extension:
        target.extension = 'pdf'

    output_path = get_info_path(target)
    fig.savefig(output_path, dpi=target.dpi)
    plt.close(fig)
def save_compare_waic_tree_plot_from_compared(
        compared,
        tree_plot_params: TreePlotParams = TreePlotParams(),
        info_path=InfoPath()):
    """
    Make a plot that compares models using WAIC (Widely Applicable
    Information Criterion) and save it to a file.

    Parameters
    ----------
    compared : list of WaicModelCompared
        List of compared models.
    tree_plot_params : TreePlotParams
        Plotting options for the tree plot.
    info_path : InfoPath
        Determines the location of the output file. Defaults to base name
        "compare_waic" and extension "pdf". The object is copied, so the
        caller's instance is not modified.
    """
    # Copy BEFORE set_codefile() so the caller's object and the shared
    # default argument are left untouched.
    info_path = InfoPath(**info_path.__dict__)
    info_path.set_codefile()

    fig, ax = compare_waic_tree_plot_from_compared(
        compared=compared,
        tree_plot_params=tree_plot_params)

    info_path.base_name = info_path.base_name or 'compare_waic'
    info_path.extension = info_path.extension or 'pdf'
    the_path = get_info_path(info_path)
    fig.savefig(the_path, dpi=info_path.dpi)
    plt.close(fig)