Example #1
0
def smooth_gam(x, y, n_splines=100, lam=10, width=0.95):
    """Fit a univariate penalised-spline GAM and evaluate it on a regular grid.

    Parameters
    ----------
    x : array-like
        Single predictor, passed straight to ``LinearGAM.fit``.
    y : array-like
        Response values, same length as ``x``.
    n_splines : int
        Number of basis functions of the smooth term ``s(0)``.
    lam : float
        Smoothing penalty of the term.
    width : float
        Width of the confidence interval; 0.95 matches pygam's default for
        ``confidence_intervals`` and therefore the previous behaviour.

    Returns
    -------
    tuple
        ``(grid, fitted_mean, intervals)``: the 1-D grid produced by
        ``generate_X_grid(term=0)``, ``predict_mu`` evaluated on that grid,
        and the lower/upper confidence bounds on that grid (columns 0 and 1).
    """
    from pygam import LinearGAM, s

    gam = LinearGAM(s(0, n_splines=n_splines), lam=lam).fit(x, y)
    grid = gam.generate_X_grid(term=0)
    intervals = gam.confidence_intervals(grid, width=width)
    fitted_mean = gam.predict_mu(grid)
    return grid[:, 0], fitted_mean, intervals
        # and a little thinner (0.5 instead of 1)
        plt.rcParams['axes.edgecolor'] = almost_black
        plt.rcParams['axes.labelcolor'] = almost_black

        ax1 = fig.add_subplot(211)
        ax2 = fig.add_subplot(212)

        if plot_type == "GAM":
            nsplines = 20

            lct_1D = np.tile(np.arange(8), 22)
            gam1 = LinearGAM(n_splines=nsplines).fit(
                lct_1D, soilmoist_rn1.reshape(8 * 22))
            x_pred = np.linspace(0, 7, num=100)
            y_pred1 = gam1.predict(x_pred)
            y_int1 = gam1.confidence_intervals(x_pred, width=.95)
            np.savetxt('soilmoist_rn1.out',
                       soilmoist_rn1.reshape(8 * 22),
                       delimiter=',')
            np.savetxt('soilmoist_rn2.out',
                       soilmoist_rn2.reshape(8 * 22),
                       delimiter=',')
            np.savetxt('soilmoist_tdr_rn2.out',
                       soilmoist_tdr_rn2.reshape(8 * 22),
                       delimiter=',')

            gam2 = LinearGAM(n_splines=nsplines).fit(
                lct_1D, soilmoist_rn2.reshape(8 * 22))
            y_pred2 = gam2.predict(x_pred)
            y_int2 = gam2.confidence_intervals(x_pred, width=.95)
def _draw_pft_gam(ax, x, y, colour):
    """Grid-search a 20-spline LinearGAM of ``y`` on ``x`` and draw it on ``ax``.

    Draws the fitted curve and its 95% confidence band, both in ``colour``.
    """
    gam = LinearGAM(n_splines=20).gridsearch(x, y)
    XX = generate_X_grid(gam)
    CI = gam.confidence_intervals(XX, width=.95)
    ax.plot(XX[:, 0], gam.predict(XX), color=colour, ls='-', lw=2.0)
    ax.fill_between(XX[:, 0], CI[:, 0], CI[:, 1], color=colour, alpha=0.7)


