Code example #1
    def _load_run_data(self):
        """Load run specific data from run directory"""

        # get list of basins
        self.basins = load_basin_file(
            getattr(self.cfg, f"{self.period}_basin_file"))

        # load feature scaler
        scaler_file = self.run_dir / "train_data" / "train_data_scaler.p"
        with scaler_file.open('rb') as fp:
            self.scaler = pickle.load(fp)

        # handle old scaler files, in which the center/scale parameters still used their old names
        if "xarray_means" in self.scaler:
            self.scaler["xarray_feature_center"] = self.scaler.pop("xarray_means")
        if "xarray_stds" in self.scaler:
            self.scaler["xarray_feature_scale"] = self.scaler.pop("xarray_stds")

        # load basin_id to integer dictionary for one-hot-encoding
        if self.cfg.use_basin_id_encoding:
            file_path = self.run_dir / "train_data" / "id_to_int.p"
            with file_path.open("rb") as fp:
                self.id_to_int = pickle.load(fp)

        for file in self.cfg.additional_feature_files:
            with open(file, "rb") as fp:
                self.additional_features.append(pickle.load(fp))
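
The legacy-key migration above can be reproduced outside the class when inspecting an old run by hand. A minimal standalone sketch, assuming a run directory laid out as in the excerpt (the concrete path is a placeholder):

import pickle
from pathlib import Path

run_dir = Path("runs/my_experiment")  # placeholder run directory
scaler_file = run_dir / "train_data" / "train_data_scaler.p"

with scaler_file.open("rb") as fp:
    scaler = pickle.load(fp)

# migrate legacy key names to the current ones, as in the method above
for old, new in [("xarray_means", "xarray_feature_center"),
                 ("xarray_stds", "xarray_feature_scale")]:
    if old in scaler:
        scaler[new] = scaler.pop(old)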
Code example #2
    def __init__(self, cfg: Config):
        super(BaseTrainer, self).__init__()
        self.cfg = cfg
        self.model = None
        self.optimizer = None
        self.loss_obj = None
        self.experiment_logger = None
        self.loader = None
        self.validator = None
        self.noise_sampler_y = None
        self._target_mean = None
        self._target_std = None
        self._allow_subsequent_nan_losses = cfg.allow_subsequent_nan_losses

        # load train basin list and add number of basins to the config
        self.basins = load_basin_file(cfg.train_basin_file)
        self.cfg.number_of_basins = len(self.basins)

        # check at which epoch the training starts
        self._epoch = self._get_start_epoch_number()

        self._create_folder_structure()
        setup_logging(str(self.cfg.run_dir / "output.log"))
        LOGGER.info(f"### Folder structure created at {self.cfg.run_dir}")

        if self.cfg.is_continue_training:
            LOGGER.info(f"### Continue training of run stored in {self.cfg.base_run_dir}")

        LOGGER.info(f"### Run configurations for {self.cfg.experiment_name}")
        for key, val in self.cfg.as_dict().items():
            LOGGER.info(f"{key}: {val}")

        self._set_random_seeds()
        self._set_device()
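
The trainer stores cfg.allow_subsequent_nan_losses, which suggests the training loop tolerates a bounded run of NaN losses before aborting. A hypothetical sketch of such a guard; the loop, the train_step helper, and the variable names are assumptions, not the class's actual code:

import math

nan_count = 0
for batch in loader:          # hypothetical training loop
    loss = train_step(batch)  # hypothetical helper returning the batch loss as a float
    if math.isnan(loss):
        nan_count += 1
        if nan_count > allow_subsequent_nan_losses:
            raise RuntimeError(f"Loss was NaN for {nan_count} batches in a row. Aborting training.")
    else:
        nan_count = 0         # any finite loss resets the counter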
Code example #3
    def _load_run_data(self):
        """Load run specific data from run directory"""

        # get list of basins
        self.basins = load_basin_file(getattr(self.cfg, f"{self.period}_basin_file"))

        # load feature scaler
        scaler_file = self.run_dir / "train_data" / "train_data_scaler.p"
        with scaler_file.open('rb') as fp:
            self.scaler = pickle.load(fp)

        # load basin_id to integer dictionary for one-hot-encoding
        if self.cfg.use_basin_id_encoding:
            file_path = self.run_dir / "train_data" / "id_to_int.p"
            with file_path.open("rb") as fp:
                self.id_to_int = pickle.load(fp)

        for file in self.cfg.additional_feature_files:
            with open(file, "rb") as fp:
                self.additional_features.append(pickle.load(fp))
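
The id_to_int dictionary unpickled in examples #1 and #3 is a plain basin-id-to-index lookup. A sketch of how such a lookup could be built and dumped at training time; the basin ids and output path are placeholders, while the file layout follows the excerpts:

import pickle
from pathlib import Path

basins = ["01013500", "01022500", "01030500"]  # placeholder basin ids
id_to_int = {basin_id: i for i, basin_id in enumerate(sorted(basins))}

out_file = Path("runs/my_experiment") / "train_data" / "id_to_int.p"
out_file.parent.mkdir(parents=True, exist_ok=True)
with out_file.open("wb") as fp:
    pickle.dump(id_to_int, fp)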
Code example #4
    def __init__(self,
                 cfg: Config,
                 is_train: bool,
                 period: str,
                 basin: str = None,
                 additional_features: List[Dict[str, pd.DataFrame]] = None,
                 id_to_int: Dict[str, int] = None,
                 scaler: Dict[str, Union[pd.Series, xarray.DataArray]] = None):
        super(BaseDataset, self).__init__()
        self.cfg = cfg
        self.is_train = is_train

        if period not in ["train", "validation", "test"]:
            raise ValueError("'period' must be one of 'train', 'validation' or 'test'")
        self.period = period

        if period in ["validation", "test"]:
            if not scaler:
                raise ValueError(
                    "During evaluation of the validation or test period, the scaler dictionary has to be passed")

            if cfg.use_basin_id_encoding and not id_to_int:
                raise ValueError(
                    "For basin id embedding, the id_to_int dictionary has to be passed for all periods but train")

        if basin is None:
            self.basins = utils.load_basin_file(
                getattr(cfg, f"{period}_basin_file"))
        else:
            self.basins = [basin]
        # fall back to fresh containers so that a mutable default argument is never shared across instances
        self.additional_features = additional_features if additional_features is not None else []
        self.id_to_int = id_to_int if id_to_int is not None else {}
        self.scaler = scaler if scaler is not None else {}
        # compute the scaler from the data only when training from scratch; during finetuning a
        # pre-computed scaler is passed in, so nothing has to be recomputed
        self._compute_scaler = is_train and not scaler

        # check and extract frequency information from config
        self.frequencies = []
        self.seq_len = None
        self._predict_last_n = None
        self._initialize_frequency_configuration()

        # during training we log data processing with progress bars, but not during validation/testing
        self._disable_pbar = cfg.verbose == 0 or not self.is_train

        # initialize class attributes that are filled in the data loading functions
        self.x_d = {}
        self.x_s = {}
        self.attributes = {}
        self.y = {}
        self.per_basin_target_stds = {}
        self.dates = {}
        self.num_samples = 0
        self.one_hot = None
        self.period_starts = {}  # needed for restoring date index during evaluation

        # get the start and end date periods for each basin
        self._get_start_and_end_dates()

        # if additional features files are passed in the config, load those files
        if (not additional_features) and cfg.additional_feature_files:
            self._load_additional_features()

        if cfg.use_basin_id_encoding:
            if self.is_train:
                # creates lookup table for the number of basins in the training set
                self._create_id_to_int()

            # create empty tensor of the same length as basins in id to int lookup table
            self.one_hot = torch.zeros(len(self.id_to_int),
                                       dtype=torch.float32)

        # load and preprocess data
        self._load_data()

        if self.is_train:
            self._dump_scaler()
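
Usage sketch: during training the dataset computes and dumps the scaler itself, while validation or test construction has to reuse the training artifacts. The keyword arguments follow the signature above; the cfg object is assumed to be fully populated:

train_ds = BaseDataset(cfg, is_train=True, period="train")  # computes and dumps the scaler
val_ds = BaseDataset(cfg,
                     is_train=False,
                     period="validation",
                     scaler=train_ds.scaler,        # reuse the training statistics
                     id_to_int=train_ds.id_to_int)  # reuse the basin-id lookup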