def get_gluon_train_dataset(self):
    """Load the training GluonDataset saved in the session folder during training.

    Returns:
        The object read (as a gzipped pickle) from
        ``<self.session_path>/gluon_train_dataset.pk.gz`` in ``self.folder``.
    """
    dataset_file = f"{self.session_path}/gluon_train_dataset.pk.gz"
    return read_from_folder(self.folder, dataset_file, ObjectType.PICKLE_GZ)
def get_model_predictor(self):
    """Load the GluonTS Predictor object saved into the model folder during training.

    Unless ``self.manual_selection`` is truthy, the session and model are first
    selected automatically, updating instance state:
    ``self.session_name`` is set to the last session, ``self.session_path`` to
    that session's folder under ``self.partition_root``, and
    ``self.model_label`` to the best model of that session.

    Returns:
        The object read (as a gzipped pickle) from
        ``<self.session_path>/<self.model_label>/model.pk.gz`` in ``self.folder``.

    Raises:
        ModelSelectionError: If ``read_from_folder`` raises ``ValueError`` while
            loading the model file. The original error is chained as ``__cause__``.
    """
    if not self.manual_selection:
        # Automatic mode: use the most recent session and its best model.
        self.session_name = self._get_last_session()
        self.session_path = os.path.join(self.partition_root, self.session_name)
        self.model_label = self._get_best_model()
    model_path = os.path.join(self.session_path, self.model_label, "model.pk.gz")
    try:
        model = read_from_folder(self.folder, model_path, ObjectType.PICKLE_GZ)
    except ValueError as e:
        # Chain explicitly so the underlying read error's traceback is preserved.
        raise ModelSelectionError(
            f"Unable to retrieve model '{self.model_label}' from session '{self.session_name}'. "
            f"Please make sure that it exists in the Trained model folder. Full error: {e}"
        ) from e
    return model
def _get_best_model(self):
    """Find the best model of the current session according to ``self.performance_metric``.

    Reads ``metrics.csv`` from ``self.session_path`` in ``self.folder``. If the
    target column contains aggregated rows, only those rows are considered;
    otherwise all rows are used. The model whose ``self.performance_metric``
    value is the smallest is selected.

    Returns:
        Label of the best model.

    Raises:
        ModelSelectionError: If the metrics cannot be ranked (e.g. a missing
            column), if more than one row per model remains after filtering, or
            if the selected model is not among the available model labels.
            The underlying error is chained as ``__cause__``. Errors raised while
            reading ``metrics.csv`` itself are not wrapped.
    """
    available_models_labels = list_available_models_labels()
    df = read_from_folder(self.folder, f"{self.session_path}/metrics.csv", ObjectType.CSV)
    try:
        is_aggregated = df[METRICS_DATASET.TARGET_COLUMN] == METRICS_DATASET.AGGREGATED_ROW
        if is_aggregated.any():
            df = df[is_aggregated]
        # Explicit checks rather than ``assert``: assertions are removed under
        # ``python -O``, which would silently skip these validations.
        if df[METRICS_DATASET.MODEL_COLUMN].nunique() != len(df.index):
            raise ValueError("More than one row per model")
        # Assumes lower is better for the metric; use idxmax() for a metric to maximize.
        model_label = df.loc[df[self.performance_metric].idxmin()][METRICS_DATASET.MODEL_COLUMN]
        if model_label not in available_models_labels:
            raise ValueError("Best model retrieved is not an available models")
    except Exception as e:
        raise ModelSelectionError(
            f"Unable to find the best model of session '{self.session_name}' with the performance metric '{self.performance_metric}'. Full error: {e}"
        ) from e
    return model_label