def main(flux_dir):
    """Plot per-PFT GAM fits of midday Qh and Qle against SW down and Tair.

    Reads each ``*_flux.nc`` / ``*_met.nc`` pair in ``flux_dir`` via
    ``open_file``. It keeps records with Qle_qc == 1, Qh_qc == 1 and Qle > 0
    (dew removed) between 09:00 and 13:00, and only timestamps present in
    both frames. Records are pooled by plant functional type; sites mapped to
    "NA" are skipped. One GAM curve with a 95% confidence band is drawn per
    PFT on a 2x2 grid:

        top-left   : Qh  vs SW down      top-right   : Qh  vs Tair
        bottom-left: Qle vs SW down      bottom-right: Qle vs Tair

    The figure is saved to ``plots/ozflux_by_pft.png``.

    NOTE(review): flux and met files are paired purely by sorted filename
    order, which assumes every site has exactly one of each -- confirm.
    """
    K_TO_C = 273.15
    sites = ["AdelaideRiver","Calperum","CapeTribulation","CowBay",
             "CumberlandPlains","DalyPasture","DalyUncleared",
             "DryRiver","Emerald","Gingin","GreatWesternWoodlands",
             "HowardSprings","Otway","RedDirtMelonFarm","RiggsCreek",
             "Samford","SturtPlains","Tumbarumba","Whroo",
             "WombatStateForest","Yanco"]

    pfts = ["SAV","SHB","TRF","TRF","EBF","GRA","SAV",
            "SAV","NA","EBF","EBF",
            "SAV","GRA","NA","GRA",
            "GRA","GRA","EBF","EBF",
            "EBF","GRA"]

    site_to_pft = dict(zip(sites, pfts))

    plot_dir = "plots"
    os.makedirs(plot_dir, exist_ok=True)

    flux_files = sorted(glob.glob(os.path.join(flux_dir, "*_flux.nc")))
    met_files = sorted(glob.glob(os.path.join(flux_dir, "*_met.nc")))

    data_qle, data_qh, data_tair, data_sw, pft_ids = [], [], [], [], []

    # collect up data
    for flux_fn, met_fn in zip(flux_files, met_files):
        (site, df_flx, df_met) = open_file(flux_fn, met_fn)

        pft = site_to_pft[site]
        if pft == "NA":
            continue

        # Keep good-quality (qc == 1) Qle/Qh records and mask dew (Qle <= 0);
        # the same flux-based mask is applied to the met data.
        good = (df_flx.Qle_qc == 1) & (df_flx.Qh_qc == 1) & (df_flx.Qle > 0.)
        df_flx = df_flx.where(good).dropna()
        df_met = df_met.where(good).dropna()

        df_flx = df_flx.between_time("09:00", "13:00")
        df_met = df_met.between_time("09:00", "13:00")

        # The two frames are dropna'd independently, so restrict both to the
        # timestamps they share. Otherwise flux and met values would be paired
        # positionally and the pooled arrays could differ in length.
        common = df_flx.index.intersection(df_met.index)
        if len(common) == 0:
            continue
        df_flx = df_flx.loc[common]
        df_met = df_met.loc[common]

        data_qle.append(df_flx.Qle.values)
        data_qh.append(df_flx.Qh.values)
        data_tair.append(df_met.Tair.values - K_TO_C)
        data_sw.append(df_met.SWdown.values)
        pft_ids.append([pft] * len(common))

    if not pft_ids:
        print("No usable data found in", flux_dir)
        return

    data_qle = np.concatenate(data_qle)
    data_qh = np.concatenate(data_qh)
    data_tair = np.concatenate(data_tair)
    data_sw = np.concatenate(data_sw)
    pft_ids = np.concatenate(pft_ids)

    colours = ["red", "green", "blue", "yellow", "pink"]

    fig = plt.figure(figsize=(14, 4))
    fig.subplots_adjust(hspace=0.1)
    fig.subplots_adjust(wspace=0.1)
    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 14
    plt.rcParams['font.size'] = 14
    plt.rcParams['legend.fontsize'] = 14
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14

    almost_black = '#262626'
    # change the tick colors also to the almost black
    plt.rcParams['ytick.color'] = almost_black
    plt.rcParams['xtick.color'] = almost_black

    # change the text colors also to the almost black
    plt.rcParams['text.color'] = almost_black

    # Change the default axis edge and label colours from black to a
    # slightly lighter black.
    plt.rcParams['axes.edgecolor'] = almost_black
    plt.rcParams['axes.labelcolor'] = almost_black

    # Columns share the driver (x) and rows share the flux (y); this matches
    # the per-column x limits and the hidden top-row x tick labels below.
    ax_sw_qh = fig.add_subplot(221)
    ax_tair_qh = fig.add_subplot(222)
    ax_sw_qle = fig.add_subplot(223)
    ax_tair_qle = fig.add_subplot(224)

    plotted_pfts = [p for p in np.unique(pfts) if p != "NA"]
    for colour_id, pft in enumerate(plotted_pfts):
        selected = pft_ids == pft
        qle = data_qle[selected]
        qh = data_qh[selected]
        tair = data_tair[selected]
        sw = data_sw[selected]

        print(pft, len(qle), len(qh), len(tair), len(sw))
        if len(qle) == 0:
            # No surviving records for this PFT (e.g. no files for its sites);
            # a GAM cannot be fitted to an empty sample.
            continue

        # Wrap around so more PFTs than colours cannot raise IndexError.
        colour = colours[colour_id % len(colours)]
        _draw_pft_gam(ax_sw_qh, sw, qh, colour)
        _draw_pft_gam(ax_tair_qh, tair, qh, colour)
        _draw_pft_gam(ax_sw_qle, sw, qle, colour)
        _draw_pft_gam(ax_tair_qle, tair, qle, colour)

    plt.setp(ax_sw_qh.get_xticklabels(), visible=False)
    plt.setp(ax_tair_qh.get_xticklabels(), visible=False)

    for ax in (ax_sw_qh, ax_sw_qle):
        ax.set_xlim(0, 1300)
    for ax in (ax_tair_qh, ax_tair_qle):
        ax.set_xlim(0, 45)
    for ax in (ax_sw_qh, ax_tair_qh, ax_sw_qle, ax_tair_qle):
        ax.set_ylim(0, 1000)
    ax_sw_qle.set_xlabel("SW down (W m$^{-2}$)")
    ax_tair_qle.set_xlabel(r"Tair ($^\circ$C)")
    ax_sw_qh.set_ylabel("Qh flux (W m$^{-2}$)")
    ax_sw_qle.set_ylabel("Qle flux (W m$^{-2}$)")

    fig.savefig(os.path.join(plot_dir, "ozflux_by_pft.png"),
                bbox_inches='tight',
                pad_inches=0.1,
                dpi=150)
