Ejemplo n.º 1
0
def test_model(config):
    """Run a trained network tile by tile over the input files and plot the results.

    The function builds the network named by
    ``config[ProjTrainingParams.network_type]``, loads the weights in
    ``config[PredictionParams.model_weights_file]`` and, for every
    ``model_<year>_<day>.nc`` file in the input folder whose day of year is 5,
    slides a ``rows`` x ``cols`` window over the domain. Each window is
    predicted, denormalized, copied into whole-domain arrays and plotted.
    After all windows of a file are processed, a whole-domain RMSE is computed
    and, for roughly 10% of the files, a whole-domain figure is written.

    Args:
        config: Mapping indexed by the ``PredictionParams``,
            ``ProjTrainingParams`` and ``TrainingParams`` keys read below.

    Raises:
        ValueError: If the configured network type is not one of the types
            handled here.
    """
    input_folder = config[PredictionParams.input_folder]
    output_folder = config[PredictionParams.output_folder]
    output_fields = config[ProjTrainingParams.output_fields]
    model_weights_file = config[PredictionParams.model_weights_file]
    output_imgs_folder = config[PredictionParams.output_imgs_folder]
    field_names_model = config[ProjTrainingParams.fields_names]
    field_names_obs = config[ProjTrainingParams.fields_names_obs]
    rows = config[ProjTrainingParams.rows]
    cols = config[ProjTrainingParams.cols]
    run_name = config[TrainingParams.config_name]
    norm_type = config[ProjTrainingParams.norm_type]

    output_imgs_folder = join(output_imgs_folder, run_name)
    create_folder(output_imgs_folder)

    # *********** Chooses the proper model ***********
    print('Reading model ....')
    net_type = config[ProjTrainingParams.network_type]
    if net_type == NetworkTypes.UNET or net_type == NetworkTypes.UNET_MultiStream:
        model = select_2d_model(config, last_activation=None)
    elif net_type == NetworkTypes.SimpleCNN_2:
        model = simpleCNN(config, nn_type="2d", hid_lay=2, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_4:
        model = simpleCNN(config, nn_type="2d", hid_lay=4, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_8:
        model = simpleCNN(config, nn_type="2d", hid_lay=8, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_16:
        model = simpleCNN(config, nn_type="2d", hid_lay=16, out_lay=2)
    else:
        # Fail here with a clear message instead of an unbound `model` later.
        raise ValueError(F"Unsupported network type: {net_type}")

    plot_model(model,
               to_file=join(output_folder, 'running.png'),
               show_shapes=True)

    # *********** Reads the weights ***********
    print('Reading weights ....')
    model.load_weights(model_weights_file)

    # *********** Read files to predict ***********
    all_files = os.listdir(input_folder)
    all_files.sort()
    model_files = np.array([x for x in all_files if x.startswith('model')])

    z_layers = [0]
    var_file = join(input_folder, "cov_mat", "tops_ias_std.nc")
    field_names_std = config[ProjTrainingParams.fields_names_var]
    if len(field_names_std) > 0:
        input_fields_std = read_netcdf(var_file, field_names_std, z_layers)
    else:
        input_fields_std = []

    cmap_out = chooseCMAP(output_fields)
    cmap_model = chooseCMAP(field_names_model)
    cmap_obs = chooseCMAP(field_names_obs)
    cmap_std = chooseCMAP(field_names_std)
    error_cmap = cmocean.cm.diff

    # Size of the whole domain that the tiles are stitched into
    tot_rows = 891
    tot_cols = 1401

    # Value passed as perc_ocean to generateXandY for every tile
    perc_ocean = .05

    # NOTE: these per-file statistics are accumulated but not written out
    # anywhere in this function.
    all_whole_mean_times = []
    all_whole_sum_times = []
    all_whole_rmse = []

    for c_file in model_files:
        # File names are parsed as model_<year>_<day of year>.<ext>
        name_parts = c_file.split('_')
        year = int(name_parts[1])
        day_of_year = int(name_parts[2].split('.')[0])

        # NOTE(review): only day 5 of each year is processed; this looks like
        # a leftover testing filter -- confirm before relying on the output.
        if day_of_year != 5:
            continue

        model_file = join(input_folder, F'model_{year}_{day_of_year:03d}.nc')
        inc_file = join(input_folder, F'increment_{year}_{day_of_year:03d}.nc')
        obs_file = join(input_folder, F'obs_{year}_{day_of_year:03d}.nc')

        # *********************** Reading files **************************
        input_fields_model = read_netcdf(model_file, field_names_model,
                                         z_layers)
        input_fields_obs = read_netcdf(obs_file, field_names_obs, z_layers)
        output_field_increment = read_netcdf(inc_file, output_fields, z_layers)

        # Whole-domain mosaics; tiles that are skipped stay at zero
        whole_cnn = np.zeros((tot_rows, tot_cols))
        whole_y = np.zeros((tot_rows, tot_cols))

        this_file_times = []

        start_row = 0
        donerow = False
        while not donerow:
            donecol = False
            start_col = 0
            while not donecol:
                # Generate the proper inputs for the NN
                try:
                    input_data, y_data = generateXandY(input_fields_model,
                                                       input_fields_obs,
                                                       input_fields_std,
                                                       output_field_increment,
                                                       field_names_model,
                                                       field_names_obs,
                                                       field_names_std,
                                                       output_fields,
                                                       start_row,
                                                       start_col,
                                                       rows,
                                                       cols,
                                                       norm_type=norm_type,
                                                       perc_ocean=perc_ocean)
                except Exception as e:
                    # Any failure is treated as a tile to skip; the exception
                    # text is printed so genuine errors are not hidden.
                    print(F"Land for {c_file} row:{start_row} col:{start_col} ({e})")
                    start_col, donecol = verifyBoundaries(
                        start_col, cols, tot_cols)
                    continue

                # ******************* Replacing nan values *********
                # NaNs become 0 in the inputs and -0.5 in the target; the -0.5
                # marker is used below to recover the masked positions.
                input_data_nans = np.isnan(input_data)
                input_data = np.nan_to_num(input_data, nan=0)
                y_data = np.nan_to_num(y_data, nan=-0.5)

                X = np.expand_dims(input_data, axis=0)
                Y = np.expand_dims(y_data, axis=0)

                # Make the prediction of the network and time it
                start = time.time()
                output_nn_original = model.predict(X, verbose=1)
                this_file_times.append(time.time() - start)

                # Make nan all values inside the land
                land_indexes = Y == -0.5
                output_nn_original[land_indexes] = np.nan

                # ==== Denormalizing all inputs and outputs
                denorm_cnn_output = denormalizeData(output_nn_original,
                                                    output_fields,
                                                    PreprocParams.type_inc,
                                                    norm_type)
                denorm_y = denormalizeData(Y, output_fields,
                                           PreprocParams.type_inc, norm_type)
                input_types = (
                    [PreprocParams.type_model for _ in input_fields_model] +
                    [PreprocParams.type_obs for _ in input_fields_obs] +
                    [PreprocParams.type_std for _ in input_fields_std])
                denorm_input = denormalizeData(
                    input_data,
                    field_names_model + field_names_obs + field_names_std,
                    input_types, norm_type)

                # Recover the original land areas, they are lost after denormalization
                denorm_input[input_data_nans] = np.nan
                denorm_y[land_indexes] = np.nan

                # Remove the 'extra dimension' and add the tile to the mosaics
                denorm_cnn_output = np.squeeze(denorm_cnn_output)
                denorm_y = np.squeeze(denorm_y)
                whole_cnn[start_row:start_row + rows,
                          start_col:start_col + cols] = denorm_cnn_output
                whole_y[start_row:start_row + rows,
                        start_col:start_col + cols] = denorm_y

                # ---------- Per-tile plots (generated for every tile) ----------
                if len(denorm_cnn_output.shape) == 2:
                    # Only one output field: add a trailing axis so it can be
                    # indexed per field below.
                    denorm_cnn_output = np.expand_dims(denorm_cnn_output,
                                                       axis=2)
                    denorm_y = np.expand_dims(denorm_y, axis=2)

                # Compute RMSE per output field over non-NaN target cells
                rmse_cnn = np.zeros(len(output_fields))
                for i in range(len(output_fields)):
                    ocean_indexes = np.logical_not(
                        np.isnan(denorm_y[:, :, i]))
                    rmse_cnn[i] = np.sqrt(
                        mean_squared_error(
                            denorm_cnn_output[:, :, i][ocean_indexes],
                            denorm_y[:, :, i][ocean_indexes]))

                viz_obj = EOAImageVisualizer(
                    output_folder=output_imgs_folder, disp_images=False)

                # ================== DISPLAYS ALL INPUTS AND OUTPUTS DENORMALIZED ===================
                viz_obj.plot_2d_data_np_raw(
                    np.concatenate(
                        (denorm_input.swapaxes(0, 2),
                         denorm_y.swapaxes(0, 2),
                         denorm_cnn_output.swapaxes(0, 2))),
                    var_names=[F"in_model_{x}" for x in field_names_model] +
                    [F"in_obs_{x}" for x in field_names_obs] +
                    [F"in_var_{x}" for x in field_names_std] +
                    [F"out_inc_{x}" for x in output_fields] +
                    [F"cnn_{x}" for x in output_fields],
                    file_name=
                    F"Input_and_CNN_{c_file}_{start_row:03d}_{start_col:03d}",
                    cmap=cmap_model + cmap_obs + cmap_std + cmap_out +
                    cmap_out,
                    rot_90=True,
                    cols_per_row=len(field_names_model),
                    title=
                    F"Input data: {field_names_model} and obs {field_names_obs}, increment {output_fields}, cnn {output_fields}"
                )

                # =========== Making the same color bar for desired output and the NN =====================
                n_out = denorm_cnn_output.shape[-1]
                mincbar = [np.nanmin(denorm_y[:, :, x]) for x in range(n_out)]
                maxcbar = [np.nanmax(denorm_y[:, :, x]) for x in range(n_out)]
                error = (denorm_y - denorm_cnn_output).swapaxes(0, 2)
                mincbarerror = [
                    np.nanmin(error[i, :, :])
                    for i in range(len(output_fields))
                ]
                maxcbarerror = [
                    np.nanmax(error[i, :, :])
                    for i in range(len(output_fields))
                ]
                viz_obj = EOAImageVisualizer(
                    output_folder=output_imgs_folder,
                    disp_images=False,
                    mincbar=mincbar + mincbar + mincbarerror,
                    maxcbar=maxcbar + maxcbar + maxcbarerror)

                # ================== Displays CNN and TSIS with RMSE ================
                viz_obj.output_folder = join(output_imgs_folder,
                                             'JoinedErrrorCNN')
                viz_obj.plot_2d_data_np_raw(
                    np.concatenate(
                        (denorm_cnn_output.swapaxes(0, 2),
                         denorm_y.swapaxes(0, 2), error),
                        axis=0),
                    var_names=[F"CNN INC {x}" for x in output_fields] +
                    [F"TSIS INC {x}" for x in output_fields] +
                    [F'RMSE {c_rmse_cnn:0.4f}' for c_rmse_cnn in rmse_cnn],
                    file_name=
                    F"AllError_{c_file}_{start_row:03d}_{start_col:03d}",
                    rot_90=True,
                    cmap=cmap_out + cmap_out + [error_cmap],
                    cols_per_row=len(output_fields),
                    title=F"{output_fields} RMSE: {np.mean(rmse_cnn):0.5f}"
                )

                # Advance to the next column window
                start_col, donecol = verifyBoundaries(start_col, cols,
                                                      tot_cols)
            # Advance to the next row window
            start_row, donerow = verifyBoundaries(start_row, rows, tot_rows)

        # ======= Whole output with RMSE
        error = whole_y - whole_cnn
        whole_mincbar = np.nanmin(whole_y) / 2
        whole_maxcbar = np.nanmax(whole_y) / 2
        whole_mincbarerror = np.nanmin(error) / 2
        whole_maxcbarerror = np.nanmax(error) / 2

        # Count only cells that hold a real prediction: NaN (land) cells would
        # be counted as non-zero by np.count_nonzero, and zero cells belong to
        # tiles that were skipped.
        predicted_cells = np.logical_not(np.isnan(error)) & (whole_cnn != 0)
        n_predicted = np.count_nonzero(predicted_cells)
        if n_predicted > 0:
            rmse_cnn = np.sqrt(np.nansum(error**2) / n_predicted)
        else:
            rmse_cnn = np.nan

        all_whole_rmse.append(rmse_cnn)
        all_whole_mean_times.append(np.mean(np.array(this_file_times)))
        all_whole_sum_times.append(np.sum(np.array(this_file_times)))

        # Plot 10% of the times (and always for day 353)
        if np.random.random() > .9 or day_of_year == 353:
            # One colorbar range per panel: CNN, TSIS and error
            viz_obj = EOAImageVisualizer(
                output_folder=output_imgs_folder,
                disp_images=False,
                mincbar=[whole_mincbar, whole_mincbar, whole_mincbarerror],
                maxcbar=[whole_maxcbar, whole_maxcbar, whole_maxcbarerror])

            # ================== Displays CNN and TSIS with RMSE ================
            viz_obj.output_folder = join(output_imgs_folder,
                                         'WholeOutput_CNN_TSIS')
            viz_obj.plot_2d_data_np_raw(
                [
                    np.flip(whole_cnn, axis=0),
                    np.flip(whole_y, axis=0),
                    np.flip(error, axis=0)
                ],
                var_names=[F"CNN INC {x}" for x in output_fields] +
                [F"TSIS INC {x}"
                 for x in output_fields] + [F'RMSE {rmse_cnn:0.4f}'],
                file_name=F"WholeOutput_CNN_TSIS_{c_file}",
                rot_90=False,
                cols_per_row=3,
                cmap=cmocean.cm.algae,
                title=F"{output_fields} RMSE: {np.mean(rmse_cnn):0.5f}")
Ejemplo n.º 2
0
def test_model(config):
    """Run a trained network once per input file and save plots and a summary CSV.

    The function builds the network named by
    ``config[ProjTrainingParams.network_type]``, loads the weights in
    ``config[PredictionParams.model_weights_file]`` and, for every
    ``model_<year>_<day>.nc`` file in the input folder, makes a single
    prediction, denormalizes it, computes an RMSE against the target increment
    and writes two figures. A CSV with the RMSE and prediction time of every
    processed file is written to ``Global_RMSE_and_times.csv`` in the run's
    image folder.

    Args:
        config: Mapping indexed by the ``PredictionParams``,
            ``ProjTrainingParams`` and ``TrainingParams`` keys read below.

    Raises:
        ValueError: If the configured network type is not one of the types
            handled here.
    """
    input_folder = config[PredictionParams.input_folder]
    output_folder = config[PredictionParams.output_folder]
    output_fields = config[ProjTrainingParams.output_fields]
    model_weights_file = config[PredictionParams.model_weights_file]
    output_imgs_folder = config[PredictionParams.output_imgs_folder]
    field_names_model = config[ProjTrainingParams.fields_names]
    field_names_obs = config[ProjTrainingParams.fields_names_obs]
    run_name = config[TrainingParams.config_name]
    norm_type = config[ProjTrainingParams.norm_type]

    output_imgs_folder = join(output_imgs_folder, run_name)
    create_folder(output_imgs_folder)

    # *********** Chooses the proper model ***********
    print('Reading model ....')

    net_type = config[ProjTrainingParams.network_type]
    if net_type == NetworkTypes.UNET or net_type == NetworkTypes.UNET_MultiStream:
        model = select_2d_model(config, last_activation=None)
    elif net_type == NetworkTypes.SimpleCNN_2:
        model = simpleCNN(config, nn_type="2d", hid_lay=2, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_4:
        model = simpleCNN(config, nn_type="2d", hid_lay=4, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_8:
        model = simpleCNN(config, nn_type="2d", hid_lay=8, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_16:
        model = simpleCNN(config, nn_type="2d", hid_lay=16, out_lay=2)
    else:
        # Fail here with a clear message instead of an unbound `model` later.
        raise ValueError(F"Unsupported network type: {net_type}")

    plot_model(model,
               to_file=join(output_folder, 'running.png'),
               show_shapes=True)

    # *********** Reads the weights ***********
    print('Reading weights ....')
    model.load_weights(model_weights_file)

    # *********** Read files to predict ***********
    all_files = os.listdir(input_folder)
    all_files.sort()
    model_files = np.array([x for x in all_files if x.startswith('model')])

    z_layers = [0]
    var_file = join(input_folder, "cov_mat", "tops_ias_std.nc")
    field_names_std = config[ProjTrainingParams.fields_names_var]
    if len(field_names_std) > 0:
        input_fields_std = read_netcdf(var_file, field_names_std, z_layers)
    else:
        input_fields_std = []

    cmap_out = chooseCMAP(output_fields)
    cmap_model = chooseCMAP(field_names_model)
    cmap_obs = chooseCMAP(field_names_obs)
    cmap_std = chooseCMAP(field_names_std)
    error_cmap = cmocean.cm.diff

    # NOTE(review): `grows` and `gcols` are not defined in this function;
    # presumably they are module-level sizes of the whole domain -- confirm.
    # They are read once here so that a missing name fails immediately instead
    # of being swallowed by the per-file exception handler below.
    whole_rows, whole_cols = grows, gcols

    # Value passed as perc_ocean to generateXandY
    perc_ocean = .01

    # Summary columns; one entry per file that was actually processed
    processed_files = []
    all_whole_mean_times = []
    all_whole_sum_times = []
    all_whole_rmse = []

    for c_file in model_files:
        # File names are parsed as model_<year>_<day of year>.<ext>
        name_parts = c_file.split('_')
        year = int(name_parts[1])
        day_of_year = int(name_parts[2].split('.')[0])

        model_file = join(input_folder, F'model_{year}_{day_of_year:03d}.nc')
        inc_file = join(input_folder, F'increment_{year}_{day_of_year:03d}.nc')
        obs_file = join(input_folder, F'obs_{year}_{day_of_year:03d}.nc')

        # *********************** Reading files **************************
        input_fields_model = read_netcdf(model_file, field_names_model,
                                         z_layers)
        input_fields_obs = read_netcdf(obs_file, field_names_obs, z_layers)
        output_field_increment = read_netcdf(inc_file, output_fields, z_layers)

        # ******************* Normalizing and Cropping Data *******************
        try:
            input_data, y_data = generateXandY(input_fields_model,
                                               input_fields_obs,
                                               input_fields_std,
                                               output_field_increment,
                                               field_names_model,
                                               field_names_obs,
                                               field_names_std,
                                               output_fields,
                                               0,
                                               0,
                                               whole_rows,
                                               whole_cols,
                                               norm_type=norm_type,
                                               perc_ocean=perc_ocean)
        except Exception as e:
            # Skip this file: continuing would reuse the previous file's data
            # (or fail on the first file because nothing was generated yet).
            print(F"Exception {e}")
            continue

        # ******************* Replacing nan values *********
        # NaNs become 0 in the inputs and -0.5 in the target; the -0.5 marker
        # is used below to recover the masked positions.
        input_data_nans = np.isnan(input_data)
        input_data = np.nan_to_num(input_data, nan=0)
        y_data = np.nan_to_num(y_data, nan=-0.5)

        X = np.expand_dims(input_data, axis=0)
        Y = np.expand_dims(y_data, axis=0)

        # Make the prediction of the network and time it
        start = time.time()
        output_nn_original = model.predict(X, verbose=1)
        this_file_times = [time.time() - start]

        # Make nan all values inside the land
        land_indexes = Y == -0.5
        output_nn_original[land_indexes] = np.nan

        # ==== Denormalizing all inputs and outputs
        denorm_cnn_output = denormalizeData(output_nn_original, output_fields,
                                            PreprocParams.type_inc, norm_type)
        denorm_y = denormalizeData(Y, output_fields, PreprocParams.type_inc,
                                   norm_type)
        input_types = (
            [PreprocParams.type_model for _ in input_fields_model] +
            [PreprocParams.type_obs for _ in input_fields_obs] +
            [PreprocParams.type_std for _ in input_fields_std])
        denorm_input = denormalizeData(
            input_data, field_names_model + field_names_obs + field_names_std,
            input_types, norm_type)

        # Recover the original land areas, they are lost after denormalization
        denorm_y[land_indexes] = np.nan
        denorm_input[input_data_nans] = np.nan

        # Remove the 'extra dimension'
        denorm_cnn_output = np.squeeze(denorm_cnn_output)
        denorm_y = np.squeeze(denorm_y)
        whole_cnn = denorm_cnn_output
        whole_y = denorm_y

        if len(denorm_cnn_output.shape) == 2:
            # Only one output field: add a trailing axis for the first figure
            denorm_cnn_output = np.expand_dims(denorm_cnn_output, axis=2)
            denorm_y = np.expand_dims(denorm_y, axis=2)

        if output_fields[0] == 'srfhgt':
            # This should only be for SSH to adjust the units. The division
            # makes copies on purpose: whole_cnn shares memory with
            # denorm_cnn_output, which is plotted unconverted below.
            whole_cnn = whole_cnn / 9.81
            whole_y = whole_y / 9.81

        # Error in the same units as the two fields plotted next to it
        error = whole_y - whole_cnn

        # Count only cells that hold a real prediction: NaN (land) cells would
        # be counted as non-zero by np.count_nonzero.
        predicted_cells = np.logical_not(np.isnan(error)) & (whole_cnn != 0)
        n_predicted = np.count_nonzero(predicted_cells)
        if n_predicted > 0:
            rmse_cnn = np.sqrt(np.nansum(error**2) / n_predicted)
        else:
            rmse_cnn = np.nan

        processed_files.append(c_file)
        all_whole_rmse.append(rmse_cnn)
        all_whole_mean_times.append(np.mean(np.array(this_file_times)))
        all_whole_sum_times.append(np.sum(np.array(this_file_times)))

        # ================== DISPLAYS ALL INPUTS AND OUTPUTS DENORMALIZED ===================
        viz_obj = EOAImageVisualizer(output_folder=output_imgs_folder,
                                     disp_images=False)
        viz_obj.plot_2d_data_np_raw(
            np.concatenate(
                (denorm_input.swapaxes(0, 2), denorm_y.swapaxes(0, 2),
                 denorm_cnn_output.swapaxes(0, 2))),
            var_names=[F"in_model_{x}" for x in field_names_model] +
            [F"in_obs_{x}" for x in field_names_obs] +
            [F"in_var_{x}" for x in field_names_std] +
            [F"out_inc_{x}" for x in output_fields] +
            [F"cnn_{x}" for x in output_fields],
            file_name=F"Global_Input_and_CNN_{c_file}",
            rot_90=True,
            cmap=cmap_model + cmap_obs + cmap_std + cmap_out + cmap_out,
            cols_per_row=len(field_names_model),
            title=
            F"Input data: {field_names_model} and obs {field_names_obs}, increment {output_fields}, cnn {output_fields}"
        )

        # Fixed colorbar ranges: the field range for CNN and TSIS, [-1, 1]
        # for the error panel.
        minmax = getMinMaxPlot(output_fields)[0]
        viz_obj = EOAImageVisualizer(
            output_folder=output_imgs_folder,
            disp_images=False,
            mincbar=[minmax[0], minmax[0], -1],
            maxcbar=[minmax[1], minmax[1], 1])

        # ================== Displays CNN and TSIS with RMSE ================
        viz_obj.output_folder = join(output_imgs_folder,
                                     'WholeOutput_CNN_TSIS')
        viz_obj.plot_2d_data_np_raw(
            [
                np.flip(whole_cnn, axis=0),
                np.flip(whole_y, axis=0),
                np.flip(error, axis=0)
            ],
            var_names=["CNN increment SSH" for _ in output_fields] +
            ["TSIS increment SSH" for _ in output_fields] +
            [F'TSIS - CNN \n (Mean RMSE {rmse_cnn:0.4f} m)'],
            file_name=F"Global_WholeOutput_CNN_TSIS_{c_file}",
            rot_90=False,
            cmap=cmap_out + cmap_out + [error_cmap],
            cols_per_row=3,
            title=F"SSH RMSE: {np.mean(rmse_cnn):0.5f} m.")

    print("DONE ALL FILES!!!!!!!!!!!!!")

    # Only processed files are listed so that all columns have the same length
    dic_summary = {
        "File": processed_files,
        "rmse": all_whole_rmse,
        "times mean": all_whole_mean_times,
        "times sum": all_whole_sum_times,
    }
    df = pd.DataFrame.from_dict(dic_summary)
    df.to_csv(join(output_imgs_folder, "Global_RMSE_and_times.csv"))