Пример #1
0
def summarize_psd(run, time_dir):
    '''
    Summarizes the samples of one time directory as a DataFrame.

    Each row of the result carries the median across samples together with
    the 50% and 90% credible intervals. The intervals come from hpd()
    (highest posterior density), whose alpha argument is the desired
    probability of type I error, i.e. 1 - C.I.
    The result is indexed by the same MultiIndex that import_time() returns.

    Input
    -----
      run : forwarded unchanged to import_time()
      time_dir : relative path to the time directory

    Output
    ------
      DataFrame with columns MEDIAN, CI_50_LO, CI_50_HI, CI_90_LO, CI_90_HI
    '''
    samples = import_time(run, time_dir)
    row_index = samples.index

    # hpd() is handed the transposed values so that every row of the
    # imported frame ends up with one (low, high) pair.
    transposed = samples.to_numpy().T
    interval_50 = hpd(transposed, alpha=0.5)
    interval_90 = hpd(transposed, alpha=0.1)

    def as_column(interval, side):
        # side 0 is the lower bound, side 1 the upper bound
        return pd.Series(interval[:, side], index=row_index)

    columns = {}
    columns['MEDIAN'] = samples.median(axis=1)
    columns['CI_50_LO'] = as_column(interval_50, 0)
    columns['CI_50_HI'] = as_column(interval_50, 1)
    columns['CI_90_LO'] = as_column(interval_90, 0)
    columns['CI_90_HI'] = as_column(interval_90, 1)
    return pd.DataFrame(columns, index=row_index)
Пример #2
0
def summarize_linechain(run, time_dir, channel, time_counts, log):
    '''
    Summarizes the spectral line parameters of one time and channel.

    The preferred line model is the most frequently chosen entry of
    time_counts. When it contains at least one line, the line chain is
    imported and the median plus the 50% and 90% HPD bounds of every
    parameter are returned; otherwise the returned DataFrame is empty
    (columns only).

    Input
    -----
      run : object providing get_time() and get_channel_index()
      time_dir : string, time directory (used as a plain prefix of the
                 linechain file name)
      channel : str, channel name
      time_counts : histogram of the number of times each model was chosen for
                    this time and channel
      log : utils.Log object

    Output
    ------
      DataFrame with columns MEDIAN, CI_50_LO, CI_50_HI, CI_90_LO, CI_90_HI,
      indexed by (CHANNEL, TIME, LINE, PARAMETER) when lines were found
    '''
    time = run.get_time(time_dir)
    ch_idx = run.get_channel_index(channel)
    log.log(f'\n-- {time} CHANNEL {channel} --')

    # Location of the line chain and the most frequently chosen model
    lc_file = f'{time_dir}linechain_channel{ch_idx}.dat'
    model = time_counts.argmax()

    log.log('Line model histogram:')
    log.log(np.array2string(time_counts, max_line_width=80))
    log.log(f'{model} spectral lines found.')

    # Start from an empty summary that already has the final columns
    stat_columns = pd.Series(
        ['MEDIAN', 'CI_50_LO', 'CI_50_HI', 'CI_90_LO', 'CI_90_HI'])
    param_names = ['FREQ', 'AMP', 'QF']
    summary = pd.DataFrame([], columns=stat_columns)

    # Nothing more to do when the preferred model has no lines
    if not model > 0:
        return summary

    chain = import_linechain(lc_file, model)
    # From here on the line count is taken from the imported chain itself
    model = chain.shape[1]
    if model > 1:
        chain = sort_params(chain, log)

    # One row per (line, parameter): median followed by the HPD bounds
    center = np.median(chain, axis=0).flatten()[:, np.newaxis]
    bounds = [np.vstack(hpd(chain, alpha=level)) for level in (0.5, 0.1)]
    table = np.hstack([center, *bounds])

    row_index = pd.MultiIndex.from_product(
        [[channel], [time], list(range(model)), param_names],
        names=['CHANNEL', 'TIME', 'LINE', 'PARAMETER'])
    summary = pd.DataFrame(table, columns=stat_columns, index=row_index)

    log.log('Line parameter summary:')
    log.log(summary.to_string(max_cols=80))
    return summary