def _overlay_lai_gam(ax, frame, n_splines, colour):
    """Fit Total_LAI ~ MAP with a LinearGAM and draw it on ``ax``.

    Rows missing either value are dropped first, because pygam's input checks
    reject NaN. Draws the fitted curve over the observed MAP range and its 95%
    confidence band in ``colour``. Draws nothing when no complete rows remain.
    """
    complete = frame.dropna(subset=["MAP", "Total_LAI"])
    if complete.empty:
        return
    x = complete.MAP.values
    y = complete.Total_LAI.values
    gam = LinearGAM(n_splines=n_splines).fit(x, y)
    x_pred = np.linspace(x.min(), x.max(), num=100)
    y_pred = gam.predict(x_pred)
    y_int = gam.confidence_intervals(x_pred, width=.95)
    ax.plot(x_pred, y_pred, color=colour, ls='-', lw=2.0, zorder=10)
    ax.fill_between(x_pred,
                    y_int[:, 0],
                    y_int[:, 1],
                    alpha=0.2,
                    facecolor=colour,
                    zorder=10)


def plot_LAI_MAP(df):
    """Plot LAI against mean annual precipitation for natural EB forests.

    Sites with PFT == "EB", MAP < 3000 mm and Vegetation_status == "Natural"
    are split by whether Dominant_species mentions "Eucalyptus". Each group is
    scattered and overlaid with a LinearGAM fit and its 95% band: 7 splines
    for non-eucalypt EBF and 4 for eucalypts. The figure is saved to
    ``plots/LAI_vs_MAP_Eucs_vs_EBF.pdf`` and then shown.

    ``df`` is not modified; the literature MAT/MAP columns are renamed on a copy.
    """
    # Clean the df
    df = df.drop(columns=['Potentially_erroneous_data', 'Reference_number',
                          'Method'])
    df = df.rename(columns={
        'MAT_(Literature value)': 'MAT',
        'MAP_(Literature value)': 'MAP',
    })

    natural_eb = ((df['PFT'] == "EB") &
                  (df['MAP'] < 3000.0) &
                  (df['Vegetation_status'] == "Natural"))
    # na=False: a missing species counts as "not a eucalypt". Without it the
    # mask contains NaN and the ``~`` below fails.
    is_euc = df['Dominant_species'].str.contains('Eucalyptus', na=False)

    df_eucs = df[natural_eb & is_euc]
    df_ebf = df[natural_eb & ~is_euc]

    fig = plt.figure(figsize=(9, 6))
    fig.subplots_adjust(hspace=0.1)
    fig.subplots_adjust(wspace=0.05)
    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 14
    plt.rcParams['font.size'] = 14
    plt.rcParams['legend.fontsize'] = 14
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14

    ax = fig.add_subplot(111)

    colours = sns.color_palette("Set2", 8)

    ax.plot(df_ebf.MAP,
            df_ebf.Total_LAI,
            color=colours[0],
            ls=" ",
            marker="o",
            label="Global EBF")
    _overlay_lai_gam(ax, df_ebf, n_splines=7, colour=colours[0])

    ax.plot(df_eucs.MAP,
            df_eucs.Total_LAI,
            color=colours[1],
            ls=" ",
            marker="o",
            label="Eucalypts")
    _overlay_lai_gam(ax, df_eucs, n_splines=4, colour=colours[1])

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.legend(numpoints=1, loc="best", frameon=False)
    ax.set_xlabel("Mean annual precipitation (mm)")
    ax.set_ylabel("Leaf area index (m$^{2}$ m$^{-2}$)")

    # Save before showing: once a blocking plt.show() returns, the figure may
    # already be closed, and saving it then writes a blank file.
    odir = "plots"
    os.makedirs(odir, exist_ok=True)
    fig.savefig(os.path.join(odir, "LAI_vs_MAP_Eucs_vs_EBF.pdf"),
                bbox_inches='tight',
                pad_inches=0.1)

    plt.show()
