def _prepare_data(cfg: Dict, basins: List) -> Dict:
    """Preprocess training data.

    Always (re)creates the static-attribute database in the run directory.
    If no training file is configured, the training ``.h5`` file is created
    in ``cfg["train_dir"]``. Otherwise the configured scaler file is copied
    into ``cfg["train_dir"]``.

    Parameters
    ----------
    cfg : dict
        Dictionary containing the run config. It is modified in place:
        ``db_path`` is always set, and ``train_file`` / ``scaler_file`` are set
        when the training data is generated here.
    basins : List
        List containing the 8-digit USGS gauge id

    Returns
    -------
    dict
        Dictionary containing the updated run config (the same object as
        ``cfg``)
    """
    # create database file containing the static basin attributes
    cfg["db_path"] = str(cfg["run_dir"] / "attributes.db")
    add_camels_attributes(cfg["camels_root"], db_path=cfg["db_path"])

    # create .h5 files for train and validation data
    if cfg["train_file"] is None:
        cfg["train_file"] = cfg["train_dir"] / 'train_data.h5'
        cfg["scaler_file"] = cfg["train_dir"] / "scaler.p"

        # get additional static inputs
        file_name = Path(__file__).parent / 'data' / 'dynamic_features_nwm_v2.p'
        with file_name.open("rb") as fp:
            additional_features = pickle.load(fp)

        # ad hoc static climate indices
        # requires the training period for this experiment
        # overwrites the *_dyn type climate indices in 'additional_features'
        if not cfg['use_dynamic_climate']:
            train_clim_indexes = training_period_climate_indices(
                db_path=cfg['db_path'],
                camels_root=cfg['camels_root'],
                basins=basins,
                start_date=cfg['train_start'],
                end_date=cfg['train_end'])
            for basin in basins:
                basin_features = additional_features[basin]
                # adding columns does not change the row count, so hoist it
                n_rows = basin_features.shape[0]
                for col in train_clim_indexes[basin].columns:
                    # repeat the index value(s) on every row of this basin
                    basin_features[col] = np.tile(
                        train_clim_indexes[basin][col].values, [n_rows, 1])

        create_h5_files_v2(
            camels_root=cfg["camels_root"],
            out_file=cfg["train_file"],
            basins=basins,
            dates=[cfg["train_start"], cfg["train_end"]],
            db_path=cfg["db_path"],
            cfg=cfg,
            additional_features=additional_features,
            num_workers=cfg["num_workers"],
            seq_length=cfg["seq_length"])
    else:
        # copy scaler file into run folder
        dst = cfg["train_dir"] / "scaler.p"
        try:
            shutil.copyfile(cfg["scaler_file"], dst)
        except shutil.SameFileError:
            # the configured scaler already is the run-folder copy
            pass

    return cfg
