Example #1
def evaluate(user_cfg: Dict):
    """Train model for a single epoch.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user entered evaluation config
        
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    basins = run_cfg["basins"]

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path, 
                                 basins=basins,
                                 drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()

    # create model (5 meteorological forcings, 27 static catchment attributes;
    # 32 = 5 + 27 when the static attributes are concatenated to the dynamic inputs)
    input_size_stat = 0 if run_cfg["no_static"] else 27
    input_size_dyn = 5 if (run_cfg["no_static"] or not run_cfg["concat_static"]) else 32
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  no_static=run_cfg["no_static"]).to(DEVICE)

    # load trained model
    weight_file = user_cfg["run_dir"] / 'model_epoch30.pt'
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    date_range = pd.date_range(start=GLOBAL_SETTINGS["val_start"], end=GLOBAL_SETTINGS["val_end"])
    results = {}
    for basin in tqdm(basins):
        ds_test = CamelsTXT(camels_root=user_cfg["camels_root"],
                            basin=basin,
                            dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                            is_train=False,
                            seq_length=run_cfg["seq_length"],
                            with_attributes=True,
                            attribute_means=means,
                            attribute_stds=stds,
                            concat_static=run_cfg["concat_static"],
                            db_path=db_path)
        loader = DataLoader(ds_test, batch_size=1024, shuffle=False, num_workers=4)

        preds, obs = evaluate_basin(model, loader)

        df = pd.DataFrame(data={'qobs': obs.flatten(), 'qsim': preds.flatten()}, index=date_range)

        results[basin] = df

    _store_results(user_cfg, run_cfg, results)
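For context, a minimal invocation might look like the sketch below. The keys "run_dir" and "camels_root" are the ones the function body above actually reads; the concrete paths are hypothetical placeholders.

from pathlib import Path

# Hypothetical paths -- point them at a finished training run and the CAMELS data root.
user_cfg = {
    "run_dir": Path("runs/my_run"),         # must contain cfg.json, attributes.db and model_epoch30.pt
    "camels_root": Path("data/CAMELS_US"),
}
evaluate(user_cfg)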
Example #2

def evaluate(user_cfg: Dict):
    """Evaluate a trained model on the validation period.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user-entered evaluation config
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    basins = run_cfg["basins"]

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()

    # load trained model (the full model object is serialized with pickle)
    model_file = user_cfg["run_dir"] / 'model.pkl'
    with open(model_file, 'rb') as fp:
        model = pickle.load(fp)

    date_range = pd.date_range(start=GLOBAL_SETTINGS["val_start"],
                               end=GLOBAL_SETTINGS["val_end"])
    results = {}
    for basin in tqdm(basins):
        ds_test = CamelsTXT(
            camels_root=user_cfg["camels_root"],
            basin=basin,
            dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
            is_train=False,
            seq_length=run_cfg["seq_length"],
            with_attributes=True,
            attribute_means=means,
            attribute_stds=stds,
            concat_static=False,
            db_path=db_path)

        preds, obs = evaluate_basin(model, ds_test, run_cfg["no_static"])

        df = pd.DataFrame(data={
            'qobs': obs.flatten(),
            'qsim': preds.flatten()
        },
                          index=date_range)

        results[basin] = df

    _store_results(user_cfg, run_cfg, results)
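Note the difference to Example #1: here the trained model is a complete object restored via pickle (presumably a non-PyTorch baseline such as a scikit-learn style regressor), so no DataLoader and no device transfer are needed, and evaluate_basin receives the dataset and the no_static flag directly.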
Example #3
def eval_robustness(user_cfg: Dict):
    """Evaluate model robustness of EA-LSTM

    In this experiment, gaussian noise with increasing scale is added to the static features to
    evaluate the model robustness against pertubations of the static catchment characteristics.
    For each scale, 50 noise vectors are drawn.
    
    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user entered evaluation config
    
    Raises
    ------
    NotImplementedError
        If the run_dir specified points not to a EA-LSTM model folder.
    """
    random.seed(user_cfg["seed"])
    np.random.seed(user_cfg["seed"])

    # fixed settings for this analysis
    n_repetitions = 50
    scales = [0.1 * i for i in range(11)]

    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if run_cfg["concat_static"] or run_cfg["no_static"]:
        raise NotImplementedError("This function is only implemented for EA-LSTM models")

    basins = run_cfg["basins"]

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path, 
                                 basins=basins,
                                 drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()

    # initialize EA-LSTM (5 meteorological forcings, 27 static catchment attributes)
    model = Model(input_size_dyn=5,
                  input_size_stat=27,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"]).to(DEVICE)
    weight_file = user_cfg["run_dir"] / "model_epoch30.pt"
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    overall_results = {}
    # progress bar handle
    pbar = tqdm(basins, file=sys.stdout)
    for basin in pbar:
        ds_test = CamelsTXT(camels_root=user_cfg["camels_root"],
                            basin=basin,
                            dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                            is_train=False,
                            with_attributes=True,
                            attribute_means=means,
                            attribute_stds=stds,
                            db_path=db_path)
        loader = DataLoader(ds_test, batch_size=len(ds_test), shuffle=False, num_workers=0)
        basin_results = defaultdict(list)
        step = 1
        for scale in scales:
            for _ in range(1 if scale == 0.0 else n_repetitions):
                noise = np.random.normal(loc=0, scale=scale, size=27).astype(np.float32)
                noise = torch.from_numpy(noise).to(DEVICE)
                nse = eval_with_added_noise(model, loader, noise)
                basin_results[scale].append(nse)
                pbar.set_postfix_str(f"Basin progress: {step}/{(len(scales)-1)*n_repetitions+1}")
                step += 1

        overall_results[basin] = basin_results
    out_file = (Path(__file__).absolute().parent /
                f'results/{user_cfg["run_dir"].name}_model_robustness.p')
    if not out_file.parent.is_dir():
        out_file.parent.mkdir(parents=True)
    with out_file.open("wb") as fp:
        pickle.dump(overall_results, fp)
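The pickle written above maps each basin to a dict of noise scale → list of NSE values. A minimal post-processing sketch under that assumption (the file name follows the f"results/{run_dir.name}_model_robustness.p" pattern used above; the run name is a placeholder):

import pickle
import numpy as np

# Placeholder path; substitute the name of your own run directory.
with open("results/my_run_model_robustness.p", "rb") as fp:
    overall_results = pickle.load(fp)

# Median NSE over all basins and noise repetitions, per scale.
scales = sorted({scale for per_basin in overall_results.values() for scale in per_basin})
for scale in scales:
    nses = [nse for per_basin in overall_results.values() for nse in per_basin[scale]]
    print(f"scale {scale:.1f}: median NSE = {np.median(nses):.3f}")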
Example #4

def predict_basin(
    basin: str,
    run_dir: Union[str, Path],
    camels_dir: Union[str, Path],
    period: str = "train",
    epoch: int = 30,
):
    """Predict discharge for a single basin with a trained model and return the
    result DataFrame together with the corresponding date range."""
    if isinstance(run_dir, str):
        run_dir = Path(run_dir)
    elif not isinstance(run_dir, Path):
        raise TypeError(f"run_dir must be str or Path, not {type(run_dir)}")
    if isinstance(camels_dir, str):
        camels_dir = Path(camels_dir)
    elif not isinstance(camels_dir, Path):
        raise TypeError(f"run_dir must be str or Path, not {type(camels_dir)}")

    with open(run_dir / "cfg.json", "r") as fp:
        run_cfg = json.load(fp)

    if period not in ["train", "val"]:
        raise ValueError("period must be either 'train' or 'val'")
    basins = get_basin_list()
    db_path = str(run_dir / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()
    # derive input sizes from the attribute table and the number of forcing time series
    attrs_count = len(attributes.columns)
    timeseries_count = 6
    input_size_stat = timeseries_count if run_cfg["no_static"] else attrs_count
    input_size_dyn = (timeseries_count if
                      (run_cfg["no_static"] or not run_cfg["concat_static"])
                      else timeseries_count + attrs_count)
    model = Model(
        input_size_dyn=input_size_dyn,
        input_size_stat=input_size_stat,
        hidden_size=run_cfg["hidden_size"],
        dropout=run_cfg["dropout"],
        concat_static=run_cfg["concat_static"],
        no_static=run_cfg["no_static"],
    ).to(DEVICE)

    # load trained model
    weight_file = run_dir / f"model_epoch{epoch}.pt"
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    ds_test = CamelsTXT(
        camels_root=camels_dir,
        basin=basin,
        dates=[
            GLOBAL_SETTINGS[f"{period}_start"],
            GLOBAL_SETTINGS[f"{period}_end"]
        ],
        is_train=False,
        seq_length=run_cfg["seq_length"],
        with_attributes=True,
        attribute_means=means,
        attribute_stds=stds,
        concat_static=run_cfg["concat_static"],
        db_path=db_path,
    )
    date_range = ds_test.dates_index[run_cfg["seq_length"] - 1:]
    loader = DataLoader(ds_test, batch_size=1024, shuffle=False, num_workers=4)
    preds, obs = evaluate_basin(model, loader)
    df = pd.DataFrame(data={
        "qobs": obs.flatten(),
        "qsim": preds.flatten()
    },
                      index=date_range)

    results = df
    # plt.plot(date_range, results["qobs"], label="Obs")
    # plt.plot(date_range, results["qsim"], label="Preds")
    # plt.legend()
    # plt.savefig(f"{run_dir}/pred_basin_{basin}.pdf")
    # plt.close()
    return results, date_range
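A hedged usage sketch for predict_basin; the basin id and paths are placeholders, and the plotting simply mirrors the commented-out lines above.

import matplotlib.pyplot as plt

# Hypothetical arguments -- use one of your trained runs and your local CAMELS root.
results, date_range = predict_basin(
    basin="01013500",              # example CAMELS gauge id
    run_dir="runs/my_run",
    camels_dir="data/CAMELS_US",
    period="val",
)

plt.plot(date_range, results["qobs"], label="Obs")
plt.plot(date_range, results["qsim"], label="Sim")
plt.legend()
plt.show()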