Пример #3
0
def mcmc_prior_proposal(n_models, calc_posterior, guess_params, guess_sd):
    """Estimate a (location, width) prior proposal for every model parameter.

    For each model an MCMC run without model swaps (``swap_freq = 0``) is
    started from that model's initial guess. After discarding the tuning
    samples, each trace column is summarised by

    * the centre of the most populated bin of a 100-bin histogram, and
    * a width derived from the 95% HPD interval, ``(hi - lo) / (2 * 1.96)``,
      inflated by a factor of 1.5.

    A histogram of every column and the matching normal curve are drawn on the
    current matplotlib figure as a visual check.

    Parameters
    ----------
    n_models : int
        Number of models; one MCMC run is made per model.
    calc_posterior : callable
        Passed through unchanged to ``RJMC_outerloop``.
    guess_params : array, indexed as ``guess_params[i, :]``
        Initial parameter values for model ``i``.
    guess_sd : array, indexed as ``guess_sd[i, :]``
        Initial proposal standard deviations for model ``i``.

    Returns
    -------
    parameter_prior_proposal : ndarray, shape (n_models, n_params, 2)
        ``[..., 0]`` is the histogram-mode estimate, ``[..., 1]`` the inflated
        width. ``n_params`` is the number of columns of the first trace;
        every model's trace is assumed to have the same number of columns
        (TODO confirm against ``RJMC_outerloop``). An array of shape
        ``(0, 0, 2)`` is returned when ``n_models`` is 0.
    """
    swap_freq = 0.0
    n_iter = 50000
    tune_freq = 100
    tune_for = 10000

    # Allocated once, on the first trace, so results of earlier models are
    # kept (previously the array was re-created on every model iteration).
    parameter_prior_proposal = None

    for i in range(n_models):
        initial_values = guess_params[i, :]
        initial_sd = guess_sd[i, :]
        # Only the trace (first returned item) is needed here.
        trace, *_ = RJMC_outerloop(
            calc_posterior, n_iter, initial_values, initial_sd, n_models,
            swap_freq, tune_freq, tune_for, 1, 1, 1, 1)
        trace_tuned = trace[tune_for:]
        n_params = np.size(trace_tuned, 1)

        if parameter_prior_proposal is None:
            parameter_prior_proposal = np.empty((n_models, n_params, 2))

        # Columns live on axis 1 of the 2-D trace.
        for j in range(n_params):
            samples = trace_tuned[:, j]

            # Mode estimate: midpoint of the fullest histogram bin.
            counts, edges = np.histogram(samples, bins=100)
            peak = np.argmax(counts)
            max_ap = (edges[peak] + edges[peak + 1]) / 2

            # Width from the 95% HPD of *this* column; a normal 95% interval
            # spans 2 * 1.96 standard deviations.
            lower, upper = hpd(samples, alpha=0.05)
            sigma_hat = (upper - lower) / (2 * 1.96)

            parameter_prior_proposal[i, j] = [max_ap, sigma_hat * 1.5]

            support = np.linspace(np.min(samples), np.max(samples), 100)
            plt.hist(samples, density=True)
            # NOTE(review): assumes `norm` is scipy.stats.norm, so that
            # norm.pdf(x, loc, scale) is the proposal density -- confirm.
            plt.plot(support,
                     norm.pdf(support, *parameter_prior_proposal[i, j]))

    if parameter_prior_proposal is None:
        # No models requested: return an empty, correctly-ranked array
        # instead of failing on an unbound name.
        parameter_prior_proposal = np.empty((0, 0, 2))
    return parameter_prior_proposal
Пример #4
0
def calc_credible_intervals(trace_arr, trace_dict, alpha=.05):
    """Map every parameter name to the HPD interval of its trace column.

    Parameters
    ----------
    trace_arr : 2-D array
        Column ``i`` is read for the ``i``-th name in
        ``trace_dict['param_list']``.
    trace_dict : mapping
        Must provide the key ``'param_list'`` (iterable of parameter names).
    alpha : float, optional
        Passed positionally to ``pymc3.stats.hpd`` (defaults to 0.05).

    Returns
    -------
    dict
        ``{parameter name: result of hpd() for that column}``.
    """
    from pymc3.stats import hpd

    return {
        name: hpd(trace_arr[:, column], alpha)
        for column, name in enumerate(trace_dict['param_list'])
    }
Пример #5
0
 def display_hpd():
     """Draw the HPD interval of trace_values on ax and annotate it.

     Reads trace_values, alpha_level, ax, plot_height and round_to from an
     enclosing scope. A thick black bar marks the interval, both bounds are
     printed above its ends and a percentage caption is centred over it.
     """
     interval = hpd(trace_values, alpha=alpha_level)

     # Horizontal bar slightly above the baseline
     bar_height = plot_height * 0.02
     ax.plot(interval, (bar_height, bar_height), linewidth=4, color='k')

     label_style = dict(size=16, horizontalalignment='center')

     # Rounded value of each bound, printed above the matching bar end
     for bound in (interval[0], interval[1]):
         ax.text(bound, plot_height * 0.07, bound.round(round_to),
                 **label_style)

     # Caption centred between the two bounds
     midpoint = (interval[0] + interval[1]) / 2
     caption = format_as_percent(1 - alpha_level) + ' HPD'
     ax.text(midpoint, plot_height * 0.2, caption, **label_style)