# Example #2
# 0
def evaluate(user_cfg: Dict):
    """Evaluate a trained model on the test basins.

    Loads the run config (``cfg.json``) and the epoch-30 weights from
    ``user_cfg["run_dir"]``. For every evaluation basin over
    ``[val_start, val_end]`` it:

    - predicts discharge,
    - collects cell states and embeddings,
    - records the NSE.

    It then prints the mean and median NSE and hands everything to
    ``_store_results``.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user entered evaluation config

    Raises
    ------
    RuntimeError
        If dynamic climate indices are disabled and
        ``user_cfg["static_climate"]`` is neither ``'test'`` nor ``'train'``
        (case-insensitive).
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    # basins to evaluate: the "test" part of the run's split when a split
    # file is given, otherwise the default evaluation basin list
    if user_cfg["split_file"] is not None:
        with Path(user_cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[run_cfg["split"]]["test"]
    else:
        basins = get_basin_list(basin_list_file_evaluate)

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True,
                                 keep_features=user_cfg["camels_attr"])

    # get remaining scaler from pickle file
    scaler_file = user_cfg["run_dir"] / "data" / "train" / "scaler.p"
    with open(scaler_file, "rb") as fp:
        scaler = pickle.load(fp)
    # NOTE(review): attribute mean/std are computed over the evaluation
    # basins loaded above, not the training basins -- confirm intended.
    scaler["camels_attr_mean"] = attributes.mean()
    scaler["camels_attr_std"] = attributes.std()

    # create model
    # static inputs are appended to the dynamic inputs only when concat_static
    # is set and no static embedding network is configured
    if run_cfg["concat_static"] and not run_cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(run_cfg["dynamic_inputs"]) +
                          len(run_cfg["camels_attr"]) +
                          len(run_cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(run_cfg["camels_attr"]) + len(
            run_cfg["static_inputs"])
        input_size_dyn = len(run_cfg["dynamic_inputs"])
        concat_static = False
    # NOTE(review): Model gets run_cfg["concat_static"], while the dataset
    # below gets the local concat_static. The two differ when
    # embedding_hiddens is set; presumably Model handles that case itself --
    # confirm.
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  embedding_hiddens=run_cfg["embedding_hiddens"]).to(DEVICE)

    # load trained model (the epoch-30 checkpoint is hard-coded)
    weight_file = user_cfg["run_dir"] / 'model_epoch30.pt'
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    date_range = pd.date_range(start=user_cfg["val_start"],
                               end=user_cfg["val_end"])
    results = {}
    cell_states = {}
    embeddings = {}
    nses = []

    file_name = Path(__file__).parent / 'data' / 'dynamic_features_nwm_v2.p'
    with file_name.open("rb") as fp:
        additional_features = pickle.load(fp)

    # ad hoc static climate indices, computed over either the test or the
    # training period; they overwrite the *_dyn type climate indices in
    # 'additional_features'
    if not user_cfg['use_dynamic_climate']:
        climate_period = user_cfg['static_climate'].lower()
        if climate_period == 'test':
            clim_start, clim_end = user_cfg['val_start'], user_cfg['val_end']
        elif climate_period == 'train':
            clim_start, clim_end = user_cfg['train_start'], user_cfg['train_end']
        else:
            raise RuntimeError(
                f"Unknown static_climate variable: "
                f"{user_cfg['static_climate']!r} (expected 'test' or 'train')")

        eval_clim_indexes = training_period_climate_indices(
            db_path=db_path,
            camels_root=user_cfg['camels_root'],
            basins=basins,
            start_date=clim_start,
            end_date=clim_end)

        for basin in basins:
            for col in eval_clim_indexes[basin].columns:
                # repeat the index value(s) on every row of this basin
                additional_features[basin][col] = np.tile(
                    eval_clim_indexes[basin][col].values,
                    [additional_features[basin].shape[0], 1])

    for basin in tqdm(basins):
        ds_test = CamelsTXTv2(
            camels_root=user_cfg["camels_root"],
            basin=basin,
            dates=[user_cfg["val_start"], user_cfg["val_end"]],
            is_train=False,
            dynamic_inputs=user_cfg["dynamic_inputs"],
            camels_attr=user_cfg["camels_attr"],
            static_inputs=user_cfg["static_inputs"],
            additional_features=additional_features[basin],
            scaler=scaler,
            seq_length=run_cfg["seq_length"],
            concat_static=concat_static,
            db_path=db_path)
        loader = DataLoader(ds_test,
                            batch_size=2500,
                            shuffle=False,
                            num_workers=user_cfg["num_workers"])

        preds, obs, cells, embeds = evaluate_basin(model, loader)

        # rescale predictions back to discharge units
        preds = preds * scaler["q_std"] + scaler["q_mean"]

        # set discharges < 0 to zero
        preds[preds < 0] = 0
        # NSE only over time steps with non-negative observations
        # (negative values presumably flag missing data -- confirm)
        valid = obs >= 0
        nses.append(calc_nse(obs[valid], preds[valid]))

        # store predictions
        results[basin] = pd.DataFrame(data={
            'qobs': obs.flatten(),
            'qsim': preds.flatten()
        }, index=date_range)

        # store cell states and embedding values
        cell_states[basin] = pd.DataFrame(data=cells, index=date_range)
        embeddings[basin] = pd.DataFrame(data=embeds, index=date_range)

    print(f"Mean NSE {np.mean(nses)}, median NSE {np.median(nses)}")

    _store_results(user_cfg, run_cfg, results, cell_states, embeddings)