コード例 #1
0
    def plot_1d_data_np(self,
                        X,
                        Ys,
                        title='',
                        labels=None,
                        file_name_prefix='',
                        wide_ratio=1):
        """
        Plots multiple lines (one per element of Ys) over a shared X axis in
        a single figure and saves it as a PNG.

        :param X: 1D array of x-coordinates shared by every line.
        :param Ys: iterable of 1D arrays, one per line to plot.
        :param title: title of the figure.
        :param labels: optional legend labels; when given, must match len(Ys).
        :param file_name_prefix: output file name (without the .png extension).
        :param wide_ratio: horizontal stretch factor for the figure width.
        """
        # Avoid a mutable default argument (shared across calls).
        if labels is None:
            labels = []
        plt.figure(figsize=[16 * wide_ratio, self._figsize])
        create_folder(self._output_folder)
        for i, y in enumerate(Ys):
            # Dash-dot line, cycling through the configured colors.
            style = F'-.{self._COLORS[i % len(self._COLORS)]}'
            if len(labels) > 0:
                assert len(labels) == len(Ys)
                plt.plot(X, y, style, label=labels[i])
            else:
                plt.plot(X, y, style)

        if len(labels) > 0:
            plt.legend()

        plt.grid(True)
        plt.title(title, fontsize=self._font_size)

        file_name = F'{file_name_prefix}'
        pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                      bbox_inches='tight')
        self._close_figure()
コード例 #2
0
    def plot_3d_data_ncds(self,
                          ds,
                          var_names: list,
                          z_levels: list,
                          title='',
                          file_name_prefix='',
                          cmap='viridis',
                          show_color_bar=True):
        """
        Plots multiple z-levels (no time dimension) of several variables from
        a netCDF dataset. One PNG is saved per z-level, with one subplot per
        variable.

        :param ds: open netCDF dataset (netCDF4.Dataset-like, read via ds.variables).
        :param var_names: names of the variables to plot.
        :param z_levels: vertical level indices to plot (one figure each).
        :param title: extra text appended to each subplot title.
        :param file_name_prefix: prefix of the saved PNG files.
        :param cmap: matplotlib colormap name.
        :param show_color_bar: whether to draw a colorbar next to each subplot.
        """
        create_folder(self._output_folder)
        for c_slice in z_levels:
            # BUG FIX: plt.subplots returns a (figure, axes) tuple; the
            # original bound the whole tuple to 'fig' and passed it to
            # add_colorbar. Also removed a leftover debug plt.title("TEST").
            fig, _axs = plt.subplots(1,
                                     len(var_names),
                                     squeeze=True,
                                     figsize=self.get_proper_size(
                                         1, len(var_names)))
            for idx_var, c_var_name in enumerate(var_names):
                cur_var = ds.variables.get(c_var_name)
                ax = plt.subplot(1, len(var_names), idx_var + 1)
                im = self.plot_slice_eoa(cur_var[c_slice, :, :], ax, cmap=cmap)
                c_title = F'{c_var_name} {title} Z-level:{c_slice}'
                plt.title(c_title, fontsize=self._font_size)
                self.add_colorbar(fig, im, ax, show_color_bar)

            file_name = F'{file_name_prefix}_{c_slice:04d}'
            pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                          bbox_inches='tight')
            self._close_figure()
コード例 #3
0
    def plot_4d_data_xarray_map(self,
                                ds,
                                var_names: list,
                                z_levels: list,
                                timesteps: list,
                                title='',
                                file_name_prefix='',
                                cmap='viridis',
                                proj=ccrs.PlateCarree(),
                                lonvar='Longitude',
                                latvar='Latitude',
                                timevar='Time',
                                z_levelvar='Depth'):
        """
        Plots multiple z-levels and timesteps from an xarray dataset (4D
        data) on a map: one PNG per (z-level, timestep) pair, with one map
        subplot per variable.
        http://xarray.pydata.org/en/stable/plotting.html#maps

        :param ds: xarray Dataset.
        :param var_names: variables to plot (one subplot each).
        :param z_levels: indices into the depth coordinate.
        :param timesteps: time values to select (with method='nearest').
        :param title: extra text appended to each subplot title.
        :param file_name_prefix: prefix of the saved PNG files.
        :param cmap: matplotlib colormap passed to xarray's plot().
        :param proj: cartopy projection for the axes and the data transform.
        :param lonvar: unused, kept for interface compatibility.
        :param latvar: unused, kept for interface compatibility.
        :param timevar: name of the time coordinate in ds.
        :param z_levelvar: name of the depth coordinate in ds.
        """
        projection = proj
        create_folder(self._output_folder)
        for c_slice in z_levels:
            for c_time in timesteps:
                fig, axs = plt.subplots(1,
                                        len(var_names),
                                        squeeze=True,
                                        figsize=self.get_proper_size(
                                            1, len(var_names)))
                for idx_var, c_var_name in enumerate(var_names):
                    # Depth selection, not sure if the name is always the same
                    c_depth = ds[z_levelvar].values[c_slice]
                    # Interpolates data to the nearest requested depth/time
                    cur_var = ds[c_var_name].sel(**{
                        z_levelvar: c_depth,
                        timevar: c_time
                    },
                                                 method='nearest')
                    ax = plt.subplot(1,
                                     len(var_names),
                                     idx_var + 1,
                                     projection=projection)
                    # ------------------ MAP STUFF -------------------------
                    # https://rabernat.github.io/research_computing_2018/maps-with-cartopy.html
                    # BUG FIX: the cmap parameter was previously ignored.
                    cur_var.plot(ax=ax, transform=projection, cmap=cmap)
                    ax.coastlines(resolution='50m')  # Draws the coastline
                    ax.add_feature(cartopy.feature.LAND, edgecolor='black')
                    # ------------------ MAP STUFF -------------------------
                    c_title = F'{c_var_name} {title} Z-level:{c_slice} Time:{c_time} '
                    ax.set_title(c_title, fontsize=self._font_size)

                # BUG FIX: include the timestep in the file name; previously
                # every timestep of a z-level overwrote the same PNG.
                file_name = F'{file_name_prefix}_{c_slice:04d}_{c_time}'
                pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                              bbox_inches='tight')
                self._close_figure()
コード例 #4
0
    def plot_4d_data_ncds_map(self,
                              ds,
                              var_names: list,
                              z_levels: list,
                              timesteps: list,
                              title='',
                              file_name_prefix='',
                              cmap='viridis',
                              proj=ccrs.PlateCarree(),
                              lonvar='Longitude',
                              latvar='Latitude',
                              show_color_bar=True):
        """
        Plots multiple z-levels and timesteps from a netCDF file (4D data,
        indexed [time, z, lat, lon]) on a map: one PNG per (z-level,
        timestep) pair, with one map subplot per variable.

        :param ds: open netCDF dataset (read via ds.variables).
        :param var_names: variables to plot (one subplot each).
        :param z_levels: vertical level indices to plot.
        :param timesteps: time indices to plot.
        :param title: extra text appended to each subplot title.
        :param file_name_prefix: prefix of the saved PNG files.
        :param cmap: matplotlib colormap name.
        :param proj: cartopy projection for the map axes.
        :param lonvar: name of the longitude variable in ds.
        :param latvar: name of the latitude variable in ds.
        :param show_color_bar: whether to add a colorbar per subplot.
        """
        lon = ds.variables[lonvar][:]
        lat = ds.variables[latvar][:]
        projection = proj

        create_folder(self._output_folder)
        for c_slice in z_levels:
            for c_time in timesteps:
                fig, axs = plt.subplots(1,
                                        len(var_names),
                                        squeeze=True,
                                        figsize=self.get_proper_size(
                                            1, len(var_names)))
                for idx_var, c_var_name in enumerate(var_names):
                    cur_var = ds.variables.get(c_var_name)
                    ax = plt.subplot(1,
                                     len(var_names),
                                     idx_var + 1,
                                     projection=projection)
                    # ------------------ MAP STUFF -------------------------
                    # https://rabernat.github.io/research_computing_2018/maps-with-cartopy.html
                    ax.stock_img()  # Draws a basic topography
                    ax.coastlines(resolution='50m')  # Draws the coastline
                    ax.add_feature(cartopy.feature.OCEAN)
                    ax.add_feature(cartopy.feature.LAND, edgecolor='black')
                    ax.add_feature(cartopy.feature.LAKES, edgecolor='black')
                    ax.add_feature(cartopy.feature.RIVERS)
                    # ------------------ MAP STUFF -------------------------
                    # BUG FIX: pass the requested colormap (it was ignored).
                    im = ax.imshow(cur_var[c_time, c_slice, :, :],
                                   extent=self.getExtent(lat, lon),
                                   cmap=cmap)
                    c_title = F'{c_var_name} {title} Z-level:{c_slice} Time:{c_time} '
                    ax.set_title(c_title, fontsize=self._font_size)
                    self.add_colorbar(fig, im, ax, show_color_bar)

                # BUG FIX: include the timestep in the file name; previously
                # every timestep of a z-level overwrote the same PNG. Also,
                # tight_layout must run BEFORE savefig to affect the output.
                file_name = F'{file_name_prefix}_{c_slice:04d}_{c_time:04d}'
                plt.tight_layout()
                pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                              bbox_inches='tight')
                self._close_figure()
コード例 #5
0
    def plot_2d_data_np(self,
                        np_variables: list,
                        var_names: list,
                        title='',
                        file_name_prefix='',
                        cmap='viridis',
                        flip_data=False,
                        plot_mode=PlotMode.RASTER,
                        show_color_bar=True):
        """
        Plots an array of 2D fields that come as np arrays, one subplot per
        field, and saves everything as a single PNG.

        :param np_variables: list of 2D numpy arrays to plot.
        :param var_names: subplot titles, one per array (the subplot index is
            used when the list is empty).
        :param title: overall figure title.
        :param file_name_prefix: output file name (without .png).
        :param cmap: matplotlib colormap name.
        :param flip_data: if True, flip each field on both axes before plotting.
        :param plot_mode: PlotMode used by plot_slice_eoa (e.g. RASTER).
        :param show_color_bar: whether to add a colorbar to each subplot.
        :return: None
        """
        create_folder(self._output_folder)
        fig, axs = plt.subplots(squeeze=True,
                                figsize=self.get_proper_size(
                                    1, len(var_names)),
                                ncols=len(var_names))

        for idx_var, c_var in enumerate(np_variables):
            ax = plt.subplot(1, len(var_names), idx_var + 1)
            if flip_data:
                im = self.plot_slice_eoa(np.flip(np.flip(c_var), axis=1),
                                         ax,
                                         cmap=cmap,
                                         mode=plot_mode)
            else:
                im = self.plot_slice_eoa(c_var, ax, cmap=cmap, mode=plot_mode)
            self.add_colorbar(fig, im, ax, show_color_bar)

            # BUG FIX: the original compared the list against '' which is
            # always True; test for a non-empty list instead so the index
            # fallback is actually reachable.
            if len(var_names) > 0:
                c_title = F'{var_names[idx_var]} '
            else:
                c_title = F'{idx_var}'
            ax.set_title(c_title, fontsize=self._font_size)

        fig.suptitle(title, fontsize=self._font_size * 1.1)
        file_name = F'{file_name_prefix}'
        pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                      bbox_inches='tight')
        self._close_figure()
コード例 #6
0
    def plot_1d_data_xarray(self,
                            xr_ds,
                            var_names: list,
                            title='',
                            file_name_prefix=''):
        """
        Plots each requested variable of an xarray dataset as a line plot
        (via pandas) and saves the current (last generated) figure as a PNG.

        :param xr_ds: xarray Dataset holding 1D variables.
        :param var_names: names of the variables to plot.
        :param title: extra text appended to each plot title.
        :param file_name_prefix: output file name (without .png).
        """
        create_folder(self._output_folder)
        for c_var_name in var_names:
            # pandas' .plot() opens a fresh axes for every variable.
            xr_ds[c_var_name].to_dataframe().plot()
            plt.title(F'{c_var_name} {title}', fontsize=self._font_size)

        pylab.savefig(join(self._output_folder, F'{file_name_prefix}.png'),
                      bbox_inches='tight')
        self._close_figure()
コード例 #7
0
    def plot_3d_data_singlevar_np(self,
                                  data: list,
                                  z_levels=None,
                                  title='',
                                  file_name_prefix='',
                                  cmap='viridis',
                                  flip_data=False,
                                  show_color_bar=True):
        """
        Plots all requested z-layers of a single 3D variable in one figure
        grid and saves it as a PNG.

        :param data: 3D numpy array indexed [z, y, x].
        :param z_levels: z indices to plot; all layers when empty/None.
        :param title: extra text prepended to each subplot title.
        :param file_name_prefix: prefix of the saved PNG file.
        :param cmap: matplotlib colormap name.
        :param flip_data: flip each slice on both axes before plotting.
        :param show_color_bar: whether to draw a colorbar per subplot.
        """
        create_folder(self._output_folder)
        # BUG FIX: default the z-levels BEFORE computing the grid size; the
        # original computed rows/cols from the still-empty list, producing a
        # subplots(0, 0) call whenever z_levels was omitted.
        if z_levels is None or len(z_levels) == 0:
            z_levels = np.arange(data.shape[0])
        rows = int(np.ceil(len(z_levels) / self._max_imgs_per_row))
        cols = int(min(self._max_imgs_per_row, len(z_levels)))

        fig, axs = plt.subplots(rows,
                                cols,
                                squeeze=True,
                                figsize=self.get_proper_size(rows, cols))

        for slice_idx, c_slice in enumerate(z_levels):
            ax = plt.subplot(rows, cols, slice_idx + 1)
            if flip_data:
                im = self.plot_slice_eoa(np.flip(np.flip(data[c_slice, :, :]),
                                                 axis=1),
                                         ax,
                                         cmap=cmap)
            else:
                im = self.plot_slice_eoa(data[c_slice, :, :], ax, cmap=cmap)
            c_title = F'{title} Z-level:{c_slice}'
            ax.set_title(c_title, fontsize=self._font_size)
            self.add_colorbar(fig, im, ax, show_color_bar)

        # The file is named after the last plotted z-level.
        file_name = F'{file_name_prefix}_{c_slice:04d}'
        pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                      bbox_inches='tight')
        self._close_figure()
コード例 #8
0
    def plot_3d_data_np(self,
                        np_variables: list,
                        var_names: list,
                        z_levels=None,
                        title='',
                        file_name_prefix='',
                        cmap='viridis',
                        z_lavels_names=None,
                        flip_data=False,
                        show_color_bar=True,
                        plot_mode=PlotMode.RASTER):
        """
        Plots multiple z-levels of several 3D numpy fields. One PNG is saved
        per z-level, with one subplot per variable.

        :param np_variables: list of 3D numpy arrays [z, y, x], one per variable.
        :param var_names: subplot title per variable (index used when empty).
        :param z_levels: z indices to plot; all levels when empty/None.
        :param title: extra text appended to each subplot title.
        :param file_name_prefix: prefix of the saved PNG files.
        :param cmap: matplotlib colormap name.
        :param z_lavels_names: optional original names of the z-levels
            (name kept for interface compatibility; entries must support the
            ':04d' format, i.e. be integers, for the output file name).
        :param flip_data: flip each slice on both axes before plotting.
        :param show_color_bar: whether to draw a colorbar per subplot.
        :param plot_mode: PlotMode used by plot_slice_eoa.
        """
        # Avoid mutable default arguments (shared across calls).
        if z_levels is None:
            z_levels = []
        if z_lavels_names is None:
            z_lavels_names = []
        create_folder(self._output_folder)

        # If the user does not request any z-level, then all are plotted.
        if len(z_levels) == 0:
            z_levels = range(np_variables[0].shape[0])

        for c_slice in z_levels:
            fig, _axs = plt.subplots(1,
                                     len(var_names),
                                     squeeze=True,
                                     figsize=self.get_proper_size(
                                         1, len(var_names)))

            # Map the index back to the original z-level name when provided.
            if len(z_lavels_names) != 0:
                c_slice_txt = z_lavels_names[c_slice]
            else:
                c_slice_txt = c_slice

            for idx_var, c_var in enumerate(np_variables):
                # NOTE: with squeeze=True and a single variable, _axs is a
                # bare Axes (not indexable) — assumes len(var_names) > 1.
                ax = _axs[idx_var]
                if flip_data:
                    im = self.plot_slice_eoa(np.flip(np.flip(
                        c_var[c_slice, :, :]),
                                                     axis=1),
                                             ax,
                                             cmap=cmap,
                                             mode=plot_mode)
                else:
                    im = self.plot_slice_eoa(c_var[c_slice, :, :],
                                             ax,
                                             cmap=cmap,
                                             mode=plot_mode)

                # BUG FIX: comparing the list against '' was always True;
                # check for a non-empty list instead.
                if len(var_names) > 0:
                    c_title = F'{var_names[idx_var]} {title} Z-level:{c_slice_txt}'
                else:
                    c_title = F'{idx_var} {title} Z-level:{c_slice_txt}'

                ax.set_title(c_title, fontsize=self._font_size)

                self.add_colorbar(fig, im, ax, show_color_bar)

            file_name = F'{file_name_prefix}_{c_slice_txt:04d}'
            pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                          bbox_inches='tight')
            self._close_figure()
