def summarize_psd(run, time_dir):
    '''
    Summarize the samples of one time directory as a DataFrame.

    The result holds the median plus the 50% and 90% credible intervals.
    Intervals are computed with pymc3's highest posterior density (HPD)
    function, whose alpha argument is the desired probability of type I
    error (so, 1 - C.I.). The returned frame uses the same MultiIndex as
    import_time().

    Input
    -----
      run : run object, handed straight to import_time()
      time_dir : relative path to the time directory
    '''
    samples = import_time(run, time_dir)
    row_index = samples.index

    # hpd() receives the transposed sample matrix
    sample_matrix = samples.to_numpy().T
    interval_50 = hpd(sample_matrix, alpha=0.5)
    interval_90 = hpd(sample_matrix, alpha=0.1)

    def indexed(values):
        # Attach the shared MultiIndex to one column of interval bounds
        return pd.Series(values, index=row_index)

    columns = {
        'MEDIAN': samples.median(axis=1),
        'CI_50_LO': indexed(interval_50[:, 0]),
        'CI_50_HI': indexed(interval_50[:, 1]),
        'CI_90_LO': indexed(interval_90[:, 0]),
        'CI_90_HI': indexed(interval_90[:, 1]),
    }
    return pd.DataFrame(columns, index=row_index)
def summarize_linechain(run, time_dir, channel, time_counts, log):
    '''
    Summarize the spectral-line parameters of one channel at one time.

    Returns a DataFrame with the median and the 50% / 90% HPD bounds of
    every line parameter. The frame is empty when the preferred model
    contains no spectral lines.

    Input
    -----
      run : run object providing get_time() and get_channel_index()
      time_dir : string, time directory
      channel : str, channel name
      time_counts : histogram of the number of times each model was chosen
        for this time and channel
      log : utils.Log object
    '''
    time = run.get_time(time_dir)
    ch_idx = run.get_channel_index(channel)
    log.log(f'\n-- {time} CHANNEL {channel} --')

    # Linechain file for this channel
    lc_file = f'{time_dir}linechain_channel{ch_idx}.dat'

    # The preferred model is the most frequently chosen one
    model = time_counts.argmax()
    log.log('Line model histogram:')
    log.log(np.array2string(time_counts, max_line_width=80))
    log.log(f'{model} spectral lines found.')

    cols = pd.Series(
        ['MEDIAN', 'CI_50_LO', 'CI_50_HI', 'CI_90_LO', 'CI_90_HI'])

    # Nothing to summarize without at least one line
    if model <= 0:
        return pd.DataFrame([], columns=cols)

    line_params = import_linechain(lc_file, model)

    # The line count is taken from the imported array
    n_lines = line_params.shape[1]
    if n_lines > 1:
        line_params = sort_params(line_params, log)

    # One row per (line, parameter): median, 50% bounds, 90% bounds
    center = np.median(line_params, axis=0).flatten()[:, np.newaxis]
    band_50 = np.vstack(hpd(line_params, alpha=0.5))
    band_90 = np.vstack(hpd(line_params, alpha=0.1))
    table = np.hstack([center, band_50, band_90])

    levels = [[channel], [time], list(range(n_lines)), ['FREQ', 'AMP', 'QF']]
    midx = pd.MultiIndex.from_product(
        levels, names=['CHANNEL', 'TIME', 'LINE', 'PARAMETER'])
    summary = pd.DataFrame(table, columns=cols, index=midx)

    log.log('Line parameter summary:')
    log.log(summary.to_string(max_cols=80))
    return summary
