示例#1
0
def mcycle_data_linear():
    """Fit a grid-searched LinearGAM to the mcycle data and save two plots.

    1. ``imgs/pygam_mcycle_data_linear.png``: raw data, the fitted curve and
       the 95% prediction interval over the model's default X grid.
    2. ``imgs/pygam_mcycle_data_extrapolation.png``: the fit extended 10 units
       past both ends of the observed X range.  Confidence intervals are drawn
       in blue for the two extrapolated segments and in red over the data.
    """
    X, y = mcycle()

    gam = LinearGAM()
    gam.gridsearch(X, y)

    # --- Figure 1: in-sample fit with prediction interval -----------------
    grid = generate_X_grid(gam)
    plt.figure()
    plt.scatter(X, y, facecolor='gray', edgecolors='none')
    plt.plot(grid, gam.predict(grid), 'r--')
    plt.plot(grid, gam.prediction_intervals(grid, width=.95), color='b', ls='--')
    plt.title('95% prediction interval')
    plt.savefig('imgs/pygam_mcycle_data_linear.png', dpi=300)

    # --- Figure 2: extrapolation beyond the observed range ----------------
    lo, hi = X.min(), X.max()
    margin = 10
    wide = np.linspace(lo - margin, hi + margin, 500)
    left = np.linspace(lo - margin, lo, 50)
    right = np.linspace(hi, hi + margin, 50)

    plt.figure()
    plt.plot(wide, gam.predict(wide), 'k')
    for xs, colour in ((left, 'b'), (right, 'b'), (X, 'r')):
        plt.plot(xs, gam.confidence_intervals(xs), color=colour, ls='--')
    plt.savefig('imgs/pygam_mcycle_data_extrapolation.png', dpi=300)
示例#2
0
    def test_shape_of_random_samples(self, mcycle, mcycle_gam):
        """Check the array shapes returned by ``sample`` for each quantity.

        Draws are taken twice: at the training inputs, then at an explicit
        grid passed via ``sample_at_X``. Coefficient draws always have one
        column per model coefficient; ``mu`` and ``y`` draws have one column
        per evaluation point.
        """
        X, y = mcycle
        n_draws = 5
        quantities = ('coef', 'mu', 'y')

        def draw_all(**extra):
            # One sample() call per quantity, in the order coef, mu, y.
            return {
                quantity: mcycle_gam.sample(X, y, quantity=quantity,
                                            n_draws=n_draws, **extra)
                for quantity in quantities
            }

        at_data = draw_all()
        assert at_data['coef'].shape == (n_draws, len(mcycle_gam.coef_))
        assert at_data['mu'].shape == (n_draws, len(X))
        assert at_data['y'].shape == (n_draws, len(X))

        grid = generate_X_grid(mcycle_gam)
        at_grid = draw_all(sample_at_X=grid)
        assert at_grid['coef'].shape == (n_draws, len(mcycle_gam.coef_))
        assert at_grid['mu'].shape == (n_draws, len(grid))
        assert at_grid['y'].shape == (n_draws, len(grid))
示例#3
0
def dependencia_parcial(modelo, X_train):

    x_grid = generate_X_grid(modelo)
    plt.figure(figsize=(10, 5))
    attribute = X_train.columns

    cols = 3
    rows = int(len(attribute) / cols)

    for i, n in enumerate(range(len(attribute))):

        plt.subplot(rows, cols, i + 1)

        partial_dep, confidence_intervals = modelo.partial_dependence(
            x_grid, feature=i + 1, width=.95)

        plt.plot(x_grid[:, n], partial_dep, color='tomato')

        plt.fill_between(x_grid[:, n],
                         confidence_intervals[0][:, 0],
                         confidence_intervals[0][:, 1],
                         color='tomato',
                         alpha=.25)

        plt.title(attribute[n])
        plt.plot(X_train[attribute[n]],
                 [plt.ylim()[0]] * len(X_train[attribute[n]]),
                 '|',
                 color='orange',
                 alpha=.5)
    plt.tight_layout()
示例#4
0
文件: gen_imgs.py 项目: habedi/pyGAM
def mcycle_data_linear():
    """Save in-sample and extrapolation plots of a LinearGAM fitted to mcycle.

    Writes ``imgs/pygam_mcycle_data_linear.png`` (data, fit and 95% prediction
    interval) and ``imgs/pygam_mcycle_data_extrapolation.png`` (fit extended
    10 units past both ends of the data, with confidence intervals).
    """
    X, y = mcycle()
    gam = LinearGAM()
    gam.gridsearch(X, y)

    x_grid = generate_X_grid(gam)
    plt.figure()
    plt.scatter(X, y, facecolor='gray', edgecolors='none')
    plt.plot(x_grid, gam.predict(x_grid), 'r--')
    plt.plot(x_grid, gam.prediction_intervals(x_grid, width=.95),
             color='b', ls='--')
    plt.title('95% prediction interval')
    plt.savefig('imgs/pygam_mcycle_data_linear.png', dpi=300)

    x_min = X.min()
    x_max = X.max()
    x_wide = np.linspace(x_min - 10, x_max + 10, 500)
    segments = [
        (np.linspace(x_min - 10, x_min, 50), 'b'),  # left extrapolation
        (np.linspace(x_max, x_max + 10, 50), 'b'),  # right extrapolation
        (X, 'r'),                                   # observed range
    ]

    plt.figure()
    plt.plot(x_wide, gam.predict(x_wide), 'k')
    for xs, colour in segments:
        plt.plot(xs, gam.confidence_intervals(xs), color=colour, ls='--')
    plt.savefig('imgs/pygam_mcycle_data_extrapolation.png', dpi=300)
