コード例 #1
0
ファイル: Matplot.py プロジェクト: aflaxman/gbd
def histogram(data, name, nbins=None, datarange=(None, None), format='png', suffix='', path='./', rows=1, columns=1, num=1, last=True, fontmap = {1:10, 2:8, 3:6, 4:5, 5:4}, verbose=1):
    """
    Draw a histogram of ``data`` with median and 95% HPD markers.

    :Arguments:
        data: array
            Values to plot. Must provide a ``tolist()`` method.

        name: string
            Used in the plot title and in the output file name.

        nbins (optional): int
            Number of bins. When omitted (or 0), one bin per unique value
            is used if there are at most 25 unique values; otherwise
            ``int(4 + 1.5*log(len(data)))`` bins are used.

        datarange (optional): tuple
            Limits handed to ``xlim``.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Directory for saving the plot (defaults to local directory).
            It is created if it does not exist.

        rows, columns, num (optional): int
            Subplot grid and position, handed to ``subplot``. The plot is
            treated as stand-alone (new figure, saved to file) only when
            all three are 1.

        last (optional): bool
            Not used by this function.

        fontmap (optional): dict
            Maps ``rows`` to the tick-label font size. Only read, never
            modified. If ``rows`` has no entry, the smallest size in the
            map is used.

        verbose (optional): int
            A value above 0 prints a progress message for stand-alone
            plots.

    An OverflowError raised while plotting is caught and reported with a
    printed message instead of being propagated.
    """

    try:

        # Stand-alone plot or subplot?
        standalone = rows == 1 and columns == 1 and num == 1
        if standalone:
            if verbose > 0:
                # Single-argument form prints the same text under the
                # print statement and the print function
                print('Generating histogram of %s' % (name,))
            figure()

        subplot(rows, columns, num)

        # Choose the number of bins when the caller did not
        if not nbins:
            uniquevals = len(unique(data))
            if 0 < uniquevals <= 25:
                # Few distinct values: one bin each
                nbins = uniquevals
            else:
                nbins = int(4 + 1.5 * log(len(data)))

        # Generate histogram
        hist(data.tolist(), nbins, histtype='stepfilled')

        xlim(datarange)

        # Plot options
        title('\n\n   %s hist' % name, x=0., y=1., ha='left', va='top', fontsize='medium')

        ylabel("Frequency", fontsize='x-small')

        # Plot vertical lines for median and 95% HPD interval
        quant = calc_quantiles(data)
        axvline(x=quant[50], linewidth=2, color='black')
        for q in hpd(data, 0.05):
            axvline(x=q, linewidth=2, color='grey', linestyle='dotted')

        # Smaller tick labels; a grid with more rows than the map covers
        # falls back to the smallest configured size instead of raising
        # KeyError after the plot has already been drawn
        if rows in fontmap:
            tick_size = fontmap[rows]
        else:
            tick_size = min(fontmap.values())
        axes = gca()
        setp(axes.get_xticklabels(), 'fontsize', tick_size)
        setp(axes.get_yticklabels(), 'fontsize', tick_size)

        if standalone:
            # An empty path means the current directory: nothing to create
            # and no separator to add
            if path:
                if not os.path.exists(path):
                    # makedirs also creates missing parent directories
                    os.makedirs(path)
                if not path.endswith('/'):
                    path += '/'
            # Save to file (the figure is not closed afterwards)
            savefig("%s%s%s.%s" % (path, name, suffix, format))

    except OverflowError:
        print('... cannot generate histogram')