def mcmc_prior_proposal(n_models, calc_posterior, guess_params, guess_sd):
    """Derive Gaussian prior-proposal parameters from one MCMC run per model.

    For every model index ``i`` the sampler ``RJMC_outerloop`` is run with
    model swapping disabled (``swap_freq = 0.0``), starting from
    ``guess_params[i]`` with step sizes ``guess_sd[i]``. The first
    ``tune_for`` samples are dropped, and each remaining trace column is
    summarised by a location and a width. A histogram of the column and the
    matching Gaussian curve are drawn on the current matplotlib figure.

    Parameters
    ----------
    n_models : int
        Number of models; one run is performed per model.
    calc_posterior : callable
        Posterior function handed unchanged to ``RJMC_outerloop``.
    guess_params : ndarray
        Row ``i`` holds the initial values for model ``i``.
    guess_sd : ndarray
        Row ``i`` holds the initial proposal widths for model ``i``.

    Returns
    -------
    parameter_prior_proposal : ndarray, shape (n_models, n_columns, 2)
        ``[i, j, 0]`` is the centre of the most populated of 100 histogram
        bins of trace column ``j``; ``[i, j, 1]`` is 1.5 times the width of
        the 95% HPD interval divided by ``2 * 1.96``.

    Notes
    -----
    Assumes the trace returned by ``RJMC_outerloop`` is a 2-D array of
    shape (samples, columns) -- TODO confirm against the sampler.
    """
    swap_freq = 0.0
    n_iter = 50000
    tune_freq = 100
    tune_for = 10000

    # Allocated lazily: the column count is only known after the first run.
    parameter_prior_proposal = None

    for i in range(n_models):
        initial_values = guess_params[i, :]
        initial_sd = guess_sd[i, :]
        trace, _, _, _, _, _ = RJMC_outerloop(
            calc_posterior, n_iter, initial_values, initial_sd, n_models,
            swap_freq, tune_freq, tune_for, 1, 1, 1, 1)

        # Discard the tuning phase
        trace_tuned = trace[tune_for:]
        n_columns = np.size(trace_tuned, 1)

        if parameter_prior_proposal is None:
            # One shared result array, so earlier models are not overwritten
            parameter_prior_proposal = np.empty((n_models, n_columns, 2))

        for j in range(n_columns):
            column = trace_tuned[:, j]

            # Location: centre of the most populated histogram bin
            counts, edges = np.histogram(column, bins=100)
            peak = np.argmax(counts)
            max_ap = (edges[peak] + edges[peak + 1]) / 2

            # Width: treat the 95% HPD interval as +/- 1.96 sigma
            lower, upper = hpd(column, alpha=0.05)
            sigma_hat = (upper - lower) / (2 * 1.96)

            parameter_prior_proposal[i, j] = [max_ap, sigma_hat * 1.5]

            # Diagnostic plot: sample histogram with the proposed Gaussian
            plt.hist(column, density=True)
            mu, sd = parameter_prior_proposal[i, j]
            if sd > 0:
                # A constant column has zero width; skip the undefined curve
                support = np.linspace(np.min(column), np.max(column), 100)
                density = (np.exp(-0.5 * ((support - mu) / sd) ** 2)
                           / (sd * np.sqrt(2 * np.pi)))
                plt.plot(support, density)

    if parameter_prior_proposal is None:
        # No model was requested
        parameter_prior_proposal = np.empty((n_models, 0, 2))
    return parameter_prior_proposal
def calc_credible_intervals(trace_arr, trace_dict, alpha=.05):
    """Map every parameter name to the HPD interval of its trace column.

    Column ``i`` of ``trace_arr`` is paired with the ``i``-th entry of
    ``trace_dict['param_list']``; ``alpha`` is forwarded to pymc3's ``hpd``.
    """
    from pymc3.stats import hpd

    return {
        name: hpd(trace_arr[:, column], alpha)
        for column, name in enumerate(trace_dict['param_list'])
    }
def display_hpd():
    """Draw the HPD interval as a thick bar and annotate it on ``ax``.

    Reads ``trace_values``, ``alpha_level``, ``plot_height``, ``round_to``
    and ``ax`` from the enclosing scope.
    """
    interval = hpd(trace_values, alpha=alpha_level)

    # Horizontal bar spanning the interval, just above the axis
    bar_height = plot_height * 0.02
    ax.plot(interval, (bar_height, bar_height), linewidth=4, color='k')

    label_style = {'size': 16, 'horizontalalignment': 'center'}

    # Rounded value above each end of the bar
    for bound in (interval[0], interval[1]):
        ax.text(bound, plot_height * 0.07, bound.round(round_to),
                **label_style)

    # Interval mass, centred over the bar
    centre = (interval[0] + interval[1]) / 2
    caption = format_as_percent(1 - alpha_level) + ' HPD'
    ax.text(centre, plot_height * 0.2, caption, **label_style)