示例#5
0
def cake_data_in_one():
    """Fit a LinearGAM (with intercept) to the cake data via grid search and
    save its partial dependence, evaluated on the default X grid, to
    ``imgs/pygam_cake_data.png``.

    The curve is drawn against the grid row index, not against X.
    """
    X, y = cake()

    model = LinearGAM(fit_intercept=True)
    model.gridsearch(X, y)

    grid = generate_X_grid(model)
    dependence = model.partial_dependence(grid)

    plt.figure()
    plt.plot(dependence)
    plt.title('LinearGAM')
    plt.savefig('imgs/pygam_cake_data.png', dpi=300)
示例#6
0
文件: gen_imgs.py 项目: habedi/pyGAM
def cake_data_in_one():
    """Save a plot of the partial dependence of a grid-searched LinearGAM
    (``fit_intercept=True``) on the cake data to ``imgs/pygam_cake_data.png``.
    """
    X, y = cake()
    gam = LinearGAM(fit_intercept=True)
    gam.gridsearch(X, y)

    x_grid = generate_X_grid(gam)
    plt.figure()
    # Single-argument plot: the curve(s) are drawn against the row index.
    plt.plot(gam.partial_dependence(x_grid))
    plt.title('LinearGAM')
    plt.savefig('imgs/pygam_cake_data.png', dpi=300)
def test_convex(hepatitis):
    """
    check that convex constraint produces convex function
    """
    X, y = hepatitis

    gam = LinearGAM(constraints='convex')
    gam.fit(X, y)

    # Sort along the sample axis (axis=0). np.sort's default axis=-1 only
    # sorts within each row, which leaves the order of the samples untouched.
    XX = np.sort(generate_X_grid(gam), axis=0)
    Y = gam.predict(XX)
    # Convex <=> every second difference is >= 0 (allowing float noise).
    diffs = np.diff(Y, n=2)
    assert ((diffs >= 0) | np.isclose(diffs, 0.)).all()
def test_monotonic_dec(hepatitis):
    """
    check that monotonic_dec constraint produces monotonic decreasing function
    """
    X, y = hepatitis

    gam = LinearGAM(constraints='monotonic_dec')
    gam.fit(X, y)

    # Sort along the sample axis (axis=0). np.sort's default axis=-1 only
    # sorts within each row, which leaves the order of the samples untouched.
    XX = np.sort(generate_X_grid(gam), axis=0)
    Y = gam.predict(XX)
    # Monotone decreasing <=> every first difference is <= 0 (allowing
    # float noise).
    diffs = np.diff(Y, n=1)
    assert ((diffs <= 0) | np.isclose(diffs, 0.)).all()
def display_partial_dependence(model, X_train, marker='|', color='dodgerblue',
                               max_columns=4):
    """Draw a grid with the partial-dependence plot of every feature.

    Parameters
    ----------
    model : fitted pyGAM model
        Previously trained model exposing ``partial_dependence``.
    X_train : pandas.DataFrame
        Training features. One panel is drawn per column; columns are assumed
        to be in the model's feature order.
    marker : str, default '|'
        Marker used for the rug of training values.
    color : str, default 'dodgerblue'
        Colour of the dependence curve and its 95% confidence band.
    max_columns : int, default 4
        Maximum number of panels per grid row.

    Returns
    -------
    None

    Version: 1
    Date: 07/27/2019
    """
    x_grid = generate_X_grid(model)
    variables = X_train.columns
    # plt.subplot requires integer grid dimensions; np.ceil returns a float.
    rows = int(np.ceil(len(variables) / max_columns))
    plt.figure(figsize=(15, 10))
    for i, name in enumerate(variables):
        plt.subplot(rows, max_columns, i + 1)
        # Partial dependence and its 95% confidence interval.
        # NOTE(review): ``feature=i + 1`` assumes feature 0 is the intercept
        # (older pyGAM API) -- confirm against the installed version.
        partial_dep, confidence_intervals = model.partial_dependence(
            x_grid, feature=i + 1, width=0.95)
        plt.plot(x_grid[:, i], partial_dep, color=color)
        plt.fill_between(x_grid[:, i],
                         confidence_intervals[0][:, 0],
                         confidence_intervals[0][:, 1],
                         color=color,
                         alpha=.25)
        # Rug: one marker per training value along the bottom of the axes.
        x_vect = X_train[name]
        y_vect = [plt.ylim()[0]] * len(x_vect)
        plt.scatter(x_vect,
                    y_vect,
                    marker=marker,
                    color='orange',
                    alpha=.5,
                    s=500)
        plt.title(name)
    # Lay out once, after every panel exists.
    plt.tight_layout()
