def test_trainer_validation(configuration):
    """Run ``trainer.validate`` on fake data and check the results CSV it produces.

    The model is pointed at a temporary ``results_file_name``; after validation the
    test reads ``<results_file_name>_0.csv`` and checks its row count and columns.
    """
    input_data = configuration.input_data
    model = Model(
        forecast_minutes=input_data.default_forecast_minutes,
        history_minutes=input_data.default_history_minutes,
        output_variable="gsp_yield",
    )

    # batch_size=None: each item yielded by FakeDataset is passed through as-is
    loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=configuration), batch_size=None
    )
    trainer = pl.Trainer(gpus=0, max_epochs=1)

    expected_columns = (
        "t0_datetime_utc",
        "target_datetime_utc",
        "gsp_id",
        "actual_gsp_pv_outturn_mw",
        "forecast_gsp_pv_outturn_mw",
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.results_file_name = f"{tmp_dir}/temp"
        trainer.validate(model, loader)

        # the CSV must be read before the temporary directory is cleaned up
        results_df = pd.read_csv(f"{model.results_file_name}_0.csv")
        expected_rows = (
            len(loader) * configuration.process.batch_size * model.forecast_len_30
        )
        assert len(results_df) == expected_rows
        for column in expected_columns:
            assert column in results_df.keys()
def test_model_forward_no_satellite(configuration_conv3d):
    """Forward pass of the conv3d sat+NWP model with ``include_future_satellite`` off.

    Loads the model config from YAML, disables ``include_future_satellite``, shrinks
    the NWP image size to 16 pixels, and checks the output shape on one fake batch.
    """
    config_file = "tests/configs/model/conv3d_sat_nwp.yaml"
    config = load_config(config_file)
    config["include_future_satellite"] = False

    # start model
    model = Model(**config)

    # NOTE(review): this mutates the fixture object in place; that is only safe if
    # the fixture is function-scoped -- confirm.
    dataset_configuration = configuration_conv3d
    dataset_configuration.input_data.nwp.nwp_image_size_pixels = 16

    # create fake data loader (batch_size=None: FakeDataset already yields batches)
    train_dataset = FakeDataset(configuration=dataset_configuration)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=None)
    x = next(iter(train_dataloader))

    # run data through model
    y = model(x)

    # Output should be 2-D: (batch, forecast steps). Take the batch size from the
    # configuration rather than hard-coding it, so the test follows the fixture.
    assert len(y.shape) == 2
    assert y.shape[0] == dataset_configuration.process.batch_size
    assert y.shape[1] == model.forecast_len_30
def test_model_forward(configuration):
    """Check the output shape of a single forward pass on one fake batch."""
    inputs = configuration.input_data
    model = Model(
        forecast_minutes=inputs.default_forecast_minutes,
        history_minutes=inputs.default_history_minutes,
        output_variable="gsp_yield",
    )

    dataloader = torch.utils.data.DataLoader(
        FakeDataset(configuration=configuration), batch_size=None
    )
    batch = next(iter(dataloader))

    prediction = model(batch)

    # one row per example, one column per 30-minute forecast step
    expected_batch = configuration.process.batch_size
    expected_steps = inputs.default_forecast_minutes // 30
    assert len(prediction.shape) == 2
    assert prediction.shape[0] == expected_batch
    assert prediction.shape[1] == expected_steps
def test_model_forward(configuration_perceiver):
    """Check the output shape of the Perceiver model on one fake batch.

    The NWP and satellite image sizes are overridden on the fixture before the
    fake dataset is built.
    """
    history_minutes = 30
    forecast_minutes = 60

    dataset_configuration = configuration_perceiver
    dataset_configuration.input_data.nwp.nwp_image_size_pixels = 64
    dataset_configuration.input_data.satellite.satellite_image_size_pixels = 16

    model = PerceiverModel(
        history_minutes=history_minutes,
        forecast_minutes=forecast_minutes,
        nwp_channels=params["nwp_channels"],
        embedding_dem=2048,  # per the original author, this argument has no effect
    )

    loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=dataset_configuration), batch_size=None
    )
    batch = next(iter(loader))

    prediction = model(batch)

    # one row per example, one column per 5-minute forecast step
    assert len(prediction.shape) == 2
    assert prediction.shape[0] == dataset_configuration.process.batch_size
    assert prediction.shape[1] == forecast_minutes // 5