def version_pymc3():
    """Fit two Gamma-Poisson models with PyMC3 and print HPD intervals for θ.

    Model (a) uses θ directly as the Poisson mean; model (b) multiplies θ by
    ``miles_e8_estimate``. The trace of model (b) is plotted at the end.
    """

    def draw_posterior(poisson_mean):
        # Gamma(1, 1) prior on θ; the Poisson mean is built from θ by the caller
        with pm.Model():
            θ = pm.Gamma("θ", alpha=1, beta=1)
            pm.Poisson("n_a", mu=poisson_mean(θ), observed=observed_n_a)
            return pm.sample(5000, tune=1000, cores=4)

    print("(a) -- independent θ, not modelling exposure")
    trace = draw_posterior(lambda rate: rate)
    print("(a) 95% credible interval for θ:")
    print(pms.hpd(trace["θ"]))

    print("(b) -- independent θ, modelling exposure")
    trace = draw_posterior(lambda rate: miles_e8_estimate * rate)
    print("(b) 95% credible interval for θ:")
    print(pms.hpd(trace["θ"]))

    pm.traceplot(trace)
    plt.show()
def calc_hdi(hist, bins, size, random_state=None):
    """
    Estimate HDI from histogram of probability.

    Data is sampled with the given probability and HDI is calculated with
    pymc3.stats.hpd

    Parameters
    ----------
    hist : list of float
        Histogram of probability. Used as the sampling probabilities, so it
        must have one entry per element of `bins`.
    bins : list of float
        Bins of the histogram. The bin width is taken from the first two
        entries.
    size : int
        Number of samples to generate.
    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    hdi_x : [float, float]
        HDI range, each bound shifted up by one bin width.
    hdi_y : float
        Probability value corresponding to the HDI (mean of the two `hist`
        entries looked up for the bounds).
    """
    rng = sklearn.utils.check_random_state(random_state)
    samples = rng.choice(bins, size=size, p=hist)

    # Scalar bin width. The previous float(x + np.diff(bins[:2])) converted a
    # 1-element array to a scalar, which NumPy has deprecated.
    bin_width = bins[1] - bins[0]
    hdi_x = [float(x + bin_width) for x in hpd(samples)]

    # For each shifted bound: find the bin matching it within 1e-6 and read
    # the hist entry one position further on.
    # NOTE(review): raises IndexError when no bin matches -- confirm callers
    # never hit the last bin.
    hdi_y = np.mean([
        hist[np.argwhere(np.abs(bins - x) < 1e-6).ravel()[0] + 1]
        for x in hdi_x
    ])
    return hdi_x, hdi_y
def forestplot(trace_obj, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, chain_spacing=0.05, vline=0, gs=None,
               plot_transformed=False):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------
    trace_obj: NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha (optional): float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles (optional): bool
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat (optional): bool
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main (optional): string
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle (optional): string
        Label for x-axis. Defaults to no label
    xlim (optional): list or tuple
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels (optional): list or array
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing (optional): float
        Plot spacing between chains (defaults to 0.05).
    vline (optional): numeric
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None. When supplied together
        with an R-hat panel, cell 1 of the GridSpec is used for that panel.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).

    Returns
    -------
    gs : matplotlib GridSpec
    """
    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj, plot_transformed)

    plot_rhat = (rhat and nchains > 1)

    if gs is None:
        # Initialize plot
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(trace_obj, qlist, transform=transform,
                                squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    # Vertical offset of each chain around the variable's row; it depends
    # only on the chain count, so it is computed once.
    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i
                    for i in range(nchains - 1)]

    # y-axis labels, collected while walking the first chain
    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for quantile
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [min(plotrange[0], np.min(quants)),
                             max(plotrange[1], np.max(quants))]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list
            if j == 0:
                if k > 1:
                    labels += _var_str(varname, np.shape(value))
                else:
                    labels.append(varname)

            # Y coordinate with offset
            y = -var + offset[j]

            # Deal with multivariate nodes
            if k > 1:
                for q in np.transpose(quants).squeeze():
                    # Multiple y values
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants,
                                           quartiles)

            # Increment index
            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    interval_plot.set_ylim(-var + 0.5, 0.5)

    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels)

    # Add title
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title)

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. The tick1On/tick2On attributes no longer
    # exist in current Matplotlib, so hide the tick line artists instead.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1line.set_visible(False)
        ticks.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle='--')

    # Generate Gelman-Rubin plot
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels,
                        varnames, plot_transformed)

    return gs
def forestplot(trace_obj, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, chain_spacing=0.05, vline=0, gs=None,
               plot_transformed=False, **plot_kwargs):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------
    trace_obj: NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha (optional): float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles (optional): bool
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat (optional): bool
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main (optional): string
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle (optional): string
        Label for x-axis. Defaults to no label
    xlim (optional): list or tuple
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels (optional): list or array
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing (optional): float
        Plot spacing between chains (defaults to 0.05).
    vline (optional): numeric
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None. When supplied together
        with an R-hat panel, cell 1 of the GridSpec is used for that panel.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'color', 'marker', and 'markersize'.

    Returns
    -------
    gs : matplotlib GridSpec
    """
    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj.varnames, plot_transformed)

    plot_rhat = (rhat and nchains > 1)

    if gs is None:
        # Initialize plot
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(trace_obj, qlist, transform=transform,
                                squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    # Vertical offset of each chain around the variable's row; it depends
    # only on the chain count, so it is computed once.
    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i
                    for i in range(nchains - 1)]

    # y-axis labels, collected while walking the first chain
    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for quantile
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [
                    min(plotrange[0], np.min(quants)),
                    max(plotrange[1], np.max(quants))
                ]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list
            if j == 0:
                if k > 1:
                    labels += _var_str(varname, np.shape(value))
                else:
                    labels.append(varname)

            # Y coordinate with offset
            y = -var + offset[j]

            # Deal with multivariate nodes
            if k > 1:
                for q in np.transpose(quants).squeeze():
                    # Multiple y values
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles,
                                               **plot_kwargs)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants,
                                           quartiles, **plot_kwargs)

            # Increment index
            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    interval_plot.set_ylim(-var + 0.5, 0.5)

    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title,
                                fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. The tick1On/tick2On attributes no longer
    # exist in current Matplotlib, so hide the tick line artists instead.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1line.set_visible(False)
        ticks.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    # Generate Gelman-Rubin plot
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels,
                        varnames, plot_transformed)

    return gs
