def test_Config_multi_output_levels(levels):
    levels_set = list(set(levels))
    str_levels = ",".join(str(s) for s in levels_set)

    parser = argparse.ArgumentParser()
    Config.register_parser(parser)
    args = parser.parse_args(["--multi-output", "--levels", str_levels])
    config = Config.from_args(args)

    assert config.target.levels == levels_set
def test_Config_register_parser(args, loss_cls):
    parser = argparse.ArgumentParser()
    Config.register_parser(parser)
    args = parser.parse_args(args)
    config = Config.from_args(args)

    assert isinstance(config, Config)
    assert isinstance(config.target, loss_cls)
    if args.relative_humidity:
        assert config.relative_humidity
    else:
        assert not config.relative_humidity
def test_OnlineEmulator_partial_fit(state):
    config = Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=63)
    emulator = get_xarray_emulator(config)
    emulator.partial_fit(state, state)
def test_checkpointed_model(tmpdir):
    # dump a model
    config = Config()
    emulator = Trainer(config)
    emulator.dump(tmpdir)

    # load it
    emulator = Trainer.load(str(tmpdir))
    assert isinstance(emulator, Trainer)
def test_OnlineEmulator_fails_when_accessing_nonexistent_var(state):
    config = Config(
        batch_size=32,
        learning_rate=0.001,
        momentum=0.0,
        # deliberately not a variable present in any state
        extra_input_variables=["not a varialbe in any state 332r23r90e9d"],
        levels=63,
    )
    emulator = get_xarray_emulator(config)
    with pytest.raises(KeyError):
        emulator.partial_fit(state, state)
def test_OnlineEmulator_fit_predict(state, extra_inputs):
    config = Config(
        batch_size=32,
        learning_rate=0.001,
        momentum=0.0,
        extra_input_variables=extra_inputs,
        levels=63,
    )
    emulator = get_xarray_emulator(config)
    emulator.partial_fit(state, state)
    stateout = emulator.predict(state)

    assert isinstance(stateout, dict)
    assert list(stateout["eastward_wind"].dims) == ["z", "y", "x"]
def test_OnlineEmulator_partial_fit_logged(state, tmpdir):
    config = Config(batch_size=8, learning_rate=0.0001, momentum=0.0, levels=63)
    time = datetime.datetime.now().isoformat()
    tf_summary_dir = str(tmpdir.join(time))
    emulator = get_xarray_emulator(config)
    writer = tf.summary.create_file_writer(tf_summary_dir)
    with writer.as_default():
        for i in range(10):
            emulator.partial_fit(state, state)
def test_dump_load_OnlineEmulator(state, tmpdir, output_exists):
    if output_exists:
        path = str(tmpdir)
    else:
        path = str(tmpdir.join("model"))

    n = state["air_temperature"].sizes["z"]
    config = Config(levels=n)
    emulator = get_xarray_emulator(config)
    emulator.partial_fit(state, state)
    emulator.dump(path)
    new_emulator = XarrayEmulator.load(path)

    # assert that the air_temperature output is unchanged
    field = "air_temperature"
    np.testing.assert_array_equal(
        new_emulator.predict(state)[field], emulator.predict(state)[field]
    )
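# Hedged sketch (not part of the original test suite): the `state` fixture used by the
# tests above is presumably provided by conftest.py as a mapping from variable names to
# xarray DataArrays with ("z", "y", "x") dimensions, as implied by the dims assertion in
# test_OnlineEmulator_fit_predict and the sizes["z"] lookup in test_dump_load_OnlineEmulator.
# The variable names and array shapes below are illustrative assumptions only.
def _example_state_sketch(nz: int = 63) -> dict:
    import xarray as xr  # local import so this sketch stays self-contained

    arr = xr.DataArray(np.zeros((nz, 8, 8)), dims=["z", "y", "x"])
    return {"air_temperature": arr, "eastward_wind": arr.copy()}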
def test_top_level():
    dict_ = {"target": {"level": 10}}
    config = Config.from_dict(dict_)
    assert QVLossSingleLevel(10) == config.target
@pytest.mark.parametrize("with_validation", [True, False])
@pytest.mark.parametrize(
    "config",
    [
        Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=79),
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            target=QVLossSingleLevel(0),
            levels=79,
        ),
        Config(
            target=RHLossSingleLevel(50),
            levels=79,
        ),
    ],
)
def test_OnlineEmulator_batch_fit(config, with_validation):
    x = to_dict(_get_argsin(config.levels))
def main(config: Config):
    tf.random.set_seed(1)
    logging.info(config)

    if config.batch is None:
        raise ValueError("No training dataset detected.")

    if config.wandb_logger:
        # `args` refers to the module-level argparse namespace parsed in __main__ below
        wandb.init(
            entity="ai2cm", project="emulator-noah", config=args,  # type: ignore
        )

    emulator = Trainer(config)
    prep = partial(compute_in_out, timestep=args.timestep)
    variables = config.extra_input_variables + all_required_variables()

    train_dataset = data.netcdf_url_to_dataset(
        config.batch.training_path, variables, shuffle=True,
    ).map(prep)
    test_dataset = data.netcdf_url_to_dataset(
        config.batch.testing_path, variables,
    ).map(prep)

    if args.nfiles:
        train_dataset = train_dataset.take(args.nfiles)
        test_dataset = test_dataset.take(args.nfiles)

    train_dataset = train_dataset.unbatch().cache()
    test_dataset = test_dataset.unbatch().cache()

    # detect number of levels
    sample_ins, _ = next(iter(train_dataset.batch(10).take(1)))
    config.levels = nz(sample_ins)

    id_ = pathlib.Path(os.getcwd()).name
    with tf.summary.create_file_writer(f"data/emulator/{id_}").as_default():
        emulator.batch_fit(
            train_dataset.shuffle(100_000), validation_data=test_dataset
        )

    train_scores = emulator.score(train_dataset)
    test_scores = emulator.score(test_dataset)

    if config.output_path:
        os.makedirs(config.output_path, exist_ok=True)
        with open(os.path.join(config.output_path, "scores.json"), "w") as f:
            json.dump({"train": train_scores, "test": test_scores}, f)
        emulator.dump(os.path.join(config.output_path, "model"))

        if config.wandb_logger:
            model = wandb.Artifact("model", type="model")
            model.add_dir(os.path.join(config.output_path, "model"))
            wandb.log_artifact(model)
def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    Config.register_parser(parser)
    return parser
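# Hedged usage sketch (illustrative, not part of the original script): the parser built by
# get_parser() can also be driven programmatically. The "--multi-output"/"--levels" flags
# are taken from test_Config_multi_output_levels; the particular level values are made up.
#
#     parser = get_parser()
#     args = parser.parse_args(["--multi-output", "--levels", "60,61,62"])
#     config = Config.from_args(args)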
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    config = Config.from_args(args)
    print(config)
    main(config)