コード例 #2
0
ファイル: bmsplot.py プロジェクト: Jimmy-INL/msd
            ax.acorr(mcmc_trace[-1][ck],
                     color='crimson',
                     detrend=ml.detrend_mean,
                     linestyle='-',
                     linewidth=1.5,
                     maxlags=acorr_maxlags)
            ax.set_xlim(-acorr_maxlags, acorr_maxlags)
            ax.set_ylim(-0.1, 1.1)
            ax.set_ylabel(C_str[ck], rotation='horizontal')
            if (i == 0):
                ax.set_title("Autocorrelation (detrended)")

            ax = AxesArr[i, 2]
            ax.grid(color='lightgrey', linestyle=':')
            # Calculate the median and 95% Highest Probability Density (HPD) or minimum width Bayesian Confidence (BCI) interval
            hist_quant = calc_quantiles(mcmc_trace[-1][ck])
            hist_hpd = calc_hpd(mcmc_trace[-1][ck], hist_hpd_alpha)
            (hist_n, hist_bins, hist_patches) = ax.hist(mcmc_trace[-1][ck],
                                                        bins=hist_num_bins,
                                                        color='steelblue',
                                                        histtype='stepfilled',
                                                        linewidth=0.0,
                                                        normed=True,
                                                        zorder=2)
            ax.set_ylim(0.0, max(hist_n) * 1.1)
            ax.axvspan(hist_hpd[0],
                       hist_hpd[1],
                       alpha=0.25,
                       facecolor='darkslategray',
                       linewidth=1.5)
            ax.axvline(hist_quant[50],
コード例 #3
0
ファイル: Matplot.py プロジェクト: studentmicky/gbd
def histogram(data,
              name,
              nbins=None,
              datarange=(None, None),
              format='png',
              suffix='',
              path='./',
              rows=1,
              columns=1,
              num=1,
              last=True,
              fontmap={
                  1: 10,
                  2: 8,
                  3: 6,
                  4: 5,
                  5: 4
              },
              verbose=1):
    """
    Draw a histogram of ``data`` with median and 95% HPD markers.

    :Arguments:
        data: array
            Values to plot. Must provide a ``tolist()`` method.

        name: string
            Used in the plot title and in the output file name.

        nbins (optional): int
            Number of bins. When omitted (or 0), one bin per unique value
            is used if there are at most 25 unique values; otherwise
            ``int(4 + 1.5*log(len(data)))`` bins are used.

        datarange (optional): tuple
            Limits handed to ``xlim``.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Directory for saving the plot (defaults to local directory).
            It is created if it does not exist.

        rows, columns, num (optional): int
            Subplot grid and position, handed to ``subplot``. The plot is
            treated as stand-alone (new figure, saved to file) only when
            all three are 1.

        last (optional): bool
            Not used by this function.

        fontmap (optional): dict
            Maps ``rows`` to the tick-label font size. Only read, never
            modified. If ``rows`` has no entry, the smallest size in the
            map is used.

        verbose (optional): int
            A value above 0 prints a progress message for stand-alone
            plots.

    An OverflowError raised while plotting is caught and reported with a
    printed message instead of being propagated.
    """

    try:

        # Stand-alone plot or subplot?
        standalone = rows == 1 and columns == 1 and num == 1
        if standalone:
            if verbose > 0:
                # Single-argument form prints the same text under the
                # print statement and the print function
                print('Generating histogram of %s' % (name,))
            figure()

        subplot(rows, columns, num)

        # Choose the number of bins when the caller did not
        if not nbins:
            uniquevals = len(unique(data))
            if 0 < uniquevals <= 25:
                # Few distinct values: one bin each
                nbins = uniquevals
            else:
                nbins = int(4 + 1.5 * log(len(data)))

        # Generate histogram
        hist(data.tolist(), nbins, histtype='stepfilled')

        xlim(datarange)

        # Plot options
        title('\n\n   %s hist' % name,
              x=0.,
              y=1.,
              ha='left',
              va='top',
              fontsize='medium')

        ylabel("Frequency", fontsize='x-small')

        # Plot vertical lines for median and 95% HPD interval
        quant = calc_quantiles(data)
        axvline(x=quant[50], linewidth=2, color='black')
        for q in hpd(data, 0.05):
            axvline(x=q, linewidth=2, color='grey', linestyle='dotted')

        # Smaller tick labels; a grid with more rows than the map covers
        # falls back to the smallest configured size instead of raising
        # KeyError after the plot has already been drawn
        if rows in fontmap:
            tick_size = fontmap[rows]
        else:
            tick_size = min(fontmap.values())
        axes = gca()
        setp(axes.get_xticklabels(), 'fontsize', tick_size)
        setp(axes.get_yticklabels(), 'fontsize', tick_size)

        if standalone:
            # An empty path means the current directory: nothing to create
            # and no separator to add
            if path:
                if not os.path.exists(path):
                    # makedirs also creates missing parent directories
                    os.makedirs(path)
                if not path.endswith('/'):
                    path += '/'
            # Save to file (the figure is not closed afterwards)
            savefig("%s%s%s.%s" % (path, name, suffix, format))

    except OverflowError:
        print('... cannot generate histogram')
コード例 #4
0
ファイル: Matplot.py プロジェクト: studentmicky/gbd
def _draw_interval(q, y, quartiles):
    """Draw the median marker and interval bar(s) for one row at height y.

    ``q`` holds the quantile values in the order they were requested:
    (lower, 25, 50, 75, upper) when ``quartiles`` is true, otherwise
    (lower, 50, upper).
    """
    if quartiles:
        # Plot median
        pyplot(q[2], y, 'bo', markersize=4)
        # Plot quartile interval
        errorbar(x=(q[1], q[3]), y=(y, y), linewidth=2, color="blue")
    else:
        # Plot median
        pyplot(q[1], y, 'bo', markersize=4)

    # Plot outer interval
    errorbar(x=(q[0], q[-1]), y=(y, y), linewidth=1, color="blue")


def _strip_y_axis(axes):
    """Turn off y-axis tick lines and hide the left/right spines of axes."""
    for tick in axes.yaxis.get_major_ticks():
        tick.tick1On = False
        tick.tick2On = False

    for loc, spine in axes.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine


def summary_plot(pymc_obj,
                 name='model',
                 format='png',
                 suffix='-summary',
                 path='./',
                 alpha=0.05,
                 quartiles=True,
                 rhat=True,
                 main=None,
                 chain_spacing=0.05,
                 vline_pos=0):
    """
    Model summary plot

    Draws one row per variable element showing the median and credible
    interval of each chain, optionally with a second panel of
    Gelman-Rubin statistics, and saves the figure to
    ``<path>/<name><suffix>.<format>``.

    :Arguments:
        pymc_obj: PyMC object, trace or array
            A trace from an MCMC sample or a PyMC object with one or more traces.

        name (optional): string
            The name of the object.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Directory for saving plots (defaults to local directory). It
            is created if it does not exist.

        alpha (optional): float
            Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).

        quartiles (optional): bool
            Flag for also plotting the interquartile range as a thicker
            bar (defaults to True).

        rhat (optional): bool
            Flag for plotting Gelman-Rubin statistics. Requires 2 or more
            chains (defaults to True).

        main (optional): string
            Title for main plot. Passing False results in titles being
            suppressed; passing None (default) results in default titles.

        chain_spacing (optional): float
            Plot spacing between chains (defaults to 0.05).

        vline_pos (optional): numeric
            Location of vertical reference line (defaults to 0).

    """

    if not gridspec:
        print('\nYour installation of matplotlib is not recent enough to support summary_plot; this function is disabled until matplotlib is updated.')
        return

    # Quantiles to be calculated
    quantiles = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]
    if quartiles:
        quantiles = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    # Number of chains
    chains = None

    # Gridspec
    gs = None

    # Subplot holding the interval rows
    interval_plot = None

    try:
        # First try Model type
        variables = pymc_obj._variables_to_tally

    except AttributeError:

        try:
            # Try a database object
            variables = pymc_obj._traces

        except AttributeError:
            # Assume an iterable
            variables = pymc_obj

    # Empty list for y-axis labels
    labels = []
    # Row index of the next variable (rows are plotted at negative y)
    var = 1

    # Make sure there is something to print
    if all([v._plot == False for v in variables]):
        print('No variables to plot')
        return

    for variable in variables:

        # If plot flag is off, do not print
        if variable._plot == False:
            continue

        # Extract name
        varname = variable.__name__

        # Retrieve trace(s): keep asking for the next chain until the
        # lookup fails
        chain_index = 0
        traces = []
        while True:
            try:
                traces.append(variable.trace(chain=chain_index))
                chain_index += 1
            except (KeyError, IndexError):
                break

        chains = len(traces)

        if gs is None:
            # Initialize plot on the first plotted variable; a second,
            # narrower column is reserved when R-hat will be drawn
            if rhat and chains > 1:
                gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
            else:
                gs = gridspec.GridSpec(1, 1)

            # Subplot for confidence intervals
            interval_plot = subplot(gs[0])

        # Get quantiles, ordered as in ``quantiles``
        data = [calc_quantiles(d, quantiles) for d in traces]
        data = [[d[q] for q in quantiles] for d in data]

        # Ensure x-axis contains range of current interval
        if plotrange:
            plotrange = [
                min(plotrange[0], nmin(data)),
                max(plotrange[1], nmax(data))
            ]
        else:
            plotrange = [nmin(data), nmax(data)]

        try:
            # First try missing-value stochastic
            value = variable.get_stoch_value()
        except AttributeError:
            # All other variable types
            value = variable.value

        # Number of elements in current variable
        k = size(value)

        # Append variable name(s) to list
        if k > 1:
            labels += var_str(varname, shape(value))
        else:
            labels.append(varname)

        # Vertical offset for each chain, alternating above and below the
        # row centre.
        # NOTE(review): relies on `/` between ints; under Python 2 without
        # `from __future__ import division` this floors - confirm before
        # porting.
        e = [0] + [(chain_spacing * ((c + 2) / 2)) * (-1)**c
                   for c in range(chains - 1)]

        # Loop over chains
        for j, quants in enumerate(data):

            if k > 1:
                # Multivariate node: one row per element
                for offset, q in enumerate(transpose(quants)):
                    _draw_interval(q, -(var + offset) + e[j], quartiles)
            else:
                _draw_interval(quants, -var + e[j], quartiles)

        # Increment index
        var += k

    # Update margins: widen the left margin with the longest label
    left_margin = max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    ylim(-var + 0.5, -0.5)

    datarange = plotrange[1] - plotrange[0]
    xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    yticks([-(l + 1) for l in range(len(labels))], labels)

    # Add title
    if main is not False:
        plot_title = main or str(int(
            (1 - alpha) * 100)) + "% Credible Intervals"
        title(plot_title)

    # Remove ticklines and side spines on the y-axis
    _strip_y_axis(interval_plot)

    # Reference line
    axvline(vline_pos, color='k', linestyle='--')

    # Generate Gelman-Rubin plot
    if rhat and chains > 1:

        from diagnostics import gelman_rubin

        # If there are multiple chains, calculate R-hat
        rhat_plot = subplot(gs[1])

        if main is not False:
            title("R-hat")

        # Set x range
        xlim(0.9, 2.1)

        # X axis labels
        xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
        yticks([-(l + 1) for l in range(len(labels))], "")

        # Calculate diagnostic; fall back to one call per variable when
        # the whole object is rejected
        try:
            R = gelman_rubin(pymc_obj)
        except ValueError:
            R = {}
            for variable in variables:
                R[variable.__name__] = gelman_rubin(variable)

        row = 1
        for variable in variables:

            if variable._plot == False:
                continue

            # Extract name
            varname = variable.__name__

            try:
                value = variable.get_stoch_value()
            except AttributeError:
                value = variable.value

            k = size(value)

            # Values above 2 are clipped to 2 (the "2+" tick)
            if k > 1:
                pyplot([min(r, 2) for r in R[varname]],
                       [-(j + row) for j in range(k)],
                       'bo',
                       markersize=4)
            else:
                pyplot(min(R[varname], 2), -row, 'bo', markersize=4)

            row += k

        # Define range of y-axis
        ylim(-row + 0.5, -0.5)

        # Remove ticklines and side spines on the y-axis
        _strip_y_axis(rhat_plot)

    # ``path`` is a directory: create it when missing and make sure it
    # ends with a separator before the file name is appended. An empty
    # path means the current directory.
    if path:
        import os
        if not os.path.exists(path):
            os.makedirs(path)
        if not path.endswith('/'):
            path += '/'

    savefig("%s%s%s.%s" % (path, name, suffix, format))