def forestplot(trace, models=None, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, colors='C0', chain_spacing=0.1,
               vline=0, gs=None, plot_transformed=False, plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a
    trace or list of traces.

    Parameters
    ----------
    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model is
        chosen automatically from matplotlib's cycle. If a single color is
        passed, e.g. 'k', 'C2', 'red', this color is used for all models.
        Defaults to 'C0' (blueish in most matplotlib styles).
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None. When supplied and an
        R-hat panel is drawn, cell 1 of the GridSpec is used for that panel.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------
    gs : matplotlib GridSpec

    Raises
    ------
    ValueError
        If `models` is given and its length differs from the number of traces.
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(len(models))]
    elif isinstance(colors, str):
        colors = [colors for i in range(len(models))]

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        # Union of the default variable names, in order of first appearance
        varnames = []
        for tr in trace:
            varnames_tmp = get_default_varnames(tr.varnames, plot_transformed)
            for v in varnames_tmp:
                if v not in varnames:
                    varnames.append(v)

    plot_rhat = [rhat and nch > 1 for nch in nchains]
    show_rhat = bool(np.any(plot_rhat))

    if gs is None:
        # Initialize plot
        if show_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # R-hat panel. Created whenever it is needed, including when the caller
    # supplied the GridSpec; otherwise gr_plot would be undefined further on.
    gr_plot = None
    if show_rhat:
        gr_plot = plt.subplot(gs[1])
        # Labels are set separately: older Matplotlib reads the second
        # positional argument of set_xticks as `minor`.
        gr_plot.set_xticks((1.0, 1.5, 2.0))
        gr_plot.set_xticklabels(("1", "1.5", "2+"))
        gr_plot.set_xlim(0.9, 2.1)
        gr_plot.set_yticks([])
        gr_plot.set_title('R-hat')

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(
            quantiles(tr, qlist, transform=transform, squeeze=False))
        hpd_intervals.append(
            hpd(tr, alpha, transform=transform, squeeze=False))

    labels = []
    var = 0
    all_quants = []
    # Alternating shading strength for the per-variable background bands
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5

    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                # Keep an empty, labelled row for models lacking the variable
                labels.append(models[h] + ' ' + v)
                var += 1
            else:
                for j, chain in enumerate(tr.chains):
                    var_quantiles = trace_quantiles[h][chain][v]

                    quants = [var_quantiles[vq] for vq in qlist]
                    var_hpd = hpd_intervals[h][chain][v].T

                    # Substitute HPD interval for quantile
                    quants[0] = var_hpd[0].T
                    quants[-1] = var_hpd[1].T

                    # Ensure x-axis contains range of current interval
                    all_quants.extend(np.ravel(quants))

                    # Number of elements in current variable
                    value = tr.get_values(v, chains=[chain])[0]
                    k = np.size(value)

                    # Append variable name(s) to list
                    if j == 0:
                        if k > 1:
                            names = _var_str(v, np.shape(value))
                            names[0] = models[h] + ' ' + names[0]
                            labels += names
                        else:
                            labels.append(models[h] + ' ' + v)

                    # Add spacing for each chain, if more than one
                    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1)**i
                                    for i in range(nchains[h] - 1)]

                    # Y coordinate with offset
                    y = -var + offset[j]

                    # Deal with multivariate nodes
                    if k > 1:
                        qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                        for q in qs.reshape(-1, len(quants)):
                            # Multiple y values
                            interval_plot = _plot_tree(interval_plot, y, q,
                                                       quartiles, colors[h],
                                                       plot_kwargs)
                            y -= 1
                    else:
                        interval_plot = _plot_tree(interval_plot, y, quants,
                                                   quartiles, colors[h],
                                                   plot_kwargs)

                # Generate Gelman-Rubin plot
                if plot_rhat[h] and v in tr.varnames:
                    R = gelman_rubin(tr, [v])
                    if k > 1:
                        Rval = dict2pd(R, 'rhat').values
                        gr_plot.plot([min(r, 2) for r in Rval],
                                     [-(elem + var) for elem in range(k)],
                                     'o', color=colors[h], markersize=4)
                    else:
                        gr_plot.plot(min(R[v], 2), -var, 'o',
                                     color=colors[h], markersize=4)
                var += k

        if len(trace) > 1:
            # Shade the rows of this variable to group the models visually
            band_bottom = y - chain_spacing - 0.5
            interval_plot.axhspan(var_old, band_bottom,
                                  facecolor='k', alpha=bands[v_idx])
            if gr_plot is not None:
                gr_plot.axhspan(var_old, band_bottom,
                                facecolor='k', alpha=bands[v_idx])
            var_old = band_bottom

    if ylabels is not None:
        labels = ylabels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(-var + 0.5, 0.5)
    if gr_plot is not None:
        gr_plot.set_ylim(-var + 0.5, 0.5)

    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title,
                            fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. The tick1On/tick2On attributes no longer
    # exist in current Matplotlib, so hide the tick line artists instead.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1line.set_visible(False)
        ticks.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs
