def test_model_equality(get_config: Fixture[Callable[[str], dict]], custom_lstm_supported_models: Fixture[str]):
    """Verify that ``CustomLSTM`` reproduces the optimized model's predictions.

    Both models are built from the same config. The optimized model's weights are
    copied into the custom implementation, and both models then get the same random
    batch in eval mode. Their ``"y_hat"`` outputs must agree to within ``atol=1e-6``.

    A single data set, input/output setting and model specification is sufficient,
    because the test targets the model implementations, not the data pipeline.
    """
    config = get_config('daily_regression')
    config.force_update({
        'dataset': 'camels_us',
        'data_dir': config.data_dir / 'camels_us',
        'target_variables': 'QObs(mm/d)',
        'forcings': 'daymet',
        'dynamic_inputs': ['prcp(mm/day)', 'tmax(C)'],
        'model': custom_lstm_supported_models,
    })

    # Random input batch. The dict literal evaluates 'x_d' before 'x_s', so the
    # RNG is consumed in the same order as before.
    batch_size = config.batch_size
    seq_length = 50
    n_dynamic = len(config.dynamic_inputs)
    n_static = len(config.static_attributes)
    data = {
        'x_d': torch.rand((batch_size, seq_length, n_dynamic)),
        'x_s': torch.rand((batch_size, n_static)),
    }

    # Build the reference model first, then the custom one. The order matters
    # because weight initialisation draws from the global RNG.
    optimized_model = get_model(config)
    custom_lstm = CustomLSTM(config)
    custom_lstm.copy_weights(optimized_model)

    for model in (optimized_model, custom_lstm):
        model.eval()

    with torch.no_grad():
        pred_custom = custom_lstm(data)
        pred_optimized = optimized_model(data)

    y_custom = pred_custom["y_hat"]
    y_optimized = pred_optimized["y_hat"]
    assert torch.allclose(y_custom, y_optimized, atol=1e-6)
def __init__(self, cfg: Config, run_dir: Path, period: str = "test", init_model: bool = True):
    """Store run settings, optionally build the model, and load run data.

    Args:
        cfg: Run configuration. It is stored as ``self.cfg`` and passed to ``get_model``.
        run_dir: Run directory, stored as ``self.run_dir``.
        period: Data period to use; must be ``"train"``, ``"validation"`` or ``"test"``.
        init_model: When true, ``get_model(cfg)`` is called and the result is moved to
            ``self.device`` and stored as ``self.model``.

    Raises:
        ValueError: If ``period`` is not one of the accepted values.
    """
    self.cfg, self.run_dir, self.init_model = cfg, run_dir, init_model

    allowed_periods = ("train", "validation", "test")
    if period not in allowed_periods:
        raise ValueError(
            f'Invalid period {period}. Must be one of ["train", "validation", "test"]'
        )
    self.period = period

    # _set_device presumably assigns self.device, which is read right below;
    # it must therefore run before the model is created.
    self._set_device()
    if self.init_model:
        self.model = get_model(cfg).to(self.device)

    # Placeholders that other methods of the class fill in later.
    self.basins = self.scaler = None
    self.id_to_int = {}
    self.additional_features = []

    # Cache for validation data sets.
    self.cached_datasets = {}

    self._load_run_data()
def _get_model(self) -> torch.nn.Module:
    """Create a model from this object's configuration.

    Returns:
        The module produced by ``get_model`` for ``self.cfg``.
    """
    run_cfg = self.cfg
    return get_model(cfg=run_cfg)