def test_model_forward_no_satellite(configuration_conv3d):
    """Run one forward pass of the sat+NWP conv3d model without future satellite.

    The model is built from ``tests/configs/model/conv3d_sat_nwp.yaml`` with
    ``include_future_satellite`` forced to ``False``. The data configuration
    comes from the ``configuration_conv3d`` fixture, with the NWP image size
    raised to 16 pixels. A single fake batch is pushed through the model and
    the shape of the output is checked.
    """
    model_kwargs = load_config("tests/configs/model/conv3d_sat_nwp.yaml")
    model_kwargs['include_future_satellite'] = False

    model = Model(**model_kwargs)

    # NOTE: the fixture object is mutated in place (no copy is taken).
    data_config = configuration_conv3d
    data_config.input_data.nwp.nwp_image_size_pixels = 16

    # batch_size=None disables DataLoader's automatic batching; presumably
    # FakeDataset already yields whole batches — TODO confirm.
    fake_loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=data_config),
        batch_size=None,
    )
    batch = next(iter(fake_loader))

    prediction = model(batch)

    # Output must be 2-D. The leading 2 presumably matches the fixture's
    # ``process.batch_size``; the trailing dim is compared against
    # ``model.forecast_len_30`` — verify both against the fixture/model.
    assert len(prediction.shape) == 2
    assert prediction.shape[0] == 2
    assert prediction.shape[1] == model.forecast_len_30
def test_train(configuration_conv3d):
    """Fit the conv3d GSP model for one epoch on fake data, then predict.

    Smoke test only: it passes as long as ``Trainer.fit`` and
    ``Trainer.predict`` complete without raising. The predictions are
    discarded.
    """
    model = Model(**load_config("tests/configs/model/conv3d_gsp.yaml"))

    # batch_size=None disables DataLoader's automatic batching; presumably
    # FakeDataset already yields whole batches — TODO confirm.
    loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=configuration_conv3d),
        batch_size=None,
    )

    # gpus=0 requests no GPUs. NOTE(review): ``gpus`` is deprecated in newer
    # PyTorch Lightning releases in favour of ``accelerator=...`` — check the
    # pinned version before upgrading.
    trainer = pl.Trainer(gpus=0, max_epochs=1)
    trainer.fit(model, loader)

    # Predict over the same loader so the inference path is exercised too.
    trainer.predict(model, loader)
# Example #3
# 0
def configuration_conv3d():
    """Build a dataset ``Configuration`` that matches the conv3d test model config.

    The history/forecast windows and the satellite image size are copied from
    ``tests/configs/model/conv3d.yaml`` so the fake data lines up with the
    model. The tests receive this as an argument, so it is presumably
    registered as a pytest fixture elsewhere — no decorator is visible here.

    Returns:
        The populated ``Configuration`` (process batch size 2, NWP image
        size 2 pixels).
    """
    model_config = load_config("tests/configs/model/conv3d.yaml")
    history_minutes = model_config['history_minutes']
    forecast_minutes = model_config['forecast_minutes']

    configuration = Configuration()
    configuration.process.batch_size = 2

    # The default_* values must be set *before* set_all_to_defaults() runs;
    # its result is assigned back to input_data. Presumably it copies the
    # defaults into each data source — TODO confirm.
    configuration.input_data.default_history_minutes = history_minutes
    configuration.input_data.default_forecast_minutes = forecast_minutes
    configuration.input_data = configuration.input_data.set_all_to_defaults()

    # Alias only after the reassignment above, so we edit the live object.
    inputs = configuration.input_data
    inputs.nwp.nwp_image_size_pixels = 2

    satellite = inputs.satellite
    satellite.satellite_image_size_pixels = model_config['image_size_pixels']
    satellite.forecast_minutes = forecast_minutes
    satellite.history_minutes = history_minutes

    return configuration
# Example #4
# 0
def test_model_forward(configuration_conv3d):
    """Run one forward pass of the conv3d model and check the output shape.

    The model is built from ``tests/configs/model/conv3d.yaml``. One batch is
    drawn from a ``FakeDataset`` driven by the ``configuration_conv3d``
    fixture.
    """
    model = Model(**load_config("tests/configs/model/conv3d.yaml"))

    # batch_size=None disables DataLoader's automatic batching; presumably
    # FakeDataset already yields whole batches — TODO confirm.
    fake_loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=configuration_conv3d),
        batch_size=None,
    )
    batch = next(iter(fake_loader))

    prediction = model(batch)

    # Output must be 2-D. The leading 2 presumably matches the fixture's
    # ``process.batch_size``; the trailing dim is compared against
    # ``model.forecast_len_5`` — verify both against the fixture/model.
    assert len(prediction.shape) == 2
    assert prediction.shape[0] == 2
    assert prediction.shape[1] == model.forecast_len_5
def test_init():
    """Smoke test: ``Model`` can be constructed from the sat+NWP conv3d test config."""
    model_kwargs = load_config("tests/configs/model/conv3d_sat_nwp.yaml")
    # Construction raising is the only failure mode; the instance is unused.
    Model(**model_kwargs)
# Example #6
# 0
def test_init():
    """Smoke test: ``Model`` can be constructed from ``configs/model/conv3d.yaml``.

    NOTE(review): unlike the other tests, this path has no ``tests/`` prefix;
    it presumably resolves against the repository root — confirm.
    """
    model_kwargs = load_config("configs/model/conv3d.yaml")
    # Construction raising is the only failure mode; the instance is unused.
    Model(**model_kwargs)