예제 #1
0
def plot_logp(sol, save=False, draw=True, save_as_png=False, dpi=None):
    """
    Plot the model log-likelihood against the sampler iteration number.

    Parameters
    ----------
    sol : mcmcinv object
        Fitted inversion; ``sol.MDL`` is the sampled model.
    save : bool
        If True, write the figure to the 'LogLikelihood' subfolder.
    draw : bool
        If True, return the figure; otherwise return None.
    save_as_png : bool
        Save as PNG instead of PDF.
    dpi : int or None
        Resolution forwarded to ``save_figure``.
    """
    file_ext = 'png' if save_as_png else 'pdf'
    fig, ax = plt.subplots(figsize=(4, 3))
    model = sol.MDL
    loglik = logp_trace(model)
    state = model.get_state()["sampler"]
    burn, last, step = state["_burn"], state["_iter"], state["_thin"]
    # One x value per stored (post-burn-in, thinned) sample.
    iterations = np.arange(burn + 1, last + 1, step)

    plt.plot(iterations, loglik, "-")
    plt.xlabel("Iteration")
    plt.ylabel("Log-likelihood")
    plt.grid('on')
    # With no burn-in the axis starts at iteration 1, so a log axis is used.
    if burn != 0:
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    else:
        plt.xscale('log')
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    fig.tight_layout()

    if save:
        fn = 'LOGP-%s-%s.%s' % (sol.model_type_str, sol.filename, file_ext)
        save_figure(fig, subfolder='LogLikelihood', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #2
0
def plot_deviance(sol, save=False, draw=True, save_as_png=False, dpi=None):
    """
    Plot the model deviance trace, annotated with DIC and BPIC.

    Parameters
    ----------
    sol : mcmcinv object
        Fitted inversion; ``sol.MDL`` is the sampled model.
    save : bool
        If True, write the figure to the 'ModelDeviance' subfolder.
    draw : bool
        If True, return the figure; otherwise return None.
    save_as_png : bool
        Save as PNG instead of PDF.
    dpi : int or None
        Resolution forwarded to ``save_figure``.
    """
    file_ext = 'png' if save_as_png else 'pdf'
    fig, ax = plt.subplots(figsize=(4, 3))
    model = sol.MDL
    dev = model.trace('deviance')[:]
    state = model.get_state()["sampler"]
    burn = state["_burn"]
    iterations = np.arange(burn + 1, state["_iter"] + 1, state["_thin"])

    # DIC / BPIC are shown as integers in the legend.
    legend_text = "DIC = %d\nBPIC = %d" % (model.DIC, model.BPIC)
    plt.plot(iterations, dev, "-", color="C3", label=legend_text)
    plt.xlabel("Iteration")
    plt.ylabel("Model deviance")
    plt.legend(numpoints=1, loc="best", fontsize=9)
    plt.grid('on')
    if burn != 0:
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    else:
        # Axis starts at iteration 1: log scale.
        plt.xscale('log')
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    fig.tight_layout()

    if save:
        fn = 'MDEV-%s-%s.%s' % (sol.model_type_str, sol.filename, file_ext)
        save_figure(fig, subfolder='ModelDeviance', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #3
0
def plot_logp(sol, save=False, draw=True, save_as_png=False, dpi=None):
    """
    Plot the log-likelihood of the model at every stored iteration.

    sol         -- mcmcinv object (``sol.MDL`` holds the sampler)
    save        -- write the figure under 'LogLikelihood'
    draw        -- return the figure (True) or None (False)
    save_as_png -- PNG instead of PDF
    dpi         -- forwarded to save_figure
    """
    fmt = 'png' if save_as_png else 'pdf'
    fig, ax = plt.subplots(figsize=(4,3))
    values = logp_trace(sol.MDL)
    sampler = sol.MDL.get_state()["sampler"]
    xs = np.arange(sampler["_burn"]+1, sampler["_iter"]+1, sampler["_thin"])

    plt.plot(xs, values, "-")
    plt.xlabel("Iteration")
    plt.ylabel("Log-likelihood")
    plt.grid('on')
    no_burn = sampler["_burn"] == 0
    if not no_burn:
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    else:
        plt.xscale('log')
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    fig.tight_layout()

    if save:
        fn = 'LOGP-%s-%s.%s'%(sol.model_type_str,sol.filename,fmt)
        save_figure(fig, subfolder='LogLikelihood', fname=fn, dpi=dpi)

    plt.close(fig)
    if not draw:
        return None
    return fig
예제 #4
0
def plot_deviance(sol, save=False, draw=True, save_as_png=False, dpi=None):
    """
    Plot the deviance of the model at every stored iteration.

    sol         -- mcmcinv object (``sol.MDL`` holds the sampler)
    save        -- write the figure under 'ModelDeviance'
    draw        -- return the figure (True) or None (False)
    save_as_png -- PNG instead of PDF
    dpi         -- forwarded to save_figure
    """
    fmt = 'png' if save_as_png else 'pdf'
    fig, ax = plt.subplots(figsize=(4,3))
    mdl = sol.MDL
    dev_samples = mdl.trace('deviance')[:]
    sampler = mdl.get_state()["sampler"]
    xs = np.arange(sampler["_burn"]+1, sampler["_iter"]+1, sampler["_thin"])

    plt.plot(xs, dev_samples, "-", color="C3",
             label="DIC = %d\nBPIC = %d" %(mdl.DIC, mdl.BPIC))
    plt.xlabel("Iteration")
    plt.ylabel("Model deviance")
    plt.legend(numpoints=1, loc="best", fontsize=9)
    plt.grid('on')
    no_burn = sampler["_burn"] == 0
    if not no_burn:
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    else:
        plt.xscale('log')
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    fig.tight_layout()

    if save:
        fn = 'MDEV-%s-%s.%s'%(sol.model_type_str,sol.filename,fmt)
        save_figure(fig, subfolder='ModelDeviance', fname=fn, dpi=dpi)

    plt.close(fig)
    if not draw:
        return None
    return fig
예제 #5
0
def plot_KDE(sol, var1, var2, fig=None, ax=None, draw=True, save=False, save_as_png=False, dpi=None):
    """
    Like the hexbin plot but a 2D KDE
    Pass mcmcinv object and 2 variable names as strings

    Parameters
    ----------
    sol : mcmcinv object; ``sol.MDL`` is the sampled model.
    var1, var2 : str
        Variables for the x and y axes. A number embedded in the name selects
        one component of a vector variable ("m2" -> column 1 of trace "m",
        "m12" -> column 11). "R0" is always used as a plain variable name.
    fig, ax : optional
        Figure/axes to draw into; a new 3x3 figure is created unless BOTH are
        given. Note the figure is closed before returning, as before.
    draw : bool
        Return the figure when True, otherwise None.
    save, save_as_png, dpi :
        Save under '2D-KDE' as PDF (or PNG) through ``save_figure``.
    """
    ext = 'png' if save_as_png else 'pdf'
    if fig is None or ax is None:
        fig, ax = plt.subplots(figsize=(3,3))
    MDL = sol.MDL

    def component_trace(var):
        """Return the sample array for *var*, selecting a column if indexed."""
        if var == "R0":
            return MDL.trace("R0")[:]
        name = ''.join(c for c in var if not c.isdigit())
        index = ''.join(c for c in var if c.isdigit())
        samples = MDL.trace(name)[:]
        # Index only genuinely vector-valued traces, using the whole number
        # rather than its first digit. Errors now surface instead of being
        # silently swallowed by a bare except.
        if index and np.ndim(samples) > 1:
            return samples[:, int(index) - 1]
        return samples

    x = component_trace(var1)
    y = component_trace(var2)
    xmin, xmax = min(x), max(x)
    ymin, ymax = min(y), max(y)
    # Perform the kernel density estimate on a 100x100 grid spanning the samples
    xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    values = np.vstack([x, y])
    kernel = gaussian_kde(values)
    kernel.set_bandwidth(bw_method='silverman')
    f = np.reshape(kernel(positions).T, xx.shape)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    plt.sca(ax)
    # Contourf plot of the density with the raw samples on top
    plt.grid(None)
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0))
    plt.xticks(rotation=90)
    plt.locator_params(axis = 'y', nbins = 7)
    plt.locator_params(axis = 'x', nbins = 7)
    ax.contourf(xx, yy, f, cmap=plt.cm.viridis, alpha=0.8)
    ax.scatter(x, y, color='k', s=1, zorder=2)

    plt.ylabel("%s" %var2)
    plt.xlabel("%s" %var1)

    if save:
        fn = 'KDE-%s-%s.%s'%(sol.model_type_str,sol.filename,ext)
        save_figure(fig, subfolder='2D-KDE', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #6
0
def plot_autocorr(
    sol,
    save=False,
    draw=True,
    save_as_png=False,
    dpi=None,
    ignore=subplots_to_ignore,
):
    """
    Plots autocorrelations of every parameter trace not in ``ignore``.

    Vector-valued variables get one subplot per component, labelled
    ``name1``, ``name2``, ... Traces longer than 50 samples are thinned by
    ``len // 50`` first, so the lag axis is in units of that factor.

    Parameters
    ----------
    sol : mcmcinv object; ``sol.MDL`` is the sampled model.
    save, save_as_png, dpi :
        Save under 'Autocorrelations' as PDF (or PNG) via ``save_figure``.
    draw : bool
        Return the figure when True, otherwise None.
    ignore : collection of str
        Variable names (keys of ``sol.var_dict``) to skip.
    """
    ext = 'png' if save_as_png else 'pdf'
    MDL = sol.MDL

    # (label, 1-D samples) pairs in display order. Samples keep their chain
    # order: autocorrelation is only meaningful on the sequence as drawn
    # (the previous version sorted them, which made every lag look ~1).
    series = []
    for k in sol.var_dict.keys():
        if k in ignore:
            continue
        samples = np.asarray(MDL.trace(k)[:])
        cols = samples.reshape(len(samples), -1)
        if cols.shape[1] > 1:
            series.extend((k + "%d" % n, cols[:, n - 1])
                          for n in range(1, cols.shape[1] + 1))
        else:
            series.append((k, cols[:, 0]))

    ncols = 2
    nrows = int(ceil(len(series) * 1.0 / ncols))
    fig, ax = plt.subplots(nrows, ncols, figsize=(10, nrows * 2))
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
    for a, (k, data) in zip(ax.flat, series):
        plt.sca(a)
        plt.gca().get_yaxis().get_major_formatter().set_useOffset(False)
        plt.gca().get_xaxis().get_major_formatter().set_useOffset(False)
        plt.yticks(fontsize=12)
        plt.xticks(fontsize=12)
        plt.ylabel(k, fontsize=12)
        to_thin = len(data) // 50
        if to_thin != 0:
            plt.xlabel("Lags / %d" % to_thin, fontsize=12)
        else:
            plt.xlabel("Lags", fontsize=12)
        if len(data) > 50:
            data = data[::to_thin]
        plt.acorr(data,
                  usevlines=True,
                  maxlags=None,
                  detrend=plt.mlab.detrend_mean)
        plt.grid(None)
    fig.tight_layout()
    # Hide the unused trailing subplots of the grid
    for a in ax.flat[ax.size - 1:len(series) - 1:-1]:
        a.set_visible(False)

    if save:
        fn = 'AC-%s-%s.%s' % (sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Autocorrelations', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #7
0
def plot_traces(sol, save=False, draw=True, save_as_png=False, dpi=None, 
                ignore=default_ignore,
                ):
    """
    Plots the traces of stochastic and
    deterministic parameters in mcmcinv object (sol)
    Ignores the ones in list argument ignore

    Each subplot shows the raw trace, its mean (dashed) and mean +/- one
    standard deviation (dotted). Headers are matched against ``ignore``
    after stripping leading/trailing digits, so "m1", "m2" are skipped when
    "m" is ignored. Returns the figure when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    mcmc = sol.mcmc  # MCMC run parameters (nb_burn, nb_iter, thin)

    # Variable names in alphabetical order, minus the ignored ones
    headers = [h for h in sorted(sol.trace_dict.keys())
               if h.strip('0123456789') not in ignore]
    traces = [sol.trace_dict[h] for h in headers]

    ncols = 2
    nrows = int(ceil(len(headers)*1.0 / ncols))
    # squeeze=False keeps `ax` 2-D even when nrows == 1, so `ax[-1]` below is
    # always a row of Axes (it used to be a single Axes and crash).
    fig, ax = plt.subplots(nrows, ncols, figsize=(8, nrows*1.5), sharex=True,
                           squeeze=False)

    # Iteration number of every stored sample; identical for all parameters.
    x = np.arange(mcmc["nb_burn"]+1, mcmc["nb_iter"]+1, mcmc["thin"])
    for a, h, data in zip(ax.flat, headers, traces):
        plt.sca(a)
        plt.ylabel(parlbl_dic[h])
        plt.plot(x, data,'-', color='0.8', alpha=1)
        av = np.mean(data)*np.ones(len(x))
        sd = np.std(data)*np.ones(len(x))
        plt.plot(x, av, linestyle='--', linewidth=1.5)
        plt.plot(x, av+sd, color='0.2',linestyle=':', linewidth=1)
        plt.plot(x, av-sd, color='0.2',linestyle=':', linewidth=1)
        # No burn-in: the axis starts at 1, so use a log scale.
        if x[0] == 1:
            plt.xscale('log')
        else:
            plt.ticklabel_format(style='sci', axis='x', scilimits=(-1,1))

    for a in ax.flat:
        a.grid(False)
    for a in ax[-1]:
        a.set_xlabel("Iteration number")

    plt.tight_layout(pad=0, w_pad=0.5, h_pad=0.5)

    if save:
        fn = 'TRA-%s-%s.%s'%(sol.model_type_str,sol.filename,ext)
        save_figure(fig, subfolder='Traces', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #8
0
def plot_histo(sol, save=False, draw=True, save_as_png=False, dpi=None, 
               ignore=default_ignore,
               ):
    """
    Plots histograms of the stochastic and
    deterministic parameters in mcmcinv object (sol)
    Ignores the ones in list argument ignore

    Each histogram (20 bins, counts) is overlaid with a normal pdf using the
    sample mean/std, scaled to counts (N * bin width). Headers are matched
    against ``ignore`` after stripping leading/trailing digits. Returns the
    figure when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'

    headers = [h for h in sorted(sol.trace_dict.keys())
               if h.strip('0123456789') not in ignore]
    traces = [sol.trace_dict[h] for h in headers]

    ncols = 2
    nrows = int(ceil(len(headers)*1.0 / ncols))
    # squeeze=False keeps `ax` 2-D so ax[row][0] also works when nrows == 1.
    fig, ax = plt.subplots(nrows, ncols, figsize=(8,nrows*1.8), squeeze=False)

    for a, h, trace in zip(ax.flat, headers, traces):
        data = sorted(trace)  # sorted so the fitted pdf draws as one curve
        plt.sca(a)
        plt.xlabel(parlbl_dic[h])
        try:
            _, edges, _ = plt.hist(data, bins=20, histtype='stepfilled', density=False, linewidth=1.0, color='0.95', alpha=1)
            plt.hist(data, bins=20, histtype='step', density=False, linewidth=1.0, alpha=1)
            fit = norm.pdf(data, np.mean(data), np.std(data))
            # Convert pdf to expected counts with the true (uniform) bin
            # width; the old centre-based estimate was ~10% too small.
            fit *= len(data) * (edges[1] - edges[0])
            plt.plot(data, fit, "-", color='k', linewidth=1)
        except Exception:
            print("File %s: failed to plot %s histogram.\nNot enough accepted moves." %(sol.filename,h))
        plt.grid(False)
        plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0))

    for row in ax:
        row[0].set_ylabel("Frequency")
    # Hide the unused trailing subplots of the grid
    for a in ax.flat[ax.size - 1:len(headers) - 1:-1]:
        a.set_visible(False)
    plt.tight_layout(pad=1, w_pad=1, h_pad=0)

    if save:
        fn = 'HST-%s-%s.%s'%(sol.model_type_str,sol.filename,ext)
        save_figure(fig, subfolder='Histograms', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #9
0
def plot_hexbin(sol,
                var1,
                var2,
                draw=True,
                save=False,
                save_as_png=False,
                dpi=None):
    """
    Like the 2D KDE plot but a hexbin
    Pass mcmcinv object and 2 variable names as strings

    A number embedded in a variable name selects one component of a vector
    variable ("m2" -> column 1 of trace "m", "m12" -> column 11); "R0" is
    always a plain variable name. Saves under 'Hexbins' when ``save`` is
    True and returns the figure when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    MDL = sol.MDL

    def component_trace(var):
        """Return the sample array for *var*, selecting a column if indexed."""
        if var == "R0":
            return MDL.trace("R0")[:]
        name = ''.join(c for c in var if not c.isdigit())
        index = ''.join(c for c in var if c.isdigit())
        samples = MDL.trace(name)[:]
        # Only vector traces are indexed; errors are no longer swallowed.
        if index and np.ndim(samples) > 1:
            return samples[:, int(index) - 1]
        return samples

    x = component_trace(var1)
    y = component_trace(var2)
    xmin, xmax = min(x), max(x)
    ymin, ymax = min(y), max(y)
    fig, ax = plt.subplots(figsize=(4, 3))
    plt.grid(None)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    plt.hexbin(x, y, gridsize=15, cmap=plt.cm.magma_r)
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
    plt.xticks(rotation=90)
    plt.locator_params(axis='y', nbins=5)
    plt.locator_params(axis='x', nbins=5)
    cb = plt.colorbar()
    cb.set_label('Number of observations')
    plt.ylabel("%s" % var2)
    plt.xlabel("%s" % var1)

    if save:
        fn = 'HEX-%s-%s.%s' % (sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Hexbins', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #10
0
def plot_rtd(sol, save=False, draw=True, save_as_png=False, dpi=None):
    """
    Plots the relaxation time distribution (RTD)
    for a polynomial decomposition or ccdt results

    Uses "log_m_i"/"log_tau_i" statistics when present (converted back with
    10**), otherwise falls back to "m_i" and ``sol.MDL.log_tau``. Marks the
    peak relaxation times (if "log_peak_tau" is available), tau_50 and mean
    tau with their 95% HPD intervals. Saves under 'RTD' when ``save`` is
    True; returns the figure when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    fig, ax = plt.subplots(figsize=(4,3))
    # MDL.stats() summarises every variable; compute it once, not per lookup.
    stats = sol.MDL.stats()
    try:
        bot95 = 10**stats["log_m_i"]['95% HPD interval'][0]
        top95 = 10**stats["log_m_i"]['95% HPD interval'][1]
        tau = 10**stats["log_tau_i"]['mean']
        m = 10**stats["log_m_i"]['mean']
    except (KeyError, TypeError):
        # Model without log-parameterised m_i
        bot95 = stats["m_i"]['95% HPD interval'][0]
        top95 = stats["m_i"]['95% HPD interval'][1]
        tau = 10**sol.MDL.log_tau
        m = stats["m_i"]['mean']
    plt.errorbar(tau, m, None, None, color="C7", linestyle='-', label="RTD")
    try:
        peak_stats = stats["log_peak_tau"]
        peaks = 10**np.atleast_1d(peak_stats["mean"])
        uncer_peaks = 10**peak_stats['95% HPD interval'].T.reshape(len(peaks), 2)
        m_peaks = m[[list(tau).index(find_nearest(tau, p)) for p in peaks]]
        if len(peaks) >= 1:
            plt.errorbar(peaks, m_peaks*1.2, None, None, color="C3", marker="v", markersize=5, linestyle="", label=r"$\tau_{peak}$")
            for u in uncer_peaks:
                plt.axvspan(u[0], u[1], alpha=0.2, color="C3")
    except Exception:
        # Peak markers are optional decoration (e.g. "log_peak_tau" may be
        # absent); skip them rather than fail the whole plot.
        pass
    plt.axvline(10**stats["log_half_tau"]['mean'],color="C0",linestyle=':', label=r"$\tau_{50}$")
    plt.axvline(10**stats["log_mean_tau"]['mean'],color='C2',linestyle='--', label=r"$\bar{\tau}$")
    inter = 10**stats["log_half_tau"]['95% HPD interval']
    plt.axvspan(inter[0], inter[1], alpha=0.2, color="C0")
    inter = 10**stats["log_mean_tau"]['95% HPD interval']
    plt.axvspan(inter[0], inter[1], alpha=0.2, color='C2')
    # Shade the first and last decade of the tau range
    plt.axvspan(min(tau), min(tau)*10, alpha=0.1, color='C7')
    plt.axvspan(max(tau)/10, max(tau), alpha=0.1, color='C7')
    plt.fill_between(tau, bot95, top95, color="C7", alpha=0.2)
    plt.xlim([10**np.ceil(np.log10(min(tau))), 10**np.floor(np.log10(max(tau)))])
    ax.set_xlabel(r'$\tau$ (s)')
    ax.set_ylabel(r'$m$')
    plt.grid(False)
    plt.legend(fontsize=9, loc=1,labelspacing=0.2, handlelength=1.5)
    plt.xscale('log')
    try:
        plt.yscale('log', nonpositive='clip')  # Matplotlib >= 3.3
    except (TypeError, ValueError):
        plt.yscale('log', nonposy='clip')  # older Matplotlib
    fig.tight_layout()
    if save:
        fn = 'RTD-%s-%s.%s'%(sol.model_type_str,sol.filename,ext)
        save_figure(fig, subfolder='RTD', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #11
0
def plot_autocorr(sol, save=False, draw=True, save_as_png=False, dpi=None,
                 ignore=default_ignore,
                 ):
    """
    Plots autocorrelations of every parameter trace not in ``ignore``.

    Vector-valued variables get one subplot per component (``name1``,
    ``name2``, ...). Traces longer than 50 samples are thinned by
    ``len // 50`` first; the lag axis label states that factor.
    Saves under 'Autocorrelations' when ``save`` is True; returns the figure
    when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    MDL = sol.MDL

    # (label, 1-D samples) in display order. Chain order is preserved:
    # sorting the samples (as before) made the autocorrelation meaningless.
    series = []
    for k in sol.var_dict.keys():
        if k in ignore:
            continue
        samples = np.asarray(MDL.trace(k)[:])
        cols = samples.reshape(len(samples), -1)
        if cols.shape[1] > 1:
            series.extend((k+"%d"%n, cols[:, n-1]) for n in range(1, cols.shape[1]+1))
        else:
            series.append((k, cols[:, 0]))

    ncols = 2
    nrows = int(ceil(len(series)*1.0 / ncols))
    fig, ax = plt.subplots(nrows, ncols, figsize=(10,nrows*2))
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0))
    for a, (k, data) in zip(ax.flat, series):
        plt.sca(a)
        plt.gca().get_yaxis().get_major_formatter().set_useOffset(False)
        plt.gca().get_xaxis().get_major_formatter().set_useOffset(False)
        plt.yticks(fontsize=12)
        plt.xticks(fontsize=12)
        plt.ylabel(k, fontsize=12)
        to_thin = len(data) // 50
        if to_thin != 0:
            plt.xlabel("Lags / %d"%to_thin, fontsize=12)
        else:
            plt.xlabel("Lags", fontsize=12)
        if len(data) > 50:
            data = data[::to_thin]
        plt.acorr(data, usevlines=True, maxlags=None, detrend=plt.mlab.detrend_mean)
        plt.grid(None)
    fig.tight_layout()
    # Hide the unused trailing subplots of the grid
    for a in ax.flat[ax.size - 1:len(series) - 1:-1]:
        a.set_visible(False)

    if save:
        fn = 'AC-%s-%s.%s'%(sol.model_type_str,sol.filename,ext)
        save_figure(fig, subfolder='Autocorrelations', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #12
0
def plot_hexbin(sol, var1, var2, draw=True, save=False, save_as_png=False, dpi=None):
    """
    Like the 2D KDE plot but a hexbin
    Pass mcmcinv object and 2 variable names as strings

    A number embedded in a variable name selects one component of a vector
    variable ("m2" -> column 1 of trace "m", "m12" -> column 11); "R0" is
    always a plain variable name. Saves under 'Hexbins' when ``save`` is
    True and returns the figure when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    MDL = sol.MDL

    def component_trace(var):
        """Return the sample array for *var*, selecting a column if indexed."""
        if var == "R0":
            return MDL.trace("R0")[:]
        name = ''.join(c for c in var if not c.isdigit())
        index = ''.join(c for c in var if c.isdigit())
        samples = MDL.trace(name)[:]
        # Only vector traces are indexed; errors are no longer swallowed.
        if index and np.ndim(samples) > 1:
            return samples[:, int(index) - 1]
        return samples

    x = component_trace(var1)
    y = component_trace(var2)
    xmin, xmax = min(x), max(x)
    ymin, ymax = min(y), max(y)
    fig, ax = plt.subplots(figsize=(4,3))
    plt.grid(None)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    plt.hexbin(x, y, gridsize=15, cmap=plt.cm.magma_r)
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0,0))
    plt.xticks(rotation=90)
    plt.locator_params(axis = 'y', nbins = 5)
    plt.locator_params(axis = 'x', nbins = 5)
    cb = plt.colorbar()
    cb.set_label('Number of observations')
    plt.ylabel("%s" %var2)
    plt.xlabel("%s" %var1)

    if save:
        fn = 'HEX-%s-%s.%s'%(sol.model_type_str,sol.filename,ext)
        save_figure(fig, subfolder='Hexbins', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #13
0
def plot_data(filename, headers, ph_units, save=False, 
              save_as_png=False, dpi=None, fig_nb=None):
    """
    Plots data before doing inversion
    Pass full file path, number of headers and phase units

    Draws a 2x2 panel (shared frequency axis): real and negative imaginary
    part of Z on top, amplitude and negative phase (x1000) below.
    ``fig_nb`` is accepted for interface compatibility but unused.
    Always returns the figure.
    """
    ext = 'png' if save_as_png else 'pdf'
    data = get_data(filename,headers,ph_units)

    zn_dat = data["Z"]
    zn_err = data["Z_err"]
    f = data["freq"]
    Pha_dat = 1000*data["pha"]
    Pha_err = 1000*data["pha_err"]
    Amp_dat = data["amp"]
    Amp_err = data["amp_err"]

    fig, ax = plt.subplots(2, 2, figsize=(8,5), sharex=True)
    # Real part (plt.sca is the supported way to select an existing Axes)
    plt.sca(ax[0,0])
    plt.errorbar(f, zn_dat.real, zn_err.real, None, fmt='o', mfc='white', markersize=5, label='Data', zorder=0)
    ax[0,0].set_xscale("log")
    plt.ylabel(sym_labels['realrho'])

    # Imaginary part
    plt.sca(ax[0,1])
    plt.errorbar(f, -zn_dat.imag, zn_err.imag, None, fmt='o', mfc='white', markersize=5, label='Data', zorder=0)
    ax[0,1].set_xscale("log")
    plt.ylabel(sym_labels['imagrho'])

    # Freq-Phase
    plt.sca(ax[1,1])
    plt.errorbar(f, -Pha_dat, Pha_err, None, fmt='o', mfc='white', markersize=5, label='Data', zorder=0)
    try:
        ax[1,1].set_yscale("log", nonpositive='clip')  # Matplotlib >= 3.3
    except (TypeError, ValueError):
        ax[1,1].set_yscale("log", nonposy='clip')  # older Matplotlib
    ax[1,1].set_xscale("log")
    plt.xlabel(sym_labels['freq'])
    plt.ylabel(sym_labels['phas'])

    # Lower the phase axis floor for low phase responses. The 0.01 case
    # takes precedence (it used to overwrite the 0.1 limit). nanmax ignores
    # NaNs from log10 of non-positive values, which made builtin max()
    # order-dependent.
    if (-Pha_dat < 0.1).any() and (-Pha_dat >= 0.01).any():
        plt.ylim([0.01,10**np.ceil(np.nanmax(np.log10(-Pha_dat)))])
    elif (-Pha_dat < 1).any() and (-Pha_dat >= 0.1).any():
        plt.ylim([0.1,10**np.ceil(np.nanmax(np.log10(-Pha_dat)))])

    # Freq-Amplitude
    plt.sca(ax[1,0])
    plt.errorbar(f, Amp_dat, Amp_err, None, fmt='o', mfc='white', markersize=5, label='Data', zorder=0)
    ax[1,0].set_xscale("log")
    plt.xlabel(sym_labels['freq'])
    plt.ylabel(sym_labels['resi'])

    for a in ax.flat:
        a.grid('on')

    fig.tight_layout()

    if save:
        # NOTE(review): `filename` is described as a full path; confirm
        # save_figure copes with path separators inside fname.
        fn = 'DAT-%s.%s'%(filename,ext)
        save_figure(fig, subfolder='Data', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig
예제 #14
0
def plot_summary(
    sol,
    save=False,
    draw=True,
    save_as_png=False,
    dpi=None,
    ignore=subplots_to_ignore,
    fig_nb="",
):
    """
    Plots a parameter summary and 
    Gelman-Rubin R-hat for multiple chains

    Left panel: per-chain posterior mean of every parameter component.
    Right panel: Gelman-Rubin R-hat (clipped at 5, shown as "5+"); it is
    left empty when fewer than two chains were run. ``fig_nb`` is prefixed
    to the saved filename. Returns the figure when ``draw`` is True.
    """
    ext = 'png' if save_as_png else 'pdf'
    ch_nb = sol.mcmc["nb_chain"]

    keys = sorted(k for k in sol.var_dict.keys() if k not in ignore)
    # Flattened per-component posterior means, one array per chain
    # (computed once instead of inside the plotting loop).
    chain_means = [
        np.array(flatten([sol.var_dict[x].trace(chain=n).mean(axis=0)
                          for x in keys]))
        for n in range(ch_nb)
    ]
    deps = [var_depth(sol.var_dict[x]) for x in keys]
    # Labels are reversed so the first parameter is drawn at the top.
    lbls = list(
        reversed(
            flatten([[k + '%s' % (x + 1) for x in range(d)] if d > 1 else k
                     for k, d in zip(keys, deps)])))

    if ch_nb >= 2:
        rhat = [
            gelman_rubin([
                sol.MDL.trace(var, -x)[:] for x in range(sol.mcmc['nb_chain'])
            ]) for var in keys
        ]
        R = np.array(flatten(rhat))
        R[R > 5] = 5
    else:
        print(
            "\nTwo or more chains of equal length required for Gelman-Rubin convergence"
        )
        R = len(lbls) * [None]

    # Build the figure directly on a GridSpec; plt.subplots() would leave a
    # stray full-size Axes behind on Matplotlib >= 3.6.
    fig = plt.figure(figsize=(6, 4))
    gs2 = gridspec.GridSpec(3, 3)
    ax1 = fig.add_subplot(gs2[:, :-1])
    ax2 = fig.add_subplot(gs2[:, -1], sharey=ax1)
    n_lbls = len(lbls)
    for i in range(n_lbls):
        for val_m in chain_means:
            ax1.scatter(val_m[i],
                        len(val_m) - (i + 1),
                        color="C0",
                        marker=".",
                        s=50,
                        facecolor='k',
                        edgecolors='k',
                        alpha=1)
        # Same reversed row as the means so R-hat lines up with its label
        # (it was drawn at row i, i.e. vertically flipped). Skip missing
        # values for single-chain runs.
        if R[i] is not None:
            ax2.scatter(R[i], n_lbls - (i + 1), color="C3", marker="<", s=50, alpha=1)

    ax1.set_ylim([-1, n_lbls])
    ax1.set_yticks(list(range(0, n_lbls)))
    ax1.set_yticklabels([parlbl_dic[l] for l in lbls])
    ax1.set_axisbelow(True)
    ax1.yaxis.grid(True)
    ax1.xaxis.grid(False)
    ax1.set_xlim(ax1.get_xlim())
    ax1.set_xlabel(r'Parameter value')

    plt.setp(ax2.get_yticklabels(), visible=False)
    ax2.set_xlim([0.5, 5.5])
    # Set tick positions before their labels so the labels stay attached.
    ax2.set_xticks([0.5, 1, 2, 3, 4, 5])
    ax2.set_xticklabels(["", "1", "2", "3", "4", "5+"])
    ax2.set_axisbelow(True)
    ax2.yaxis.grid(True)
    ax2.xaxis.grid(False)
    ax2.set_xlabel(r'$\hat{R}$')
    ax2.axvline(1, ls='--', color='C0', zorder=0)

    fig.tight_layout()
    plt.close(fig)

    if save:
        fn = '%sSUM-%s-%s.%s' % (fig_nb, sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Summaries', fname=fn, dpi=dpi)

    return fig if draw else None
예제 #15
0
def plot_KDE(sol,
             var1,
             var2,
             fig=None,
             ax=None,
             draw=True,
             save=False,
             save_as_png=False,
             dpi=None):
    """
    Like the hexbin plot but a 2D KDE
    Pass mcmcinv object and 2 variable names as strings

    A number embedded in a variable name selects one component of a vector
    variable ("m2" -> column 1 of trace "m", "m12" -> column 11); "R0" is
    always a plain variable name. A new 3x3 figure is created unless both
    ``fig`` and ``ax`` are given (the figure is closed before returning).
    Saves under '2D-KDE' when ``save`` is True; returns the figure when
    ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    if fig is None or ax is None:
        fig, ax = plt.subplots(figsize=(3, 3))
    MDL = sol.MDL

    def component_trace(var):
        """Return the sample array for *var*, selecting a column if indexed."""
        if var == "R0":
            return MDL.trace("R0")[:]
        name = ''.join(c for c in var if not c.isdigit())
        index = ''.join(c for c in var if c.isdigit())
        samples = MDL.trace(name)[:]
        # Only vector traces are indexed; errors are no longer swallowed.
        if index and np.ndim(samples) > 1:
            return samples[:, int(index) - 1]
        return samples

    x = component_trace(var1)
    y = component_trace(var2)
    xmin, xmax = min(x), max(x)
    ymin, ymax = min(y), max(y)
    # Perform the kernel density estimate on a 100x100 grid
    xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    values = np.vstack([x, y])
    kernel = gaussian_kde(values)
    kernel.set_bandwidth(bw_method='silverman')
    f = np.reshape(kernel(positions).T, xx.shape)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    plt.sca(ax)
    # Contourf plot of the density with the raw samples on top
    plt.grid(None)
    plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
    plt.xticks(rotation=90)
    plt.locator_params(axis='y', nbins=7)
    plt.locator_params(axis='x', nbins=7)
    ax.contourf(xx, yy, f, cmap=plt.cm.viridis, alpha=0.8)
    ax.scatter(x, y, color='k', s=1, zorder=2)

    plt.ylabel("%s" % var2)
    plt.xlabel("%s" % var1)

    if save:
        fn = 'KDE-%s-%s.%s' % (sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='2D-KDE', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #16
0
def plot_traces(
    sol,
    save=False,
    draw=True,
    save_as_png=False,
    dpi=None,
    ignore=subplots_to_ignore,
):
    """
    Plots the traces of stochastic and
    deterministic parameters in mcmcinv object (sol)
    Ignores the ones in list argument ignore

    Vector-valued variables get one subplot per component, labelled
    ``name1``, ``name2``, ...; each subplot shows the trace and its mean
    (dashed). Saves under 'Traces' when ``save`` is True; returns the figure
    when ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'

    MDL = sol.MDL
    sampler = MDL.get_state()["sampler"]

    # (label, 1-D samples) pairs built straight from each variable's trace,
    # instead of re-parsing digits out of the label (which broke for
    # multi-digit component numbers and relied on a bare except).
    series = []
    for k in sol.var_dict.keys():
        if k in ignore:
            continue
        samples = np.asarray(MDL.trace(k)[:])
        cols = samples.reshape(len(samples), -1)
        if cols.shape[1] > 1:
            series.extend((k + "%d" % n, cols[:, n - 1])
                          for n in range(1, cols.shape[1] + 1))
        else:
            series.append((k, cols[:, 0]))

    ncols = 2
    nrows = int(ceil(len(series) * 1.0 / ncols))

    # squeeze=False keeps `ax` 2-D so `ax[-1]` is a row even when nrows == 1.
    fig, ax = plt.subplots(nrows, ncols, figsize=(8, nrows * 1.5), sharex=True,
                           squeeze=False)

    # Iteration number of every stored sample; identical for all subplots.
    x = np.arange(sampler["_burn"] + 1, sampler["_iter"] + 1, sampler["_thin"])
    for a, (k, data) in zip(ax.flat, series):
        plt.sca(a)
        plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
        plt.ylabel(parlbl_dic[k])
        plt.plot(x, data, '-', alpha=0.8)
        plt.plot(x,
                 np.mean(data) * np.ones(len(x)),
                 color='k',
                 linestyle='--',
                 linewidth=2)

        # No burn-in: the axis starts at 1, so use a log scale.
        if sampler["_burn"] == 0:
            plt.xscale('log')
        else:
            plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))

        plt.grid(False)

    plt.tight_layout(pad=0, w_pad=0.5, h_pad=0)

    for a in ax[-1]:
        a.set_xlabel("Iteration number")

    if save:
        fn = 'TRA-%s-%s.%s' % (sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Traces', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #17
0
def plot_histo(
    sol,
    save=False,
    draw=True,
    save_as_png=False,
    dpi=None,
    ignore=subplots_to_ignore,
):
    """
    Plots histograms of the stochastic and
    deterministic parameters in mcmcinv object (sol)
    Ignores the ones in list argument ignore

    Each histogram (20 bins, counts) is overlaid with a normal pdf using the
    sample mean/std, scaled to counts (N * bin width). Vector-valued
    variables get one subplot per component (``name1``, ``name2``, ...).
    Saves under 'Histograms' when ``save`` is True; returns the figure when
    ``draw`` is True, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    MDL = sol.MDL

    # (label, 1-D samples) pairs built directly from each variable's trace.
    series = []
    for k in sol.var_dict.keys():
        if k in ignore:
            continue
        samples = np.asarray(MDL.trace(k)[:])
        cols = samples.reshape(len(samples), -1)
        if cols.shape[1] > 1:
            series.extend((k + "%d" % n, cols[:, n - 1])
                          for n in range(1, cols.shape[1] + 1))
        else:
            series.append((k, cols[:, 0]))

    ncols = 2
    nrows = int(ceil(len(series) * 1.0 / ncols))
    # squeeze=False keeps `ax` 2-D so row indexing works when nrows == 1.
    fig, ax = plt.subplots(nrows, ncols, figsize=(8, nrows * 1.8),
                           squeeze=False)
    for a, (k, samples) in zip(ax.flat, series):
        data = sorted(samples)  # sorted so the fitted pdf draws as one curve
        plt.sca(a)
        plt.xlabel(parlbl_dic[k])
        try:
            _, edges, _ = plt.hist(data,
                                   bins=20,
                                   histtype='stepfilled',
                                   density=False,
                                   linewidth=1.0,
                                   color='0.95',
                                   alpha=1)
            plt.hist(data,
                     bins=20,
                     histtype='step',
                     density=False,
                     linewidth=1.0,
                     alpha=1)
            fit = norm.pdf(data, np.mean(data), np.std(data))
            # Convert pdf to expected counts with the true (uniform) bin
            # width; the old centre-based estimate was ~10% too small.
            fit *= len(data) * (edges[1] - edges[0])
            plt.plot(data, fit, "-", color='k', linewidth=1)
        except Exception:
            print(
                "File %s: failed to plot %s histogram. Parameter not mobile enough (see traces)."
                % (sol.filename, k))
        plt.grid(False)
        plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))

    for row in ax:
        row[0].set_ylabel("Frequency")

    plt.tight_layout(pad=1, w_pad=1, h_pad=0)
    # Hide the unused trailing subplots of the grid
    for a in ax.flat[ax.size - 1:len(series) - 1:-1]:
        a.set_visible(False)

    if save:
        fn = 'HST-%s-%s.%s' % (sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Histograms', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #18
0
def plot_fit(sol,
             save=False,
             draw=True,
             save_as_png=False,
             dpi=None,
             fig_nb=""):
    """
    Plots the average fit and uncertainty
    Pass mcmcinv object (sol)

    Draws a 2x2 grid comparing the data (with error bars) to the best model
    response and its 95% HPD band, versus frequency:
    -imaginary part (top left), real part (top right), -phase x 1000 on a
    log y-axis (bottom left) and amplitude (bottom right). Impedances and
    amplitudes are divided by ``sol.data["Z_max"]``.

    Parameters
    ----------
    sol : mcmcinv object. Reads ``sol.data`` keys "freq", "Z", "Z_err",
        "Z_max", "pha", "pha_err", "amp", "amp_err"; ``sol.fit`` keys
        "best", "lo95", "up95"; and ``sol.model_type_str`` /
        ``sol.filename`` when saving.
    save : if True, write the figure with ``save_figure`` into the
        'Fit figures' subfolder.
    draw : if True return the figure, otherwise return None.
    save_as_png : save as PNG instead of PDF.
    dpi : resolution forwarded to ``save_figure``.
    fig_nb : string prefix prepended to the saved file name.

    Returns
    -------
    The matplotlib Figure (already closed with ``plt.close``) if ``draw``,
    else None.
    """
    ext = 'png' if save_as_png else 'pdf'

    # Prepare data for plotting
    f = sol.data["freq"]
    Zr0 = sol.data["Z_max"]
    zn_dat = sol.data["Z"] / Zr0
    zn_err = sol.data["Z_err"] / Zr0
    zn_fit = sol.fit["best"] / Zr0
    zn_min = sol.fit["lo95"] / Zr0
    zn_max = sol.fit["up95"] / Zr0

    Pha_dat = 1000 * sol.data["pha"]
    Pha_err = 1000 * sol.data["pha_err"]
    Pha_fit = 1000 * np.angle(sol.fit["best"])
    Pha_min = 1000 * np.angle(sol.fit["lo95"])
    Pha_max = 1000 * np.angle(sol.fit["up95"])

    Amp_dat = sol.data["amp"] / Zr0
    Amp_err = sol.data["amp_err"] / Zr0
    Amp_fit = abs(sol.fit["best"]) / Zr0
    Amp_min = abs(sol.fit["lo95"]) / Zr0
    Amp_max = abs(sol.fit["up95"]) / Zr0

    fig, ax = plt.subplots(2, 2, figsize=(8, 5), sharex=True)

    def _panel(a, dat, err, fit, band_1, band_2, ylabel):
        """Draw data points, the model curve and its 95% HPD band on ``a``."""
        a.errorbar(f, dat, err, None, color='k', fmt='o', mfc='white',
                   markersize=5, label='Data', zorder=0)
        line = a.plot(f, fit, ls='-', label='Model', zorder=2)
        # Band uses the same colour as the model curve it surrounds.
        a.fill_between(f, band_1, band_2, alpha=0.4,
                       color=line[0].get_color(), zorder=1, label='95% HPD')
        a.set_ylabel(ylabel)
        a.legend(loc='best', labelspacing=0.2, handlelength=1, framealpha=1)

    # Freq-Imag (sign flipped)
    _panel(ax[0, 0], -zn_dat.imag, zn_err.imag, -zn_fit.imag,
           -zn_max.imag, -zn_min.imag, sym_labels['imag'])

    # Freq-Real: use the 'real' label; fall back to 'imag' only if the
    # label table has no 'real' entry.
    _panel(ax[0, 1], zn_dat.real, zn_err.real, zn_fit.real,
           zn_max.real, zn_min.real,
           sym_labels.get('real', sym_labels['imag']))

    # Freq-Phas (sign flipped, log y-axis with non-positive values clipped)
    _panel(ax[1, 0], -Pha_dat, Pha_err, -Pha_fit,
           -Pha_max, -Pha_min, sym_labels['phas'])
    try:
        ax[1, 0].set_yscale("log", nonpositive='clip')
    except (TypeError, ValueError):
        # matplotlib < 3.3 only accepts the older `nonposy` keyword
        ax[1, 0].set_yscale("log", nonposy='clip')

    # Freq-Ampl
    _panel(ax[1, 1], Amp_dat, Amp_err, Amp_fit,
           Amp_max, Amp_min, sym_labels['ampl'])

    # Log frequency axis on the bottom row; sharex=True shares the scale
    # with the top row.
    for a in ax[1]:
        a.set_xscale('log')
        a.set_xlabel(sym_labels['freq'])

    for a in ax.flat:
        a.grid(True)

    fig.tight_layout(pad=0, h_pad=0.5, w_pad=1)

    if save:
        fn = '%sFIT-%s-%s.%s' % (fig_nb, sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Fit figures', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #19
0
def plot_rtd(sol, save=False, draw=True, save_as_png=False, dpi=None):
    """
    Plots the relaxation time distribution (RTD)
    for a polynomial decomposition or ccdt results

    Draws chargeability ``m`` against relaxation time ``tau`` (both axes
    log-scaled) with its 95% HPD band. When the model provides
    "log_peak_tau" statistics, it also draws the peak relaxation times and
    their HPD spans. It marks tau_50 ("log_half_tau") and the mean tau
    ("log_mean_tau") with their 95% HPD intervals. The decade at each end of
    the tau range is shaded.

    Parameters
    ----------
    sol : mcmcinv object; reads ``sol.MDL.stats()``, ``sol.MDL.log_tau``
        (fallback branch only), and ``sol.model_type_str`` /
        ``sol.filename`` when saving.
    save : if True, write the figure with ``save_figure`` into 'RTD'.
    draw : if True return the figure, otherwise return None.
    save_as_png : save as PNG instead of PDF.
    dpi : resolution forwarded to ``save_figure``.

    Returns
    -------
    The matplotlib Figure (already closed) if ``draw``, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    fig, ax = plt.subplots(figsize=(4, 3))

    # MDL.stats() summarises every traced variable: compute it only once.
    stats = sol.MDL.stats()
    try:
        # Variables stored as log10 values: convert back to linear values.
        bot95 = 10**stats["log_m_i"]['95% HPD interval'][0]
        top95 = 10**stats["log_m_i"]['95% HPD interval'][1]
        tau = 10**stats["log_tau_i"]['mean']
        m = 10**stats["log_m_i"]['mean']
    except Exception:
        # Fallback: linear "m_i" statistics, relaxation times from MDL.log_tau.
        bot95 = stats["m_i"]['95% HPD interval'][0]
        top95 = stats["m_i"]['95% HPD interval'][1]
        tau = 10**sol.MDL.log_tau
        m = stats["m_i"]['mean']

    plt.errorbar(tau,
                 m,
                 None,
                 None,
                 color="C7",
                 linestyle='-',
                 label="RTD")

    # Peak markers are optional: any failure here just skips them.
    try:
        peak_stats = stats["log_peak_tau"]
        peaks = 10**np.atleast_1d(peak_stats["mean"])
        uncer_peaks = 10**peak_stats['95% HPD interval'].T.reshape(
            len(peaks), 2)
        tau_list = list(tau)
        m_peaks = m[[tau_list.index(find_nearest(tau, p)) for p in peaks]]
        if len(peaks) >= 1:
            plt.errorbar(peaks,
                         m_peaks * 1.2,
                         None,
                         None,
                         color="C3",
                         marker="v",
                         markersize=5,
                         linestyle="",
                         label=r"$\tau_{peak}$")
            for lo, hi in uncer_peaks:
                plt.axvspan(lo, hi, alpha=0.2, color="C3")
    except Exception:
        pass

    plt.axvline(10**stats["log_half_tau"]['mean'],
                color="C0",
                linestyle=':',
                label=r"$\tau_{50}$")
    plt.axvline(10**stats["log_mean_tau"]['mean'],
                color='C2',
                linestyle='--',
                label=r"$\bar{\tau}$")
    inter = 10**stats["log_half_tau"]['95% HPD interval']
    plt.axvspan(inter[0], inter[1], alpha=0.2, color="C0")
    inter = 10**stats["log_mean_tau"]['95% HPD interval']
    plt.axvspan(inter[0], inter[1], alpha=0.2, color='C2')

    # Shade the first and last decade of the tau range.
    plt.axvspan(min(tau), min(tau) * 10, alpha=0.1, color='C7')
    plt.axvspan(max(tau) / 10, max(tau), alpha=0.1, color='C7')
    plt.fill_between(tau, bot95, top95, color="C7", alpha=0.2)
    plt.xlim([
        10**np.ceil(np.log10(min(tau))),
        10**np.floor(np.log10(max(tau)))
    ])
    ax.set_xlabel(r'$\tau$ (s)')
    ax.set_ylabel(r'$m$')
    plt.grid(False)
    plt.legend(fontsize=9, loc=1, labelspacing=0.2, handlelength=1.5)
    plt.xscale('log')
    try:
        ax.set_yscale('log', nonpositive='clip')
    except (TypeError, ValueError):
        # matplotlib < 3.3 only accepts the older `nonposy` keyword
        ax.set_yscale('log', nonposy='clip')
    fig.tight_layout()

    if save:
        fn = 'RTD-%s-%s.%s' % (sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='RTD', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None
예제 #20
0
def plot_summary(sol, save=False, draw=True, save_as_png=False, dpi=None,
                 ignore=default_ignore,
                 fig_nb="",
                 ):
    """
    Plots a parameter summary and
    Gelman-Rubin R-hat for multiple chains

    The left panel shows the per-chain trace mean of every (flattened)
    parameter, one row per parameter. The right panel shows the
    Gelman-Rubin R-hat of the same parameter on the same row; values above 5
    are drawn at 5 and labelled "5+". R-hat is computed only when there are
    at least two chains.

    Parameters
    ----------
    sol : mcmcinv object; reads ``sol.mcmc["nb_chain"]``, ``sol.var_dict``,
        ``sol.MDL`` and, when saving, ``sol.model_type_str`` /
        ``sol.filename``.
    save : if True, write the figure with ``save_figure`` into 'Summaries'.
    draw : if True return the figure, otherwise return None.
    save_as_png : save as PNG instead of PDF.
    dpi : resolution forwarded to ``save_figure``.
    ignore : names in ``sol.var_dict`` to leave out of the summary.
    fig_nb : string prefix prepended to the saved file name.

    Returns
    -------
    The matplotlib Figure (already closed) if ``draw``, else None.
    """
    ext = 'png' if save_as_png else 'pdf'
    ch_nb = sol.mcmc["nb_chain"]

    keys = sorted(k for k in sol.var_dict.keys() if k not in ignore)
    # For each chain: the trace mean of every kept variable.
    trac = [[sol.var_dict[x].trace(chain=n).mean(axis=0) for x in keys]
            for n in range(ch_nb)]
    deps = [var_depth(sol.var_dict[x]) for x in keys]
    # Variables of depth > 1 get one numbered label per component (k1, k2, ...).
    # The list is reversed so that the first parameter ends up on the top row.
    lbls = list(reversed(flatten(
        [[k + '%s' % (x + 1) for x in range(d)] if d > 1 else k
         for k, d in zip(keys, deps)])))
    n_par = len(lbls)

    if ch_nb >= 2:
        rhat = [gelman_rubin([sol.MDL.trace(var, -x)[:] for x in range(ch_nb)])
                for var in keys]
        R = np.array(flatten(rhat))
        R[R > 5] = 5  # displayed as "5+" on the R-hat axis
    else:
        print("\nTwo or more chains of equal length required for Gelman-Rubin convergence")
        R = n_par * [None]

    # Start from a bare figure: plt.subplots() would add a full-size Axes
    # underneath the GridSpec axes, and recent matplotlib no longer deletes
    # such overlapped axes automatically.
    fig = plt.figure(figsize=(6, 4))
    gs2 = gridspec.GridSpec(3, 3)
    ax1 = fig.add_subplot(gs2[:, :-1])
    ax2 = fig.add_subplot(gs2[:, -1], sharey=ax1)

    # Flatten each chain's means once rather than once per parameter.
    chain_means = [np.array(flatten(t)) for t in trac]
    for i in range(n_par):
        # lbls is reversed, so flattened parameter i is labelled on row
        # n_par - 1 - i. The values and the R-hat marker both go on that row.
        row = n_par - (i + 1)
        for val_m in chain_means:
            ax1.scatter(val_m[i], row, color="C0", marker=".",
                        s=50, facecolor='k', edgecolors='k', alpha=1)
        if R[i] is not None:  # no R-hat with a single chain
            ax2.scatter(R[i], row, color="C3", marker="<", s=50, alpha=1)

    ax1.set_ylim([-1, n_par])
    ax1.set_yticks(list(range(0, n_par)))
    ax1.set_yticklabels([parlbl_dic[l] for l in lbls])
    ax1.set_axisbelow(True)
    ax1.yaxis.grid(True)
    ax1.xaxis.grid(False)
    ax1.set_xlim(ax1.get_xlim())  # freeze the autoscaled x-limits
    ax1.set_xlabel(r'Parameter value')

    plt.setp(ax2.get_yticklabels(), visible=False)
    ax2.set_xlim([0.5, 5.5])
    # Fix tick positions before labelling them (FixedFormatter needs a
    # FixedLocator).
    ax2.set_xticks([0.5, 1, 2, 3, 4, 5])
    ax2.set_xticklabels(["", "1", "2", "3", "4", "5+"])
    ax2.set_axisbelow(True)
    ax2.yaxis.grid(True)
    ax2.xaxis.grid(False)
    ax2.set_xlabel(r'$\hat{R}$')
    ax2.axvline(1, ls='--', color='C0', zorder=0)

    fig.tight_layout()
    plt.close(fig)

    if save:
        fn = '%sSUM-%s-%s.%s' % (fig_nb, sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Summaries', fname=fn, dpi=dpi)

    return fig if draw else None
예제 #21
0
def plot_data(filename,
              headers,
              ph_units,
              save=False,
              save_as_png=False,
              dpi=None,
              fig_nb=None):
    """
    Plots data before doing inversion
    Pass full file path, number of headers and phase units

    Draws a 2x2 grid of the raw data versus a log frequency axis:
    real part (top left), negated imaginary part (top right), amplitude
    (bottom left) and negated phase x 1000 on a log y-axis (bottom right).

    Parameters
    ----------
    filename : path of the data file; forwarded to ``get_data`` and used to
        build the saved file name.
    headers, ph_units : forwarded unchanged to ``get_data``.
    save : if True, write the figure with ``save_figure`` into 'Data'.
    save_as_png : save as PNG instead of PDF.
    dpi : resolution forwarded to ``save_figure``.
    fig_nb : accepted for call compatibility; not used by this function.

    Returns
    -------
    The matplotlib Figure (already closed with ``plt.close``).
    """
    ext = 'png' if save_as_png else 'pdf'
    data = get_data(filename, headers, ph_units)

    # Data to plot
    f = data["freq"]
    zn_dat = data["Z"]
    zn_err = data["Z_err"]
    Pha_dat = 1000 * data["pha"]
    Pha_err = 1000 * data["pha_err"]
    Amp_dat = data["amp"]
    Amp_err = data["amp_err"]

    fig, ax = plt.subplots(2, 2, figsize=(8, 5), sharex=True)

    # Freq-Real
    plt.sca(ax[0, 0])
    plt.errorbar(f,
                 zn_dat.real,
                 zn_err.real,
                 None,
                 fmt='o',
                 mfc='white',
                 markersize=5,
                 label='Data',
                 zorder=0)
    ax[0, 0].set_xscale("log")
    plt.ylabel(sym_labels['realrho'])

    # Freq-Imag (sign flipped)
    plt.sca(ax[0, 1])
    plt.errorbar(f,
                 -zn_dat.imag,
                 zn_err.imag,
                 None,
                 fmt='o',
                 mfc='white',
                 markersize=5,
                 label='Data',
                 zorder=0)
    ax[0, 1].set_xscale("log")
    plt.ylabel(sym_labels['imagrho'])

    # Freq-Phas (sign flipped, log y-axis with non-positive values clipped)
    plt.sca(ax[1, 1])
    plt.errorbar(f,
                 -Pha_dat,
                 Pha_err,
                 None,
                 fmt='o',
                 mfc='white',
                 markersize=5,
                 label='Data',
                 zorder=0)
    try:
        ax[1, 1].set_yscale("log", nonpositive='clip')
    except (TypeError, ValueError):
        # matplotlib < 3.3 only accepts the older `nonposy` keyword
        ax[1, 1].set_yscale("log", nonposy='clip')
    ax[1, 1].set_xscale("log")
    plt.xlabel(sym_labels['freq'])
    plt.ylabel(sym_labels['phas'])

    # Adjust for low or high phase response
    if (-Pha_dat < 1).any() and (-Pha_dat >= 0.1).any():
        plt.ylim([0.1, 10**np.ceil(max(np.log10(-Pha_dat)))])
    if (-Pha_dat < 0.1).any() and (-Pha_dat >= 0.01).any():
        plt.ylim([0.01, 10**np.ceil(max(np.log10(-Pha_dat)))])

    # Freq-Ampl
    plt.sca(ax[1, 0])
    plt.errorbar(f,
                 Amp_dat,
                 Amp_err,
                 None,
                 fmt='o',
                 mfc='white',
                 markersize=5,
                 label='Data',
                 zorder=0)
    ax[1, 0].set_xscale("log")
    plt.xlabel(sym_labels['freq'])
    plt.ylabel(sym_labels['resi'])

    for a in ax.flat:
        a.grid(True)

    fig.tight_layout()

    if save:
        # NOTE(review): `filename` is documented as a full path; confirm that
        # save_figure copes with directory separators inside `fname`.
        fn = 'DAT-%s.%s' % (filename, ext)
        save_figure(fig, subfolder='Data', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig
예제 #22
0
def plot_fit(sol,
             save=False,
             draw=True,
             save_as_png=False,
             dpi=None,
             fig_nb=""):
    """
    Plots the average fit and uncertainty
    Pass mcmcinv object (sol)

    Draws a 2x2 grid comparing the data (with error bars) to the best model
    response and its 95% HPD band, versus frequency:
    -imaginary part (top left), real part (top right), -phase x 1000 on a
    log y-axis (bottom left) and amplitude (bottom right). Impedances and
    amplitudes are divided by ``sol.data["Z_max"]``.

    Parameters
    ----------
    sol : mcmcinv object. Reads ``sol.data`` keys "freq", "Z", "Z_err",
        "Z_max", "pha", "pha_err", "amp", "amp_err"; ``sol.fit`` keys
        "best", "lo95", "up95"; and ``sol.model_type_str`` /
        ``sol.filename`` when saving.
    save : if True, write the figure with ``save_figure`` into the
        'Fit figures' subfolder.
    draw : if True return the figure, otherwise return None.
    save_as_png : save as PNG instead of PDF.
    dpi : resolution forwarded to ``save_figure``.
    fig_nb : string prefix prepended to the saved file name.

    Returns
    -------
    The matplotlib Figure (already closed with ``plt.close``) if ``draw``,
    else None.
    """
    ext = 'png' if save_as_png else 'pdf'

    # Prepare data for plotting
    f = sol.data["freq"]
    Zr0 = sol.data["Z_max"]
    zn_dat = sol.data["Z"] / Zr0
    zn_err = sol.data["Z_err"] / Zr0
    zn_fit = sol.fit["best"] / Zr0
    zn_min = sol.fit["lo95"] / Zr0
    zn_max = sol.fit["up95"] / Zr0

    Pha_dat = 1000 * sol.data["pha"]
    Pha_err = 1000 * sol.data["pha_err"]
    Pha_fit = 1000 * np.angle(sol.fit["best"])
    Pha_min = 1000 * np.angle(sol.fit["lo95"])
    Pha_max = 1000 * np.angle(sol.fit["up95"])

    Amp_dat = sol.data["amp"] / Zr0
    Amp_err = sol.data["amp_err"] / Zr0
    Amp_fit = abs(sol.fit["best"]) / Zr0
    Amp_min = abs(sol.fit["lo95"]) / Zr0
    Amp_max = abs(sol.fit["up95"]) / Zr0

    fig, ax = plt.subplots(2, 2, figsize=(8, 5), sharex=True)

    def _panel(a, dat, err, fit, band_1, band_2, ylabel):
        """Draw data points, the model curve and its 95% HPD band on ``a``."""
        a.errorbar(f, dat, err, None, color='k', fmt='o', mfc='white',
                   markersize=5, label='Data', zorder=0)
        line = a.plot(f, fit, ls='-', label='Model', zorder=2)
        # Band uses the same colour as the model curve it surrounds.
        a.fill_between(f, band_1, band_2, alpha=0.4,
                       color=line[0].get_color(), zorder=1, label='95% HPD')
        a.set_ylabel(ylabel)
        a.legend(loc='best', labelspacing=0.2, handlelength=1, framealpha=1)

    # Freq-Imag (sign flipped)
    _panel(ax[0, 0], -zn_dat.imag, zn_err.imag, -zn_fit.imag,
           -zn_max.imag, -zn_min.imag, sym_labels['imag'])

    # Freq-Real: use the 'real' label; fall back to 'imag' only if the
    # label table has no 'real' entry.
    _panel(ax[0, 1], zn_dat.real, zn_err.real, zn_fit.real,
           zn_max.real, zn_min.real,
           sym_labels.get('real', sym_labels['imag']))

    # Freq-Phas (sign flipped, log y-axis with non-positive values clipped)
    _panel(ax[1, 0], -Pha_dat, Pha_err, -Pha_fit,
           -Pha_max, -Pha_min, sym_labels['phas'])
    try:
        ax[1, 0].set_yscale("log", nonpositive='clip')
    except (TypeError, ValueError):
        # matplotlib < 3.3 only accepts the older `nonposy` keyword
        ax[1, 0].set_yscale("log", nonposy='clip')

    # Freq-Ampl
    _panel(ax[1, 1], Amp_dat, Amp_err, Amp_fit,
           Amp_max, Amp_min, sym_labels['ampl'])

    # Log frequency axis on the bottom row; sharex=True shares the scale
    # with the top row.
    for a in ax[1]:
        a.set_xscale('log')
        a.set_xlabel(sym_labels['freq'])

    for a in ax.flat:
        a.grid(True)

    fig.tight_layout(pad=0, h_pad=0.5, w_pad=1)

    if save:
        fn = '%sFIT-%s-%s.%s' % (fig_nb, sol.model_type_str, sol.filename, ext)
        save_figure(fig, subfolder='Fit figures', fname=fn, dpi=dpi)

    plt.close(fig)
    return fig if draw else None