trace_model_0.append(trace_tuned[i]) elif trace_tuned[i,0] == 1: trace_model_1.append(trace_tuned[i]) trace_model_0=np.asarray(trace_model_0) trace_model_1=np.asarray(trace_model_1) f = plt.figure() plt.scatter(trace_model_0[::10,1],trace_model_0[::10,2],s=1,label='Model 0',marker=',',alpha=0.7) plt.scatter(trace_model_1[::10,1],trace_model_1[::10,2],s=1,label='Model 1',marker=',',alpha=0.7) plt.legend() plt.show() plt.plot(trace[::500,1],trace[::500,2]) plt.show() map_x_0=hpd(trace_model_0[:,1],alpha=0.05) map_x_1=hpd(trace_model_1[:,1],alpha=0.05) map_y_0=hpd(trace_model_0[:,2],alpha=0.05) map_y_1=hpd(trace_model_1[:,2],alpha=0.05) CI_x_0=map_x_0[1]-map_x_0[0] CI_x_1=map_x_1[1]-map_x_1[0] CI_y_0=map_y_0[1]-map_y_0[0] CI_y_1=map_y_1[1]-map_y_1[0] #trace_model_0_subsample=trace_model_0[::1000] #trace_model_1_subsample=trace_model_1[::1000] #trace_subsample=trace_tuned[::1000] #Try subsampling to make the graphs look better. plt.hist(trace_model_0[:,1],bins=100,label='x values Model 0',density=True)
def forestplot(trace, models=None, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, colors='C0', chain_spacing=0.1,
               vline=0, gs=None, plot_transformed=False, plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a
    trace or list of traces.

    Parameters
    ----------
    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being
        suppressed; passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model is
        chosen automatically from matplotlib's cycle. If a single color is
        passed, e.g. 'k', 'C2', 'red', this color is used for all models.
        Defaults to 'C0' (blueish in most matplotlib styles).
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None. When supplied and an
        R-hat panel is drawn, cell 1 of the GridSpec is used for that panel.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------
    gs : matplotlib GridSpec

    Raises
    ------
    ValueError
        If `models` is given and its length differs from the number of traces.
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(len(models))]
    elif isinstance(colors, str):
        colors = [colors for i in range(len(models))]

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        # Union of the default variable names, in order of first appearance
        varnames = []
        for tr in trace:
            varnames_tmp = get_default_varnames(tr.varnames, plot_transformed)
            for v in varnames_tmp:
                if v not in varnames:
                    varnames.append(v)

    plot_rhat = [rhat and nch > 1 for nch in nchains]
    show_rhat = bool(np.any(plot_rhat))

    if gs is None:
        # Initialize plot
        if show_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

    # R-hat panel. Created whenever it is needed, including when the caller
    # supplied the GridSpec; otherwise gr_plot would be undefined further on.
    gr_plot = None
    if show_rhat:
        gr_plot = plt.subplot(gs[1])
        # Labels are set separately: older Matplotlib reads the second
        # positional argument of set_xticks as `minor`.
        gr_plot.set_xticks((1.0, 1.5, 2.0))
        gr_plot.set_xticklabels(("1", "1.5", "2+"))
        gr_plot.set_xlim(0.9, 2.1)
        gr_plot.set_yticks([])
        gr_plot.set_title('R-hat')

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(quantiles(tr, qlist, transform=transform,
                                         squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform,
                                 squeeze=False))

    labels = []
    var = 0
    all_quants = []
    # Alternating shading strength for the per-variable background bands
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5

    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                # Keep an empty, labelled row for models lacking the variable
                labels.append(models[h] + ' ' + v)
                var += 1
            else:
                for j, chain in enumerate(tr.chains):
                    var_quantiles = trace_quantiles[h][chain][v]

                    quants = [var_quantiles[vq] for vq in qlist]
                    var_hpd = hpd_intervals[h][chain][v].T

                    # Substitute HPD interval for quantile
                    quants[0] = var_hpd[0].T
                    quants[-1] = var_hpd[1].T

                    # Ensure x-axis contains range of current interval
                    all_quants.extend(np.ravel(quants))

                    # Number of elements in current variable
                    value = tr.get_values(v, chains=[chain])[0]
                    k = np.size(value)

                    # Append variable name(s) to list
                    if j == 0:
                        if k > 1:
                            names = _var_str(v, np.shape(value))
                            names[0] = models[h] + ' ' + names[0]
                            labels += names
                        else:
                            labels.append(models[h] + ' ' + v)

                    # Add spacing for each chain, if more than one
                    offset = [0] + [(chain_spacing * ((i + 2) / 2)) *
                                    (-1) ** i for i in range(nchains[h] - 1)]

                    # Y coordinate with offset
                    y = - var + offset[j]

                    # Deal with multivariate nodes
                    if k > 1:
                        qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                        for q in qs.reshape(-1, len(quants)):
                            # Multiple y values
                            interval_plot = _plot_tree(interval_plot, y, q,
                                                       quartiles, colors[h],
                                                       plot_kwargs)
                            y -= 1
                    else:
                        interval_plot = _plot_tree(interval_plot, y, quants,
                                                   quartiles, colors[h],
                                                   plot_kwargs)

                # Generate Gelman-Rubin plot
                if plot_rhat[h] and v in tr.varnames:
                    R = gelman_rubin(tr, [v])
                    if k > 1:
                        Rval = dict2pd(R, 'rhat').values
                        gr_plot.plot([min(r, 2) for r in Rval],
                                     [-(elem + var) for elem in range(k)],
                                     'o', color=colors[h], markersize=4)
                    else:
                        gr_plot.plot(min(R[v], 2), -var, 'o',
                                     color=colors[h], markersize=4)
                var += k

        if len(trace) > 1:
            # Shade the rows of this variable to group the models visually
            band_bottom = y - chain_spacing - 0.5
            interval_plot.axhspan(var_old, band_bottom,
                                  facecolor='k', alpha=bands[v_idx])
            if gr_plot is not None:
                gr_plot.axhspan(var_old, band_bottom,
                                facecolor='k', alpha=bands[v_idx])
            var_old = band_bottom

    if ylabels is not None:
        labels = ylabels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(- var + 0.5, 0.5)
    if gr_plot is not None:
        gr_plot.set_ylim(- var + 0.5, 0.5)

    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([- l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title,
                            fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. The tick1On/tick2On attributes no longer
    # exist in current Matplotlib, so hide the tick line artists instead.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1line.set_visible(False)
        ticks.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs