Example #1

import pickle
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# The following names are assumed to be project-local helpers imported from
# the snippet's surrounding package: Model, CamelsH5v2, NSELoss,
# load_attributes, get_basin_list, _setup_run, _prepare_data, train_epoch
# and the device constant DEVICE.

def train(cfg):
    """Train model.

    Parameters
    ----------
    cfg : Dict
        Dictionary containing the run config
    """
    # fix random seeds
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])
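    # note: these seeds do not make cuDNN deterministic; for bitwise
    # reproducibility on GPU one would additionally set
    # torch.backends.cudnn.deterministic = True (a standard PyTorch flag,
    # not set in the original snippet), at some speed cost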

    if cfg["split_file"] is not None:
        with Path(cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[cfg["split"]]["train"]
    else:
        basins = get_basin_list()
        #basins = basins[:30]

    # create folder structure for this run
    cfg = _setup_run(cfg)

    # prepare data for training
    cfg = _prepare_data(cfg=cfg, basins=basins)
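    # _prepare_data presumably adds the file paths used below (e.g.
    # "scaler_file" and "train_file") to the returned config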

    # load the feature scaler, presumably written during _prepare_data above
    with open(cfg["scaler_file"], 'rb') as fp:
        scaler = pickle.load(fp)

    # load the CAMELS catchment attributes and store their statistics in the
    # scaler so the dataset can normalize the static inputs
    camels_attr = load_attributes(cfg["db_path"],
                                  basins,
                                  drop_lat_lon=True,
                                  keep_features=cfg["camels_attr"])
    scaler["camels_attr_mean"] = camels_attr.mean()
    scaler["camels_attr_std"] = camels_attr.std()

    # derive the model input sizes: if the static catchment attributes are
    # concatenated to the dynamic inputs at every time step (and no embedding
    # network is configured), they count towards the dynamic input size;
    # otherwise they are fed to the model through a separate static input
    if cfg["concat_static"] and not cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(cfg["dynamic_inputs"]) +
                          len(cfg["camels_attr"]) + len(cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(cfg["camels_attr"]) + len(cfg["static_inputs"])
        input_size_dyn = len(cfg["dynamic_inputs"])
        concat_static = False

    # create model and optimizer
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  embedding_hiddens=cfg["embedding_hiddens"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])
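    # Adam starts at the configured learning rate; the schedule defined
    # further below lowers it in-place during training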

    # prepare PyTorch DataLoader
    ds = CamelsH5v2(h5_file=cfg["train_file"],
                    basins=basins,
                    db_path=cfg["db_path"],
                    concat_static=concat_static,
                    cache=cfg["cache_data"],
                    camels_attr=cfg["camels_attr"],
                    scaler=scaler)
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=True,
                        num_workers=cfg["num_workers"])
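    # note: shuffle=True reshuffles the training samples every epoch and
    # num_workers sets how many worker processes load batches in parallel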

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()
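    # NSELoss presumably implements a loss based on the Nash-Sutcliffe
    # efficiency (the project's own criterion, in contrast to plain MSE)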

    # learning rate schedule: each key is the first epoch at which the new
    # rate applies, i.e. the rate is reduced after epochs 10 and 20
    learning_rates = {11: 5e-4, 21: 1e-4}

    for epoch in range(1, cfg["epochs"] + 1):
        # set the new learning rate by overwriting it in-place in all of the
        # optimizer's parameter groups
        if epoch in learning_rates:
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        train_epoch(model, optimizer, loss_func, loader, cfg, epoch,
                    cfg["use_mse"])

        # checkpoint the model weights after every epoch
        model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
        torch.save(model.state_dict(), str(model_path))
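
    # After training, a checkpoint from a given epoch can be restored for
    # evaluation. A minimal sketch (an assumption, not part of the original
    # script), reusing the constructor arguments from above:
    #
    #     model = Model(input_size_dyn=input_size_dyn, ...)
    #     model.load_state_dict(torch.load(str(model_path), map_location=DEVICE))
    #     model.eval()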