示例#10
0
文件: gen_imgs.py 项目: habedi/pyGAM
def gen_basis_fns():
    """Illustrate the B-spline basis of a LinearGAM on the hepatitis data.

    Fits a LinearGAM with ``lam=.6`` and no intercept. The top panel shows the
    basis functions of feature 0 over the default X grid. The bottom panel
    shows the data, each basis column scaled by its fitted coefficient, and
    the model prediction. Saved to ``imgs/pygam_basis.png``.
    """
    X, y = hepatitis()
    gam = LinearGAM(lam=.6, fit_intercept=False).fit(X, y)
    XX = generate_X_grid(gam)

    # plt.subplots creates its own figure. A preceding plt.figure() only left
    # an empty figure behind that was never used or closed.
    fig, ax = plt.subplots(2, 1)
    # The model matrix is sparse: densify with toarray() (the ``.A`` shorthand
    # was deprecated in SciPy 1.11 and later removed).
    ax[0].plot(XX, gam._modelmat(XX, feature=0).toarray())
    ax[0].set_title('b-Spline Basis Functions')

    ax[1].scatter(X, y, facecolor='gray', edgecolors='none')
    # Each basis column weighted by its coefficient (broadcast over columns).
    ax[1].plot(XX, gam._modelmat(XX).toarray() * gam.coef_)
    ax[1].plot(XX, gam.predict(XX), 'k')
    ax[1].set_title('Fitted Model')
    fig.tight_layout()
    fig.savefig('imgs/pygam_basis.png', dpi=300)
示例#11
0
def gen_basis_fns():
    """Illustrate the B-spline basis of a LinearGAM on the hepatitis data.

    Fits a LinearGAM with ``lam=.6`` and no intercept. The top panel shows the
    basis functions of feature 0 over the default X grid. The bottom panel
    shows the data, each basis column scaled by its fitted coefficient, and
    the model prediction. Saved to ``imgs/pygam_basis.png``.
    """
    X, y = hepatitis()
    gam = LinearGAM(lam=.6, fit_intercept=False).fit(X, y)
    XX = generate_X_grid(gam)

    # plt.subplots creates its own figure. A preceding plt.figure() only left
    # an empty figure behind that was never used or closed.
    fig, ax = plt.subplots(2, 1)
    # The model matrix is sparse: densify with toarray() (the ``.A`` shorthand
    # was deprecated in SciPy 1.11 and later removed).
    ax[0].plot(XX,
               gam._modelmat(XX, feature=0).toarray())
    ax[0].set_title('b-Spline Basis Functions')

    ax[1].scatter(X, y, facecolor='gray', edgecolors='none')
    # Each basis column weighted by its coefficient (broadcast over columns).
    ax[1].plot(XX,
               gam._modelmat(XX).toarray() * gam.coef_)
    ax[1].plot(XX, gam.predict(XX), 'k')
    ax[1].set_title('Fitted Model')
    fig.tight_layout()
    fig.savefig('imgs/pygam_basis.png', dpi=300)
示例#12
0
文件: gen_imgs.py 项目: habedi/pyGAM
def default_data_logistic(n=500):
    """Grid-search a LogisticGAM on the 'default' data and plot the partial
    dependence of its three features with 95% confidence intervals.

    Saved to ``imgs/pygam_default_data_logistic.png``.

    Parameters
    ----------
    n : int, optional
        Unused; kept only so existing callers keep working.
    """
    X, y = default()

    gam = LogisticGAM()
    gam.gridsearch(X, y)

    XX = generate_X_grid(gam)

    # plt.subplots creates its own figure; a separate plt.figure() would only
    # leave an empty, never-closed figure behind.
    fig, axs = plt.subplots(1, 3)

    titles = ['student', 'balance', 'income']
    for i, ax in enumerate(axs):
        # A single call returns both the curve and its 95% intervals.
        # NOTE(review): ``feature=i + 1`` assumes feature 0 is the intercept.
        pdep, confi = gam.partial_dependence(XX, feature=i + 1, width=.95)
        ax.plot(XX[:, i], pdep)
        ax.plot(XX[:, i], *confi, c='r', ls='--')
        ax.set_title(titles[i])

    fig.tight_layout()
    fig.savefig('imgs/pygam_default_data_logistic.png', dpi=300)
示例#13
0
def test_prediction_interval_known_scale():
    """
    the prediction intervals should be correct to a few decimal places
    we test at a large sample limit.
    """
    n = 1000000
    X = np.linspace(0, 1, n)
    # Seeded local generator: the test is reproducible and does not touch
    # NumPy's global random state.
    y = np.random.RandomState(0).randn(n)

    gam_a = LinearGAM(fit_linear=True, fit_splines=False, scale=1.).fit(X, y)
    gam_b = LinearGAM(n_splines=4, scale=1.).fit(X, y)

    XX = generate_X_grid(gam_a)
    intervals_a = gam_a.prediction_intervals(XX, quantiles=[0.1, .9]).mean(axis=0)
    intervals_b = gam_b.prediction_intervals(XX, quantiles=[0.1, .9]).mean(axis=0)

    # With unit scale and a zero-mean response, the quantiles should match
    # those of a standard normal.
    lower, upper = sp.stats.norm.ppf(0.1), sp.stats.norm.ppf(0.9)

    assert np.allclose(intervals_a[0], lower, atol=0.01)
    assert np.allclose(intervals_a[1], upper, atol=0.01)

    assert np.allclose(intervals_b[0], lower, atol=0.01)
    assert np.allclose(intervals_b[1], upper, atol=0.01)
