Exemplo n.º 1
0
def plot():
    """Produce the standard set of diagnostic figures for the evaluation set.

    In order: the CO2-change figure, histograms of ``x_all`` as-is (labelled
    "whightened") and after ``prepare_data.x_to_params``, histograms of the
    last 20 entries of the spectrum vector (once after
    ``prepare_data.y_to_spectra`` and once taken directly from ``y_all``),
    and finally example spectra via ``show_spectra``.
    """
    n_tail = 20
    show_co2_change()

    orig_params = prepare_data.x_to_params(x_all)
    spectra = prepare_data.y_to_spectra(y_all)

    show_param_hists(x_all, name="whightened")
    show_param_hists(orig_params, name="orig")

    # Entry 6 of the trailing block is divided by 12 before plotting.
    # NOTE(review): this is an in-place update on a slice, so it presumably
    # also changes ``spectra`` (and whatever y_to_spectra returned) — confirm
    # nothing else relies on the unscaled values.
    spectra_tail = spectra[:, -n_tail:]
    spectra_tail[:, 6] /= 12
    show_param_hists(spectra_tail,
                     name="spec-20",
                     n_parameters=n_tail,
                     elem_names=prepare_data.spectra_names[-n_tail:])

    raw_tail = y_all[:, -n_tail:]
    show_param_hists(raw_tail,
                     name="y-20",
                     n_parameters=n_tail,
                     elem_names=prepare_data.spectra_names[-n_tail:])

    show_spectra()