def _overlay_heatwave_gam(ax, x, y, n_splines, colour):
    """Grid-search a LinearGAM of ``y`` on ``x`` and draw it on ``ax``.

    Pairs where either value is NaN are dropped first. Draws the fitted curve
    over the observed x range and its 95% confidence band in ``colour``.
    Draws nothing when no complete pairs remain.
    """
    complete = x.notna() & y.notna()
    x, y = x[complete], y[complete]
    if len(x) == 0:
        return
    gam = LinearGAM(n_splines=n_splines).gridsearch(x, y)
    x_pred = np.linspace(x.min(), x.max(), num=100)
    y_pred = gam.predict(x_pred)
    y_int = gam.confidence_intervals(x_pred, width=.95)
    ax.plot(x_pred, y_pred, color=colour, ls='-', lw=2.0, zorder=10)
    ax.fill_between(x_pred, y_int[:, 0], y_int[:, 1], alpha=0.2,
                    facecolor=colour, zorder=10)


def main(fname1, fname2, fname3, fname4,
         ofdir="/Users/mdekauwe/Dropbox/fluxnet_heatwaves_paper/figures/figs"):
    """Plot ET against GPP x sqrt(D) per EBF site, heatwave vs non-heatwave days.

    One panel per site, on a 3x2 grid (so at most six sites). Each panel shows
    the non-heatwave days (blue) and heatwave days (red), each overlaid with a
    LinearGAM fit and 95% confidence band. The heatwave fit is omitted for
    CumberlandPlains.

    Parameters
    ----------
    fname1, fname2 : str
        CSVs for non-heatwave (fname1) and heatwave (fname2) days; only rows
        with pft == "EBF" are used.
    fname3, fname4 : str
        FLUXNET CSVs for the same two cases. Their "ENF" rows with a valid
        gpp_sqrt_d are appended to fname1/fname2 respectively. Per the source
        comment, Alice Springs (AU-ASM) is misclassified by FLUXNET; it is
        renamed "Alice Springs".
    ofdir : str
        Directory the PDF is written to (defaults to the previous hard-coded
        location).
    """
    from matplotlib.ticker import MaxNLocator

    df1 = pd.read_csv(fname1)
    df1 = df1[df1.pft == "EBF"]

    df2 = pd.read_csv(fname2)
    df2 = df2[df2.pft == "EBF"]

    # Add in Alice information from FLUXNET (Alice misclassified by fluxnet).
    # DataFrame.append was removed in pandas 2.0, hence pd.concat.
    df3 = pd.read_csv(fname3)
    df3 = df3[(df3.pft == "ENF") & df3.gpp_sqrt_d.notna()]
    df1 = pd.concat([df1, df3], ignore_index=True)

    df4 = pd.read_csv(fname4)
    df4 = df4[(df4.pft == "ENF") & df4.gpp_sqrt_d.notna()]
    df2 = pd.concat([df2, df4], ignore_index=True)

    df1.loc[df1.site == "AU-ASM", "site"] = "Alice Springs"
    df2.loc[df2.site == "AU-ASM", "site"] = "Alice Springs"

    width = 14
    height = 9
    fig = plt.figure(figsize=(width, height))
    fig.subplots_adjust(hspace=0.05)
    fig.subplots_adjust(wspace=0.02)
    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 16
    plt.rcParams['font.size'] = 16
    plt.rcParams['legend.fontsize'] = 12
    plt.rcParams['xtick.labelsize'] = 16
    plt.rcParams['ytick.labelsize'] = 16

    labels = label_generator('lower', start="(", end=")")
    props = dict(boxstyle='round', facecolor='white', alpha=1.0,
                 ec="white")

    for count, site in enumerate(np.unique(df1.site)):
        df_site1 = df1[df1.site == site]
        df_site2 = df2[df2.site == site]

        # Insert a space before interior capitals, e.g.
        # "CapeTribulation" -> "Cape Tribulation".
        site_name = re.sub(r"(\w)([A-Z])", r"\1 \2", site)
        ax = fig.add_subplot(3, 2, 1 + count)

        ax.plot(df_site2.gpp_sqrt_d, df_site2.et,
                label="Heatwave days", color="red", ls="", marker=".",
                lw=2, zorder=100, ms=9, mec="#ffb3b3")
        ax.plot(df_site1.gpp_sqrt_d, df_site1.et,
                label="Non-heatwave days", color="blue", ls="", marker=".",
                lw=2, alpha=0.3, ms=9)

        # Alice Springs is fitted with fewer splines than the other sites.
        nsplines = 10 if site == "Alice Springs" else 20

        if site != "CumberlandPlains":
            _overlay_heatwave_gam(ax, df_site2.gpp_sqrt_d, df_site2.et,
                                  nsplines, "red")
        _overlay_heatwave_gam(ax, df_site1.gpp_sqrt_d, df_site1.et,
                              nsplines, "blue")

        # 3x2 grid: only the bottom row (count 4, 5) keeps x tick labels and
        # only the left column (even count) keeps y tick labels.
        if count < 4:
            plt.setp(ax.get_xticklabels(), visible=False)
        if count % 2 != 0:
            plt.setp(ax.get_yticklabels(), visible=False)

        fig_label = "%s %s" % (next(labels), site_name)
        ax.text(0.02, 0.95, fig_label,
                transform=ax.transAxes, fontsize=14, verticalalignment='top',
                bbox=props)
        ax.set_xlim(0, 11)
        ax.set_ylim(0, 4)

        if count == 2:
            ax.set_ylabel('E (mm d$^{-1}$)')
        if count == 4:
            # Shifted right so the single x label is centred under both columns.
            ax.set_xlabel(r"GPP $\times$ D$^{0.5}$ (g C kPa$^{0.5}$ m$^{-2}$ d$^{-1}$)",
                          position=(1.0, 0.5))
        else:
            ax.set_xlabel(" ")
        ax.yaxis.set_major_locator(MaxNLocator(3))
        ax.xaxis.set_major_locator(MaxNLocator(3))
        ax.tick_params(direction='in', length=4)

        if site == "Calperum":
            ax.legend(numpoints=1, ncol=1, frameon=False, loc="best")

    ofname = "ozflux_heatwave_vs_non_heatwave_gppsqrtd_et.pdf"
    fig.savefig(os.path.join(ofdir, ofname),
                bbox_inches='tight', pad_inches=0.1)