示例#14
0
文件: gen_imgs.py 项目: habedi/pyGAM
def wage_data_linear():
    """Grid-search a 10-spline LinearGAM on the wage data and plot the partial
    dependence of 'year', 'age' and 'education' with 95% confidence bands.

    The smoothing parameter is searched over 50 log-spaced values in
    [1e-5, 1e3]. The 'year' panel's y-axis is clipped to [-30, 30].
    Saved to ``imgs/pygam_wage_data_linear.png``.
    """
    X, y = wage()

    gam = LinearGAM(n_splines=10)
    gam.gridsearch(X, y, lam=np.logspace(-5, 3, 50))

    XX = generate_X_grid(gam)

    # plt.subplots creates its own figure; a separate plt.figure() would only
    # leave an empty, never-closed figure behind.
    fig, axs = plt.subplots(1, 3)

    titles = ['year', 'age', 'education']
    for i, ax in enumerate(axs):
        # A single call returns both the curve and its 95% intervals.
        # NOTE(review): ``feature=i + 1`` assumes feature 0 is the intercept.
        pdep, confi = gam.partial_dependence(XX, feature=i + 1, width=.95)
        ax.plot(XX[:, i], pdep)
        ax.plot(XX[:, i], *confi, c='r', ls='--')
        if i == 0:
            ax.set_ylim(-30, 30)
        ax.set_title(titles[i])

    fig.tight_layout()
    fig.savefig('imgs/pygam_wage_data_linear.png', dpi=300)
示例#15
0
def default_data_logistic(n=500):
    """Grid-search a LogisticGAM on the 'default' data and plot the partial
    dependence of its three features with 95% confidence intervals.

    Saved to ``imgs/pygam_default_data_logistic.png``.

    Parameters
    ----------
    n : int, optional
        Unused; kept only so existing callers keep working.
    """
    X, y = default()

    gam = LogisticGAM()
    gam.gridsearch(X, y)

    XX = generate_X_grid(gam)

    # plt.subplots creates its own figure; a separate plt.figure() would only
    # leave an empty, never-closed figure behind.
    fig, axs = plt.subplots(1, 3)

    titles = ['student', 'balance', 'income']
    for i, ax in enumerate(axs):
        # A single call returns both the curve and its 95% intervals.
        # NOTE(review): ``feature=i + 1`` assumes feature 0 is the intercept.
        pdep, confi = gam.partial_dependence(XX, feature=i + 1, width=.95)
        ax.plot(XX[:, i], pdep)
        ax.plot(XX[:, i],
                *confi,
                c='r',
                ls='--')
        ax.set_title(titles[i])

    fig.tight_layout()
    fig.savefig('imgs/pygam_default_data_logistic.png', dpi=300)
示例#16
0
def wage_data_linear():
    """Grid-search a 10-spline LinearGAM on the wage data and plot the partial
    dependence of 'year', 'age' and 'education' with 95% confidence bands.

    The smoothing parameter is searched over 50 log-spaced values in
    [1e-5, 1e3]. The 'year' panel's y-axis is clipped to [-30, 30].
    Saved to ``imgs/pygam_wage_data_linear.png``.
    """
    X, y = wage()

    gam = LinearGAM(n_splines=10)
    gam.gridsearch(X, y, lam=np.logspace(-5, 3, 50))

    XX = generate_X_grid(gam)

    # plt.subplots creates its own figure; a separate plt.figure() would only
    # leave an empty, never-closed figure behind.
    fig, axs = plt.subplots(1, 3)

    titles = ['year', 'age', 'education']
    for i, ax in enumerate(axs):
        # A single call returns both the curve and its 95% intervals.
        # NOTE(review): ``feature=i + 1`` assumes feature 0 is the intercept.
        pdep, confi = gam.partial_dependence(XX, feature=i + 1, width=.95)
        ax.plot(XX[:, i], pdep)
        ax.plot(XX[:, i],
                *confi,
                c='r',
                ls='--')
        if i == 0:
            ax.set_ylim(-30, 30)
        ax.set_title(titles[i])

    fig.tight_layout()
    fig.savefig('imgs/pygam_wage_data_linear.png', dpi=300)
示例#17
0
from bokeh.layouts import gridplot, row
import matplotlib.pyplot as plt

# Importing data from the web
path = ('http://www.stat.cmu.edu/~larry/'
        'all-of-nonpar/=data/rock.dat')

# The file is whitespace-delimited. Use r'\s+' rather than ' *': a separator
# that can match the empty string makes the python engine's re.split break
# the line between every character on Python 3.7+.
data = pd.read_csv(path, sep=r'\s+', engine='python')