コード例 #9
0
    def plot_3d_data_xarray_map(self,
                                xr_ds,
                                var_names: list,
                                timesteps: list,
                                title='',
                                file_name_prefix='',
                                proj=ccrs.PlateCarree(),
                                timevar_name='time'):
        """
        Plots multiple timesteps from a NetCDF file (3D data). It plots the
        results in a map, one PNG per timestep with one subplot per variable.
        It is assuming that the dims are ordered time, lat, lon (the first
        dim is used as a fallback when nearest-time selection fails).
        http://xarray.pydata.org/en/stable/plotting.html#maps

        :param xr_ds: xarray Dataset.
        :param var_names: variables to plot (one subplot each).
        :param timesteps: indices into the time coordinate to plot.
        :param title: extra text appended to each subplot title.
        :param file_name_prefix: prefix of the saved PNG files.
        :param proj: cartopy projection for the axes and data transform.
        :param timevar_name: name of the time coordinate in xr_ds.
        """
        projection = proj

        create_folder(self._output_folder)
        for c_time_idx in timesteps:
            print(F"Time: {c_time_idx}")
            plt.subplots(1,
                         len(var_names),
                         squeeze=True,
                         figsize=self.get_proper_size(1, len(var_names)))
            for idx_var, c_var_name in enumerate(var_names):
                print(c_var_name)
                cur_var = xr_ds[c_var_name]
                # TODO: hardcoded order — assuming the coordinates are ordered lat, lon, time
                # NOTE(review): cur_coords_names is computed but never used.
                cur_coords_names = list(cur_var.coords.keys())
                # Assuming the order of the dims is time, lat, lon
                cur_dims_names = list(cur_var.dims)
                # Translate the index into the actual time coordinate value
                c_time = xr_ds[timevar_name].values[c_time_idx]
                # Obtains data at the nearest requested time
                try:
                    cur_var = xr_ds[c_var_name].sel(**{timevar_name: c_time},
                                                    method='nearest')
                except Exception as e:
                    # Fallback: select positionally by the first dim (assumed time)
                    print(
                        F"Warning for {c_var_name}!! (couldn't interpolate to the proper 'time' value: {e}"
                    )
                    cur_var = xr_ds[c_var_name].sel(
                        **{cur_dims_names[0]: c_time_idx})
                ax = plt.subplot(1,
                                 len(var_names),
                                 idx_var + 1,
                                 projection=projection)
                # ------------------ MAP STUFF -------------------------
                # https://rabernat.github.io/research_computing_2018/maps-with-cartopy.html
                # cur_var.plot(ax=ax, transform=projection, cbar_kwargs={'shrink': 0.4})
                cur_var.plot(ax=ax, transform=projection)
                # ax.stock_img()  # Draws a basic topography of the world
                ax.coastlines(resolution='50m')  # Draws the coastline
                # ax = self.add_states(ax)
                ax = self.add_roads(ax)
                ax.add_feature(cartopy.feature.OCEAN)
                ax.add_feature(cartopy.feature.LAND, edgecolor='black')
                ax.add_feature(cartopy.feature.LAKES, edgecolor='black')
                ax.add_feature(cartopy.feature.RIVERS)
                ax.gridlines()
                # ------------------ MAP STUFF -------------------------
                c_title = F'{c_var_name} {title} Time:{c_time} '
                plt.title(c_title, fontsize=self._font_size)

            # One file per timestep, named by the time coordinate value
            file_name = F'{file_name_prefix}_{c_time}'
            pylab.savefig(join(self._output_folder, F'{file_name}.png'),
                          bbox_inches='tight')
            self._close_figure()
