コード例 #1
0
ファイル: predict.py プロジェクト: renedlog/Paraglidable
	def __set_trained(self, dictContent):
		"""Create the trained model from a dict describing its content.

		Every (key, value) pair of ``dictContent`` is forwarded unchanged to
		``ModelContent.add``. The resulting content is handed to
		``self.trainedModel.new``, positionally, together with this object's
		``wind_dim``, ``other_dim``, ``humidity_dim``, ``nb_altitudes`` and
		``model_type`` attributes.
		"""
		Verbose.print_arguments()

		content = ModelContent()
		for entry in dictContent.items():
			# entry is a (key, value) pair, passed as ModelContent.add(key, value)
			content.add(*entry)

		self.trainedModel.new(
			content,
			self.wind_dim,
			self.other_dim,
			self.humidity_dim,
			self.nb_altitudes,
			self.model_type,
		)
コード例 #2
0
ファイル: train.py プロジェクト: vrana/Paraglidable
    def set_trained(self, cells, super_resolution=1, load_weights=True):
        """Build a model covering ``cells`` and prepare its inputs/outputs.

        For a ``ModelType.SPOTS`` model, every spot of each cell (one index
        per entry of ``self.Y_spots[c]``) is added. Otherwise each whole cell
        is added with the ``-1`` marker. For ``ModelType.CELLS``,
        ``super_resolution`` is applied to the model content first.

        Args:
            cells: iterable of cell identifiers to include in the model.
            super_resolution: value passed to
                ``ModelContent.set_super_resolution``. It is only used when
                ``self.model_type`` is ``ModelType.CELLS``.
            load_weights: when True, saved shared and specific weights are
                re-loaded into the freshly created model.

        Returns:
            False for a SPOTS model that would contain no spot. In that case
            no network is created and ``all_X``/``all_Y`` are left untouched.
            True otherwise, with ``self.all_X`` and ``self.all_Y`` populated.
        """
        Verbose.print_arguments()

        model_content = ModelContent()
        is_spots_model = self.model_type == ModelType.SPOTS

        if self.model_type == ModelType.CELLS:
            model_content.set_super_resolution(super_resolution)

        for c in cells:
            if is_spots_model:
                # All the spots of the cell
                model_content.add(c, list(range(len(self.Y_spots[c]))))
            else:
                # The whole cell
                model_content.add(c, -1)

        # The cells may have no spots at all, which would crash at network creation
        if is_spots_model and model_content.total_nb_spots() == 0:
            return False

        #===================================================================
        # Create model and load weights
        #===================================================================

        # Reset the Keras global state before building a new model
        # (see https://github.com/keras-team/keras/issues/3579).
        tf.keras.backend.clear_session()

        # NOTE(review): nb_altitudes assumes the last axis of X_wind stacks
        # wind_dim values per altitude — confirm against the data layout.
        self.trained_model.new(model_content,
                               wind_dim=self.wind_dim,
                               other_dim=self.X_other[0].shape[-1],
                               humidity_dim=self.X_humidity[0].shape[-1],
                               nb_altitudes=self.X_wind[0].shape[-1] // self.wind_dim,
                               model_type=self.model_type)

        # Re-load shared and specific weights if they exist
        if load_weights:
            self.trained_model.load_all_weights()

        #===================================================================
        # Model inputs/outputs
        #===================================================================

        self.all_X = self.__get_X(model_content)
        self.all_Y = self.__get_Y(model_content)

        return True