def calculate_gene_trends(session_ID, list_of_genes, branch_ID):
    """Fit a smooth pseudotime trend (LinearGAM) for each requested gene.

    Parameters
    ----------
    session_ID
        Session key passed to ``cache_adata``, ``get_obs_vector`` and
        ``cache_progress``.
    list_of_genes : list
        Genes to fit. Each gene yields three output columns.
    branch_ID : int
        -1 fits over all cells with unit weights. Any other value uses the
        ``"pseudotime_branch_<branch_ID>"`` obs column as per-cell fitting
        weights. The prediction grid then spans only cells whose value in that
        column exceeds 0.2.

    Returns
    -------
    pandas.DataFrame
        125 rows: a ``"pseudotime"`` grid column, then for each gene the
        predicted trend ``<gene>`` and its 95% band ``<gene>_ci_upper`` /
        ``<gene>_ci_lower``. Every value is clipped below at 0.
    """
    total_steps = 2 + len(list_of_genes)

    def report(step):
        # Progress is stored as an integer percentage of total_steps.
        cache_progress(session_ID, progress=int(step / total_steps * 100))

    obs = cache_adata(session_ID, group="obs")
    report(1)

    branch_probs = (None if branch_ID == -1
                    else obs["pseudotime_branch_" + str(branch_ID)])

    pseudotime = obs["pseudotime"]
    report(2)

    # branch_probs is None exactly when branch_ID == -1.
    if branch_probs is None:
        cells_in_branch = obs.index
    else:
        cells_in_branch = obs[branch_probs > 0.2].index
    print(f"[DEBUG] branch: {branch_ID!s}")

    # Fit on at most this many cells (a random subset) to bound fitting time.
    max_samples_to_fit = 5000
    t_all = pseudotime.to_numpy()
    if len(t_all) > max_samples_to_fit:
        keep = np.zeros_like(t_all)
        keep[:max_samples_to_fit] = 1
        np.random.shuffle(keep)
    else:
        keep = np.ones_like(t_all)
    keep = keep.astype(bool)

    w_all = np.ones_like(t_all) if branch_probs is None else branch_probs.to_numpy()

    X_fit = t_all[keep].reshape(-1, 1)
    w_fit = w_all[keep].reshape(-1, 1)

    branch_pt = pseudotime[cells_in_branch]
    X_plot = np.linspace(np.min(branch_pt), np.max(branch_pt), 125)

    gene_trends = pd.DataFrame()
    gene_trends["pseudotime"] = X_plot

    for step, gene in enumerate(list_of_genes, start=3):
        started = datetime.now()
        Y_fit = get_obs_vector(session_ID, gene, layer="imputed")
        print(f"[BENCH] time for get_obs_vector: {datetime.now() - started!s}")
        Y_fit = Y_fit[keep]

        gam = LinearGAM(n_splines=5, spline_order=3)

        started = datetime.now()
        gam.gridsearch(X_fit, Y_fit, weights=w_fit, progress=False)
        print(f"[BENCH] time for gam fit: {datetime.now() - started!s}")

        gene_trends[gene] = gam.predict(X_plot)

        bounds = gam.confidence_intervals(X_plot, width=.95)
        gene_trends[gene + "_ci_upper"] = bounds[:, 1]
        gene_trends[gene + "_ci_lower"] = bounds[:, 0]
        report(step)

    return gene_trends.clip(lower=0)