X = data[['peri', 'shape', 'perm']]
y = data['area']

# Mean-centred response; only referenced by the commented-out loop below.
adjy = y - np.mean(y)

gam = LinearGAM(n_splines=10).gridsearch(X, y)
XX = generate_X_grid(gam)

# fig, axs = plt.subplots(1, 3)
titles = ['peri', 'shape', 'perm']

# for i, ax in enumerate(axs):
#     pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)

#     ax.scatter(X[X.columns[i]], adjy, color='gray', edgecolors='none')
#     ax.plot(XX[:, i], pdep)
#     ax.plot(XX[:, i], confi[0], c='r', ls='--')
#     ax.set_title(titles[i])

# Partial dependence of all features with 95% confidence intervals.
pdep, confi = gam.partial_dependence(XX, width=.95)
p = list()
示例#18
0
def make_plot(plot_dir, site, df_flx, df_met):
    """Plot GPP and Qle against air temperature for one site.

    Records failing QC (``GPP_qc``, ``Qle_qc`` or ``Tair_qc`` != 1) or with
    non-positive Qle ("dew") are masked to NaN. Each panel shows the remaining
    observations as a faint scatter plus a black 20-spline LinearGAM fit.
    Air temperature is converted from K to deg C. The figure is written to
    ``<plot_dir>/<site>.pdf`` and then closed.

    Parameters
    ----------
    plot_dir : str
        Output directory (not created here).
    site : str
        Site name, used as the output file stem.
    df_flx, df_met : pandas.DataFrame
        Flux and meteorology records for the site; presumably sharing the same
        index -- confirm with the caller. NOTE: both are masked *in place*,
        so the caller's frames are modified.
    """
    K_TO_C = 273.15

    fig = plt.figure(figsize=(14, 4))
    fig.subplots_adjust(hspace=0.1)
    fig.subplots_adjust(wspace=0.2)
    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 14
    plt.rcParams['font.size'] = 14
    plt.rcParams['legend.fontsize'] = 14
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14

    almost_black = '#262626'
    # Tick and text colours: almost black.
    plt.rcParams['ytick.color'] = almost_black
    plt.rcParams['xtick.color'] = almost_black
    plt.rcParams['text.color'] = almost_black
    # Axis edge and label colours: a slightly lighter black than pure black.
    plt.rcParams['axes.edgecolor'] = almost_black
    plt.rcParams['axes.labelcolor'] = almost_black

    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    colour_list = ["#E69F00", "#56B4E9", "#009E73", "#CC79A7"]

    # Keep only records that pass QC on every variable used below.
    df_flx.where(df_flx.GPP_qc == 1, inplace=True)
    df_flx.where(df_flx.Qle_qc == 1, inplace=True)
    df_flx.where(df_met.Tair_qc == 1, inplace=True)

    df_met.where(df_flx.GPP_qc == 1, inplace=True)
    df_met.where(df_flx.Qle_qc == 1, inplace=True)
    df_met.where(df_met.Tair_qc == 1, inplace=True)

    # Mask dew (non-positive latent heat flux).
    df_met.where(df_flx.Qle > 0., inplace=True)
    df_flx.where(df_flx.Qle > 0., inplace=True)

    tair = df_met.Tair - K_TO_C

    def _add_gam_curve(ax, flux):
        """Fit a 20-spline LinearGAM of ``flux`` on air temperature, dropping
        any pair where either value is NaN, and draw the fit on ``ax``."""
        x = tair.values
        y = flux.values
        valid = ~(np.isnan(x) | np.isnan(y))
        gam = LinearGAM(n_splines=20).gridsearch(x[valid], y[valid])
        XX = generate_X_grid(gam)
        ax.plot(XX, gam.predict(XX), 'k-', lw=2.0)

    # Raw observations (alpha kept very low because the point count is large).
    ax1.plot(tair,
             df_flx.GPP,
             ls=" ",
             marker="o",
             color=colour_list[1],
             alpha=0.01)
    ax2.plot(tair,
             df_flx.Qle,
             ls=" ",
             marker="o",
             color=colour_list[1],
             alpha=0.01)

    _add_gam_curve(ax1, df_flx.GPP)
    _add_gam_curve(ax2, df_flx.Qle)

    ax1.set_xlim(0, 45)
    ax1.set_ylim(0, 40)
    ax2.set_xlim(0, 45)
    ax2.set_ylim(0, 600)
    ax1.set_xlabel("Tair (deg C)")
    # Move the shared x label so it sits centred beneath both panels.
    ax1.xaxis.set_label_coords(1.05, -0.1)
    ax1.set_ylabel(r"GPP (umol m$^{-2}$ s$^{-1}$)")
    ax2.set_ylabel(r"LE (W m$^{-2}$)")

    fig.savefig(os.path.join(plot_dir, "%s.pdf" % (site)),
                bbox_inches='tight',
                pad_inches=0.1)
    # Release the figure; repeated calls (one per site) would otherwise
    # accumulate open figures.
    plt.close(fig)
示例#19
0
from bokeh.layouts import gridplot, row
import matplotlib.pyplot as plt

