def train(cfg):
    """Train model.

    Parameters
    ----------
    cfg : dict
        Dictionary containing the run config
    """
    # fix random seeds for reproducibility
    random.seed(cfg["seed"])
    np.random.seed(cfg["seed"])
    torch.cuda.manual_seed(cfg["seed"])
    torch.manual_seed(cfg["seed"])

    # determine training basins, either from a split file or the full basin list
    if cfg["split_file"] is not None:
        with Path(cfg["split_file"]).open('rb') as fp:
            splits = pickle.load(fp)
        basins = splits[cfg["split"]]["train"]
    else:
        basins = get_basin_list()

    # create folder structure for this run
    cfg = _setup_run(cfg)

    # prepare data for training
    cfg = _prepare_data(cfg=cfg, basins=basins)

    # load feature scaler and extend it with the CAMELS attribute statistics
    with open(cfg["scaler_file"], 'rb') as fp:
        scaler = pickle.load(fp)
    camels_attr = load_attributes(cfg["db_path"],
                                  basins,
                                  drop_lat_lon=True,
                                  keep_features=cfg["camels_attr"])
    scaler["camels_attr_mean"] = camels_attr.mean()
    scaler["camels_attr_std"] = camels_attr.std()

    # create model and optimizer; if static features are concatenated to the
    # dynamic inputs, they count towards the dynamic input size instead
    if cfg["concat_static"] and not cfg["embedding_hiddens"]:
        input_size_stat = 0
        input_size_dyn = (len(cfg["dynamic_inputs"]) + len(cfg["camels_attr"]) +
                          len(cfg["static_inputs"]))
        concat_static = True
    else:
        input_size_stat = len(cfg["camels_attr"]) + len(cfg["static_inputs"])
        input_size_dyn = len(cfg["dynamic_inputs"])
        concat_static = False
    model = Model(input_size_dyn=input_size_dyn,
                  input_size_stat=input_size_stat,
                  hidden_size=cfg["hidden_size"],
                  initial_forget_bias=cfg["initial_forget_gate_bias"],
                  embedding_hiddens=cfg["embedding_hiddens"],
                  dropout=cfg["dropout"],
                  concat_static=cfg["concat_static"]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["learning_rate"])

    # prepare PyTorch DataLoader
    ds = CamelsH5v2(h5_file=cfg["train_file"],
                    basins=basins,
                    db_path=cfg["db_path"],
                    concat_static=concat_static,
                    cache=cfg["cache_data"],
                    camels_attr=cfg["camels_attr"],
                    scaler=scaler)
    loader = DataLoader(ds,
                        batch_size=cfg["batch_size"],
                        shuffle=True,
                        num_workers=cfg["num_workers"])

    # define loss function
    if cfg["use_mse"]:
        loss_func = nn.MSELoss()
    else:
        loss_func = NSELoss()

    # reduce the learning rate after epochs 10 and 20
    learning_rates = {11: 5e-4, 21: 1e-4}

    for epoch in range(1, cfg["epochs"] + 1):
        # set new learning rate
        if epoch in learning_rates:
            for param_group in optimizer.param_groups:
                param_group["lr"] = learning_rates[epoch]

        train_epoch(model, optimizer, loss_func, loader, cfg, epoch, cfg["use_mse"])

        # save a model checkpoint after every epoch
        model_path = cfg["run_dir"] / f"model_epoch{epoch}.pt"
        torch.save(model.state_dict(), str(model_path))
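
# Example invocation (illustrative sketch only): the keys below mirror the ones
# train() reads from cfg, but the values are assumptions, not a verified run
# configuration. Keys such as "scaler_file", "train_file", "db_path", and
# "run_dir" are assumed to be filled in by _setup_run()/_prepare_data().
if __name__ == "__main__":
    example_cfg = {
        "seed": 42,
        "split_file": None,      # None -> train on the full get_basin_list()
        "split": None,           # only read when split_file is set
        "camels_attr": ["elev_mean", "slope_mean"],  # hypothetical attribute names
        "dynamic_inputs": ["prcp", "tmax", "tmin"],  # hypothetical forcing names
        "static_inputs": [],
        "concat_static": True,   # concatenate static features to dynamic inputs
        "embedding_hiddens": [],
        "hidden_size": 256,
        "initial_forget_gate_bias": 5,
        "dropout": 0.0,
        "learning_rate": 1e-3,
        "cache_data": False,
        "batch_size": 256,
        "num_workers": 4,
        "use_mse": False,        # False -> train with NSELoss instead of MSE
        "epochs": 30,
    }
    train(example_cfg)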