Example #1
def create_splits(cfg: dict):
    """Create random k-fold cross-validation splits.

    Takes a set of basins and randomly creates n splits. The result is stored in a dictionary
    that contains one entry per split, each holding a `train` and a `test` key with the
    corresponding lists of train and test basins.

    Parameters
    ----------
    cfg : dict
        Dictionary containing the user entered evaluation config
    
    Raises
    ------
    RuntimeError
        If file for the same random seed already exists.
    FileNotFoundError
        If the user defined basin list path does not exist.
    """
    output_file = (Path(__file__).absolute().parent /
                   f'data/kfolds/kfold_splits_seed{cfg["seed"]}.p')
    # check if split file already exists
    if output_file.is_file():
        raise RuntimeError(f"File '{output_file}' already exists.")

    # set random seed for reproducibility
    np.random.seed(cfg["seed"])

    # read in basin file
    basins = get_basin_list(basin_list_file_train)

    # create folds
    kfold = KFold(n_splits=cfg["n_splits"],
                  shuffle=True,
                  random_state=cfg["seed"])
    kfold.get_n_splits(basins)

    # dict to store the results of all folds
    splits = defaultdict(dict)

    for split, (train_idx, test_idx) in enumerate(kfold.split(basins)):
        # map the fold indices back to basin ids
        train_basins = [basins[i] for i in train_idx]
        test_basins = [basins[i] for i in test_idx]

        splits[split] = {'train': train_basins, 'test': test_basins}

    with output_file.open('wb') as fp:
        pickle.dump(splits, fp)

    print(f"Stored dictionary with basin splits at {output_file}")
Example #2
def get_args() -> Dict:
    """Parse input arguments

    Returns
    -------
    dict
        Dictionary containing the run config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=["train", "evaluate", "eval_robustness"])
    parser.add_argument('--camels_root', type=str, help="Root directory of CAMELS data set")
    parser.add_argument('--seed', type=int, required=False, help="Random seed")
    parser.add_argument('--run_dir', type=str, help="For evaluation mode. Path to run directory.")
    parser.add_argument('--cache_data',
                        action='store_true',
                        help="If passed, loads all data into memory")
    parser.add_argument('--num_workers',
                        type=int,
                        default=12,
                        help="Number of parallel threads for data loading")
    parser.add_argument('--no_static',
                        action='store_true',
                        help="If passed, trains LSTM without static features")
    parser.add_argument('--concat_static',
                        action='store_true',
                        help="If passed, trains LSTM with static feats concatenated at each time step")
    parser.add_argument('--use_mse',
                        action='store_true',
                        help="If passed, uses mean squared error as loss function.")
    parser.add_argument('--run_dir_base',
                        type=str,
                        default="runs",
                        help="For training mode. Path to store run directories in.")
    parser.add_argument('--run_name', type=str, required=False, help="For training mode. Name of the run.")
    parser.add_argument('--train_start', type=str, help="Training start date (ddmmyyyy).")
    parser.add_argument('--train_end', type=str, help="Training end date (ddmmyyyy).")
    parser.add_argument('--basins', 
                        nargs='+', default=get_basin_list(),
                        help='List of basins')
    cfg = vars(parser.parse_args())
    
    cfg["train_start"] = pd.to_datetime(cfg["train_start"], format='%d%m%Y')
    cfg["train_end"] = pd.to_datetime(cfg["train_end"], format='%d%m%Y')

    # Validation checks
    if (cfg["mode"] == "train") and (cfg["seed"] is None):
        # generate random seed for this run
        cfg["seed"] = int(np.random.uniform(low=0, high=1e6))

    if (cfg["mode"] in ["evaluate", "eval_robustness"]) and (cfg["run_dir"] is None):
        raise ValueError("In evaluation mode a run directory (--run_dir) has to be specified")

    # combine global settings with user config
    cfg.update(GLOBAL_SETTINGS)

    if cfg["mode"] == "train":
        # print config to terminal
        for key, val in cfg.items():
            print(f"{key}: {val}")

    # convert paths to Path objects
    cfg["camels_root"] = Path(cfg["camels_root"])
    if cfg["run_dir"] is not None:
        cfg["run_dir"] = Path(cfg["run_dir"])
    if cfg["run_dir_base"] is not None:
        cfg["run_dir_base"] = Path(cfg["run_dir_base"])
    return cfg
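
# Hedged sketch of an assumed entry point (not shown in this excerpt) that
# dispatches on the parsed mode; train, evaluate and eval_robustness refer to
# the functions defined in the other examples.
if __name__ == "__main__":
    cfg = get_args()
    if cfg["mode"] == "train":
        train(cfg)
    elif cfg["mode"] == "evaluate":
        evaluate(cfg)
    else:  # "eval_robustness"
        eval_robustness(cfg)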
Example #3
def create_splits(cfg: dict):
    """Create random k-Fold cross validation splits.
    
    Takes a set of basins and randomly creates n splits. The result is stored into a dictionary,
    that contains for each split one key that contains a `train` and a `test` key, which contain
    the list of train and test basins.

    Parameters
    ----------
    cfg : dict
        Dictionary containing the user entered evaluation config
    
    Raises
    ------
    RuntimeError
        If file for the same random seed already exists.
    FileNotFoundError
        If the user defined basin list path does not exist.
    """
    output_file = (Path(__file__).absolute().parent /
                   f'data/kfold_splits_seed{cfg["seed"]}.p')
    # check if split file already exists
    if output_file.is_file():
        raise RuntimeError(f"File '{output_file}' already exists.")

    # set random seed for reproducibility
    np.random.seed(cfg["seed"])

    # read in basin file
    if cfg["basin_file"] is not None:
        if not Path(cfg["basin_file"]).is_file():
            raise FileNotFoundError(f"No file found at {cfg['basin_file']}")
        with open(cfg["basin_file"], 'r') as fp:
            basins = fp.readlines()
        basins = [b.strip() for b in basins]
        """
        Delete some basins because of missing data:
        - '06775500' & '06846500' no attributes
        - '09535100' no streamflow records
        """
        ignore_basins = ['06775500', '06846500', '09535100']
        basins = [b for b in basins if b not in ignore_basins]
    else:
        basins = get_basin_list()

    # create folds
    kfold = KFold(n_splits=cfg["n_splits"],
                  shuffle=True,
                  random_state=cfg["seed"])
    kfold.get_n_splits(basins)

    # dict to store the results of all folds
    splits = defaultdict(dict)

    for split, (train_idx, test_idx) in enumerate(kfold.split(basins)):
        # map the fold indices back to basin ids
        train_basins = [basins[i] for i in train_idx]
        test_basins = [basins[i] for i in test_idx]

        splits[split] = {'train': train_basins, 'test': test_basins}

    with output_file.open('wb') as fp:
        pickle.dump(splits, fp)

    print(f"Stored dictionary with basin splits at {output_file}")
Example #4
def evaluate(user_cfg: Dict):
    """Train model for a single epoch.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user entered evaluation config
        
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if user_cfg["split_file"] is not None:
        with Path(user_cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[run_cfg["split"]]["test"]
    else:
        basins = get_basin_list()

    # get attribute means/stds from the training dataset
    train_file = user_cfg["run_dir"] / "data/train/train_data.h5"
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    ds_train = CamelsH5(h5_file=train_file,
                        db_path=db_path,
                        basins=basins,
                        concat_static=run_cfg["concat_static"])
    means = ds_train.get_attribute_means()
    stds = ds_train.get_attribute_stds()

    # create model
    input_size_dyn = 5 if (run_cfg["no_static"]
                           or not run_cfg["concat_static"]) else 32
    model = Model(input_size_dyn=input_size_dyn,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  no_static=run_cfg["no_static"]).to(DEVICE)

    # load trained model
    weight_file = user_cfg["run_dir"] / 'model_epoch30.pt'
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    date_range = pd.date_range(start=GLOBAL_SETTINGS["val_start"],
                               end=GLOBAL_SETTINGS["val_end"])
    results = {}
    for basin in tqdm(basins):
        ds_test = CamelsTXT(
            camels_root=user_cfg["camels_root"],
            basin=basin,
            dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
            is_train=False,
            seq_length=run_cfg["seq_length"],
            with_attributes=True,
            attribute_means=means,
            attribute_stds=stds,
            concat_static=run_cfg["concat_static"],
            db_path=db_path)
        loader = DataLoader(ds_test,
                            batch_size=1024,
                            shuffle=False,
                            num_workers=4)

        preds, obs = evaluate_basin(model, loader)

        df = pd.DataFrame(data={
            'qobs': obs.flatten(),
            'qsim': preds.flatten()
        },
                          index=date_range)

        results[basin] = df

    _store_results(user_cfg, run_cfg, results)
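
# Hedged post-processing sketch (not part of evaluate): computing the NSE per
# basin from a `results` dict like the one built above, using the standard
# definition NSE = 1 - sum((sim - obs)^2) / sum((obs - mean(obs))^2).
import numpy as np

def nse_from_df(df):
    obs, sim = df["qobs"].values, df["qsim"].values
    return 1.0 - np.sum((sim - obs) ** 2) / np.sum((obs - obs.mean()) ** 2)

# nses = {basin: nse_from_df(df) for basin, df in results.items()}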
Example #5
def train(cfg):
    """Train model.

    Parameters
    ----------
    cfg : Dict
        Dictionary containing the run config
    """
    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    if cfg["split_file"] is not None:
        with Path(cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[cfg["split"]]["train"]
    else:
        basins = get_basin_list()

    # create folder structure for this run
    cfg = _setup_run(cfg)

    # prepare data for training
    cfg = _prepare_data(cfg=cfg, basins=basins)

    # prepare PyTorch DataLoader
    ds = CamelsH5(h5_file=cfg["train_file"],
                  basins=basins,
                  db_path=cfg["db_path"],
                  concat_static=cfg["concat_static"],
                  cache=cfg["cache_data"],
                  no_static=cfg["no_static"])
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=True,
                        num_workers=cfg["num_workers"])

    # create model and optimizer
    input_size_dyn = 5 if (cfg["no_static"]
                           or not cfg["concat_static"]) else 32
    model = Model(input_size_dyn=input_size_dyn,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"],
                  no_static=cfg["no_static"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce the learning rate after epochs 10 and 20
    learning_rates = {11: 5e-4, 21: 1e-4}

    for epoch in range(1, cfg["epochs"] + 1):
        # set new learning rate
        if epoch in learning_rates.keys():
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        train_epoch(model, optimizer, loss_func, loader, cfg, epoch,
                    cfg["use_mse"])

        model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
        torch.save(model.state_dict(), str(model_path))
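
# Hedged sketch (an assumption, not the repository's NSELoss implementation) of
# a basin-normalized squared-error loss: each sample's squared error is scaled
# by the discharge std of its basin, so basins with very different flow
# magnitudes contribute comparably to the gradient.
import torch
import torch.nn as nn

class BasinNormalizedLoss(nn.Module):
    def __init__(self, eps: float = 0.1):
        super().__init__()
        self.eps = eps

    def forward(self, y_hat, y, q_std):
        # q_std: per-sample tensor holding the basin's discharge std
        weights = 1.0 / (q_std + self.eps) ** 2
        return torch.mean(weights * (y_hat - y) ** 2)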
Example #6
def eval_robustness(user_cfg: Dict):
    """Evaluate model robustness of EA-LSTM

    In this experiment, Gaussian noise with increasing scale is added to the static features to
    evaluate the model's robustness against perturbations of the static catchment characteristics.
    For each scale, 50 noise vectors are drawn.
    
    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user entered evaluation config
    
    Raises
    ------
    NotImplementedError
        If the specified run_dir does not point to an EA-LSTM model folder.
    """
    random.seed(user_cfg["seed"])
    np.random.seed(user_cfg["seed"])

    # fixed settings for this analysis
    n_repetitions = 50
    scales = [0.1 * i for i in range(11)]

    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if run_cfg["concat_static"] or run_cfg["no_static"]:
        raise NotImplementedError(
            "This function is only implemented for EA-LSTM models")

    basins = get_basin_list()

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()

    # initialize Model
    model = Model(input_size_dyn=5,
                  input_size_stat=27,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"]).to(DEVICE)
    weight_file = user_cfg["run_dir"] / "model_epoch30.pt"
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    overall_results = {}
    # progress bar handle
    pbar = tqdm(basins, file=sys.stdout)
    for basin in pbar:
        ds_test = CamelsTXT(
            camels_root=user_cfg["camels_root"],
            basin=basin,
            dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
            is_train=False,
            with_attributes=True,
            attribute_means=means,
            attribute_stds=stds,
            db_path=db_path)
        loader = DataLoader(ds_test,
                            batch_size=len(ds_test),
                            shuffle=False,
                            num_workers=0)
        basin_results = defaultdict(list)
        step = 1
        for scale in scales:
            for _ in range(1 if scale == 0.0 else n_repetitions):
                noise = np.random.normal(loc=0, scale=scale,
                                         size=27).astype(np.float32)
                noise = torch.from_numpy(noise).to(DEVICE)
                nse = eval_with_added_noise(model, loader, noise)
                basin_results[scale].append(nse)
                pbar.set_postfix_str(
                    f"Basin progress: {step}/{(len(scales)-1)*n_repetitions+1}"
                )
                step += 1

        overall_results[basin] = basin_results
    out_file = (Path(__file__).absolute().parent /
                f'results/{user_cfg["run_dir"].name}_model_robustness.p')
    if not out_file.parent.is_dir():
        out_file.parent.mkdir(parents=True)
    with out_file.open("wb") as fp:
        pickle.dump(overall_results, fp)
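
# Hedged sketch (not in the original script) for summarizing the robustness
# pickle written above: mean NSE per noise scale, aggregated over all basins
# and noise repetitions. The path is a placeholder.
import pickle
import numpy as np

with open("results/<run_name>_model_robustness.p", "rb") as fp:  # placeholder
    overall_results = pickle.load(fp)

scales = sorted(next(iter(overall_results.values())).keys())
for scale in scales:
    values = [nse for basin in overall_results.values() for nse in basin[scale]]
    print(f"scale {scale:.1f}: mean NSE {np.mean(values):.3f}")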
Example #7
            else:
                temp_cfg[key] = val
        json.dump(temp_cfg, fp, sort_keys=True, indent=4)

    return cfg


################
# Prepare grid #
################
if __name__ == "__main__":
    cfg = get_args()
    cfg = _setup_run(cfg)
    
    np.random.seed(0)
    basins = get_basin_list()
    basin_samples = []
    for n_basins in cfg["n_basins"]:
        for i in range(cfg["basin_samples_per_grid_cell"]):
            basin_samples.append(np.random.choice(basins, size=n_basins, replace=False))
            if n_basins == 531:
                # all 531 basins yield only one distinct sample, so draw it once
                break
      
    # Do the XGB parameter search for one configuration, then reuse these parameters for all others
    train_start = cfg["xgb_param_search_range"][0]
    train_end = cfg["xgb_param_search_range"][1]
    basin_sample_id, basin_sample = [
        (i, b) for i, b in enumerate(basin_samples)
        if len(b) == cfg["xgb_param_search_basins"]
    ][0]
    
    if cfg["use_params"] is None:
        param_search_name = f"run_xgb_param_search_{train_start}_{train_end}_basinsample{len(basin_sample)}_{basin_sample_id}_seed111"
        param_search_model_dir = cfg["run_dir"] / param_search_name
Example #8
def evaluate(user_cfg: Dict):
    """Train model for a single epoch.

    Parameters
    ----------
    user_cfg : Dict
        Dictionary containing the user entered evaluation config
        
    """
    with open(user_cfg["run_dir"] / 'cfg.json', 'r') as fp:
        run_cfg = json.load(fp)

    if user_cfg["split_file"] is not None:
        with Path(user_cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[run_cfg["split"]]["test"]
    else:
        basins = get_basin_list(basin_list_file_evaluate)

    # get attribute means/stds
    db_path = str(user_cfg["run_dir"] / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True,
                                 keep_features=user_cfg["camels_attr"])

    # get remaining scaler from pickle file
    scaler_file = user_cfg["run_dir"] / "data" / "train" / "scaler.p"
    with open(scaler_file, "rb") as fp:
        scaler = pickle.load(fp)
    scaler["camels_attr_mean"] = attributes.mean()
    scaler["camels_attr_std"] = attributes.std()

    # create model
    if run_cfg["concat_static"] and not run_cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(run_cfg["dynamic_inputs"]) +
                          len(run_cfg["camels_attr"]) +
                          len(run_cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(run_cfg["camels_attr"]) + len(
            run_cfg["static_inputs"])
        input_size_dyn = len(run_cfg["dynamic_inputs"])
        concat_static = False
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=run_cfg["hidden_size"],
                  dropout=run_cfg["dropout"],
                  concat_static=run_cfg["concat_static"],
                  embedding_hiddens=run_cfg["embedding_hiddens"]).to(DEVICE)

    # load trained model
    weight_file = user_cfg["run_dir"] / 'model_epoch30.pt'
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    date_range = pd.date_range(start=user_cfg["val_start"],
                               end=user_cfg["val_end"])
    results = {}
    cell_states = {}
    embeddings = {}
    nses = []

    file_name = Path(__file__).parent / 'data' / 'dynamic_features_nwm_v2.p'
    with file_name.open("rb") as fp:
        additional_features = pickle.load(fp)

    # ad hoc static climate indices
    # requires the training period for this experiment
    # overwrites the *_dyn type climate indices in 'additional_features'
    if not user_cfg['use_dynamic_climate']:
        if user_cfg['static_climate'].lower() == 'test':
            eval_clim_indexes = training_period_climate_indices(
                db_path=db_path,
                camels_root=user_cfg['camels_root'],
                basins=basins,
                start_date=user_cfg['val_start'],
                end_date=user_cfg['val_end'])
        elif user_cfg['static_climate'].lower() == 'train':
            eval_clim_indexes = training_period_climate_indices(
                db_path=db_path,
                camels_root=user_cfg['camels_root'],
                basins=basins,
                start_date=user_cfg['train_start'],
                end_date=user_cfg['train_end'])
        else:
            raise RuntimeError(f"Unknown static_climate variable: {user_cfg['static_climate']}")

        for basin in basins:
            for col in eval_clim_indexes[basin].columns:
                additional_features[basin][col] = np.tile(
                    eval_clim_indexes[basin][col].values,
                    [additional_features[basin].shape[0], 1])

    for basin in tqdm(basins):
        ds_test = CamelsTXTv2(
            camels_root=user_cfg["camels_root"],
            basin=basin,
            dates=[user_cfg["val_start"], user_cfg["val_end"]],
            is_train=False,
            dynamic_inputs=user_cfg["dynamic_inputs"],
            camels_attr=user_cfg["camels_attr"],
            static_inputs=user_cfg["static_inputs"],
            additional_features=additional_features[basin],
            scaler=scaler,
            seq_length=run_cfg["seq_length"],
            concat_static=concat_static,
            db_path=db_path)
        loader = DataLoader(ds_test,
                            batch_size=2500,
                            shuffle=False,
                            num_workers=user_cfg["num_workers"])

        preds, obs, cells, embeds = evaluate_basin(model, loader)

        # rescale predictions
        preds = preds * scaler["q_std"] + scaler["q_mean"]

        # store predictions
        # set discharges < 0 to zero
        preds[preds < 0] = 0
        nses.append(calc_nse(obs[obs >= 0], preds[obs >= 0]))
        df = pd.DataFrame(data={
            'qobs': obs.flatten(),
            'qsim': preds.flatten()
        },
                          index=date_range)
        results[basin] = df

        # store cell states and embedding values
        cell_states[basin] = pd.DataFrame(data=cells, index=date_range)
        embeddings[basin] = pd.DataFrame(data=embeds, index=date_range)

    print(f"Mean NSE {np.mean(nses)}, median NSE {np.median(nses)}")

    _store_results(user_cfg, run_cfg, results, cell_states, embeddings)
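
# Hedged illustration (an assumption based on the rescaling step above) of the
# discharge (de-)standardization implied by scaler["q_mean"] and scaler["q_std"]:
def standardize_q(q, scaler):
    return (q - scaler["q_mean"]) / scaler["q_std"]

def rescale_q(q_standardized, scaler):
    return q_standardized * scaler["q_std"] + scaler["q_mean"]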
Example #9
def train(cfg):
    """Train model.

    Parameters
    ----------
    cfg : Dict
        Dictionary containing the run config
    """
    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    if cfg["split_file"] is not None:
        with Path(cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[cfg["split"]]["train"]
    else:
        basins = get_basin_list(basin_list_file_train)

    # create folder structure for this run
    cfg = _setup_run(cfg)

    # prepare data for training
    cfg = _prepare_data(cfg=cfg, basins=basins)

    with open(cfg["scaler_file"], 'rb') as fp:
        scaler = pickle.load(fp)

    camels_attr = load_attributes(cfg["db_path"],
                                  basins,
                                  drop_lat_lon=True,
                                  keep_features=cfg["camels_attr"])
    scaler["camels_attr_mean"] = camels_attr.mean()
    scaler["camels_attr_std"] = camels_attr.std()

    # create model and optimizer
    if cfg["concat_static"] and not cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(cfg["dynamic_inputs"]) +
                          len(cfg["camels_attr"]) + len(cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(cfg["camels_attr"]) + len(cfg["static_inputs"])
        input_size_dyn = len(cfg["dynamic_inputs"])
        concat_static = False
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  embedding_hiddens=cfg["embedding_hiddens"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])

    # prepare PyTorch DataLoader
    ds = CamelsH5v2(h5_file=cfg["train_file"],
                    basins=basins,
                    db_path=cfg["db_path"],
                    concat_static=concat_static,
                    cache=cfg["cache_data"],
                    camels_attr=cfg["camels_attr"],
                    scaler=scaler)
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=True,
                        num_workers=cfg["num_workers"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce the learning rate after epochs 10 and 20
    learning_rates = {11: 5e-4, 21: 1e-4}

    for epoch in range(1, cfg["epochs"] + 1):
        # set new learning rate
        if epoch in learning_rates.keys():
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        train_epoch(model, optimizer, loss_func, loader, cfg, epoch,
                    cfg["use_mse"])

        model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
        torch.save(model.state_dict(), str(model_path))
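
# Hedged helper sketch (not part of train): loading one of the per-epoch
# checkpoints that the loop above writes, e.g. to warm-start or inspect a run.
from pathlib import Path
import torch

def load_checkpoint(model, run_dir: Path, epoch: int, device):
    weight_file = run_dir / f"model_epoch{epoch}.pt"
    model.load_state_dict(torch.load(str(weight_file), map_location=device))
    return model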
Example #10
def dist_train(rank, world_size, cfg):
    """Train model.

    Parameters
    ----------
    cfg : Dict
        Dictionary containing the run config
    """

    print(f"Running basic DDP example on rank {rank}. {world_size}")
    setup(rank, world_size)

    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    basins = get_basin_list()

    if rank == 0:
        # create folder structure for this run
        cfg = _setup_run(cfg)

        # prepare data for training
        cfg = _prepare_data(cfg=cfg, basins=basins)

        with open(str(cfg["camels_root"]) + '/cfg.pkl', 'wb') as f:
            pickle.dump(cfg, f, pickle.HIGHEST_PROTOCOL)

    dist.barrier()

    with open(str(cfg["camels_root"]) + '/cfg.pkl', 'rb') as f:
        cfg = pickle.load(f)

    # prepare PyTorch DataLoader
    ds = CamelsH5(h5_file=cfg["train_file"],
                  basins=basins,
                  db_path=cfg["db_path"],
                  concat_static=cfg["concat_static"],
                  cache=cfg["cache_data"],
                  no_static=cfg["no_static"])

    sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=world_size,
        rank=rank
    )

    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=False,
                        num_workers=cfg["num_workers"],
                        sampler=sampler,
                        pin_memory=True)

    # create model and optimizer
    input_size_stat = 0 if cfg["no_static"] else 27
    input_size_dyn = 5 if (cfg["no_static"] or not cfg["concat_static"]) else 32

    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"],
                  no_static=cfg["no_static"])

    ddp_model = DDP(model.to(rank), device_ids=[rank])

    optimizer = torch.optim.Adam(ddp_model.parameters(),
                                 lr=cfg["learning_rate"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce learning rates after each 10 epochs
    learning_rates = {11: 5e-4, 21: 1e-4}

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"

    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}

    for epoch in range(1, math.ceil(cfg["epochs"] / world_size) + 1):
        # set new learning rate
        if epoch in learning_rates.keys():
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        if rank == 0:
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

        dist.barrier()

        ddp_model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=map_location))

        train_epoch(ddp_model, optimizer, loss_func, loader, cfg, epoch,
                    cfg["use_mse"], rank)

    #         model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
    #         torch.save(ddp_model.state_dict(), str(model_path))

    #     if rank == 0:
    #         os.remove(CHECKPOINT_PATH)

    cleanup()
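
# Hedged note (an assumption, not in dist_train above): when a DistributedSampler
# is used with shuffling, PyTorch expects sampler.set_epoch(epoch) at the start of
# each epoch so that every rank sees a different shuffle per epoch, e.g.:
#
#     for epoch in range(1, cfg["epochs"] + 1):
#         sampler.set_epoch(epoch)
#         train_epoch(ddp_model, optimizer, loss_func, loader, cfg, epoch,
#                     cfg["use_mse"], rank)
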
def predict_basin(
    basin: str,
    run_dir: Union[str, Path],
    camels_dir: Union[str, Path],
    period: str = "train",
    epoch: int = 30,
):
    if isinstance(run_dir, str):
        run_dir = Path(run_dir)
    elif not isinstance(run_dir, Path):
        raise TypeError(f"run_dir must be str or Path, not {type(run_dir)}")
    if isinstance(camels_dir, str):
        camels_dir = Path(camels_dir)
    elif not isinstance(camels_dir, Path):
        raise TypeError(f"camels_dir must be str or Path, not {type(camels_dir)}")

    with open(run_dir / "cfg.json", "r") as fp:
        run_cfg = json.load(fp)

    if period not in ["train", "val"]:
        raise ValueError("period must be either train or val")
    basins = get_basin_list()
    db_path = str(run_dir / "attributes.db")
    attributes = load_attributes(db_path=db_path,
                                 basins=basins,
                                 drop_lat_lon=True)
    means = attributes.mean()
    stds = attributes.std()
    attrs_count = len(attributes.columns)
    timeseries_count = 6
    input_size_stat = timeseries_count if run_cfg["no_static"] else attrs_count
    input_size_dyn = (timeseries_count if
                      (run_cfg["no_static"] or not run_cfg["concat_static"])
                      else timeseries_count + attrs_count)
    model = Model(
        input_size_dyn=input_size_dyn,
        input_size_stat=input_size_stat,
        hidden_size=run_cfg["hidden_size"],
        dropout=run_cfg["dropout"],
        concat_static=run_cfg["concat_static"],
        no_static=run_cfg["no_static"],
    ).to(DEVICE)

    # load trained model
    weight_file = run_dir / f"model_epoch{epoch}.pt"
    model.load_state_dict(torch.load(weight_file, map_location=DEVICE))

    ds_test = CamelsTXT(
        camels_root=camels_dir,
        basin=basin,
        dates=[
            GLOBAL_SETTINGS[f"{period}_start"],
            GLOBAL_SETTINGS[f"{period}_end"]
        ],
        is_train=False,
        seq_length=run_cfg["seq_length"],
        with_attributes=True,
        attribute_means=means,
        attribute_stds=stds,
        concat_static=run_cfg["concat_static"],
        db_path=db_path,
    )
    date_range = ds_test.dates_index[run_cfg["seq_length"] - 1:]
    loader = DataLoader(ds_test, batch_size=1024, shuffle=False, num_workers=4)
    preds, obs = evaluate_basin(model, loader)
    df = pd.DataFrame(data={
        "qobs": obs.flatten(),
        "qsim": preds.flatten()
    },
                      index=date_range)

    results = df
    # plt.plot(date_range, results["qobs"], label="Obs")
    # plt.plot(date_range, results["qsim"], label="Preds")
    # plt.legend()
    # plt.savefig(f"{run_dir}/pred_basin_{basin}.pdf")
    # plt.close()
    return results, date_range
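
# Hedged usage sketch (not in the original module); the run directory and the
# CAMELS root below are placeholders, the gauge id is an example CAMELS basin.
results, date_range = predict_basin(
    basin="01013500",              # example CAMELS gauge id
    run_dir="runs/run_2306_0937",  # placeholder run directory
    camels_dir="data/CAMELS",      # placeholder CAMELS root
    period="val",
)
print(results.head())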