Exemplo n.º 1
0
def test_model_equality(get_config: Fixture[Callable[[str], dict]], custom_lstm_supported_models: Fixture[str]):
    """Verify that ``CustomLSTM`` matches the model from ``get_model`` once the weights are copied.

    Both models are fed the same random batch in eval mode without gradient tracking, and their
    ``"y_hat"`` outputs must agree within an absolute tolerance of 1e-6.

    Parameters
    ----------
    get_config : Fixture[Callable[[str], dict]]
        Fixture that is called with a configuration name to obtain the test config.
    custom_lstm_supported_models : Fixture[str]
        Value written to the ``'model'`` entry of the config.
    """
    cfg = get_config('daily_regression')

    # A single data set, input/output setting and model specification is enough for this check.
    overrides = {
        'dataset': 'camels_us',
        'data_dir': cfg.data_dir / 'camels_us',
        'target_variables': 'QObs(mm/d)',
        'forcings': 'daymet',
        'dynamic_inputs': ['prcp(mm/day)', 'tmax(C)'],
        'model': custom_lstm_supported_models,
    }
    cfg.force_update(overrides)

    # Random inputs. The middle dimension of 50 is presumably the sequence length (TODO confirm).
    dynamic_features = torch.rand((cfg.batch_size, 50, len(cfg.dynamic_inputs)))
    static_features = torch.rand((cfg.batch_size, len(cfg.static_attributes)))
    batch = {'x_d': dynamic_features, 'x_s': static_features}

    # Two independently initialized models; the custom one then receives the reference weights.
    reference_model = get_model(cfg)
    candidate_model = CustomLSTM(cfg)
    candidate_model.copy_weights(reference_model)

    # Switch both models to eval mode and run the forward passes without gradient tracking.
    for network in (reference_model, candidate_model):
        network.eval()
    with torch.no_grad():
        candidate_out = candidate_model(batch)
        reference_out = reference_model(batch)

    # Both implementations must produce (numerically) the same predictions.
    assert torch.allclose(candidate_out["y_hat"], reference_out["y_hat"], atol=1e-6)
Exemplo n.º 2
0
    def __init__(self,
                 cfg: Config,
                 run_dir: Path,
                 period: str = "test",
                 init_model: bool = True):
        """Set up the instance for one run directory and one period.

        Parameters
        ----------
        cfg : Config
            Run configuration; stored on the instance and passed to ``get_model`` when a model is built.
        run_dir : Path
            Run directory; stored on the instance.
        period : str, optional
            One of "train", "validation" or "test"; by default "test".
        init_model : bool, optional
            If True (default), a model is created via ``get_model(cfg)`` and moved to ``self.device``.

        Raises
        ------
        ValueError
            If `period` is not one of the three accepted values.
        """
        self.cfg = cfg
        self.run_dir = run_dir
        self.init_model = init_model

        # Reject unknown periods before doing any further setup.
        if period not in ("train", "validation", "test"):
            raise ValueError(
                f'Invalid period {period}. Must be one of ["train", "validation", "test"]'
            )
        self.period = period

        # Pick the device first; the model (if requested) is moved onto it right after.
        # NOTE(review): self.device is presumably set inside _set_device() -- confirm.
        self._set_device()
        if self.init_model:
            model = get_model(cfg)
            self.model = model.to(self.device)

        # Attributes that are filled in later by other methods of the class.
        self.basins = None
        self.scaler = None
        self.id_to_int = {}
        self.additional_features = []

        # Cache for data sets that have already been loaded.
        self.cached_datasets = {}

        self._load_run_data()
Exemplo n.º 3
0
 def _get_model(self) -> torch.nn.Module:
     """Build and return the model for this instance's configuration via ``get_model``."""
     model = get_model(cfg=self.cfg)
     return model