Exemplo n.º 2
0
def show_error_correlations(x=x_all, y=y_all, load=False):
    """Plot histograms of the CO2 prediction error and its estimated size.

    NOTE(review): the visible body ends by opening a triple-quoted string,
    so the rest of this function is not shown here; these comments cover
    only the visible lines.

    Arguments:
        x -- inputs passed to ``prepare_data.x_to_params`` and
             ``compute_calibration``; defaults to the module-level ``x_all``
             (bound once, when the function is defined).
        y -- inputs passed to ``prepare_data.y_to_spectra``,
             ``sample_posterior`` and ``compute_calibration``; defaults to
             ``y_all``.
        load -- forwarded unchanged to ``compute_calibration``.
    """

    #plot error co2 against: xco2, "albedo_o2","albedo_sco2","albedo_wco2", "tcwv" 4
    #and: "year","xco2_apriori","altitude","psurf","t700","longitude","latitude" 7
    #"snr_wco2","snr_sco2","snr_o2a","aod_bc","aod_dust","aod_ice","aod_oc","aod_seasalt","aod_sulfate","aod_total","aod_water"

    params = prepare_data.x_to_params(x)
    spectra = prepare_data.y_to_spectra(y)

    # NOTE(review): `post` and this `diff` are never used: `diff` is
    # reassigned from post_params a few lines below. Presumably leftover code.
    post = sample_posterior(y)
    diff = torch.mean(torch.FloatTensor(post), dim=1) - x
    #post_params =prepare_data.x_to_params(post)
    _, _, uncert_intervals, _, post_params = compute_calibration(x,
                                                                 y,
                                                                 load=load)
    # Half the width of interval index 68 for the CO2 component (c.co2_pos).
    # Presumably index 68 is the 68% interval — confirm against compute_calibration.
    uncert_error_co2 = uncert_intervals[68, :, c.co2_pos] / 2
    # Mean over axis 1 of post_params minus the true parameters.
    diff = np.mean(post_params, axis=1) - params
    # NOTE(review): column 0 is used as the CO2 error here while the
    # uncertainty above uses c.co2_pos — confirm these agree.
    diff_co2 = diff[:, 0]
    error_name = ['error_correlations', 'estimated_error_correlations']
    # Pass 0 uses the y_to_spectra output; pass 1 uses y as a plain array.
    # The loop variable intentionally rebinds `spectra`; the list is built first.
    for l, spectra in enumerate([spectra, np.array(y)]):
        if l == 1:
            params = np.array(x)
        # k == 0: actual CO2 error; k == 1: estimated half-interval width.
        # Note that `diff` is rebound (shadowed) by this loop.
        for k, diff in enumerate([diff_co2, uncert_error_co2]):

            plt.figure(error_name[k] + f'_{l}', figsize=(20, 15))
            plt.title(error_name[k] + f'_{l}')
            print(diff.shape)
            # Grid of 4 rows x 6 columns (the names refer to figure counts).
            horizontal_figures = 4
            vertical_figures = 6
            diff = np.clip(diff, -4, 4)
            for i in range(horizontal_figures):
                # First column of row i: index 6 * i + 1 in a 4x6 grid.
                ax = plt.subplot(horizontal_figures, vertical_figures,
                                 vertical_figures * i + 1)

                bins = np.linspace(np.min(diff), np.max(diff), 100)

                # The same clipped error histogram is drawn in every row.
                plt.hist(diff,
                         bins=bins,
                         histtype='step',
                         color="lightskyblue",
                         orientation="horizontal")

                # Only the top row keeps x ticks.
                if i > 0:
                    #ax.axis('off')
                    ax.set_xticks([])
                plt.ylabel(f"error of prediction in ppm")
            # NOTE(review): the function continues beyond this point in the
            # original; the string opened below is not closed in this view.
            """
Exemplo n.º 3
0
def plot_quality_map(value, name="co2", position=None):
    """Scatter one value per sample on a world map (AKA world map).

    Draws the country outlines from geopandas' bundled ``naturalearth_lowres``
    dataset and overlays one marker per sample, coloured by ``value``, with a
    colour bar.

    Arguments:
        value {array-like} -- one colour value per sample (passed as ``c`` to
            ``plt.scatter``).

    Keyword Arguments:
        name {str} -- used in the figure label ("<name>_map") and as the axes
            title (default: {"co2"})
        position {2-D array-like or None} -- per-sample coordinates; column 0
            is plotted on the x axis (labelled Longitude) and column 1 on the
            y axis (labelled Latitude). When None, the last two columns of
            ``prepare_data.y_to_spectra(y_all)`` are used (default: {None})
    """
    world = geopandas.read_file(
        geopandas.datasets.get_path('naturalearth_lowres'))

    if position is None:
        # Two columns are required. The previous ``[:, -2]`` selected a single
        # 1-D column, so ``pos[:, 0]`` raised IndexError. ``[:, -2:]`` matches
        # the module-level ``position`` definition.
        position = prepare_data.y_to_spectra(y_all)[:, -2:]

    fig = plt.figure(f"{name}_map", figsize=(20, 10))
    ax = plt.subplot(1, 1, 1)
    ax.set_aspect('equal')
    ax.set_title(f"{name}")
    world.plot(ax=ax, color='lightblue', edgecolor='black')

    plt.scatter(position[:, 0], position[:, 1], c=value, s=20)
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.colorbar()
    # Lay out once every artist (including the colour bar) exists; calling it
    # before anything was drawn had no useful effect.
    fig.tight_layout()
Exemplo n.º 4
0
def show_spectra(number=3):
    """Plot log10 radiance of the first ``number`` bands for a few samples.

    For every sample index ``10 * k`` with k in (0, 5, 10, 15, 18, 20, 24, 28)
    that exists in the evaluation set, one figure ("spectra <index>") with
    ``number`` stacked panels is drawn. Panel ``j`` shows spectrum entries
    ``[j * spec_length, (j + 1) * spec_length)``. The trailing
    ``prepare_data.params_in_spectrum`` columns of the spectrum vector are
    excluded.

    Keyword Arguments:
        number {int} -- number of bands (panels) per figure (default: {3})
    """
    spec_length = prepare_data.spec_length
    spectra = prepare_data.y_to_spectra(
        y_all)[:, :-prepare_data.params_in_spectrum]
    # The y label promises log10. The original used np.log (natural log), so
    # the plotted values did not match their axis label.
    logspectra = np.log10(spectra)
    band_titles = (r"Strong CO$_2$ Band", r"Weak CO$_2$ Band")

    for k in [0, 5, 10, 15, 18, 20, 24, 28]:
        i = k * 10
        print(len(spectra), i, y_all.shape, spectra.shape)
        if len(spectra) <= i:
            continue
        plt.figure(f"spectra {i}", figsize=(15, 15))

        for j in range(number):
            band = logspectra[i, spec_length * j:spec_length * (j + 1)]
            ax = plt.subplot(number, 1, j + 1)
            # The x values are channel indices within the band, not
            # wavelengths, so the axis is labelled as such. The previous
            # "Wavelength in µm" label was wrong.
            plt.plot(range(spec_length), band)
            ax.set_xlabel("Spectral channel index")
            ax.set_ylabel(
                r"Radiance in $log_{10}\;\frac{Ph}{sec\:m^{2}sr\; {\mu}m}$"
            )
            # Panels beyond the first two are titled as the O2 band, as before.
            ax.set_title(band_titles[j] if j < len(band_titles)
                         else r"O$_2$ Band")

            print("\n")
            print("spectrum ", i, j)
            print(band, np.shape(band))
Exemplo n.º 5
0
def show_feature_net_solution():
    """Evaluate the feature net on the evaluation set and plot its predictions.

    Runs ``model.fn_func`` once on ``y_all`` and produces:

    * a world map of |prediction - ground truth| for column 0 of the
      ``prepare_data.x_to_params`` output,
    * error statistics for that column via ``show_error_stats``,
    * world maps of the predicted uncertainty and of |error| / uncertainty,
    * for the first ``n_plots`` samples, one figure per sample with ``n_x``
      panels, each showing the prior histogram (grey), the ground truth
      (red) and the prediction (blue) for one parameter.
    """
    print("show_feature_net_solution")
    orig_prior_hists = hists(prepare_data.x_to_params(x_all))

    # Single forward pass. Its column 1 is reused for the uncertainty below
    # instead of running the network a second time on identical input.
    x = model.fn_func(y_all.to(c.device)).detach().cpu().numpy()

    orig_x_gt = prepare_data.x_to_params(x_all)
    orig_x = prepare_data.x_to_params(x)
    error = orig_x[:, 0] - orig_x_gt[:, 0]

    plot_helpers.plot_quality_map(np.abs(error), position,
                                  "Error of network prediction")

    print("shapes:", orig_x.shape, orig_x_gt.shape)
    show_error_stats(error, show_nice=True)

    print("show_predicted_error")
    # Compute s to sigma, following the paper: sigma = sqrt(exp(s)), with s
    # taken from column 1 of the raw output. Presumably s is a log-variance —
    # confirm.
    x_uncert = np.sqrt(np.exp(x[:, 1]))
    # Scale with the (0, 0) entry of inv(w_x). NOTE(review): this ignores the
    # off-diagonal terms of inv(w_x) — confirm the transform is diagonal for
    # column 0.
    uncert = x_uncert * np.linalg.inv(prepare_data.w_x)[0, 0]

    plot_helpers.plot_quality_map(uncert, position, "Predicted Uncertainty")
    plot_helpers.plot_quality_map(np.abs(error / uncert), position,
                                  "Uncertainty quality")

    # plt.subplot needs integer grid sizes. The original ``n_x / 4 + 1`` was a
    # float (rejected by current matplotlib) and, for some n_x (e.g. 10),
    # gave fewer than n_x cells. ceil(n_x / 3) columns x 3 rows always fits.
    n_cols = (n_x + 2) // 3
    # Never index past the evaluation set, even if n_plots is larger.
    for i in range(min(n_plots, len(orig_x))):
        plt.figure(f"orig_{i}", figsize=(20, 15))
        for j in range(n_x):
            plt.subplot(3, n_cols, j + 1)
            gt_line = [orig_x_gt[i, j], orig_x_gt[i, j]]
            pred_line = [orig_x[i, j], orig_x[i, j]]
            if j == 0:
                # Only the first panel carries a legend.
                plt.step(*(orig_prior_hists[j]),
                         where='post',
                         color='grey',
                         label="prior")
                plt.plot(gt_line, [0, 1], color='red', label="ground truth")
                plt.plot(pred_line, [0, 1],
                         color='blue',
                         label="predicted value")
                plt.legend()
            else:
                plt.step(*(orig_prior_hists[j]), where='post', color='grey')
                plt.plot(gt_line, [0, 1], color='red')
                plt.plot(pred_line, [0, 1], color='blue')
            plt.xlabel(f"{c.param_names[j]}")
Exemplo n.º 6
0
    x_cat, y_cat, ana_cat = torch.cat(x_all,
                                      0), torch.cat(y_all,
                                                    0), torch.cat(ana_all,
                                                                  0)  #ana_all

    if cut:
        x_cat, y_cat, ana_cat = x_cat[:cut], y_cat[:cut], ana_cat[:cut]
    return x_cat, y_cat, ana_cat


# Module-level evaluation data shared by the plotting functions in this file.
# x_all / y_all are produced by torch.cat in (presumably) concatenate_test_set,
# so they are torch tensors. c.evaluation_samples is presumably the row cut
# applied there — confirm against the function header (not visible here).
x_all, y_all, ana_all = concatenate_test_set(
    c.evaluation_samples)  #concatenate_test_set(1000)#400)#20)#400)

# Only ana_all is detached, moved to the CPU and converted to a NumPy array.
ana_all = ana_all.detach().cpu().numpy()

# Last two columns of the spectrum vector, used as per-sample map coordinates.
# NOTE(review): presumably (longitude, latitude) in that order — confirm.
position = prepare_data.y_to_spectra(y_all)[:, -2:]
#position=dataloader.ret_params[:c.evaluation_samples,[-2,-3]]


def hists(x, n_columns=None, bins=100):
    """Compute a peak-normalised histogram for each of the leading columns of ``x``.

    Arguments:
        x {2-D array} -- samples in rows, variables in columns.

    Keyword Arguments:
        n_columns {int or None} -- how many leading columns to histogram.
            ``None`` keeps the previous behaviour of using the module-level
            ``N`` (default: {None})
        bins {int} -- number of histogram bins (default: {100})

    Returns:
        list -- one ``[left_bin_edges, heights]`` pair per column. The heights
        are scaled so that the tallest bin equals 1, which suits
        ``plt.step(..., where='post')``.
    """
    if n_columns is None:
        # Backward compatibility: previously the column count always came
        # from the module-level global N.
        n_columns = N
    results = []
    for j in range(n_columns):
        heights, edges = np.histogram(x[:, j], bins=bins, density=True)
        heights /= np.max(heights)
        results.append([edges[:-1], heights])
    return results


def show_error_stats(errors, name="CO2", show_nice=False):
    plt.figure(name, figsize=(8, 2))