コード例 #3
0
    def __compute_spots_forecasts(self, models_directory, problem_formulation,
                                  lats, lons, meteo_matrix, filename):
        """Compute a prediction for every known spot and export the results.

        Steps:
          1. Build a ``Predict`` for ``ModelType.SPOTS`` fed with
             ``meteo_matrix``.
          2. Map each forecast cell ``(iLat, iLon)`` of ``self.crops`` to its
             line number in the meteo data.
          3. Load the mapping ``meteo line -> [((kc, ks), spot_dict), ...]``
             from the ``BinObj`` cache. If it is missing, compute it from
             ``SpotsData`` and save it.
          4. For each spot, load the weights of its model content
             ``(kc, ks)`` and predict on its meteo line.
          5. Export the resulting ``Spot`` objects to ``filename``.

        Args:
            models_directory: passed to ``Predict``.
            problem_formulation: passed to ``Predict``.
            lats, lons: grid latitudes/longitudes. The index of the value
                nearest to each spot locates its forecast cell (assumed 1-D
                arrays — TODO confirm).
            meteo_matrix: meteo data given to ``Predict.set_meteo_data``.
            filename: output destination passed to the export.

        Raises:
            KeyError: when the cache must be computed and a spot's nearest
                grid cell lies outside every crop in ``self.crops``.
        """
        Verbose.print_arguments()

        predict = Predict(models_directory, ModelType.SPOTS,
                          problem_formulation)
        predict.set_meteo_data(meteo_matrix, GfsData().parameters_vector_all)

        #=============================================================
        # Line of each forecast cell in the meteo data.
        # Depends on GribReader.get_values_array(self, params, crops):
        # the iteration order below must match its row order.
        #=============================================================
        forecastCellsLine = {}
        line = 0
        for crop in self.crops:
            for iLat in range(crop[0], crop[1]):
                for iLon in range(crop[2], crop[3]):
                    forecastCellsLine[(iLat, iLon)] = line
                    line += 1

        #=============================================================
        # Compute or load cells_and_spots
        #=============================================================

        # The cache file name encodes the crops, so a new crop triggers a
        # recomputation.
        filename_cells_and_spots = "Forecast_cellsAndSpots_" + "_".join(
            [str(crop[d]) for crop in self.crops for d in range(4)])

        if not BinObj.exists(filename_cells_and_spots):
            Verbose.print_text(
                0,
                "Generating precomputation file because of new crop, it may crash on the server... To be computed on my computer"
            )
            cells_and_spots = {}
            # Find the forecast cell of each spot.
            # Same as Train's cellsAndSpots, except that the cells are the
            # forecast cells (32942 cells).
            spots_data = SpotsData()
            spots = spots_data.getSpots(range(80))
            for kc, cell_spots in enumerate(spots):
                for ks, spot in enumerate(cell_spots):
                    iCell = (np.abs(lats - spot.lat).argmin(),
                             np.abs(lons - spot.lon).argmin())
                    cellLine = forecastCellsLine[iCell]
                    # ((training cell, spot index within that cell), spot data)
                    kcks_spot = ((kc, ks), spot.toDict())
                    cells_and_spots.setdefault(cellLine, []).append(kcks_spot)
            BinObj.save(cells_and_spots, filename_cells_and_spots)
        else:
            cells_and_spots = BinObj.load(filename_cells_and_spots)

        #=============================================================
        # Create a model with 1 cell of 1 spot
        #=============================================================

        predict.set_trained_spots()

        #=============================================================
        # Compute prediction for each spot, one by one
        #=============================================================

        spots_and_prediction = []
        for meteoLine, spotsLst in cells_and_spots.items():
            for (kc, ks), spot_dict in spotsLst:
                modelContent = ModelContent()
                modelContent.add(kc, ks)
                # TODO do not reload shared weights
                predict.trainedModel.load_all_weights(modelContent)
                predict.set_prediction_population()
                NN_X = predict.get_X([meteoLine])
                prediction = predict.trainedModel.model.predict(NN_X)[0][0]
                spots_and_prediction.append(
                    Spot((spot_dict['name'], spot_dict['lat'], spot_dict['lon']),
                         spot_dict['id'], spot_dict['nbFlights'], prediction))

        #=============================================================
        # Export all results
        #=============================================================

        Forecast.__export_spots_forecasts(spots_and_prediction, filename)