Пример #6
0
 def display_hpd():
     """Mark and label the HPD interval of trace_values on ax.

     Uses trace_values, alpha_level, ax, plot_height and round_to from an
     enclosing scope: the interval is drawn as a thick black line, its two
     bounds are written above the line ends, and a centred caption states
     the interval's probability mass.
     """
     bounds = hpd(trace_values, alpha=alpha_level)
     low = bounds[0]
     high = bounds[1]

     # The interval itself, drawn just above the baseline
     line_y = plot_height * 0.02
     ax.plot(bounds, (line_y, line_y), linewidth=4, color='k')

     # Numeric labels for both ends, then the caption in the middle
     text_style = dict(size=16, horizontalalignment='center')
     value_y = plot_height * 0.07
     ax.text(low, value_y, low.round(round_to), **text_style)
     ax.text(high, value_y, high.round(round_to), **text_style)
     ax.text((low + high) / 2, plot_height * 0.2,
             format_as_percent(1 - alpha_level) + ' HPD', **text_style)
Пример #7
0
def version_pymc3():
    """Sample two Gamma-Poisson models with PyMC3 and print the HPD of θ.

    Model (a) uses θ directly as the Poisson rate; model (b) scales θ by
    ``miles_e8_estimate``. Both observe ``observed_n_a``. After each fit the
    HPD interval of θ is printed; the trace of model (b) is additionally
    plotted and shown.
    """

    def sample_theta(poisson_rate):
        # Build one model whose Poisson rate is poisson_rate(θ) and sample it
        with pm.Model():
            theta = pm.Gamma("θ", alpha=1, beta=1)
            pm.Poisson("n_a", mu=poisson_rate(theta), observed=observed_n_a)
            return pm.sample(5000, tune=1000, cores=4)

    print("(a) -- independent θ, not modelling exposure")
    trace = sample_theta(lambda theta: theta)

    print("(a) 95% credible interval for θ:")
    print(pms.hpd(trace["θ"]))

    print("(b) -- independent θ, modelling exposure")
    trace = sample_theta(lambda theta: miles_e8_estimate * theta)

    print("(b) 95% credible interval for θ:")
    print(pms.hpd(trace["θ"]))
    pm.traceplot(trace)
    plt.show()
def calc_hdi(hist, bins, size, random_state=None):
    """
    Estimate HDI from histogram of probability.

    Data is sampled with the given probability and
    HDI is calculated with pymc3.stats.hpd

    Parameters
    ----------
    hist : list of float
        Histogram of probability. Used as the sampling probabilities of
        `bins`, so it must have one entry per element of `bins`.
    bins : list of float
        Bins of the histogram. Lists and arrays are both accepted; the
        spacing between the first two entries is taken as the bin width.
    size : int
        Number of samples to generate.
    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Returns
    -------
    hdi_x : [float, float]
        HDI range (each HPD bound shifted up by one bin width).
    hdi_y : float
        Probability value corresponding to the HDI: the mean of the `hist`
        entries located one position after the bins matching `hdi_x`.

    Raises
    ------
    IndexError
        If a shifted bound does not coincide (within 1e-6) with any entry of
        `bins`, or the entry after the match lies outside `hist`.
    """
    # Array form makes `bins - x` below work for plain lists as well.
    bins = np.asarray(bins)
    rng = sklearn.utils.check_random_state(random_state)
    samples = rng.choice(bins, size=size, p=hist)

    # Scalar bin width, computed once. float() on the 1-element array that
    # np.diff(bins[:2]) returns is deprecated in recent NumPy releases.
    bin_width = bins[1] - bins[0]
    hdi_x = [float(x + bin_width) for x in hpd(samples)]

    hdi_y = np.mean([
        hist[np.argwhere(np.abs(bins - x) < 1e-6).ravel()[0] + 1]
        for x in hdi_x
    ])
    return hdi_x, hdi_y