# Importing data from the web
path = ('http://www.stat.cmu.edu/~larry/'
        'all-of-nonpar/=data/rock.dat')

# The file is whitespace-delimited. Use r'\s+' rather than ' *': a separator
# that can match the empty string makes the python engine's re.split break
# the line between every character on Python 3.7+.
data = pd.read_csv(path, sep=r'\s+', engine='python')

X = data[['peri', 'shape', 'perm']]
y = data['area']

# Mean-centred response; only referenced by the commented-out loop below.
adjy = y - np.mean(y)

gam = LinearGAM(n_splines=10).gridsearch(X, y)
XX = generate_X_grid(gam)

# fig, axs = plt.subplots(1, 3)
titles = ['peri', 'shape', 'perm']

# for i, ax in enumerate(axs):
#     pdep, confi = gam.partial_dependence(XX, feature=i+1, width=.95)

#     ax.scatter(X[X.columns[i]], adjy, color='gray', edgecolors='none')
#     ax.plot(XX[:, i], pdep)
#     ax.plot(XX[:, i], confi[0], c='r', ls='--')
#     ax.set_title(titles[i])

# Partial dependence of all features with 95% confidence intervals.
pdep, confi = gam.partial_dependence(XX, width=.95)
p = list()
def make_plot(plot_dir, site, df_flx, df_met):
    """Plot midday Qle and Qh against SW down and air temperature for a site.

    Processing steps:

    * Records whose Qle or Qh fail QC (``*_qc`` != 1), or whose Qle is not
      positive ("dew"), are masked.
    * Rows with any remaining NaN are dropped.
    * Only timestamps present in both frames between 09:00 and 13:00 are kept.

    The left panel plots both fluxes against SW down. The right panel plots
    them against air temperature, converted from K to deg C. Each series gets
    a 20-spline LinearGAM fit with a 95% confidence band.

    The figure is saved to ``<plot_dir>/<site>.png`` and closed. Nothing is
    saved when no records survive the filtering.

    NOTE: ``df_flx`` and ``df_met`` are masked and NaN-dropped *in place*,
    so the caller's frames are modified.
    """
    K_TO_C = 273.15

    fig = plt.figure(figsize=(14, 4))
    fig.subplots_adjust(hspace=0.1)
    fig.subplots_adjust(wspace=0.1)
    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 14
    plt.rcParams['font.size'] = 14
    plt.rcParams['legend.fontsize'] = 14
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14

    almost_black = '#262626'
    # Tick and text colours: almost black.
    plt.rcParams['ytick.color'] = almost_black
    plt.rcParams['xtick.color'] = almost_black
    plt.rcParams['text.color'] = almost_black
    # Axis edge and label colours: a slightly lighter black than pure black.
    plt.rcParams['axes.edgecolor'] = almost_black
    plt.rcParams['axes.labelcolor'] = almost_black

    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    # Keep only records whose Qle and Qh pass QC.
    df_met.where(df_flx.Qle_qc == 1, inplace=True)
    df_met.where(df_flx.Qh_qc == 1, inplace=True)

    df_flx.where(df_flx.Qle_qc == 1, inplace=True)
    df_flx.where(df_flx.Qh_qc == 1, inplace=True)

    # Mask dew (non-positive Qle).
    df_met.where(df_flx.Qle > 0., inplace=True)
    df_flx.where(df_flx.Qle > 0., inplace=True)

    df_flx.dropna(inplace=True)
    df_met.dropna(inplace=True)
    # The two dropna calls can remove different rows. The plots and fits
    # below pair flux and met values by position, so restrict both frames to
    # their common timestamps first.
    df_flx, df_met = df_flx.align(df_met, join="inner", axis=0)

    if len(df_flx) == 0:
        plt.close(fig)
        return

    print(site, len(df_flx), len(df_met))

    # Keep only "midday" data.
    df_flx = df_flx.between_time("09:00", "13:00")
    df_met = df_met.between_time("09:00", "13:00")
    if len(df_flx) == 0:
        plt.close(fig)
        return

    alpha = 0.07

    def _scatter(ax, x, y, colour):
        """Draw raw observations as faint point markers."""
        ax.plot(x, y, ls=" ", marker=".", mec=colour, color=colour,
                alpha=alpha)

    def _overlay_gam(ax, x, y, line_colour, fill_colour, **line_kw):
        """Fit a 20-spline LinearGAM of y on x; draw the fit and its 95%
        confidence band on ax. Extra keywords go to the line plot."""
        gam = LinearGAM(n_splines=20).gridsearch(x, y)
        XX = generate_X_grid(gam)
        CI = gam.confidence_intervals(XX, width=.95)
        ax.plot(XX, gam.predict(XX), color=line_colour, ls='-', lw=2.0,
                **line_kw)
        ax.fill_between(XX[:, 0], CI[:, 0], CI[:, 1], color=fill_colour,
                        alpha=0.7)

    # Left panel: fluxes against incoming shortwave radiation.
    _scatter(ax1, df_met.SWdown, df_flx.Qle, "#FF0000")
    _scatter(ax1, df_met.SWdown, df_flx.Qh, "royalblue")
    _overlay_gam(ax1, df_met.SWdown, df_flx.Qle, "#FF0000", 'salmon',
                 label="Qle")
    _overlay_gam(ax1, df_met.SWdown, df_flx.Qh, "royalblue",
                 'CornflowerBlue', label="Qh")

    # Right panel: fluxes against air temperature (deg C).
    tair = df_met.Tair - K_TO_C
    _scatter(ax2, tair, df_flx.Qle, "#FF0000")
    _scatter(ax2, tair, df_flx.Qh, "royalblue")
    _overlay_gam(ax2, tair, df_flx.Qle, "#FF0000", 'salmon')
    _overlay_gam(ax2, tair, df_flx.Qh, "royalblue", 'CornflowerBlue')
    # Both panels share the y range, so hide the duplicate tick labels.
    plt.setp(ax2.get_yticklabels(), visible=False)

    ax1.set_xlim(0, 1300)
    ax1.set_ylim(0, 1000)
    ax2.set_xlim(0, 45)
    ax2.set_ylim(0, 1000)
    ax1.set_xlabel("SW down (W m$^{-2}$)")
    # Raw string: "\c" is an invalid escape sequence (SyntaxWarning on 3.12+).
    ax2.set_xlabel(r"Tair ($^\circ$C)")
    ax1.set_ylabel("Daytime flux (W m$^{-2}$)")

    fig.savefig(os.path.join(plot_dir, "%s.png" % (site)),
                bbox_inches='tight',
                pad_inches=0.1,
                dpi=100)
    # Release the figure so per-site calls do not accumulate open figures.
    plt.close(fig)