def val_dataloader(self):
    """Build the validation DataLoader.

    Uses ``FakeDataset`` when ``self.fake_data`` is truthy. Otherwise it builds a
    ``NetCDFDataset`` of ``self.n_val_data`` items over the "test" sub-folders of
    ``self.data_path`` and ``self.temp_path``. ``self.dataloader_config`` is
    forwarded to the DataLoader.
    """
    if not self.fake_data:
        # NOTE(review): validation reads from the "test" folder -- confirm intended.
        remote_dir = os.path.join(self.data_path, "test")
        local_dir = os.path.join(self.temp_path, "test")
        dataset = NetCDFDataset(
            self.n_val_data,
            remote_dir,
            local_dir,
            configuration=self.configuration,
        )
    else:
        dataset = FakeDataset(configuration=self.configuration)

    return torch.utils.data.DataLoader(dataset, **self.dataloader_config)
def test_dataloader(self):
    """Build the test DataLoader.

    Uses ``FakeDataset`` when ``self.fake_data`` is truthy. Otherwise it builds a
    ``NetCDFDataset`` of ``self.n_val_data`` items over the "test" sub-folders of
    ``self.data_path`` and ``self.temp_path``. ``self.dataloader_config`` is
    forwarded to the DataLoader.
    """
    if not self.fake_data:
        # TODO: need to change this to a dedicated test folder
        remote_dir = os.path.join(self.data_path, "test")
        local_dir = os.path.join(self.temp_path, "test")
        dataset = NetCDFDataset(
            self.n_val_data,
            remote_dir,
            local_dir,
            configuration=self.configuration,
        )
    else:
        dataset = FakeDataset(configuration=self.configuration)

    return torch.utils.data.DataLoader(dataset, **self.dataloader_config)
def test_train(configuration_conv3d):
    """Fit the conv3d model for one epoch on fake data, then run prediction.

    The fake dataset's ``length`` is set to 2 to keep the run short.
    """
    model = Model(**load_config("tests/configs/model/conv3d.yaml"))

    fake_data = FakeDataset(configuration=configuration_conv3d)
    fake_data.length = 2
    loader = torch.utils.data.DataLoader(fake_data, batch_size=None)

    trainer = pl.Trainer(gpus=0, max_epochs=1)
    trainer.fit(model, loader)

    # the training loader doubles as the prediction input
    trainer.predict(model, loader)
def test_trainer(configuration):
    """Run ``trainer.test`` for the model over the fake dataset."""
    forecast = configuration.input_data.default_forecast_minutes
    history = configuration.input_data.default_history_minutes
    model = Model(
        forecast_minutes=forecast,
        history_minutes=history,
        output_variable="gsp_yield",
    )

    loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=configuration), batch_size=None
    )

    trainer = pl.Trainer(gpus=0, max_epochs=1)
    # the fake training data is reused as the test set
    trainer.test(model, loader)
def test_model_validation(configuration):
    """Smoke-test ``validation_step`` on a single fake batch (no assertions)."""
    input_data = configuration.input_data
    model = Model(
        forecast_minutes=input_data.default_forecast_minutes,
        history_minutes=input_data.default_history_minutes,
        output_variable="gsp_yield",
    )

    loader = torch.utils.data.DataLoader(
        FakeDataset(configuration=configuration), batch_size=None
    )
    batch = next(iter(loader))

    # second argument is the batch index
    model.validation_step(batch, 0)
def test_model_forward(configuration_conv3d):
    """Check the output shape of the conv3d model on one fake batch.

    The model is built from ``tests/configs/model/conv3d.yaml``. The output must
    be 2-D: (batch, forecast steps at 5-minute resolution).
    """
    config_file = "tests/configs/model/conv3d.yaml"
    config = load_config(config_file)
    dataset_configuration = configuration_conv3d

    # start model
    model = Model(**config)

    # create fake data loader (batch_size=None: FakeDataset already yields batches)
    train_dataset = FakeDataset(configuration=dataset_configuration)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=None)
    x = next(iter(train_dataloader))

    # run data through model
    y = model(x)

    # Take the batch size from the configuration rather than hard-coding it, so
    # the test keeps working if the fixture's batch size changes.
    assert len(y.shape) == 2
    assert y.shape[0] == dataset_configuration.process.batch_size
    assert y.shape[1] == model.forecast_len_5