def _overlay_site_gam(ax, x, y, line_colour, band_colour, label=None):
    """Grid-search a 20-spline LinearGAM of ``y`` on ``x`` and draw it on ``ax``.

    Draws the fitted curve in ``line_colour`` (legend entry ``label``) and its
    95% confidence band in ``band_colour``.
    """
    gam = LinearGAM(n_splines=20).gridsearch(x, y)
    XX = generate_X_grid(gam)
    CI = gam.confidence_intervals(XX, width=.95)
    ax.plot(XX[:, 0], gam.predict(XX), color=line_colour, ls='-', lw=2.0,
            label=label)
    ax.fill_between(XX[:, 0], CI[:, 0], CI[:, 1], color=band_colour,
                    alpha=0.7)


def make_plot(plot_dir, site, df_flx, df_met):
    """Scatter one site's midday Qle and Qh against SW down and Tair, with GAM fits.

    Records are kept only where Qle_qc == 1, Qh_qc == 1 and Qle > 0 (dew
    removed). They are further restricted to 09:00-13:00 and to timestamps
    present in both frames.

    The left panel plots both fluxes against SW down; the right panel plots
    them against Tair (Kelvin converted to degC). Each flux gets a 20-spline
    LinearGAM curve with a 95% band. The figure is saved to
    ``<plot_dir>/<site>.png`` and then closed. Nothing is plotted or saved
    when no records survive the filtering.

    The caller's ``df_flx`` / ``df_met`` are not modified.
    """
    K_TO_C = 273.15

    plt.rcParams['text.usetex'] = False
    plt.rcParams['font.family'] = "sans-serif"
    plt.rcParams['font.sans-serif'] = "Helvetica"
    plt.rcParams['axes.labelsize'] = 14
    plt.rcParams['font.size'] = 14
    plt.rcParams['legend.fontsize'] = 14
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14

    almost_black = '#262626'
    # change the tick colors also to the almost black
    plt.rcParams['ytick.color'] = almost_black
    plt.rcParams['xtick.color'] = almost_black

    # change the text colors also to the almost black
    plt.rcParams['text.color'] = almost_black

    # Change the default axis edge and label colours from black to a
    # slightly lighter black.
    plt.rcParams['axes.edgecolor'] = almost_black
    plt.rcParams['axes.labelcolor'] = almost_black

    # Keep good-quality (qc == 1) Qle/Qh records and mask dew (Qle <= 0).
    # Non-inplace, so the caller's frames are left untouched.
    good = (df_flx.Qle_qc == 1) & (df_flx.Qh_qc == 1) & (df_flx.Qle > 0.)
    df_flx = df_flx.where(good).dropna()
    df_met = df_met.where(good).dropna()

    if len(df_flx) == 0 or len(df_met) == 0:
        return
    print(site, len(df_flx), len(df_met))

    # < "Midday" data
    df_flx = df_flx.between_time("09:00", "13:00")
    df_met = df_met.between_time("09:00", "13:00")

    # The frames were dropna'd independently; pair flux and met values by
    # timestamp rather than by position so the plotted series line up.
    common = df_flx.index.intersection(df_met.index)
    if len(common) == 0:
        return
    df_flx = df_flx.loc[common]
    df_met = df_met.loc[common]

    fig = plt.figure(figsize=(14, 4))
    fig.subplots_adjust(hspace=0.1)
    fig.subplots_adjust(wspace=0.1)
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    alpha = 0.07
    tair_c = df_met.Tair - K_TO_C
    # (axes, driver on x, whether the GAM lines carry legend labels)
    panels = ((ax1, df_met.SWdown, True), (ax2, tair_c, False))
    for ax, driver, labelled in panels:
        ax.plot(driver, df_flx.Qle, ls=" ", marker=".", mec="#FF0000",
                color="#FF0000", alpha=alpha)
        ax.plot(driver, df_flx.Qh, ls=" ", marker=".", mec="royalblue",
                color="royalblue", alpha=alpha)
        _overlay_site_gam(ax, driver, df_flx.Qle, "#FF0000", "salmon",
                          label="Qle" if labelled else None)
        _overlay_site_gam(ax, driver, df_flx.Qh, "royalblue",
                          "CornflowerBlue", label="Qh" if labelled else None)
    plt.setp(ax2.get_yticklabels(), visible=False)

    ax1.set_xlim(0, 1300)
    ax1.set_ylim(0, 1000)
    ax2.set_xlim(0, 45)
    ax2.set_ylim(0, 1000)
    ax1.set_xlabel("SW down (W m$^{-2}$)")
    ax2.set_xlabel(r"Tair ($^\circ$C)")
    ax1.set_ylabel("Daytime flux (W m$^{-2}$)")

    fig.savefig(os.path.join(plot_dir, "%s.png" % (site)),
                bbox_inches='tight',
                pad_inches=0.1,
                dpi=100)
    # One figure per call: close it so repeated per-site calls do not
    # accumulate open figures.
    plt.close(fig)
             discrete=(False, False),
             cbar=False,
             cbar_kws=dict(shrink=.75),
             ax=ax)
