def create_splits(cfg: dict):
    """Create random k-fold cross-validation splits.

    Takes a set of basins and randomly creates n splits. The result is stored
    in a dictionary that contains one key per split, which in turn contains a
    `train` and a `test` key holding the lists of train and test basins.

    Parameters
    ----------
    cfg : dict
        Dictionary containing the user-entered evaluation config

    Raises
    ------
    RuntimeError
        If a file for the same random seed already exists.
    FileNotFoundError
        If the user-defined basin list path does not exist.
    """
    output_file = (Path(__file__).absolute().parent /
                   f'data/kfolds/kfold_splits_seed{cfg["seed"]}.p')

    # check if split file already exists
    if output_file.is_file():
        raise RuntimeError(f"File '{output_file}' already exists.")

    # set random seed for reproducibility
    np.random.seed(cfg["seed"])

    # read in basin file
    basins = get_basin_list(basin_list_file_train)

    # create folds
    kfold = KFold(n_splits=cfg["n_splits"], shuffle=True,
                  random_state=cfg["seed"])
    kfold.get_n_splits(basins)

    # dict to store the basins of all folds
    splits = defaultdict(dict)
    for split, (train_idx, test_idx) in enumerate(kfold.split(basins)):
        # map the fold indices back to basin ids
        train_basins = [basins[i] for i in train_idx]
        test_basins = [basins[i] for i in test_idx]
        splits[split] = {'train': train_basins, 'test': test_basins}

    with output_file.open('wb') as fp:
        pickle.dump(splits, fp)

    print(f"Stored dictionary with basin splits at {output_file}")
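# A minimal sketch of how the resulting split pickle can be consumed later on
# (the path and seed below are hypothetical; adjust them to your run):
#
# import pickle
# from pathlib import Path
#
# split_file = Path("data/kfolds/kfold_splits_seed111.p")
# with split_file.open("rb") as fp:
#     splits = pickle.load(fp)
#
# for fold, basin_dict in splits.items():
#     print(f"fold {fold}: {len(basin_dict['train'])} train / "
#           f"{len(basin_dict['test'])} test basins")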
def get_args() -> Dict:
    """Parse input arguments.

    Returns
    -------
    dict
        Dictionary containing the run config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=["train", "evaluate", "eval_robustness"])
    parser.add_argument('--camels_root', type=str,
                        help="Root directory of CAMELS data set")
    parser.add_argument('--seed', type=int, required=False, help="Random seed")
    parser.add_argument('--run_dir', type=str,
                        help="For evaluation mode. Path to run directory.")
    # note: boolean flags use store_true; `type=bool` would evaluate any
    # non-empty string (including "False") as True
    parser.add_argument('--cache_data', action='store_true',
                        help="If passed, loads all data into memory")
    parser.add_argument('--num_workers', type=int, default=12,
                        help="Number of parallel threads for data loading")
    parser.add_argument('--no_static', action='store_true',
                        help="If passed, trains LSTM without static features")
    parser.add_argument('--concat_static', action='store_true',
                        help="If passed, trains LSTM with static features "
                             "concatenated at each time step")
    parser.add_argument('--use_mse', action='store_true',
                        help="If passed, uses mean squared error as loss function.")
    parser.add_argument('--run_dir_base', type=str, default="runs",
                        help="For training mode. Path to store run directories in.")
    parser.add_argument('--run_name', type=str, required=False,
                        help="For training mode. Name of the run.")
    parser.add_argument('--train_start', type=str,
                        help="Training start date (ddmmyyyy).")
    parser.add_argument('--train_end', type=str,
                        help="Training end date (ddmmyyyy).")
    parser.add_argument('--basins', nargs='+', default=get_basin_list(),
                        help='List of basins')
    cfg = vars(parser.parse_args())

    cfg["train_start"] = pd.to_datetime(cfg["train_start"], format='%d%m%Y')
    cfg["train_end"] = pd.to_datetime(cfg["train_end"], format='%d%m%Y')

    # validation checks
    if (cfg["mode"] == "train") and (cfg["seed"] is None):
        # generate random seed for this run
        cfg["seed"] = int(np.random.uniform(low=0, high=1e6))

    if (cfg["mode"] in ["evaluate", "eval_robustness"]) and (cfg["run_dir"] is None):
        raise ValueError("In evaluation mode a run directory (--run_dir) has to be specified")

    # combine global settings with user config
    cfg.update(GLOBAL_SETTINGS)

    if cfg["mode"] == "train":
        # print config to terminal
        for key, val in cfg.items():
            print(f"{key}: {val}")

    # convert paths to PosixPath objects
    cfg["camels_root"] = Path(cfg["camels_root"])
    if cfg["run_dir"] is not None:
        cfg["run_dir"] = Path(cfg["run_dir"])
    if cfg["run_dir_base"] is not None:
        cfg["run_dir_base"] = Path(cfg["run_dir_base"])

    return cfg
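# Example invocations, assuming main.py is the entry point that calls
# get_args() (all paths, seeds, and dates below are hypothetical; dates follow
# the ddmmyyyy format expected by pd.to_datetime above):
#
#   python main.py train --camels_root /data/CAMELS --seed 111 \
#       --train_start 01101999 --train_end 30092008
#   python main.py evaluate --camels_root /data/CAMELS --run_dir runs/run_2506_1123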
def create_splits(cfg: dict): """Create random k-Fold cross validation splits. Takes a set of basins and randomly creates n splits. The result is stored into a dictionary, that contains for each split one key that contains a `train` and a `test` key, which contain the list of train and test basins. Parameters ---------- cfg : dict Dictionary containing the user entered evaluation config Raises ------ RuntimeError If file for the same random seed already exists. FileNotFoundError If the user defined basin list path does not exist. """ output_file = (Path(__file__).absolute().parent / f'data/kfold_splits_seed{cfg["seed"]}.p') # check if split file already already exists if output_file.is_file(): raise RuntimeError(f"File '{output_file}' already exists.") # set random seed for reproduceability np.random.seed(cfg["seed"]) # read in basin file if cfg["basin_file"] is not None: if not Path(cfg["basin_file"]).is_file(): raise FileNotFoundError(f"Not file found at {cfg['basin_file']}") with open(cfg["basin_file"], 'r') as fp: basins = fp.readlines() basins = [b.strip() for b in basins] """ Delete some basins because of missing data: - '06775500' & '06846500' no attributes - '09535100' no streamflow records """ ignore_basins = ['06775500', '06846500', '09535100'] basins = [b for b in basins if b not in ignore_basins] else: basins = get_basin_list() # create folds kfold = KFold(n_splits=cfg["n_splits"], shuffle=True, random_state=cfg["seed"]) kfold.get_n_splits(basins) # dict to store the results of all folds splits = defaultdict(dict) for split, (train_idx, test_idx) in enumerate(kfold.split(basins)): # further split train_idx into train/val idx into train and val set train_basins = [basins[i] for i in train_idx] test_basins = [basins[i] for i in test_idx] splits[split] = {'train': train_basins, 'test': test_basins} with output_file.open('wb') as fp: pickle.dump(splits, fp) print(f"Stored dictionary with basin splits at {output_file}")
def evaluate(user_cfg: Dict):
    """Evaluate a trained model on the validation period.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user-entered evaluation config
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if user_cfg["split_file"] is not None:
        with Path(user_cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[run_cfg["split"]]["test"]
    else:
        basins = get_basin_list()

    # get attribute means/stds from the training dataset
    train_file = user_cfg["run_dir"] / "data/train/train_data.h5"
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    ds_train = CamelsH5(h5_file=train_file,
                        db_path=db_path,
                        basins=basins,
                        concat_static=run_cfg["concat_static"])
    means = ds_train.get_attribute_means()
    stds = ds_train.get_attribute_stds()

    # create model
    input_size_dyn = 5 if (run_cfg["no_static"] or not run_cfg["concat_static"]) else 32
    model = Model(input_size_dyn=input_size_dyn,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  no_static=run_cfg["no_static"]).to(DEVICE)

    # load trained model weights
    weight_file = user_cfg["run_dir"] / 'model_epoch30.pt'
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    date_range = pd.date_range(start=GLOBAL_SETTINGS["val_start"],
                               end=GLOBAL_SETTINGS["val_end"])
    results = {}
    for basin in tqdm(basins):
        ds_test = CamelsTXT(camels_root=user_cfg["camels_root"],
                            basin=basin,
                            dates=[GLOBAL_SETTINGS["val_start"],
                                   GLOBAL_SETTINGS["val_end"]],
                            is_train=False,
                            seq_length=run_cfg["seq_length"],
                            with_attributes=True,
                            attribute_means=means,
                            attribute_stds=stds,
                            concat_static=run_cfg["concat_static"],
                            db_path=db_path)
        loader = DataLoader(ds_test, batch_size=1024, shuffle=False, num_workers=4)

        preds, obs = evaluate_basin(model, loader)

        df = pd.DataFrame(data={'qobs': obs.flatten(),
                                'qsim': preds.flatten()},
                          index=date_range)
        results[basin] = df

    _store_results(user_cfg, run_cfg, results)
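# The stored DataFrames pair observed (qobs) and simulated (qsim) discharge.
# A minimal sketch of computing the Nash-Sutcliffe efficiency from one of
# them, using the standard definition (the repo's own calc_nse may differ in
# details such as the handling of invalid observations):
#
# import numpy as np
#
# def nse(qobs: np.ndarray, qsim: np.ndarray) -> float:
#     """NSE: 1 - sum of squared errors / variance of the observations."""
#     numerator = np.sum((qsim - qobs) ** 2)
#     denominator = np.sum((qobs - np.mean(qobs)) ** 2)
#     return 1.0 - numerator / denominator
#
# df = results[basin]
# print(nse(df["qobs"].values, df["qsim"].values))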
def train(cfg):
    """Train model.

    Parameters
    ----------
    cfg : Dict
        Dictionary containing the run config
    """
    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    if cfg["split_file"] is not None:
        with Path(cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[cfg["split"]]["train"]
    else:
        basins = get_basin_list()

    # create folder structure for this run
    cfg = _setup_run(cfg)

    # prepare data for training
    cfg = _prepare_data(cfg=cfg, basins=basins)

    # prepare PyTorch DataLoader
    ds = CamelsH5(h5_file=cfg["train_file"],
                  basins=basins,
                  db_path=cfg["db_path"],
                  concat_static=cfg["concat_static"],
                  cache=cfg["cache_data"],
                  no_static=cfg["no_static"])
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=True,
                        num_workers=cfg["num_workers"])

    # create model and optimizer
    input_size_dyn = 5 if (cfg["no_static"] or not cfg["concat_static"]) else 32
    model = Model(input_size_dyn=input_size_dyn,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"],
                  no_static=cfg["no_static"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce the learning rate at the start of epochs 11 and 21
    learning_rates = {11: 5e-4, 21: 1e-4}

    for epoch in range(1, cfg["epochs"] + 1):
        # set new learning rate
        if epoch in learning_rates.keys():
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        train_epoch(model, optimizer, loss_func, loader, cfg, epoch, cfg["use_mse"])

        model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
        torch.save(model.state_dict(), str(model_path))
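# NSELoss scales the squared error of each sample by the training-period
# discharge variance of the corresponding basin, so that basins with large
# discharge variance do not dominate the loss. A minimal sketch of this idea
# (the repo's NSELoss implementation may differ in details such as the value
# of eps or reduction over the batch):
#
# import torch
#
# def nse_loss(y_sim: torch.Tensor, y_obs: torch.Tensor,
#              q_std: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
#     """Squared error per sample, scaled by the basin discharge std."""
#     squared_error = (y_sim - y_obs) ** 2
#     weights = 1.0 / (q_std + eps) ** 2
#     return torch.mean(weights * squared_error)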
def eval_robustness(user_cfg: Dict):
    """Evaluate the model robustness of the EA-LSTM.

    In this experiment, Gaussian noise with increasing scale is added to the
    static features to evaluate the model robustness against perturbations of
    the static catchment characteristics. For each scale, 50 noise vectors
    are drawn.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user-entered evaluation config

    Raises
    ------
    NotImplementedError
        If the specified run_dir does not point to an EA-LSTM model folder.
    """
    random.seed(user_cfg["seed"])
    np.random.seed(user_cfg["seed"])

    # fixed settings for this analysis
    n_repetitions = 50
    scales = [0.1 * i for i in range(11)]

    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if run_cfg["concat_static"] or run_cfg["no_static"]:
        raise NotImplementedError("This function is only implemented for EA-LSTM models")

    basins = get_basin_list()

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path, basins=basins, drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()

    # initialize model
    model = Model(input_size_dyn=5,
                  input_size_stat=27,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"]).to(DEVICE)

    weight_file = user_cfg["run_dir"] / "model_epoch30.pt"
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    overall_results = {}

    # progress bar handle
    pbar = tqdm(basins, file=sys.stdout)
    for basin in pbar:
        ds_test = CamelsTXT(camels_root=user_cfg["camels_root"],
                            basin=basin,
                            dates=[GLOBAL_SETTINGS["val_start"],
                                   GLOBAL_SETTINGS["val_end"]],
                            is_train=False,
                            with_attributes=True,
                            attribute_means=means,
                            attribute_stds=stds,
                            db_path=db_path)
        loader = DataLoader(ds_test, batch_size=len(ds_test), shuffle=False,
                            num_workers=0)

        basin_results = defaultdict(list)
        step = 1
        for scale in scales:
            # for scale 0.0 (no noise) a single evaluation is sufficient
            for _ in range(1 if scale == 0.0 else n_repetitions):
                noise = np.random.normal(loc=0, scale=scale, size=27).astype(np.float32)
                noise = torch.from_numpy(noise).to(DEVICE)
                nse = eval_with_added_noise(model, loader, noise)
                basin_results[scale].append(nse)
                pbar.set_postfix_str(
                    f"Basin progress: {step}/{(len(scales)-1)*n_repetitions+1}")
                step += 1

        overall_results[basin] = basin_results

    out_file = (Path(__file__).absolute().parent /
                f'results/{user_cfg["run_dir"].name}_model_robustness.p')
    if not out_file.parent.is_dir():
        out_file.parent.mkdir(parents=True)
    with out_file.open("wb") as fp:
        pickle.dump(overall_results, fp)
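# A minimal sketch of summarizing the stored robustness results, e.g. the
# median NSE over all basins and repetitions per noise scale (the pickle file
# name below is a placeholder for an actual run directory name):
#
# import pickle
# import numpy as np
# from pathlib import Path
#
# with Path("results/run_XXXX_model_robustness.p").open("rb") as fp:
#     overall_results = pickle.load(fp)
#
# scales = sorted(next(iter(overall_results.values())).keys())
# for scale in scales:
#     nses = [nse for basin_results in overall_results.values()
#             for nse in basin_results[scale]]
#     print(f"scale {scale:.1f}: median NSE {np.median(nses):.3f}")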
            else:
                temp_cfg[key] = val
        json.dump(temp_cfg, fp, sort_keys=True, indent=4)

    return cfg


################
# Prepare grid #
################

if __name__ == "__main__":
    cfg = get_args()
    cfg = _setup_run(cfg)

    np.random.seed(0)
    basins = get_basin_list()
    basin_samples = []
    for n_basins in cfg["n_basins"]:
        for i in range(cfg["basin_samples_per_grid_cell"]):
            basin_samples.append(np.random.choice(basins, size=n_basins, replace=False))
            if n_basins == 531:
                # with all 531 basins there is only one unique sample
                break

    # Do the XGBoost parameter search for one configuration, then reuse these
    # parameters for all others.
    train_start = cfg["xgb_param_search_range"][0]
    train_end = cfg["xgb_param_search_range"][1]
    basin_sample_id, basin_sample = [(i, b) for i, b in enumerate(basin_samples)
                                     if len(b) == cfg["xgb_param_search_basins"]][0]
    if cfg["use_params"] is None:
        param_search_name = (f"run_xgb_param_search_{train_start}_{train_end}"
                             f"_basinsample{len(basin_sample)}_{basin_sample_id}_seed111")
        param_search_model_dir = cfg["run_dir"] / param_search_name
def evaluate(user_cfg: Dict):
    """Evaluate a trained model on the validation period.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user-entered evaluation config
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if user_cfg["split_file"] is not None:
        with Path(user_cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[run_cfg["split"]]["test"]
    else:
        basins = get_basin_list(basin_list_file_evaluate)

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True,
                                 keep_features=user_cfg["camels_attr"])

    # get remaining scaler from pickle file
    scaler_file = user_cfg["run_dir"] / "data" / "train" / "scaler.p"
    with open(scaler_file, "rb") as fp:
        scaler = pickle.load(fp)
    scaler["camels_attr_mean"] = attributes.mean()
    scaler["camels_attr_std"] = attributes.std()

    # create model
    if run_cfg["concat_static"] and not run_cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(run_cfg["dynamic_inputs"]) +
                          len(run_cfg["camels_attr"]) +
                          len(run_cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(run_cfg["camels_attr"]) + len(run_cfg["static_inputs"])
        input_size_dyn = len(run_cfg["dynamic_inputs"])
        concat_static = False
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  embedding_hiddens=run_cfg["embedding_hiddens"]).to(DEVICE)

    # load trained model weights
    weight_file = user_cfg["run_dir"] / 'model_epoch30.pt'
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    date_range = pd.date_range(start=user_cfg["val_start"], end=user_cfg["val_end"])
    results = {}
    cell_states = {}
    embeddings = {}
    nses = []

    file_name = Path(__file__).parent / 'data' / 'dynamic_features_nwm_v2.p'
    with file_name.open("rb") as fp:
        additional_features = pickle.load(fp)

    # Ad hoc static climate indices:
    # requires the training period for this experiment and overwrites the
    # *_dyn type climate indices in 'additional_features'.
    if not user_cfg['use_dynamic_climate']:
        if user_cfg['static_climate'].lower() == 'test':
            eval_clim_indexes = training_period_climate_indices(
                db_path=db_path,
                camels_root=user_cfg['camels_root'],
                basins=basins,
                start_date=user_cfg['val_start'],
                end_date=user_cfg['val_end'])
        elif user_cfg['static_climate'].lower() == 'train':
            eval_clim_indexes = training_period_climate_indices(
                db_path=db_path,
                camels_root=user_cfg['camels_root'],
                basins=basins,
                start_date=user_cfg['train_start'],
                end_date=user_cfg['train_end'])
        else:
            raise RuntimeError(
                f"Unknown static_climate variable: {user_cfg['static_climate']}")

        for basin in basins:
            for col in eval_clim_indexes[basin].columns:
                additional_features[basin][col] = np.tile(
                    eval_clim_indexes[basin][col].values,
                    [additional_features[basin].shape[0], 1])

    for basin in tqdm(basins):
        ds_test = CamelsTXTv2(
            camels_root=user_cfg["camels_root"],
            basin=basin,
            dates=[user_cfg["val_start"], user_cfg["val_end"]],
            is_train=False,
            dynamic_inputs=user_cfg["dynamic_inputs"],
            camels_attr=user_cfg["camels_attr"],
            static_inputs=user_cfg["static_inputs"],
            additional_features=additional_features[basin],
            scaler=scaler,
            seq_length=run_cfg["seq_length"],
            concat_static=concat_static,
            db_path=db_path)
        loader = DataLoader(ds_test,
                            batch_size=2500,
                            shuffle=False,
                            num_workers=user_cfg["num_workers"])

        preds, obs, cells, embeds = evaluate_basin(model, loader)

        # rescale predictions
        preds = preds * scaler["q_std"] + scaler["q_mean"]
        # set discharges < 0 to zero
        preds[preds < 0] = 0

        nses.append(calc_nse(obs[obs >= 0], preds[obs >= 0]))

        # store predictions
        df = pd.DataFrame(data={'qobs': obs.flatten(),
                                'qsim': preds.flatten()},
                          index=date_range)
        results[basin] = df

        # store cell states and embedding values
        cell_states[basin] = pd.DataFrame(data=cells, index=date_range)
        embeddings[basin] = pd.DataFrame(data=embeds, index=date_range)

    print(f"Mean NSE {np.mean(nses)}, median NSE {np.median(nses)}")

    _store_results(user_cfg, run_cfg, results, cell_states, embeddings)
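# A minimal sketch of collecting the per-basin result DataFrames into one
# long-format DataFrame for later analysis (the helper name is hypothetical,
# not part of this repo):
#
# import pandas as pd
#
# def results_to_long_df(results: dict) -> pd.DataFrame:
#     """Stack the per-basin qobs/qsim frames, tagged with the basin id."""
#     frames = []
#     for basin, df in results.items():
#         frame = df.copy()
#         frame["basin"] = basin
#         frames.append(frame)
#     return pd.concat(frames)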
def train(cfg):
    """Train model.

    Parameters
    ----------
    cfg : Dict
        Dictionary containing the run config
    """
    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    if cfg["split_file"] is not None:
        with Path(cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[cfg["split"]]["train"]
    else:
        basins = get_basin_list(basin_list_file_train)

    # create folder structure for this run
    cfg = _setup_run(cfg)

    # prepare data for training
    cfg = _prepare_data(cfg=cfg, basins=basins)

    with open(cfg["scaler_file"], 'rb') as fp:
        scaler = pickle.load(fp)

    camels_attr = load_attributes(cfg["db_path"],
                                  basins,
                                  drop_lat_lon=True,
                                  keep_features=cfg["camels_attr"])
    scaler["camels_attr_mean"] = camels_attr.mean()
    scaler["camels_attr_std"] = camels_attr.std()

    # create model and optimizer
    if cfg["concat_static"] and not cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(cfg["dynamic_inputs"]) +
                          len(cfg["camels_attr"]) +
                          len(cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(cfg["camels_attr"]) + len(cfg["static_inputs"])
        input_size_dyn = len(cfg["dynamic_inputs"])
        concat_static = False
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  embedding_hiddens=cfg["embedding_hiddens"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])

    # prepare PyTorch DataLoader
    ds = CamelsH5v2(h5_file=cfg["train_file"],
                    basins=basins,
                    db_path=cfg["db_path"],
                    concat_static=concat_static,
                    cache=cfg["cache_data"],
                    camels_attr=cfg["camels_attr"],
                    scaler=scaler)
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=True,
                        num_workers=cfg["num_workers"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce the learning rate at the start of epochs 11 and 21
    learning_rates = {11: 5e-4, 21: 1e-4}

    for epoch in range(1, cfg["epochs"] + 1):
        # set new learning rate
        if epoch in learning_rates.keys():
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        train_epoch(model, optimizer, loss_func, loader, cfg, epoch, cfg["use_mse"])

        model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
        torch.save(model.state_dict(), str(model_path))
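# The two learning-rate steps above use different decay factors (assuming an
# initial rate of 1e-3: x0.5 at epoch 11, then x0.2 at epoch 21), which is why
# the code writes into the optimizer's parameter groups directly instead of
# using a single-gamma torch.optim.lr_scheduler. A minimal self-contained
# sketch of the same pattern with a toy model (values hypothetical):
#
# import torch
# import torch.nn as nn
#
# toy_model = nn.Linear(10, 1)
# toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
# lr_steps = {11: 5e-4, 21: 1e-4}
# for epoch in range(1, 31):
#     if epoch in lr_steps:
#         for param_group in toy_optimizer.param_groups:
#             param_group["lr"] = lr_steps[epoch]
#     # ... one training epoch would run here ...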
def dist_train(rank, world_size, cfg):
    """Train model with DistributedDataParallel.

    Parameters
    ----------
    rank : int
        Rank of the current process.
    world_size : int
        Total number of processes.
    cfg : Dict
        Dictionary containing the run config
    """
    print(f"Running basic DDP example on rank {rank}. {world_size}")
    setup(rank, world_size)

    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    basins = get_basin_list()

    if rank == 0:
        # create folder structure for this run
        cfg = _setup_run(cfg)

        # prepare data for training
        cfg = _prepare_data(cfg=cfg, basins=basins)

        with open(str(cfg["camels_root"]) + '/cfg.pkl', 'wb') as f:
            pickle.dump(cfg, f, pickle.HIGHEST_PROTOCOL)

    # make sure all processes see the config written by rank 0
    dist.barrier()
    with open(str(cfg["camels_root"]) + '/cfg.pkl', 'rb') as f:
        cfg = pickle.load(f)

    # prepare PyTorch DataLoader; the DistributedSampler partitions the data
    # across processes, so the loader itself must not shuffle
    ds = CamelsH5(h5_file=cfg["train_file"],
                  basins=basins,
                  db_path=cfg["db_path"],
                  concat_static=cfg["concat_static"],
                  cache=cfg["cache_data"],
                  no_static=cfg["no_static"])
    sampler = torch.utils.data.distributed.DistributedSampler(
        ds, num_replicas=world_size, rank=rank)
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=False,
                        num_workers=cfg["num_workers"],
                        sampler=sampler,
                        pin_memory=True)

    # create model and optimizer
    input_size_stat = 0 if cfg["no_static"] else 27
    input_size_dyn = 5 if (cfg["no_static"] or not cfg["concat_static"]) else 32
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"],
                  no_static=cfg["no_static"])
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=cfg["learning_rate"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce the learning rate at the start of epochs 11 and 21
    learning_rates = {11: 5e-4, 21: 1e-4}

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}

    for epoch in range(1, math.ceil(cfg["epochs"] / world_size) + 1):
        # make the sampler shuffle differently in every epoch
        sampler.set_epoch(epoch)

        # set new learning rate
        if epoch in learning_rates.keys():
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        # synchronize model weights across processes via a rank-0 checkpoint
        if rank == 0:
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
        dist.barrier()
        ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))

        train_epoch(ddp_model, optimizer, loss_func, loader, cfg, epoch,
                    cfg["use_mse"], rank)

        # persist the model once per epoch (only on rank 0)
        if rank == 0:
            model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
            torch.save(ddp_model.state_dict(), str(model_path))

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
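# The setup() and cleanup() helpers called above are not shown in this file.
# A common implementation, following the PyTorch DDP tutorial (the master
# address and port are placeholder values for single-node training):
#
# import os
# import torch.distributed as dist
#
# def setup(rank: int, world_size: int):
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = "12355"
#     # nccl is the standard backend for multi-GPU training
#     dist.init_process_group("nccl", rank=rank, world_size=world_size)
#
# def cleanup():
#     dist.destroy_process_group()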
def predict_basin(basin: str,
                  run_dir: Union[str, Path],
                  camels_dir: Union[str, Path],
                  period: str = "train",
                  epoch: int = 30):
    """Generate model predictions for a single basin and period."""
    if isinstance(run_dir, str):
        run_dir = Path(run_dir)
    elif not isinstance(run_dir, Path):
        raise TypeError(f"run_dir must be str or Path, not {type(run_dir)}")
    if isinstance(camels_dir, str):
        camels_dir = Path(camels_dir)
    elif not isinstance(camels_dir, Path):
        raise TypeError(f"camels_dir must be str or Path, not {type(camels_dir)}")

    with open(run_dir / "cfg.json", "r") as fp:
        run_cfg = json.load(fp)

    if period not in ["train", "val"]:
        raise ValueError("period must be either 'train' or 'val'")

    basins = get_basin_list()
    db_path = str(run_dir / "attributes.db")
    attributes = load_attributes(db_path=db_path, basins=basins, drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()

    attrs_count = len(attributes.columns)
    timeseries_count = 6
    input_size_stat = timeseries_count if run_cfg["no_static"] else attrs_count
    input_size_dyn = (timeseries_count
                      if (run_cfg["no_static"] or not run_cfg["concat_static"])
                      else timeseries_count + attrs_count)

    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  no_static=run_cfg["no_static"]).to(DEVICE)

    # load trained model weights
    weight_file = run_dir / f"model_epoch{epoch}.pt"
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    ds_test = CamelsTXT(camels_root=camels_dir,
                        basin=basin,
                        dates=[GLOBAL_SETTINGS[f"{period}_start"],
                               GLOBAL_SETTINGS[f"{period}_end"]],
                        is_train=False,
                        seq_length=run_cfg["seq_length"],
                        with_attributes=True,
                        attribute_means=means,
                        attribute_stds=stds,
                        concat_static=run_cfg["concat_static"],
                        db_path=db_path)

    # the first seq_length - 1 days serve as warm-up and get no prediction
    date_range = ds_test.dates_index[run_cfg["seq_length"] - 1:]
    loader = DataLoader(ds_test, batch_size=1024, shuffle=False, num_workers=4)
    preds, obs = evaluate_basin(model, loader)

    results = pd.DataFrame(data={"qobs": obs.flatten(),
                                 "qsim": preds.flatten()},
                           index=date_range)

    return results, date_range
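# Example usage (basin id, run directory, and CAMELS path are hypothetical):
#
#   results, date_range = predict_basin(
#       basin="01013500",
#       run_dir="runs/run_2306_0813_seed111",
#       camels_dir="/data/CAMELS",
#       period="val",
#       epoch=30,
#   )
#   print(results.head())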