コード例 #5
0
ファイル: Matplot.py プロジェクト: aflaxman/gbd
def _draw_interval(q, y, quartiles):
    """Draw the median marker and interval bar(s) for one row at height y.

    ``q`` holds the quantile values in the order they were requested:
    (lower, 25, 50, 75, upper) when ``quartiles`` is true, otherwise
    (lower, 50, upper).
    """
    if quartiles:
        # Plot median
        pyplot(q[2], y, 'bo', markersize=4)
        # Plot quartile interval
        errorbar(x=(q[1], q[3]), y=(y, y), linewidth=2, color="blue")
    else:
        # Plot median
        pyplot(q[1], y, 'bo', markersize=4)

    # Plot outer interval
    errorbar(x=(q[0], q[-1]), y=(y, y), linewidth=1, color="blue")


def _strip_y_axis(axes):
    """Turn off y-axis tick lines and hide the left/right spines of axes."""
    for tick in axes.yaxis.get_major_ticks():
        tick.tick1On = False
        tick.tick2On = False

    for loc, spine in axes.spines.items():
        if loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine


def summary_plot(pymc_obj, name='model', format='png',  suffix='-summary', path='./', alpha=0.05, quartiles=True, rhat=True, main=None, chain_spacing=0.05, vline_pos=0):
    """
    Model summary plot

    Draws one row per variable element showing the median and credible
    interval of each chain, optionally with a second panel of
    Gelman-Rubin statistics, and saves the figure to
    ``<path>/<name><suffix>.<format>``.

    :Arguments:
        pymc_obj: PyMC object, trace or array
            A trace from an MCMC sample or a PyMC object with one or more traces.

        name (optional): string
            The name of the object.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Directory for saving plots (defaults to local directory). It
            is created if it does not exist.

        alpha (optional): float
            Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).

        quartiles (optional): bool
            Flag for also plotting the interquartile range as a thicker
            bar (defaults to True).

        rhat (optional): bool
            Flag for plotting Gelman-Rubin statistics. Requires 2 or more
            chains (defaults to True).

        main (optional): string
            Title for main plot. Passing False results in titles being
            suppressed; passing None (default) results in default titles.

        chain_spacing (optional): float
            Plot spacing between chains (defaults to 0.05).

        vline_pos (optional): numeric
            Location of vertical reference line (defaults to 0).

    """

    if not gridspec:
        print('\nYour installation of matplotlib is not recent enough to support summary_plot; this function is disabled until matplotlib is updated.')
        return

    # Quantiles to be calculated
    quantiles = [100*alpha/2, 50, 100*(1-alpha/2)]
    if quartiles:
        quantiles = [100*alpha/2, 25, 50, 75, 100*(1-alpha/2)]

    # Range for x-axis
    plotrange = None

    # Number of chains
    chains = None

    # Gridspec
    gs = None

    # Subplot holding the interval rows
    interval_plot = None

    try:
        # First try Model type
        variables = pymc_obj._variables_to_tally

    except AttributeError:

        try:
            # Try a database object
            variables = pymc_obj._traces

        except AttributeError:
            # Assume an iterable
            variables = pymc_obj

    # Empty list for y-axis labels
    labels = []
    # Row index of the next variable (rows are plotted at negative y)
    var = 1

    # Make sure there is something to print
    if all([v._plot == False for v in variables]):
        print('No variables to plot')
        return

    for variable in variables:

        # If plot flag is off, do not print
        if variable._plot == False:
            continue

        # Extract name
        varname = variable.__name__

        # Retrieve trace(s): keep asking for the next chain until the
        # lookup fails
        chain_index = 0
        traces = []
        while True:
            try:
                traces.append(variable.trace(chain=chain_index))
                chain_index += 1
            except (KeyError, IndexError):
                break

        chains = len(traces)

        if gs is None:
            # Initialize plot on the first plotted variable; a second,
            # narrower column is reserved when R-hat will be drawn
            if rhat and chains > 1:
                gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
            else:
                gs = gridspec.GridSpec(1, 1)

            # Subplot for confidence intervals
            interval_plot = subplot(gs[0])

        # Get quantiles, ordered as in ``quantiles``
        data = [calc_quantiles(d, quantiles) for d in traces]
        data = [[d[q] for q in quantiles] for d in data]

        # Ensure x-axis contains range of current interval
        if plotrange:
            plotrange = [min(plotrange[0], nmin(data)), max(plotrange[1], nmax(data))]
        else:
            plotrange = [nmin(data), nmax(data)]

        try:
            # First try missing-value stochastic
            value = variable.get_stoch_value()
        except AttributeError:
            # All other variable types
            value = variable.value

        # Number of elements in current variable
        k = size(value)

        # Append variable name(s) to list
        if k > 1:
            labels += var_str(varname, shape(value))
        else:
            labels.append(varname)

        # Vertical offset for each chain, alternating above and below the
        # row centre.
        # NOTE(review): relies on `/` between ints; under Python 2 without
        # `from __future__ import division` this floors - confirm before
        # porting.
        e = [0] + [(chain_spacing * ((c+2)/2))*(-1)**c for c in range(chains-1)]

        # Loop over chains
        for j, quants in enumerate(data):

            if k > 1:
                # Multivariate node: one row per element
                for offset, q in enumerate(transpose(quants)):
                    _draw_interval(q, -(var + offset) + e[j], quartiles)
            else:
                _draw_interval(quants, -var + e[j], quartiles)

        # Increment index
        var += k

    # Update margins: widen the left margin with the longest label
    left_margin = max([len(x) for x in labels])*0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    ylim(-var+0.5, -0.5)

    datarange = plotrange[1] - plotrange[0]
    xlim(plotrange[0] - 0.05*datarange, plotrange[1] + 0.05*datarange)

    # Add variable labels
    yticks([-(l+1) for l in range(len(labels))], labels)

    # Add title
    if main is not False:
        plot_title = main or str(int((1-alpha)*100)) + "% Credible Intervals"
        title(plot_title)

    # Remove ticklines and side spines on the y-axis
    _strip_y_axis(interval_plot)

    # Reference line
    axvline(vline_pos, color='k', linestyle='--')

    # Generate Gelman-Rubin plot
    if rhat and chains > 1:

        from diagnostics import gelman_rubin

        # If there are multiple chains, calculate R-hat
        rhat_plot = subplot(gs[1])

        if main is not False:
            title("R-hat")

        # Set x range
        xlim(0.9, 2.1)

        # X axis labels
        xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
        yticks([-(l+1) for l in range(len(labels))], "")

        # Calculate diagnostic; fall back to one call per variable when
        # the whole object is rejected
        try:
            R = gelman_rubin(pymc_obj)
        except ValueError:
            R = {}
            for variable in variables:
                R[variable.__name__] = gelman_rubin(variable)

        row = 1
        for variable in variables:

            if variable._plot == False:
                continue

            # Extract name
            varname = variable.__name__

            try:
                value = variable.get_stoch_value()
            except AttributeError:
                value = variable.value

            k = size(value)

            # Values above 2 are clipped to 2 (the "2+" tick)
            if k > 1:
                pyplot([min(r, 2) for r in R[varname]], [-(j+row) for j in range(k)], 'bo', markersize=4)
            else:
                pyplot(min(R[varname], 2), -row, 'bo', markersize=4)

            row += k

        # Define range of y-axis
        ylim(-row+0.5, -0.5)

        # Remove ticklines and side spines on the y-axis
        _strip_y_axis(rhat_plot)

    # ``path`` is a directory: create it when missing and make sure it
    # ends with a separator before the file name is appended. An empty
    # path means the current directory.
    if path:
        import os
        if not os.path.exists(path):
            os.makedirs(path)
        if not path.endswith('/'):
            path += '/'

    savefig("%s%s%s.%s" % (path, name, suffix, format))