Пример #9
0
def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.05, quartiles=True,
               rhat=True, main=None, xtitle=None, xlim=None, ylabels=None,
               chain_spacing=0.05, vline=0, gs=None, plot_transformed=False):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------

    trace_obj: NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha (optional): float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles (optional): bool
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat (optional): bool
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main (optional): string
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle (optional): string
        Label for x-axis. Defaults to no label
    xlim (optional): list or tuple
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels (optional): list or array
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing (optional): float
        Plot spacing between chains (defaults to 0.05).
    vline (optional): numeric
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).

    Returns
    -------

    gs : matplotlib GridSpec

    Notes
    -----
    The interval axes are only created inside this function when ``gs`` is
    None; see the review note next to the GridSpec setup below.

    """
    # Percentiles to request: the outer pair matches the (1-alpha) interval,
    # the quartiles are only added when asked for
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis (grown while looping over the variables)
    plotrange = None

    # Axes holding the intervals
    interval_plot = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj, plot_transformed)

    # The R-hat panel is only drawn when there is more than one chain
    plot_rhat = (rhat and nchains > 1)
    # Build the layout only when the caller did not supply a GridSpec
    if gs is None:
        # Initialize plot: two columns when the R-hat panel is needed
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

        # Subplot for credible intervals
        # NOTE(review): this assignment sits inside the `gs is None` branch,
        # so with a caller-supplied gs `interval_plot` is still None when it
        # is handed to _plot_tree() -- confirm that helper copes with None.
        interval_plot = plt.subplot(gs[0])

    # Per-chain summaries; both results are indexed below as
    # [chain][varname]
    trace_quantiles = quantiles(trace_obj, qlist, transform=transform, squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable (row position on the y-axis);
        # restarted for every chain so chains share the same rows
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for the outermost quantiles
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [min(plotrange[0], np.min(quants)),
                             max(plotrange[1], np.max(quants))]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable, taken from the first
            # value returned for this chain
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list (once, for the first chain)
            if j == 0:
                if k > 1:
                    names = _var_str(varname, np.shape(value))
                    labels += names
                else:
                    labels.append(varname)

            # Add spacing for each chain, if more than one: offsets
            # alternate in sign around the first chain, which stays at 0
            offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i for i in range(nchains - 1)]

            # Y coordinate with offset
            y = -var + offset[j]

            # Deal with multivariate nodes: one row per element
            if k > 1:
                for q in np.transpose(quants).squeeze():
                    # Multiple y values
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants, quartiles)

            # Increment index by the number of rows just used
            var += k

    # Caller-supplied labels replace the generated ones
    labels = ylabels if ylabels is not None else labels

    # Update margins: left margin scales with the longest label
    # (0.015 per character)
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis (rows run downwards from 0)
    interval_plot.set_ylim(-var + 0.5, 0.5)

    # Pad the x-axis by 5% of the data range on both sides
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels)

    # Add title: None -> default title, falsy -> no title, otherwise `main`
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title)

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range (overrides the padded range set above)
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes
    # NOTE(review): tick1On / tick2On are believed to be deprecated in newer
    # Matplotlib releases -- confirm against the version in use.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle='--')

    # Generate Gelman-Rubin plot in the second GridSpec cell
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels, varnames, plot_transformed)

    return gs
Пример #10
0
def forestplot(trace_obj,
               varnames=None,
               transform=identity_transform,
               alpha=0.05,
               quartiles=True,
               rhat=True,
               main=None,
               xtitle=None,
               xlim=None,
               ylabels=None,
               chain_spacing=0.05,
               vline=0,
               gs=None,
               plot_transformed=False,
               **plot_kwargs):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    Parameters
    ----------

    trace_obj: NpTrace or MultiTrace object
        Trace(s) from an MCMC sample.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha (optional): float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles (optional): bool
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat (optional): bool
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main (optional): string
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle (optional): string
        Label for x-axis. Defaults to no label
    xlim (optional): list or tuple
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels (optional): list or array
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    chain_spacing (optional): float
        Plot spacing between chains (defaults to 0.05).
    vline (optional): numeric
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'color', 'marker', and 'markersize'. They are forwarded
        to _plot_tree(); 'fontsize' is also used for the y tick labels and
        the title.

    Returns
    -------

    gs : matplotlib GridSpec

    Notes
    -----
    The interval axes are only created inside this function when ``gs`` is
    None; see the review note next to the GridSpec setup below.

    """
    # Percentiles to request: the outer pair matches the (1-alpha) interval,
    # the quartiles are only added when asked for
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    # Range for x-axis (grown while looping over the variables)
    plotrange = None

    # Axes holding the intervals
    interval_plot = None

    nchains = trace_obj.nchains

    if varnames is None:
        varnames = get_default_varnames(trace_obj.varnames, plot_transformed)

    # The R-hat panel is only drawn when there is more than one chain
    plot_rhat = (rhat and nchains > 1)
    # Build the layout only when the caller did not supply a GridSpec
    if gs is None:
        # Initialize plot: two columns when the R-hat panel is needed
        if plot_rhat:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        else:
            gs = gridspec.GridSpec(1, 1)

        # Subplot for credible intervals
        # NOTE(review): this assignment sits inside the `gs is None` branch,
        # so with a caller-supplied gs `interval_plot` is still None when it
        # is handed to _plot_tree() -- confirm that helper copes with None.
        interval_plot = plt.subplot(gs[0])

    # Per-chain summaries; both results are indexed below as
    # [chain][varname]
    trace_quantiles = quantiles(trace_obj,
                                qlist,
                                transform=transform,
                                squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    labels = []
    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable (row position on the y-axis);
        # restarted for every chain so chains share the same rows
        var = 0
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for the outermost quantiles
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [
                    min(plotrange[0], np.min(quants)),
                    max(plotrange[1], np.max(quants))
                ]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable, taken from the first
            # value returned for this chain
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list (once, for the first chain)
            if j == 0:
                if k > 1:
                    names = _var_str(varname, np.shape(value))
                    labels += names
                else:
                    labels.append(varname)

            # Add spacing for each chain, if more than one: offsets
            # alternate in sign around the first chain, which stays at 0
            offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1)**i
                            for i in range(nchains - 1)]

            # Y coordinate with offset
            y = -var + offset[j]

            # Deal with multivariate nodes: one row per element
            if k > 1:
                for q in np.transpose(quants).squeeze():
                    # Multiple y values
                    interval_plot = _plot_tree(interval_plot, y, q, quartiles,
                                               **plot_kwargs)
                    y -= 1
            else:
                interval_plot = _plot_tree(interval_plot, y, quants, quartiles,
                                           **plot_kwargs)

            # Increment index by the number of rows just used
            var += k

    # Caller-supplied labels replace the generated ones
    labels = ylabels if ylabels is not None else labels

    # Update margins: left margin scales with the longest label
    # (0.015 per character)
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis (rows run downwards from 0)
    interval_plot.set_ylim(-var + 0.5, 0.5)

    # Pad the x-axis by 5% of the data range on both sides
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title: None -> default title, falsy -> no title, otherwise `main`
    plot_title = ""
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    if plot_title:
        interval_plot.set_title(plot_title,
                                fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range (overrides the padded range set above)
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes
    # NOTE(review): tick1On / tick2On are believed to be deprecated in newer
    # Matplotlib releases -- confirm against the version in use.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    # Generate Gelman-Rubin plot in the second GridSpec cell
    if plot_rhat:
        _make_rhat_plot(trace_obj, plt.subplot(gs[1]), "R-hat", labels,
                        varnames, plot_transformed)

    return gs
Пример #11
0
def forestplot(trace,
               models=None,
               varnames=None,
               transform=identity_transform,
               alpha=0.05,
               quartiles=True,
               rhat=True,
               main=None,
               xtitle=None,
               xlim=None,
               ylabels=None,
               colors='C0',
               chain_spacing=0.1,
               vline=0,
               gs=None,
               plot_transformed=False,
               plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a trace
    or list of traces.

    Parameters
    ----------

    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace. Must have one name per trace,
        otherwise a ValueError is raised.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model
        is chosen automatically from matplotlib's color cycle ('C0' to
        'C9', repeating). If a single color is passed, e.g. 'k', 'C2',
        'red', this color will be used for all models.
        Defaults to 'C0' (blueish in most matplotlib styles).
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------

    gs : matplotlib GridSpec

    Raises
    ------
    ValueError
        If `models` is given and its length differs from the number of
        traces.

    Notes
    -----
    `interval_plot` and `gr_plot` are only assigned inside this function
    when `gs` is None (and `gr_plot` only when an R-hat panel is drawn);
    see the review notes in the body.

    """
    if plot_kwargs is None:
        plot_kwargs = {}

    # Normalise a single trace to a one-element list
    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    # Default model names: 'm_0', 'm_1', ... for several traces, an empty
    # name for a single trace
    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    # Expand `colors` to one entry per model
    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(len(models))]
    elif isinstance(colors, str):
        colors = [colors for i in range(len(models))]

    # Percentiles to request: the outer pair matches the (1-alpha) interval,
    # the quartiles are only added when asked for
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    # Default: union of every trace's default variable names, in first-seen
    # order
    if varnames is None:
        varnames = []
        for idx, tr in enumerate(trace):
            varnames_tmp = get_default_varnames(tr.varnames, plot_transformed)
            for v in varnames_tmp:
                if v not in varnames:
                    varnames.append(v)

    # One flag per trace: R-hat needs more than one chain
    plot_rhat = [rhat and nch > 1 for nch in nchains]
    # Build the layout only when the caller did not supply a GridSpec
    if gs is None:
        # Initialize plot: two columns when any trace gets an R-hat panel
        if np.any(plot_rhat):
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
            gr_plot = plt.subplot(gs[1])
            gr_plot.set_xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
            gr_plot.set_xlim(0.9, 2.1)
            gr_plot.set_yticks([])
            gr_plot.set_title('R-hat')
        else:
            gs = gridspec.GridSpec(1, 1)

        # Subplot for credible intervals
        # NOTE(review): `interval_plot` (and `gr_plot` above) are only bound
        # in this branch; with a caller-supplied gs the first use further
        # down refers to an unassigned name -- confirm intended usage.
        interval_plot = plt.subplot(gs[0])

    # Per-trace summaries; each entry is indexed below as [chain][varname]
    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(
            quantiles(tr, qlist, transform=transform, squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform,
                                 squeeze=False))

    labels = []
    # Running row position on the y-axis
    var = 0
    # Every plotted quantile, collected to derive the x-axis range
    all_quants = []
    # Alternating shading (alpha 0.05 / 0) for the per-variable bands
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5
    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                # Variable missing from this model: keep an empty, labelled
                # row so the models stay aligned
                labels.append(models[h] + ' ' + v)
                var += 1
            else:
                for j, chain in enumerate(tr.chains):
                    var_quantiles = trace_quantiles[h][chain][v]

                    quants = [var_quantiles[vq] for vq in qlist]
                    var_hpd = hpd_intervals[h][chain][v].T

                    # Substitute HPD interval for the outermost quantiles
                    quants[0] = var_hpd[0].T
                    quants[-1] = var_hpd[1].T

                    # Ensure x-axis contains range of current interval
                    all_quants.extend(np.ravel(quants))

                    # Number of elements in current variable, taken from
                    # the first value returned for this chain
                    value = tr.get_values(v, chains=[chain])[0]
                    k = np.size(value)

                    # Append variable name(s) to list (once, for the first
                    # chain); the model name prefixes the first label
                    if j == 0:
                        if k > 1:
                            names = _var_str(v, np.shape(value))
                            names[0] = models[h] + ' ' + names[0]
                            labels += names
                        else:
                            labels.append(models[h] + ' ' + v)

                    # Add spacing for each chain, if more than one: offsets
                    # alternate in sign around the first chain
                    offset = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1)**i
                                    for i in range(nchains[h] - 1)]

                    # Y coordinate with offset
                    y = -var + offset[j]

                    # Deal with multivariate nodes: flatten to one row of
                    # quantiles per element

                    if k > 1:
                        qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                        for q in qs.reshape(-1, len(quants)):
                            # Multiple y values
                            interval_plot = _plot_tree(interval_plot, y, q,
                                                       quartiles, colors[h],
                                                       plot_kwargs)
                            y -= 1
                    else:
                        interval_plot = _plot_tree(interval_plot, y, quants,
                                                   quartiles, colors[h],
                                                   plot_kwargs)

                # Generate Gelman-Rubin plot; values are capped at 2 to fit
                # the fixed x-range of the R-hat panel
                if plot_rhat[h] and v in tr.varnames:
                    R = gelman_rubin(tr, [v])
                    if k > 1:
                        Rval = dict2pd(R, 'rhat').values
                        gr_plot.plot([min(r, 2) for r in Rval],
                                     [-(j + var) for j in range(k)],
                                     'o',
                                     color=colors[h],
                                     markersize=4)
                    else:
                        gr_plot.plot(min(R[v], 2),
                                     -var,
                                     'o',
                                     color=colors[h],
                                     markersize=4)
                # Advance by the number of rows this variable used
                var += k

        # Shade alternating bands behind each variable when comparing models
        # NOTE(review): `gr_plot` is only assigned above when gs is None and
        # at least one trace has an R-hat panel, yet it is used here whenever
        # more than one trace is given -- confirm.
        if len(trace) > 1:
            interval_plot.axhspan(var_old,
                                  y - chain_spacing - 0.5,
                                  facecolor='k',
                                  alpha=bands[v_idx])
            gr_plot.axhspan(var_old,
                            y - chain_spacing - 0.5,
                            facecolor='k',
                            alpha=bands[v_idx])
            var_old = y - chain_spacing - 0.5

    # Caller-supplied labels replace the generated ones
    if ylabels is not None:
        labels = ylabels

    # Update margins: left margin scales with the longest label
    # (0.015 per character)
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(-var + 0.5, 0.5)
    if np.any(plot_rhat):
        gr_plot.set_ylim(-var + 0.5, 0.5)

    # Pad the x-axis by 5% of the data range on both sides
    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-l for l in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title: None -> default title, falsy -> empty title, otherwise `main`
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title,
                            fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range (overrides the padded range set above)
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes
    # NOTE(review): tick1On / tick2On are believed to be deprecated in newer
    # Matplotlib releases -- confirm against the version in use.
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs
        trace_model_0.append(trace_tuned[i])
    elif trace_tuned[i,0] == 1:
        trace_model_1.append(trace_tuned[i])
        
        
# Convert the per-model sample lists to arrays so columns can be sliced
trace_model_0=np.asarray(trace_model_0)
trace_model_1=np.asarray(trace_model_1)
# Scatter of column 1 against column 2 for each model, every 10th sample
f = plt.figure()
plt.scatter(trace_model_0[::10,1],trace_model_0[::10,2],s=1,label='Model 0',marker=',',alpha=0.7)
plt.scatter(trace_model_1[::10,1],trace_model_1[::10,2],s=1,label='Model 1',marker=',',alpha=0.7)
plt.legend()
plt.show()

# Path of the full trace in the same two columns, every 500th sample
plt.plot(trace[::500,1],trace[::500,2])
plt.show()
# HPD intervals (alpha=0.05) of columns 1 ("x") and 2 ("y") for each model
map_x_0=hpd(trace_model_0[:,1],alpha=0.05)
map_x_1=hpd(trace_model_1[:,1],alpha=0.05)
map_y_0=hpd(trace_model_0[:,2],alpha=0.05)
map_y_1=hpd(trace_model_1[:,2],alpha=0.05)
# Width of each interval (upper bound minus lower bound)
CI_x_0=map_x_0[1]-map_x_0[0]
CI_x_1=map_x_1[1]-map_x_1[0]
CI_y_0=map_y_0[1]-map_y_0[0]
CI_y_1=map_y_1[1]-map_y_1[0]

#trace_model_0_subsample=trace_model_0[::1000]
#trace_model_1_subsample=trace_model_1[::1000]
#trace_subsample=trace_tuned[::1000]
#Try subsampling to make the graphs look better.
# Normalised histogram of the Model 0 values in column 1
plt.hist(trace_model_0[:,1],bins=100,label='x values Model 0',density=True)

Пример #13
0
def forestplot(trace, models=None, varnames=None, transform=identity_transform,
               alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None,
               xlim=None, ylabels=None, colors='C0', chain_spacing=0.1, vline=0,
               gs=None, plot_transformed=False, plot_kwargs=None):
    """
    Forest plot (model summary plot).

    Generates a "forest plot" of 100*(1-alpha)% credible intervals from a trace
    or list of traces.

    Parameters
    ----------

    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list (optional)
        List with names for the models in the list of traces. Useful when
        plotting more than one trace.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    transform : callable
        Function to transform data (defaults to identity)
    alpha : float, optional
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    quartiles : bool, optional
        Flag for plotting the interquartile range, in addition to the
        (1-alpha)*100% intervals (defaults to True).
    rhat : bool, optional
        Flag for plotting Gelman-Rubin statistics. Requires 2 or more chains
        (defaults to True).
    main : string, optional
        Title for main plot. Passing False results in titles being suppressed;
        passing None (default) results in default titles.
    xtitle : string, optional
        Label for x-axis. Defaults to no label
    xlim : list or tuple, optional
        Range for x-axis. Defaults to matplotlib's best guess.
    ylabels : list or array, optional
        User-defined labels for each variable. If not provided, the node
        __name__ attributes are used.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively
        a string can be passed. If the string is `cycle`, a color per model is
        chosen automatically from matplotlib's cycle. If a single color is
        passed, e.g. 'k', 'C2', 'red', this color is used for all models.
        Defaults to 'C0' (blueish in most matplotlib styles).
    chain_spacing : float, optional
        Plot spacing between chains (defaults to 0.1).
    vline : numeric, optional
        Location of vertical reference line (defaults to 0).
    gs : GridSpec
        Matplotlib GridSpec object. Defaults to None, in which case one is
        created. When supplied, the intervals are drawn in ``gs[0]``; the
        R-hat panel is drawn in ``gs[1]`` if the grid has more than one cell
        and is skipped otherwise.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    plot_kwargs : dict
        Optional arguments for plot elements. Currently accepts 'fontsize',
        'linewidth', 'marker', and 'markersize'.

    Returns
    -------

    gs : matplotlib GridSpec

    Raises
    ------

    ValueError
        If `models` is given and its length differs from the number of traces.

    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    if models is None:
        if len(trace) > 1:
            models = ['m_{}'.format(i) for i in range(len(trace))]
        else:
            models = ['']
    elif len(models) != len(trace):
        raise ValueError("The number of names for the models does not match "
                         "the number of models")

    # Only compare against 'cycle' once we know `colors` is a string, so a
    # sequence of colors is never compared element-wise.
    if isinstance(colors, str):
        if colors == 'cycle':
            colors = ['C{}'.format(i % 10) for i in range(len(models))]
        else:
            colors = [colors for _ in range(len(models))]

    # Quantiles to be calculated
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]
    else:
        qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]

    nchains = [tr.nchains for tr in trace]

    if varnames is None:
        # Union of the default variable names of every trace, in first-seen
        # order.
        varnames = []
        for tr in trace:
            for v in get_default_varnames(tr.varnames, plot_transformed):
                if v not in varnames:
                    varnames.append(v)

    # R-hat is only drawn for traces that have more than one chain.
    plot_rhat = [rhat and nch > 1 for nch in nchains]

    # Axes for the R-hat panel; stays None when no R-hat panel is drawn.
    gr_plot = None
    if np.any(plot_rhat):
        if gs is None:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
        nrows, ncols = gs.get_geometry()
        if nrows * ncols > 1:
            gr_plot = plt.subplot(gs[1])
            # Set tick positions and labels separately: passing labels as the
            # second positional argument of set_xticks is interpreted as
            # `minor` by matplotlib releases older than 3.5.
            gr_plot.set_xticks((1.0, 1.5, 2.0))
            gr_plot.set_xticklabels(("1", "1.5", "2+"))
            gr_plot.set_xlim(0.9, 2.1)
            gr_plot.set_yticks([])
            gr_plot.set_title('R-hat')
        else:
            # Caller-supplied grid has no second cell for the R-hat panel.
            plot_rhat = [False] * len(trace)
    elif gs is None:
        gs = gridspec.GridSpec(1, 1)

    # Subplot for confidence intervals
    interval_plot = plt.subplot(gs[0])

    trace_quantiles = []
    hpd_intervals = []
    for tr in trace:
        trace_quantiles.append(quantiles(tr, qlist, transform=transform,
                                         squeeze=False))
        hpd_intervals.append(hpd(tr, alpha, transform=transform,
                                 squeeze=False))

    # y-axis labels, one per plotted row
    labels = []
    # Running count of rows used so far; row r is drawn at y = -r
    var = 0
    all_quants = []
    # Alternating shading strength for the per-variable background bands
    bands = [0.05, 0] * len(varnames)
    var_old = 0.5
    for v_idx, v in enumerate(varnames):
        for h, tr in enumerate(trace):
            if v not in tr.varnames:
                # Keep an empty, labelled row so models stay aligned.
                labels.append(models[h] + ' ' + v)
                var += 1
            else:
                # Add spacing for each chain, if more than one
                offset = [0] + [(chain_spacing * ((i + 2) / 2)) *
                                (-1) ** i for i in range(nchains[h] - 1)]

                for j, chain in enumerate(tr.chains):
                    var_quantiles = trace_quantiles[h][chain][v]

                    quants = [var_quantiles[vq] for vq in qlist]
                    var_hpd = hpd_intervals[h][chain][v].T

                    # Substitute HPD interval for quantile
                    quants[0] = var_hpd[0].T
                    quants[-1] = var_hpd[1].T

                    # Ensure x-axis contains range of current interval
                    all_quants.extend(np.ravel(quants))

                    # Number of elements in current variable
                    value = tr.get_values(v, chains=[chain])[0]
                    k = np.size(value)

                    # Append variable name(s) to list (once per variable,
                    # not once per chain)
                    if j == 0:
                        if k > 1:
                            names = _var_str(v, np.shape(value))
                            names[0] = models[h] + ' ' + names[0]
                            labels += names
                        else:
                            labels.append(models[h] + ' ' + v)

                    # Y coordinate with offset
                    y = - var + offset[j]

                    # Deal with multivariate nodes
                    if k > 1:
                        qs = np.moveaxis(np.array(quants), 0, -1).squeeze()
                        for q in qs.reshape(-1, len(quants)):
                            # Multiple y values
                            interval_plot = _plot_tree(interval_plot, y, q,
                                                       quartiles, colors[h],
                                                       plot_kwargs)
                            y -= 1
                    else:
                        interval_plot = _plot_tree(interval_plot, y, quants,
                                                   quartiles, colors[h],
                                                   plot_kwargs)

                # Generate Gelman-Rubin plot
                if plot_rhat[h]:
                    R = gelman_rubin(tr, [v])
                    if k > 1:
                        Rval = dict2pd(R, 'rhat').values
                        # Values above 2 are clipped to the "2+" tick.
                        gr_plot.plot([min(r, 2) for r in Rval],
                                     [-(i + var) for i in range(k)], 'o',
                                     color=colors[h], markersize=4)
                    else:
                        gr_plot.plot(min(R[v], 2), -var, 'o', color=colors[h],
                                     markersize=4)
                var += k

        if len(trace) > 1:
            # Shade the rows belonging to this variable across all models.
            band_bottom = y - chain_spacing - 0.5
            interval_plot.axhspan(var_old, band_bottom,
                                  facecolor='k', alpha=bands[v_idx])
            if gr_plot is not None:
                gr_plot.axhspan(var_old, band_bottom,
                                facecolor='k', alpha=bands[v_idx])
            var_old = band_bottom

    if ylabels is not None:
        labels = ylabels

    # Update margins (left margin scales with the longest label)
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis for forestplot and R-hat
    interval_plot.set_ylim(- var + 0.5, 0.5)
    if gr_plot is not None:
        gr_plot.set_ylim(- var + 0.5, 0.5)

    # Pad the x-range by 5% of the data range on each side
    plotrange = [np.min(all_quants), np.max(all_quants)]
    datarange = plotrange[1] - plotrange[0]
    interval_plot.set_xlim(plotrange[0] - 0.05 * datarange,
                           plotrange[1] + 0.05 * datarange)

    # Add variable labels
    interval_plot.set_yticks([-row for row in range(len(labels))])
    interval_plot.set_yticklabels(labels,
                                  fontsize=plot_kwargs.get('fontsize', None))

    # Add title
    if main is None:
        plot_title = "{:.0f}% Credible Intervals".format((1 - alpha) * 100)
    elif main:
        plot_title = main
    else:
        plot_title = ""

    interval_plot.set_title(plot_title,
                            fontsize=plot_kwargs.get('fontsize', None))

    # Add x-axis label
    if xtitle is not None:
        interval_plot.set_xlabel(xtitle)

    # Constrain to specified range
    if xlim is not None:
        interval_plot.set_xlim(*xlim)

    # Remove ticklines on y-axes. The tick1On/tick2On attributes were
    # deprecated in matplotlib 3.1 and later removed, so hide the tick line
    # artists directly instead.
    for tick in interval_plot.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)

    for loc, spine in interval_plot.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    interval_plot.axvline(vline, color='k', linestyle=':')

    return gs