コード例 #10
0
def test_model(config):
    """
    Loads a trained network and evaluates it over whole model/observation
    fields read from NetCDF files.

    For every 'model_*' input file the domain (assumed 891x1401 — see
    tot_rows/tot_cols below; TODO confirm) is tiled into (rows x cols)
    windows; each window is normalized, fed through the network,
    denormalized, and stitched back into a whole-domain prediction. RMSE
    against the TSIS increment is computed and a selection of comparison
    figures is saved.

    :param config: dict-like configuration indexed by PredictionParams,
        ProjTrainingParams and TrainingParams keys.
    """
    # ---- Configuration lookups ----
    input_folder = config[PredictionParams.input_folder]
    output_folder = config[PredictionParams.output_folder]
    output_fields = config[ProjTrainingParams.output_fields]
    model_weights_file = config[PredictionParams.model_weights_file]
    output_imgs_folder = config[PredictionParams.output_imgs_folder]
    field_names_model = config[ProjTrainingParams.fields_names]
    field_names_obs = config[ProjTrainingParams.fields_names_obs]
    rows = config[ProjTrainingParams.rows]
    cols = config[ProjTrainingParams.cols]
    run_name = config[TrainingParams.config_name]
    norm_type = config[ProjTrainingParams.norm_type]

    # Images for this run go into their own subfolder
    output_imgs_folder = join(output_imgs_folder, run_name)
    create_folder(output_imgs_folder)

    # *********** Chooses the proper model ***********
    print('Reading model ....')
    net_type = config[ProjTrainingParams.network_type]
    # NOTE(review): independent 'if's, not elif — if net_type matches none of
    # them, 'model' is undefined below.
    if net_type == NetworkTypes.UNET or net_type == NetworkTypes.UNET_MultiStream:
        model = select_2d_model(config, last_activation=None)
    if net_type == NetworkTypes.SimpleCNN_2:
        model = simpleCNN(config, nn_type="2d", hid_lay=2, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_4:
        model = simpleCNN(config, nn_type="2d", hid_lay=4, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_8:
        model = simpleCNN(config, nn_type="2d", hid_lay=8, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_16:
        model = simpleCNN(config, nn_type="2d", hid_lay=16, out_lay=2)

    # Saves a diagram of the network architecture
    plot_model(model,
               to_file=join(output_folder, F'running.png'),
               show_shapes=True)

    # *********** Reads the weights***********
    print('Reading weights ....')
    model.load_weights(model_weights_file)

    # *********** Read files to predict***********
    all_files = os.listdir(input_folder)
    all_files.sort()
    model_files = np.array([x for x in all_files if x.startswith('model')])

    z_layers = [0]
    # Standard-deviation fields used as additional (variance) inputs
    var_file = join(input_folder, "cov_mat", "tops_ias_std.nc")
    field_names_std = config[ProjTrainingParams.fields_names_var]
    if len(field_names_std) > 0:
        input_fields_std = read_netcdf(var_file, field_names_std, z_layers)
    else:
        input_fields_std = []

    # Colormaps chosen per group of fields
    cmap_out = chooseCMAP(output_fields)
    cmap_model = chooseCMAP(field_names_model)
    cmap_obs = chooseCMAP(field_names_obs)
    cmap_std = chooseCMAP(field_names_std)

    # Whole-domain size — presumably fixed for this dataset; TODO confirm
    tot_rows = 891
    tot_cols = 1401

    # Per-file aggregates (prediction times and whole-domain RMSE)
    all_whole_mean_times = []
    all_whole_sum_times = []
    all_whole_rmse = []

    # np.random.shuffle(model_files)  # TODO this is only for testing
    for id_file, c_file in enumerate(model_files):
        # Find current and next date (file names look like model_YYYY_DDD.nc)
        year = int(c_file.split('_')[1])
        day_of_year = int(c_file.split('_')[2].split('.')[0])

        # NOTE(review): debug filter — only day-of-year 5 is processed;
        # remove this to evaluate every file.
        if day_of_year != 5:
            continue

        model_file = join(input_folder, F'model_{year}_{day_of_year:03d}.nc')
        inc_file = join(input_folder, F'increment_{year}_{day_of_year:03d}.nc')
        obs_file = join(input_folder, F'obs_{year}_{day_of_year:03d}.nc')

        # *********************** Reading files **************************
        input_fields_model = read_netcdf(model_file, field_names_model,
                                         z_layers)
        input_fields_obs = read_netcdf(obs_file, field_names_obs, z_layers)
        output_field_increment = read_netcdf(inc_file, output_fields, z_layers)

        # ******************* Normalizing and Cropping Data *******************
        # Whole-domain buffers the tiled predictions are stitched into
        whole_cnn = np.zeros((891, 1401))
        whole_y = np.zeros((891, 1401))

        # Per-tile prediction times for this file
        this_file_times = []

        # Slide a (rows x cols) window over the whole domain
        start_row = 0
        donerow = False
        while not (donerow):
            donecol = False
            start_col = 0
            while not (donecol):
                # print(F"{start_row}-{start_row+rows} {start_col}-{start_col+cols}")
                # Generate the proper inputs for the NN
                try:
                    # Minimum ocean fraction required for a tile to be used
                    perc_ocean = .05
                    input_data, y_data = generateXandY(input_fields_model,
                                                       input_fields_obs,
                                                       input_fields_std,
                                                       output_field_increment,
                                                       field_names_model,
                                                       field_names_obs,
                                                       field_names_std,
                                                       output_fields,
                                                       start_row,
                                                       start_col,
                                                       rows,
                                                       cols,
                                                       norm_type=norm_type,
                                                       perc_ocean=perc_ocean)
                except Exception as e:
                    # Tile is (mostly) land; advance to the next column
                    print(F"Land for {c_file} row:{start_row} col:{start_col}")
                    start_col, donecol = verifyBoundaries(
                        start_col, cols, tot_cols)
                    continue

                # ******************* Replacing nan values *********
                # Inputs get 0 on land; targets get -0.5, for a loss function
                # that does not take land into account.
                input_data_nans = np.isnan(input_data)
                input_data = np.nan_to_num(input_data, nan=0)
                y_data = np.nan_to_num(y_data, nan=-0.5)

                X = np.expand_dims(input_data, axis=0)
                Y = np.expand_dims(y_data, axis=0)

                # Make the prediction of the network (timed)
                start = time.time()
                output_nn_original = model.predict(X, verbose=1)
                toc = time.time() - start
                this_file_times.append(toc)
                # print(F"Time to get prediction {toc:0.3f} seconds")
                # PLOT RAW DATA
                # import matplotlib.pyplot as plt
                # plt.imshow(np.flip(output_nn_original[0,:,:,0], axis=0))
                # plt.imshow(np.flip(Y[0,:,:,0], axis=0))
                # plt.show()
                # Original MSE
                # print(F"MSE: {mean_squared_error(Y[0,:,:,0], output_nn_original[0,:,:,0])}")

                # Make nan all values inside the land
                land_indexes = Y == -0.5
                output_nn_original[land_indexes] = np.nan

                # ====================== PLOTS RAW DATA  NOT NECESSARY =============================
                # viz_obj = EOAImageVisualizer(output_folder=output_imgs_folder, disp_images=False)
                # viz_obj.plot_2d_data_np_raw(np.concatenate((input_data.swapaxes(0,2), Y[0,:,:,:].swapaxes(0,2), output_nn_original[0,:,:,:].swapaxes(0,2))),
                #                             var_names=[F"in_model_{x}" for x in field_names_model] +
                #                                       [F"in_obs_{x}" for x in field_names_obs] +
                #                                       [F"in_var_{x}" for x in field_names_std] +
                #                                       [F"out_inc_{x}" for x in output_fields] +
                #                                       [F"cnn_{x}" for x in output_fields],
                #                             file_name=F"RAW_Input_and_CNN_{c_file}_{start_row:03d}_{start_col:03d}",
                #                             rot_90=True,
                #                             cols_per_row=len(field_names_model),
                #                             title=F"Input data: {field_names_model} and obs {field_names_obs}, increment {output_fields}, cnn {output_fields}")

                # Denormalize the data to the proper units in each field
                # NOTE(review): these zeros are immediately overwritten below.
                denorm_cnn_output = np.zeros(output_nn_original.shape)
                denorm_y = np.zeros(Y.shape)

                # ==== Denormalizing all inputs and outputs
                denorm_cnn_output = denormalizeData(output_nn_original,
                                                    output_fields,
                                                    PreprocParams.type_inc,
                                                    norm_type)
                denorm_y = denormalizeData(Y, output_fields,
                                           PreprocParams.type_inc, norm_type)
                input_types = [
                    PreprocParams.type_model for i in input_fields_model
                ] + [PreprocParams.type_obs for i in input_fields_obs
                     ] + [PreprocParams.type_std for i in input_fields_std]
                denorm_input = denormalizeData(
                    input_data,
                    field_names_model + field_names_obs + field_names_std,
                    input_types, norm_type)

                # Recover the original land areas, they are lost after denormalization
                denorm_input[input_data_nans] = np.nan
                denorm_y[land_indexes] = np.nan

                # Remove the 'extra dimension'
                denorm_cnn_output = np.squeeze(denorm_cnn_output)
                denorm_y = np.squeeze(denorm_y)
                whole_cnn[
                    start_row:start_row + rows, start_col:start_col +
                    cols] = denorm_cnn_output  # Add to the 'whole prediction'
                whole_y[start_row:start_row + rows, start_col:start_col +
                        cols] = denorm_y  # Add to the 'whole prediction'

                # if np.random.random() > .99: # Plot 1% of the times
                if True:  # NOTE(review): debug override — plots every tile; restore the sampling above
                    if len(
                            denorm_cnn_output.shape
                    ) == 2:  # In this case we only had one output and we need to make it 'array' to plot
                        denorm_cnn_output = np.expand_dims(denorm_cnn_output,
                                                           axis=2)
                        denorm_y = np.expand_dims(denorm_y, axis=2)

                    # Compute RMSE per output field over ocean pixels only
                    rmse_cnn = np.zeros(len(output_fields))
                    for i in range(len(output_fields)):
                        ocean_indexes = np.logical_not(
                            np.isnan(denorm_y[:, :, i]))
                        rmse_cnn[i] = np.sqrt(
                            mean_squared_error(
                                denorm_cnn_output[:, :, i][ocean_indexes],
                                denorm_y[:, :, i][ocean_indexes]))

                    # viz_obj = EOAImageVisualizer(output_folder=output_imgs_folder, disp_images=False, mincbar=mincbar, maxcbar=maxcbar)
                    viz_obj = EOAImageVisualizer(
                        output_folder=output_imgs_folder, disp_images=False)

                    # ================== DISPLAYS ALL INPUTS AND OUTPUTS DENORMALIZED ===================
                    # viz_obj.plot_2d_data_np_raw(np.concatenate((input_data.swapaxes(0,2), Y[0,:,:,:].swapaxes(0,2), output_nn_original[0,:,:,:].swapaxes(0,2))),
                    viz_obj.plot_2d_data_np_raw(
                        np.concatenate(
                            (denorm_input.swapaxes(0,
                                                   2), denorm_y.swapaxes(0, 2),
                             denorm_cnn_output.swapaxes(0, 2))),
                        var_names=[F"in_model_{x}"
                                   for x in field_names_model] +
                        [F"in_obs_{x}" for x in field_names_obs] +
                        [F"in_var_{x}" for x in field_names_std] +
                        [F"out_inc_{x}" for x in output_fields] +
                        [F"cnn_{x}" for x in output_fields],
                        file_name=
                        F"Input_and_CNN_{c_file}_{start_row:03d}_{start_col:03d}",
                        cmap=cmap_model + cmap_obs + cmap_std + cmap_out +
                        cmap_out,
                        rot_90=True,
                        cols_per_row=len(field_names_model),
                        title=
                        F"Input data: {field_names_model} and obs {field_names_obs}, increment {output_fields}, cnn {output_fields}"
                    )

                    # =========== Making the same color bar for desired output and the NN =====================
                    mincbar = [
                        np.nanmin(denorm_y[:, :, x])
                        for x in range(denorm_cnn_output.shape[-1])
                    ]
                    maxcbar = [
                        np.nanmax(denorm_y[:, :, x])
                        for x in range(denorm_cnn_output.shape[-1])
                    ]
                    error = (denorm_y - denorm_cnn_output).swapaxes(0, 2)
                    mincbarerror = [
                        np.nanmin(error[i, :, :])
                        for i in range(len(output_fields))
                    ]
                    maxcbarerror = [
                        np.nanmax(error[i, :, :])
                        for i in range(len(output_fields))
                    ]
                    viz_obj = EOAImageVisualizer(
                        output_folder=output_imgs_folder,
                        disp_images=False,
                        mincbar=mincbar + mincbar + mincbarerror,
                        maxcbar=maxcbar + maxcbar + maxcbarerror)

                    # ================== Displays CNN and TSIS with RMSE ================
                    viz_obj.output_folder = join(output_imgs_folder,
                                                 'JoinedErrrorCNN')
                    cmap = chooseCMAP(output_fields)
                    error_cmap = cmocean.cm.diff
                    viz_obj.plot_2d_data_np_raw(
                        np.concatenate((denorm_cnn_output.swapaxes(
                            0, 2), denorm_y.swapaxes(0, 2), error),
                                       axis=0),
                        var_names=[F"CNN INC {x}" for x in output_fields] +
                        [F"TSIS INC {x}" for x in output_fields] +
                        [F'RMSE {c_rmse_cnn:0.4f}' for c_rmse_cnn in rmse_cnn],
                        file_name=
                        F"AllError_{c_file}_{start_row:03d}_{start_col:03d}",
                        rot_90=True,
                        cmap=cmap + cmap + [error_cmap],
                        cols_per_row=len(output_fields),
                        title=F"{output_fields} RMSE: {np.mean(rmse_cnn):0.5f}"
                    )

                start_col, donecol = verifyBoundaries(start_col, cols,
                                                      tot_cols)
                # Column for
            start_row, donerow = verifyBoundaries(start_row, rows, tot_rows)
            # Row for

        # ======= Plots whole output with RMSE
        # NOTE(review): unlike the per-tile lists above, these are scalars —
        # 'mincbar + mincbar + mincbarerror' below is arithmetic, not list
        # concatenation; confirm EOAImageVisualizer accepts scalar cbar limits.
        mincbar = np.nanmin(whole_y) / 2
        maxcbar = np.nanmax(whole_y) / 2
        error = whole_y - whole_cnn
        mincbarerror = np.nanmin(error) / 2
        maxcbarerror = np.nanmax(error) / 2
        no_zero_ids = np.count_nonzero(whole_cnn)

        # Whole-domain RMSE over the pixels where the CNN produced output
        rmse_cnn = np.sqrt(np.nansum((whole_y - whole_cnn)**2) / no_zero_ids)
        all_whole_rmse.append(rmse_cnn)
        all_whole_mean_times.append(np.mean(np.array(this_file_times)))
        all_whole_sum_times.append(np.sum(np.array(this_file_times)))

        if np.random.random(
        ) > .9 or day_of_year == 353:  # Plot 10% of the times
            viz_obj = EOAImageVisualizer(
                output_folder=output_imgs_folder,
                disp_images=False,
                mincbar=mincbar + mincbar + mincbarerror,
                maxcbar=maxcbar + maxcbar + maxcbarerror)
            # mincbar=[-5, -5, -1],
            # maxcbar=[10, 10, 1])

            # ================== Displays CNN and TSIS with RMSE ================
            viz_obj.output_folder = join(output_imgs_folder,
                                         'WholeOutput_CNN_TSIS')
            viz_obj.plot_2d_data_np_raw(
                [
                    np.flip(whole_cnn, axis=0),
                    np.flip(whole_y, axis=0),
                    np.flip(error, axis=0)
                ],
                var_names=[F"CNN INC {x}" for x in output_fields] +
                [F"TSIS INC {x}"
                 for x in output_fields] + [F'RMSE {rmse_cnn:0.4f}'],
                file_name=F"WholeOutput_CNN_TSIS_{c_file}",
                rot_90=False,
                cols_per_row=3,
                cmap=cmocean.cm.algae,
                title=F"{output_fields} RMSE: {np.mean(rmse_cnn):0.5f}")
コード例 #11
0
def compute_metrics(gt,
                    nn,
                    metrics,
                    split_info,
                    output_file,
                    column_names=None,
                    by_column=True):
    """
    Compute the received metrics and save the results in a csv file.

    :param gt: DataFrame with the ground-truth values stored by station by column
    :param nn: NN predictions with the same column layout as gt
    :param metrics: dict mapping metric name -> metric function f(GT, NN)
    :param split_info: DataFrame whose first three columns hold the train,
                       validation and test example ids (NaN-padded)
    :param output_file: output csv path (a '.csv' suffix is stripped and
                        re-added; a companion ``*_nnprediction.csv`` is written)
    :param column_names: columns to evaluate; when empty/None the positional
                         indices of gt's first row are used as names
    :param by_column: when True compute every metric per column; when False
                      nothing is computed and None is returned
    :return: DataFrame indexed by metric name (plus _training/_validation/_test
             variants), one column per station; None when by_column is False
    """
    # Avoid the shared-mutable-default pitfall of `column_names=[]`.
    if column_names is None:
        column_names = []

    # Ids of the examples belonging to each split. `dropna()` removes the NaN
    # padding; the previous `drop(pd.isna(x).index.values)` removed *all* rows
    # because the boolean mask carries the full index.
    train_ids = split_info.iloc[:, 0].dropna()
    val_ids = split_info.iloc[:, 1].dropna()
    test_ids = split_info.iloc[:, 2].dropna()

    output_file = output_file.replace('.csv', '')
    create_folder(os.path.dirname(output_file))

    metrics_result = None  # Returned as-is when by_column is False
    if by_column:
        if len(column_names) == 0:
            column_names = [str(i) for i in range(len(gt[0]))]

        # Row index: every metric, plus one per-split variant of each.
        all_metrics = list(metrics.keys())
        for suffix in ('training', 'validation', 'test'):
            all_metrics += [F"{x}_{suffix}" for x in metrics.keys()]
        metrics_result = pd.DataFrame(
            {col: np.zeros(len(metrics) * 4)
             for col in column_names},
            index=all_metrics)

        for metric_name, metric_f in metrics.items():
            for cur_col in column_names:
                # Errors over all the examples
                GT = gt[cur_col].values
                NN = nn[cur_col].values
                # `.loc` avoids the chained-assignment pattern
                # (`df[col][row] = ...`) that newer pandas rejects.
                metrics_result.loc[metric_name, cur_col] = executeMetric(
                    GT, NN, metric_f)

                # Per-split errors (0 when the split is empty)
                for split_ids, suffix in ((train_ids, 'training'),
                                          (val_ids, 'validation'),
                                          (test_ids, 'test')):
                    if len(split_ids) > 0:
                        GT = gt[cur_col][split_ids].values
                        NN = nn[cur_col][split_ids].values
                        error = executeMetric(GT, NN, metric_f)
                    else:
                        error = 0
                    metrics_result.loc[F"{metric_name}_{suffix}",
                                       cur_col] = error

        metrics_result.to_csv(F"{output_file}.csv")
        # Also persist the raw NN predictions aligned with gt's index
        nn_df = pd.DataFrame(nn, columns=column_names, index=gt.index)
        nn_df.to_csv(F"{output_file}_nnprediction.csv")

    return metrics_result
コード例 #12
0
def singleModel(config):
    """
    Visualize a single trained model: saves the architecture diagram, prints
    parameter counts and layer names, then for every cropped window of every
    model file plots the feature maps of each convolutional layer.

    :param config: project configuration dict (PredictionParams /
                   ProjTrainingParams / TrainingParams keys)
    """
    input_folder = config[PredictionParams.input_folder]
    rows = config[ProjTrainingParams.rows]
    cols = config[ProjTrainingParams.cols]
    model_field_names = config[ProjTrainingParams.fields_names]
    obs_field_names = config[ProjTrainingParams.fields_names_obs]
    output_fields = config[ProjTrainingParams.output_fields]
    run_name = config[TrainingParams.config_name]
    output_folder = join(config[PredictionParams.output_imgs_folder],
                         'MODEL_VISUALIZATION', run_name)
    norm_type = config[ProjTrainingParams.norm_type]

    model_weights_file = config[PredictionParams.model_weights_file]

    # ----- Select the network architecture (exactly one branch applies;
    # with an unknown type `model` stays undefined, as before) -----
    net_type = config[ProjTrainingParams.network_type]
    if net_type == NetworkTypes.UNET or net_type == NetworkTypes.UNET_MultiStream:
        model = select_2d_model(config, last_activation=None)
    elif net_type == NetworkTypes.SimpleCNN_2:
        model = simpleCNN(config, nn_type="2d", hid_lay=2, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_4:
        model = simpleCNN(config, nn_type="2d", hid_lay=4, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_8:
        model = simpleCNN(config, nn_type="2d", hid_lay=8, out_lay=2)
    elif net_type == NetworkTypes.SimpleCNN_16:
        model = simpleCNN(config, nn_type="2d", hid_lay=16, out_lay=2)

    create_folder(output_folder)
    # Save a diagram of the network architecture
    plot_model(model,
               to_file=join(output_folder, F'running.png'),
               show_shapes=True)

    print('Reading weights ....')
    model.load_weights(model_weights_file)

    # # All Number of parameters
    print(F' Number of parameters: {model.count_params()}')
    # Number of parameters by layer
    print(F' Number of parameters first CNN: {model.layers[1].count_params()}')

    # Example of plotting the filters of a single layer
    print("Printing layer names:")
    print_layer_names(model)
    # plot_cnn_filters_by_layer(model.layers[1], 'First set of filters')  # The harcoded 1 should change by project

    # *********** Read files to predict***********
    # ========= Here you need to build your test input different in each project ====
    # (the listdir/sort pair was duplicated here; one redundant copy removed)
    all_files = os.listdir(input_folder)
    all_files.sort()
    model_files = np.array([x for x in all_files if x.startswith('model')])

    z_layers = [0]
    # Variance (std) fields are optional; read them only when configured
    var_file = join(input_folder, "cov_mat", "tops_ias_std.nc")
    var_field_names = config[ProjTrainingParams.fields_names_var]
    if len(var_field_names) > 0:
        input_fields_var = read_netcdf(var_file, var_field_names, z_layers)
    else:
        input_fields_var = []

    np.random.shuffle(model_files)  # TODO this is only for testing
    for id_file, c_file in enumerate(model_files):
        # Find current and next date
        year = int(c_file.split('_')[1])
        day_of_year = int(c_file.split('_')[2].split('.')[0])

        model_file = join(input_folder, F'model_{year}_{day_of_year:03d}.nc')
        inc_file = join(input_folder, F'increment_{year}_{day_of_year:03d}.nc')
        obs_file = join(input_folder, F'obs_{year}_{day_of_year:03d}.nc')

        # *********************** Reading files **************************
        z_layers = [0]
        input_fields_model = read_netcdf(model_file, model_field_names,
                                         z_layers)
        input_fields_obs = read_netcdf(obs_file, obs_field_names, z_layers)
        output_field_increment = read_netcdf(inc_file, output_fields, z_layers)

        # ******************* Normalizing and Cropping Data *******************
        # Slide a rows x cols window over the whole 891 x 1401 domain
        for start_row in np.arange(0, 891 - rows, rows):
            for start_col in np.arange(0, 1401 - cols, cols):
                try:
                    input_data, y_data = generateXandY(input_fields_model,
                                                       input_fields_obs,
                                                       input_fields_var,
                                                       output_field_increment,
                                                       model_field_names,
                                                       obs_field_names,
                                                       var_field_names,
                                                       output_fields,
                                                       start_row,
                                                       start_col,
                                                       rows,
                                                       cols,
                                                       norm_type=norm_type)
                except Exception as e:
                    print(
                        F"Failed for {c_file} row:{start_row} col:{start_col}")
                    continue

                X_nan = np.expand_dims(input_data, axis=0)
                Y_nan = np.expand_dims(y_data, axis=0)

                # ******************* Replacing nan values *********
                # We set a value of 0.5 on the land. Trying a new loss function that do not takes into account land
                X = np.nan_to_num(X_nan, nan=0)
                Y = np.nan_to_num(Y_nan, nan=-0.5)

                output_nn = model.predict(X, verbose=1)
                # Restore the land mask on the prediction
                output_nn[np.isnan(Y_nan)] = np.nan
                # =========== Output from the last layer (should be the same as output_NN
                print("Evaluating all intermediate layers")
                inp = model.input  # input placeholder
                outputs = [
                    layer.output for layer in model.layers[1:]
                    if layer.name.find("conv") != -1
                ]  # Displaying only conv layers
                # All evaluation functions (used to call the model up to each layer)
                functors = [K.function([inp], [out]) for out in outputs]
                # Outputs for every intermediate layer
                layer_outs = [func([X]) for func in functors]

                for layer_to_plot in range(0, len(outputs)):
                    title = F'Layer {layer_to_plot}_{outputs[layer_to_plot].name}. {c_file}_{start_row:03d}_{start_col:03d}'
                    file_name = F'{c_file}_{start_row:03d}_{start_col:03d}_lay_{layer_to_plot}'
                    plot_intermediate_2dcnn_feature_map(
                        layer_outs[layer_to_plot][0],
                        input_data=X_nan,
                        desired_output_data=Y_nan,
                        nn_output_data=output_nn,
                        input_fields=model_field_names + obs_field_names +
                        var_field_names,
                        title=title,
                        output_folder=output_folder,
                        file_name=file_name,
                        disp_images=False)
コード例 #13
0
                print(F"{old_name} \n {new_name} \n")
                # os.rename(old_name, new_name)


if __name__ == '__main__':
    # Column labels for the experiments summary table
    NET = "Network Type"
    OUT = "Output vars"
    IN = "Input vars"
    LOSS = "Loss value"

    # Read folders for all the experiments
    config = get_training_2d()
    trained_models_folder = config[TrainingParams.output_folder]
    output_folder = config[ProjTrainingParams.output_folder_summary_models]
    create_folder(output_folder)

    # fixNames("/data/HYCOM/DA_HYCOM_TSIS/Training")
    # exit()

    all_folders = os.listdir(trained_models_folder)
    all_folders.sort()
    print(all_folders)

    summary = []

    # Iterate over all the experiments
    for experiment in all_folders:
        all_models = os.listdir(
            join(trained_models_folder, experiment, "models"))
        # Sentinel for tracking the best (lowest) loss among saved models.
        # NOTE(review): the remainder of this loop body is missing from this
        # excerpt -- the source appears truncated at this point.
        min_loss = 100000.0
コード例 #14
0
def test_model(config):
    """
    Evaluate the trained network over every whole-domain 'model_*' file:
    predicts the increment, compares it against the TSIS increment, plots
    inputs/outputs/error maps, and saves a per-file RMSE and timing summary
    csv ('Global_RMSE_and_times.csv').

    :param config: project configuration dict (PredictionParams /
                   ProjTrainingParams / TrainingParams keys)
    """
    input_folder = config[PredictionParams.input_folder]
    output_folder = config[PredictionParams.output_folder]
    output_fields = config[ProjTrainingParams.output_fields]
    model_weights_file = config[PredictionParams.model_weights_file]
    output_imgs_folder = config[PredictionParams.output_imgs_folder]
    field_names_model = config[ProjTrainingParams.fields_names]
    field_names_obs = config[ProjTrainingParams.fields_names_obs]
    rows = config[ProjTrainingParams.rows]
    cols = config[ProjTrainingParams.cols]
    run_name = config[TrainingParams.config_name]
    norm_type = config[ProjTrainingParams.norm_type]

    output_imgs_folder = join(output_imgs_folder, run_name)
    create_folder(output_imgs_folder)

    # *********** Chooses the proper model ***********
    print('Reading model ....')

    # Exactly one branch should match; with an unknown network type `model`
    # would remain undefined and the code below would raise NameError.
    net_type = config[ProjTrainingParams.network_type]
    if net_type == NetworkTypes.UNET or net_type == NetworkTypes.UNET_MultiStream:
        model = select_2d_model(config, last_activation=None)
    if net_type == NetworkTypes.SimpleCNN_2:
        model = simpleCNN(config, nn_type="2d", hid_lay=2, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_4:
        model = simpleCNN(config, nn_type="2d", hid_lay=4, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_8:
        model = simpleCNN(config, nn_type="2d", hid_lay=8, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_16:
        model = simpleCNN(config, nn_type="2d", hid_lay=16, out_lay=2)

    # Save a diagram of the network architecture
    plot_model(model,
               to_file=join(output_folder, F'running.png'),
               show_shapes=True)

    # *********** Reads the weights***********
    print('Reading weights ....')
    model.load_weights(model_weights_file)

    # *********** Read files to predict***********
    all_files = os.listdir(input_folder)
    all_files.sort()
    model_files = np.array([x for x in all_files if x.startswith('model')])

    z_layers = [0]
    # Variance (std) fields are optional; read them only when configured
    var_file = join(input_folder, "cov_mat", "tops_ias_std.nc")
    field_names_std = config[ProjTrainingParams.fields_names_var]
    if len(field_names_std) > 0:
        input_fields_std = read_netcdf(var_file, field_names_std, z_layers)
    else:
        input_fields_std = []

    # One colormap list per group of plotted fields
    cmap_out = chooseCMAP(output_fields)
    cmap_model = chooseCMAP(field_names_model)
    cmap_obs = chooseCMAP(field_names_obs)
    cmap_std = chooseCMAP(field_names_std)

    # Whole-domain grid size (unused below in this excerpt)
    tot_rows = 891
    tot_cols = 1401

    all_whole_mean_times = []
    all_whole_sum_times = []
    all_whole_rmse = []

    # np.random.shuffle(model_files)  # TODO this is only for testing
    for id_file, c_file in enumerate(model_files):
        # Find current and next date
        year = int(c_file.split('_')[1])
        day_of_year = int(c_file.split('_')[2].split('.')[0])

        model_file = join(input_folder, F'model_{year}_{day_of_year:03d}.nc')
        inc_file = join(input_folder, F'increment_{year}_{day_of_year:03d}.nc')
        obs_file = join(input_folder, F'obs_{year}_{day_of_year:03d}.nc')

        # *********************** Reading files **************************
        input_fields_model = read_netcdf(model_file, field_names_model,
                                         z_layers)
        input_fields_obs = read_netcdf(obs_file, field_names_obs, z_layers)
        output_field_increment = read_netcdf(inc_file, output_fields, z_layers)

        # ******************* Normalizing and Cropping Data *******************
        this_file_times = []

        try:
            perc_ocean = .01
            # NOTE(review): `grows`/`gcols` are not defined in this function
            # (presumably module-level globals for the global grid size);
            # the local `rows`/`cols` read from the config above are unused
            # here -- confirm this is intentional.
            input_data, y_data = generateXandY(input_fields_model,
                                               input_fields_obs,
                                               input_fields_std,
                                               output_field_increment,
                                               field_names_model,
                                               field_names_obs,
                                               field_names_std,
                                               output_fields,
                                               0,
                                               0,
                                               grows,
                                               gcols,
                                               norm_type=norm_type,
                                               perc_ocean=perc_ocean)
        except Exception as e:
            # NOTE(review): only logs the failure; input_data/y_data stay
            # undefined and the following lines would then raise NameError.
            print(F"Exception {e}")

        # ******************* Replacing nan values *********
        # We set a value of 0.5 on the land. Trying a new loss function that do not takes into account land
        input_data_nans = np.isnan(input_data)
        input_data = np.nan_to_num(input_data, nan=0)
        y_data = np.nan_to_num(y_data, nan=-0.5)

        X = np.expand_dims(input_data, axis=0)
        Y = np.expand_dims(y_data, axis=0)

        # Make the prediction of the network
        start = time.time()
        output_nn_original = model.predict(X, verbose=1)
        toc = time.time() - start
        this_file_times.append(toc)

        # Make nan all values inside the land
        land_indexes = Y == -0.5
        output_nn_original[land_indexes] = np.nan

        # ==== Denormalizingallinput and outputs
        denorm_cnn_output = denormalizeData(output_nn_original, output_fields,
                                            PreprocParams.type_inc, norm_type)
        denorm_y = denormalizeData(Y, output_fields, PreprocParams.type_inc,
                                   norm_type)
        # One preprocessing type per input field, in the same order the
        # fields were concatenated into the input tensor
        input_types = [PreprocParams.type_model
                       for i in input_fields_model] + [
                           PreprocParams.type_obs for i in input_fields_obs
                       ] + [PreprocParams.type_std for i in input_fields_std]
        denorm_input = denormalizeData(
            input_data, field_names_model + field_names_obs + field_names_std,
            input_types, norm_type)

        # Recover the original land areas, they are lost after denormalization
        denorm_y[land_indexes] = np.nan

        # Remove the 'extra dimension'
        denorm_cnn_output = np.squeeze(denorm_cnn_output)
        denorm_y = np.squeeze(denorm_y)
        whole_cnn = denorm_cnn_output  # Add the the 'whole prediction'
        whole_y = denorm_y  # Add the the 'whole prediction'

        if len(
                denorm_cnn_output.shape
        ) == 2:  # In this case we only had one output and we need to make it 'array' to plot
            denorm_cnn_output = np.expand_dims(denorm_cnn_output, axis=2)
            denorm_y = np.expand_dims(denorm_y, axis=2)

        # Compute RMSE
        # rmse_cnn = np.zeros(len(output_fields))
        # for i in range(len(output_fields)):
        #     ocean_indexes = np.logical_not(np.isnan(denorm_y[:,:,i]))
        #     rmse_cnn[i] = np.sqrt(mean_squared_error(denorm_cnn_output[:,:,i][ocean_indexes], denorm_y[:,:,i][ocean_indexes]))

        # ================== DISPLAYS ALL INPUTS AND OUTPUTS DENORMALIZED ===================
        # Adding back mask to all the input variables
        denorm_input[input_data_nans] = np.nan

        # ======= Plots whole output with RMSE
        # NOTE(review): mincbar/maxcbar/mincbarerror/maxcbarerror are
        # computed here but not used below (the colorbar limits passed to
        # EOAImageVisualizer are hard-coded from getMinMaxPlot instead).
        mincbar = np.nanmin(whole_y)
        maxcbar = np.nanmax(whole_y)
        # NOTE(review): `error` is computed BEFORE the srfhgt /9.81 scaling
        # below, so for srfhgt it stays in the unscaled units -- verify.
        error = whole_y - whole_cnn
        mincbarerror = np.nanmin(error)
        maxcbarerror = np.nanmax(error)
        no_zero_ids = np.count_nonzero(whole_cnn)

        if output_fields[
                0] == 'srfhgt':  # This should only be for SSH to adjust the units
            whole_cnn /= 9.81
            whole_y = np.array(whole_y) / 9.81

        # RMSE over the nonzero (ocean) cells of the prediction
        rmse_cnn = np.sqrt(np.nansum((whole_y - whole_cnn)**2) / no_zero_ids)

        all_whole_rmse.append(rmse_cnn)
        all_whole_mean_times.append(np.mean(np.array(this_file_times)))
        all_whole_sum_times.append(np.sum(np.array(this_file_times)))

        # if day_of_year == 353: # Plot 10% of the times
        # NOTE(review): the 10% sampling is disabled -- every file is plotted.
        if True:  # Plot 10% of the times

            # viz_obj = EOAImageVisualizer(output_folder=output_imgs_folder, disp_images=False, mincbar=mincbar, maxcbar=maxcbar)
            viz_obj = EOAImageVisualizer(output_folder=output_imgs_folder,
                                         disp_images=False)

            # viz_obj.plot_2d_data_np_raw(np.concatenate((input_data.swapaxes(0,2), Y[0,:,:,:].swapaxes(0,2), output_nn_original[0,:,:,:].swapaxes(0,2))),
            viz_obj.plot_2d_data_np_raw(
                np.concatenate(
                    (denorm_input.swapaxes(0, 2), denorm_y.swapaxes(0, 2),
                     denorm_cnn_output.swapaxes(0, 2))),
                var_names=[F"in_model_{x}" for x in field_names_model] +
                [F"in_obs_{x}" for x in field_names_obs] +
                [F"in_var_{x}" for x in field_names_std] +
                [F"out_inc_{x}"
                 for x in output_fields] + [F"cnn_{x}" for x in output_fields],
                file_name=F"Global_Input_and_CNN_{c_file}",
                rot_90=True,
                cmap=cmap_model + cmap_obs + cmap_std + cmap_out + cmap_out,
                cols_per_row=len(field_names_model),
                title=
                F"Input data: {field_names_model} and obs {field_names_obs}, increment {output_fields}, cnn {output_fields}"
            )

            minmax = getMinMaxPlot(output_fields)[0]
            viz_obj = EOAImageVisualizer(
                output_folder=output_imgs_folder,
                disp_images=False,
                # mincbar=mincbar + mincbar + mincbarerror,
                # maxcbar=maxcbar + maxcbar + maxcbarerror)
                # mincbar=[minmax[0], minmax[0], max(minmax[0],-1)],
                # maxcbar=[minmax[1], minmax[1], min(minmax[1],1)])
                mincbar=[minmax[0], minmax[0], -1],
                maxcbar=[minmax[1], minmax[1], 1])

            # ================== Displays CNN and TSIS with RMSE ================
            error_cmap = cmocean.cm.diff
            viz_obj.output_folder = join(output_imgs_folder,
                                         'WholeOutput_CNN_TSIS')
            viz_obj.plot_2d_data_np_raw(
                [
                    np.flip(whole_cnn, axis=0),
                    np.flip(whole_y, axis=0),
                    np.flip(error, axis=0)
                ],
                # var_names=[F"CNN INC {x}" for x in output_fields] + [F"TSIS INC {x}" for x in output_fields] + [F'TSIS - CNN (Mean RMSE {rmse_cnn:0.4f} m)'],
                var_names=[F"CNN increment SSH" for x in output_fields] +
                [F"TSIS increment SSH" for x in output_fields] +
                [F'TSIS - CNN \n (Mean RMSE {rmse_cnn:0.4f} m)'],
                file_name=F"Global_WholeOutput_CNN_TSIS_{c_file}",
                rot_90=False,
                cmap=cmap_out + cmap_out + [error_cmap],
                cols_per_row=3,
                # title=F"{output_fields[0]} RMSE: {np.mean(rmse_cnn):0.5f} m.")
                title=F"SSH RMSE: {np.mean(rmse_cnn):0.5f} m.")

            print("DONE ALL FILES!!!!!!!!!!!!!")
    # Per-file summary of RMSE and prediction times
    dic_summary = {
        "File": model_files,
        "rmse": all_whole_rmse,
        "times mean": all_whole_mean_times,
        "times sum": all_whole_sum_times,
    }
    df = pd.DataFrame.from_dict(dic_summary)
    df.to_csv(join(output_imgs_folder, "Global_RMSE_and_times.csv"))
コード例 #15
0
ファイル: 3_Train_2D.py プロジェクト: olmozavala/da_hycom
def doTraining(conf):
    """
    Train the 2D network: builds the train/val/test split, selects and
    compiles the model, saves split/architecture artifacts, and fits the
    model with the preprocessed-data generators.

    :param conf: NOTE(review): this parameter is never used -- the body reads
                 the module-level `config` everywhere; confirm whether `conf`
                 was meant to be used instead.
    """
    input_folder_preproc = config[ProjTrainingParams.input_folder_preproc]
    # input_folder_obs = config[ProjTrainingParams.input_folder_obs]
    years = config[ProjTrainingParams.YEARS]
    fields = config[ProjTrainingParams.fields_names]
    fields_obs = config[ProjTrainingParams.fields_names_obs]
    output_field = config[ProjTrainingParams.output_fields]
    # day_to_predict = config[ProjTrainingParams.prediction_time]

    output_folder = config[TrainingParams.output_folder]
    val_perc = config[TrainingParams.validation_percentage]
    test_perc = config[TrainingParams.test_percentage]
    eval_metrics = config[TrainingParams.evaluation_metrics]
    loss_func = config[TrainingParams.loss_function]
    batch_size = config[TrainingParams.batch_size]
    epochs = config[TrainingParams.epochs]
    run_name = config[TrainingParams.config_name]
    optimizer = config[TrainingParams.optimizer]

    # Per-run folder layout: splits, parameters, model weights and logs
    output_folder = join(output_folder, run_name)
    split_info_folder = join(output_folder, 'Splits')
    parameters_folder = join(output_folder, 'Parameters')
    weights_folder = join(output_folder, 'models')
    logs_folder = join(output_folder, 'logs')
    create_folder(split_info_folder)
    create_folder(parameters_folder)
    create_folder(weights_folder)
    create_folder(logs_folder)

    # Compute how many cases
    files_to_read, paths_to_read = get_preproc_increment_files(
        input_folder_preproc)
    tot_examples = len(files_to_read)

    # ================ Split definition =================
    [train_ids, val_ids, test_ids
     ] = utilsNN.split_train_validation_and_test(tot_examples,
                                                 val_percentage=val_perc,
                                                 test_percentage=test_perc,
                                                 shuffle_ids=False)

    print(
        F"Train examples (total:{len(train_ids)}) :{files_to_read[train_ids]}")
    print(
        F"Validation examples (total:{len(val_ids)}) :{files_to_read[val_ids]}:"
    )
    print(F"Test examples (total:{len(test_ids)}) :{files_to_read[test_ids]}")

    print("Selecting and generating the model....")
    # Timestamped model name so repeated runs do not collide
    now = datetime.utcnow().strftime("%Y_%m_%d_%H_%M")
    model_name = F'{run_name}_{now}'

    # ******************* Selecting the model **********************
    # Exactly one branch should match; with an unknown network type `model`
    # would stay undefined and plot_model below would raise NameError.
    net_type = config[ProjTrainingParams.network_type]
    if net_type == NetworkTypes.UNET or net_type == NetworkTypes.UNET_MultiStream:
        model = select_2d_model(config, last_activation=None)
    if net_type == NetworkTypes.SimpleCNN_2:
        model = simpleCNN(config, nn_type="2d", hid_lay=2, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_4:
        model = simpleCNN(config, nn_type="2d", hid_lay=4, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_8:
        model = simpleCNN(config, nn_type="2d", hid_lay=8, out_lay=2)
    if net_type == NetworkTypes.SimpleCNN_16:
        model = simpleCNN(config, nn_type="2d", hid_lay=16, out_lay=2)

    plot_model(model,
               to_file=join(output_folder, F'{model_name}.png'),
               show_shapes=True)

    print("Saving split information...")
    file_name_splits = join(split_info_folder, F'{model_name}.txt')
    utilsNN.save_splits(file_name=file_name_splits,
                        train_ids=train_ids,
                        val_ids=val_ids,
                        test_ids=test_ids)

    print("Compiling model ...")
    model.compile(loss=loss_func, optimizer=optimizer, metrics=eval_metrics)

    print("Getting callbacks ...")

    # Early stopping monitors the validation value of the first metric
    [logger, save_callback, stop_callback] = utilsNN.get_all_callbacks(
        model_name=model_name,
        early_stopping_func=F'val_{eval_metrics[0].__name__}',
        weights_folder=weights_folder,
        logs_folder=logs_folder)

    print("Training ...")
    # ----------- Using preprocessed data -------------------
    generator_train = data_gen_from_preproc(input_folder_preproc, config,
                                            train_ids, fields, fields_obs,
                                            output_field)
    generator_val = data_gen_from_preproc(input_folder_preproc, config,
                                          val_ids, fields, fields_obs,
                                          output_field)

    # Decide which generator to use
    # NOTE(review): `data_augmentation` is read but not used in this excerpt.
    data_augmentation = config[TrainingParams.data_augmentation]

    # NOTE(review): fit_generator is deprecated in modern Keras in favor of
    # Model.fit, which accepts generators directly.
    model.fit_generator(
        generator_train,
        steps_per_epoch=1000,
        validation_data=generator_val,
        # validation_steps=min(100, len(val_ids)),
        validation_steps=100,
        use_multiprocessing=False,
        workers=1,
        # validation_freq=10, # How often to compute the validation loss
        epochs=epochs,
        callbacks=[logger, save_callback, stop_callback])