# Overlay a GAM trend of ΔT against water-table depth (WTD) on the existing
# axes ``ax``. ``df`` and ``ax`` are created earlier in the script (outside
# this view). An earlier seaborn variant is kept commented out below.
# sns.displot(df, x="WTD", y="deltaT",ax = ax)
x = df['WTD (m)'].values
y = df['ΔT (°C)'].values
# Column vectors of shape (n, 1) for the GAM fit.
xx = x.reshape(x.shape[0], 1)
yy = y.reshape(y.shape[0], 1)

# Progress/debug print.
print("I am OK 2")
# Fit a 4-spline LinearGAM with pygam's gridsearch.
gam = LinearGAM(n_splines=4).gridsearch(xx, yy)  # alternative tried: n_splines=22
# Evaluate the fit on 100 evenly spaced WTD values spanning the observed range.
x_pred = np.linspace(min(x), max(x), num=100)
y_pred = gam.predict(x_pred)
# 95% confidence band: column 0 = lower bound, column 1 = upper bound.
y_int = gam.confidence_intervals(x_pred, width=.95)
ax.plot(x_pred, y_pred, color="red", ls='-', lw=2.0, zorder=10)
ax.fill_between(x_pred,
                y_int[:, 0],
                y_int[:, 1],
                alpha=0.2,
                facecolor='red',
                zorder=10)
# ax.text(0.03, 0.95, '(f)', transform=ax.transAxes, fontsize=18, verticalalignment='top', bbox=props)

# Progress/debug print.
print("I am OK 3")
# plt.setp(ax.get_xticklabels(), visible=False)
# ax.set(xticks=xtickslocs, xticklabels=cleaner_dates) ####
# Keep y ticks and the y-axis label on the left-hand side.
ax.yaxis.tick_left()
ax.yaxis.set_label_position("left")
# NOTE(review): presumably a border/spine line width used further down
# (beyond this view) -- confirm.
bwith = 2