def main(flux_dir):
    """Fit and plot GAM flux responses pooled by plant functional type (PFT).

    Every paired ``*_flux.nc`` / ``*_met.nc`` file in ``flux_dir`` is read
    with ``open_file``. Sites whose PFT is "NA" are skipped. For the rest:

    * Records failing Qle/Qh QC or with non-positive Qle are masked.
    * NaN rows are dropped.
    * Only timestamps common to both frames between 09:00 and 13:00 are kept.
    * The surviving records are pooled per PFT.

    For each PFT a 20-spline LinearGAM with a 95% confidence band is drawn on
    a 2x2 grid. Columns share the x variable (SW down left, Tair right) and
    rows share the flux (Qh top, Qle bottom):

    * top-left: Qh vs SW down
    * top-right: Qh vs Tair
    * bottom-left: Qle vs SW down
    * bottom-right: Qle vs Tair

    This layout matches the axis limits, hidden top-row tick labels and axis
    labels. Previously the Qle-vs-SW and Qh-vs-Tair panels were swapped, so
    SW data was clipped to 0-45 and Tair data was labelled "SW down".

    The figure is saved to ``plots/ozflux_by_pft.png``.
    """
    K_TO_C = 273.15
    sites = ["AdelaideRiver", "Calperum", "CapeTribulation", "CowBay",
             "CumberlandPlains", "DalyPasture", "DalyUncleared",
             "DryRiver", "Emerald", "Gingin", "GreatWesternWoodlands",
             "HowardSprings", "Otway", "RedDirtMelonFarm", "RiggsCreek",
             "Samford", "SturtPlains", "Tumbarumba", "Whroo",
             "WombatStateForest", "Yanco"]

    pfts = ["SAV", "SHB", "TRF", "TRF", "EBF", "GRA", "SAV",
            "SAV", "NA", "EBF", "EBF",
            "SAV", "GRA", "NA", "GRA",
            "GRA", "GRA", "EBF", "EBF",
            "EBF", "GRA"]

    # Site -> PFT; "NA" marks sites excluded from the pooled analysis.
    site_pft = dict(zip(sites, pfts))

    plot_dir = "plots"
    os.makedirs(plot_dir, exist_ok=True)

    # Flux and met files are paired by sorted filename order.
    flux_files = sorted(glob.glob(os.path.join(flux_dir, "*_flux.nc")))
    met_files = sorted(glob.glob(os.path.join(flux_dir, "*_met.nc")))

    data_qle, data_qh, data_tair, data_sw, pft_ids = [], [], [], [], []

    # Collect QC-filtered midday records from every site with a known PFT.
    for flux_fn, met_fn in zip(flux_files, met_files):
        (site, df_flx, df_met) = open_file(flux_fn, met_fn)

        pft = site_pft[site]
        if pft == "NA":
            continue

        # Keep only records whose Qle and Qh pass QC.
        df_met.where(df_flx.Qle_qc == 1, inplace=True)
        df_met.where(df_flx.Qh_qc == 1, inplace=True)

        df_flx.where(df_flx.Qle_qc == 1, inplace=True)
        df_flx.where(df_flx.Qh_qc == 1, inplace=True)

        # Mask dew (non-positive Qle).
        df_met.where(df_flx.Qle > 0., inplace=True)
        df_flx.where(df_flx.Qle > 0., inplace=True)

        df_flx.dropna(inplace=True)
        df_met.dropna(inplace=True)
        # Independent dropna calls can remove different rows. Restrict both
        # frames to their common timestamps so the flux and met values
        # appended below stay paired element-for-element.
        df_flx, df_met = df_flx.align(df_met, join="inner", axis=0)

        df_flx = df_flx.between_time("09:00", "13:00")
        df_met = df_met.between_time("09:00", "13:00")

        if len(df_flx) > 0:
            data_qle.append(df_flx.Qle.values)
            data_qh.append(df_flx.Qh.values)
            data_tair.append(df_met.Tair.values - K_TO_C)
            data_sw.append(df_met.SWdown.values)
            pft_ids.append([pft] * len(df_flx))

    def _pool(chunks):
        """Concatenate per-site chunks into one flat array (empty if none)."""
        return np.concatenate(chunks) if chunks else np.asarray([])

    data_qle = _pool(data_qle)
    data_qh = _pool(data_qh)
    data_tair = _pool(data_tair)
    data_sw = _pool(data_sw)
    pft_ids = _pool(pft_ids)

    colours = ["red", "green", "blue", "yellow", "pink"]

    fig = plt.figure(figsize=(14, 4))
    fig.subplots_adjust(hspace=0.1)
    fig.subplots_adjust(wspace=0.1)
    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 14
    plt.rcParams['font.size'] = 14
    plt.rcParams['legend.fontsize'] = 14
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14

    almost_black = '#262626'
    # Tick and text colours: almost black.
    plt.rcParams['ytick.color'] = almost_black
    plt.rcParams['xtick.color'] = almost_black
    plt.rcParams['text.color'] = almost_black
    # Axis edge and label colours: a slightly lighter black than pure black.
    plt.rcParams['axes.edgecolor'] = almost_black
    plt.rcParams['axes.labelcolor'] = almost_black

    ax1 = fig.add_subplot(221)  # Qh  vs SW down
    ax2 = fig.add_subplot(222)  # Qh  vs Tair
    ax3 = fig.add_subplot(223)  # Qle vs SW down
    ax4 = fig.add_subplot(224)  # Qle vs Tair

    def _overlay_gam(ax, x, y, colour):
        """Fit a 20-spline LinearGAM of y on x and draw the curve and its 95%
        confidence band on ax in the given colour."""
        gam = LinearGAM(n_splines=20).gridsearch(x, y)
        XX = generate_X_grid(gam)
        CI = gam.confidence_intervals(XX, width=.95)
        ax.plot(XX, gam.predict(XX), color=colour, ls='-', lw=2.0)
        ax.fill_between(XX[:, 0], CI[:, 0], CI[:, 1], color=colour,
                        alpha=0.7)

    colour_id = 0
    for pft in np.unique(pfts):
        if pft == "NA":
            continue

        # Boolean mask keeps the per-PFT arrays 1-D.
        in_pft = pft_ids == pft
        qle = data_qle[in_pft]
        qh = data_qh[in_pft]
        tair = data_tair[in_pft]
        sw = data_sw[in_pft]

        print(pft, len(qle), len(qh), len(tair), len(sw))

        colour = colours[colour_id]
        _overlay_gam(ax1, sw, qh, colour)
        _overlay_gam(ax3, sw, qle, colour)
        _overlay_gam(ax2, tair, qh, colour)
        _overlay_gam(ax4, tair, qle, colour)

        colour_id += 1

    # Each column shares its x variable with the panel below, so the top row
    # hides its x tick labels.
    ax1.tick_params(labelbottom=False)
    ax2.tick_params(labelbottom=False)

    for ax in (ax1, ax3):
        ax.set_xlim(0, 1300)
    for ax in (ax2, ax4):
        ax.set_xlim(0, 45)
    ax1.set_ylim(0, 1000)
    ax2.set_ylim(0, 1000)
    ax3.set_xlabel("SW down (W m$^{-2}$)")
    # Raw string: "\c" is an invalid escape sequence (SyntaxWarning on 3.12+).
    ax4.set_xlabel(r"Tair ($^\circ$C)")
    ax1.set_ylabel("Qh flux (W m$^{-2}$)")
    ax3.set_ylabel("Qle flux (W m$^{-2}$)")

    fig.savefig(os.path.join(plot_dir, "ozflux_by_pft.png"),
                bbox_inches='tight',
                pad_inches=0.1,
                dpi=150)
    plt.close(fig)
示例#22
0
# -*- coding: utf-8 -*-
"""
@author: Christian Winkler
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pygam import LinearGAM
import pygam

from pygam.utils import generate_X_grid

# Load the example data: 'head' is the response, 'age' the single predictor.
example_data = pd.read_csv("example_data.csv")
y, X = example_data['head'].values, example_data['age'].values

# Fit a linear GAM with a 4-spline basis (fit() returns the model itself).
gam = LinearGAM(n_splines=4)
gam.fit(X, y)

# Evaluate the model on a 20-point grid instead of the default resolution.
XX = generate_X_grid(gam, n=20)

# Raw data plus the 2.5%, 50% and 97.5% prediction quantiles in black.
plt.figure(figsize=(10, 8))
plt.scatter(X, y)
plt.plot(XX, gam.prediction_intervals(XX, quantiles=[.025, .5, .975]), color="k")
plt.savefig("pygam_example_2.png")
plt.show()