Ejemplo n.º 1
0
def test_Config_multi_output_levels(levels):
    """Check that ``--multi-output --levels`` round-trips into ``config.target.levels``.

    Duplicates in ``levels`` are removed before the levels are passed on the
    command line. The parsed target must report that same de-duplicated list.
    """
    unique_levels = list(set(levels))
    cli = ["--multi-output", "--levels", ",".join(map(str, unique_levels))]

    parser = argparse.ArgumentParser()
    Config.register_parser(parser)
    config = Config.from_args(parser.parse_args(cli))

    assert config.target.levels == unique_levels
Ejemplo n.º 2
0
def test_Config_register_parser(args, loss_cls):
    """Check that parsing ``args`` with the registered parser yields the expected Config.

    The resulting Config's target must be a ``loss_cls`` instance, and its
    ``relative_humidity`` field must agree (by truthiness) with the parsed
    command-line flag.
    """
    parser = argparse.ArgumentParser()
    Config.register_parser(parser)
    namespace = parser.parse_args(args)
    config = Config.from_args(namespace)

    assert isinstance(config, Config)
    assert isinstance(config.target, loss_cls)
    # The flag on the config is truthy exactly when the parsed CLI flag is.
    assert bool(config.relative_humidity) == bool(namespace.relative_humidity)
Ejemplo n.º 3
0
def test_OnlineEmulator_partial_fit(state):
    """Check that a single ``partial_fit`` step on ``state`` completes without raising."""
    hyperparameters = dict(
        batch_size=32,
        learning_rate=0.001,
        momentum=0.0,
        levels=63,
    )
    emulator = get_xarray_emulator(Config(**hyperparameters))
    emulator.partial_fit(state, state)
Ejemplo n.º 4
0
def test_checkpointed_model(tmpdir):
    """Check that a freshly constructed Trainer can be dumped to disk and loaded back."""
    original = Trainer(Config())
    original.dump(tmpdir)

    reloaded = Trainer.load(str(tmpdir))
    assert isinstance(reloaded, Trainer)
Ejemplo n.º 5
0
def test_OnlineEmulator_fails_when_accessing_nonexistant_var(state):
    """Check that ``partial_fit`` raises KeyError for an extra input missing from ``state``."""
    # Deliberately a name that cannot occur in any state.
    missing_variable = "not a varialbe in any state 332r23r90e9d"
    config = Config(
        batch_size=32,
        learning_rate=0.001,
        momentum=0.0,
        extra_input_variables=[missing_variable],
        levels=63,
    )
    emulator = get_xarray_emulator(config)

    with pytest.raises(KeyError):
        emulator.partial_fit(state, state)
Ejemplo n.º 6
0
def test_OnlineEmulator_fit_predict(state, extra_inputs):
    """Check the output of ``predict`` after one ``partial_fit`` step.

    The prediction must be a dict whose ``eastward_wind`` entry has dims
    ``("z", "y", "x")``.
    """
    emulator = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            extra_input_variables=extra_inputs,
            levels=63,
        )
    )
    emulator.partial_fit(state, state)

    prediction = emulator.predict(state)
    assert isinstance(prediction, dict)
    wind_dims = list(prediction["eastward_wind"].dims)
    assert wind_dims == ["z", "y", "x"]
Ejemplo n.º 7
0
def test_OnlineEmulator_partial_fit_logged(state, tmpdir):
    """Check that repeated ``partial_fit`` calls succeed while a tf.summary writer is active.

    Runs ten ``partial_fit`` steps on ``state`` inside the default context of a
    summary file writer whose log directory lives under ``tmpdir``.
    """
    config = Config(batch_size=8,
                    learning_rate=0.0001,
                    momentum=0.0,
                    levels=63)
    # Colon-free timestamp: isoformat() output contains ':', which is not a
    # valid path character on every filesystem (e.g. Windows). The name also
    # avoids shadowing the stdlib ``time`` module.
    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S%f")
    tf_summary_dir = str(tmpdir.join(timestamp))

    emulator = get_xarray_emulator(config)
    writer = tf.summary.create_file_writer(tf_summary_dir)
    try:
        with writer.as_default():
            for _ in range(10):
                emulator.partial_fit(state, state)
    finally:
        # Flush pending summaries and release the event-file handle even if
        # a partial_fit step fails.
        writer.close()
Ejemplo n.º 8
0
def test_dump_load_OnlineEmulator(state, tmpdir, output_exists):
    """Check that an XarrayEmulator's air_temperature prediction survives dump/load.

    When ``output_exists`` is true, the model is dumped directly into
    ``tmpdir``. Otherwise it is dumped into a ``model`` subdirectory of
    ``tmpdir``.
    """
    path = str(tmpdir) if output_exists else str(tmpdir.join("model"))

    level_count = state["air_temperature"].sizes["z"]
    emulator = get_xarray_emulator(Config(levels=level_count))
    emulator.partial_fit(state, state)
    emulator.dump(path)
    reloaded = XarrayEmulator.load(path)

    # The air_temperature output must be unchanged by the round trip.
    field = "air_temperature"
    reloaded_output = reloaded.predict(state)[field]
    original_output = emulator.predict(state)[field]
    np.testing.assert_array_equal(reloaded_output, original_output)
Ejemplo n.º 9
0
def test_top_level():
    """Check that ``{"target": {"level": N}}`` maps to a ``QVLossSingleLevel(N)`` target."""
    config = Config.from_dict({"target": {"level": 10}})
    expected = QVLossSingleLevel(10)
    assert expected == config.target
Ejemplo n.º 10
0
        extra_input_variables=extra_inputs,
        levels=63,
    )

    emulator = get_xarray_emulator(config)
    emulator.partial_fit(state, state)
    stateout = emulator.predict(state)
    assert isinstance(stateout, dict)
    assert list(stateout["eastward_wind"].dims) == ["z", "y", "x"]


@pytest.mark.parametrize("with_validation", [True, False])
@pytest.mark.parametrize(
    "config",
    [
        # No explicit target: uses Config's default target.
        Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=79),
        # Single-level QV target at level 0.
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            target=QVLossSingleLevel(0),
            levels=79,
        ),
        # Single-level RH target at level 50; optimizer settings left at
        # Config's defaults.
        Config(
            target=RHLossSingleLevel(50),
            levels=79,
        ),
    ],
)
def test_OnlineEmulator_batch_fit(config, with_validation):
    """Test batch fitting across several Config/target variants, with and without validation.

    NOTE(review): only the input-construction line is visible here; ``x`` and
    ``with_validation`` are unused in this view, so the body appears to be
    truncated.
    """
    # Presumably builds a dict of synthetic inputs with ``config.levels``
    # vertical levels -- TODO confirm against _get_argsin / to_dict.
    x = to_dict(_get_argsin(config.levels))
Ejemplo n.º 11
0
def main(config: Config):
    """Train an emulator on the netCDF data described by ``config`` and save the results.

    Steps:
      1. Seed TensorFlow and optionally start a wandb run.
      2. Build train/test datasets from ``config.batch``.
      3. Detect the number of vertical levels from a sample and construct a
         ``Trainer``.
      4. Fit the trainer, write the train/test scores to ``scores.json``, and
         dump the model under ``config.output_path``.
      5. Optionally upload the model directory to wandb.

    Raises:
        ValueError: if ``config.batch`` is None (no training dataset).

    NOTE(review): this function reads the module-level ``args`` namespace
    (``args.timestep``, ``args.nfiles``, and ``args`` itself as the wandb
    config). That name is only bound when the module runs as a script, so
    calling ``main`` after an import raises NameError. Consider moving these
    settings onto ``Config``.
    """
    tf.random.set_seed(1)
    logging.info(config)

    if config.batch is None:
        raise ValueError("No training dataset detected.")

    if config.wandb_logger:
        wandb.init(
            entity="ai2cm",
            project="emulator-noah",
            config=args,  # type: ignore
        )

    prep = partial(compute_in_out, timestep=args.timestep)
    variables = config.extra_input_variables + all_required_variables()

    train_dataset = data.netcdf_url_to_dataset(
        config.batch.training_path,
        variables,
        shuffle=True,
    ).map(prep)

    test_dataset = data.netcdf_url_to_dataset(
        config.batch.testing_path,
        variables,
    ).map(prep)

    if args.nfiles:
        # Keep only the first ``nfiles`` dataset elements (presumably one per
        # file -- TODO confirm against netcdf_url_to_dataset).
        train_dataset = train_dataset.take(args.nfiles)
        test_dataset = test_dataset.take(args.nfiles)

    train_dataset = train_dataset.unbatch().cache()
    test_dataset = test_dataset.unbatch().cache()

    # Detect the number of vertical levels from a sample and record it on the
    # config BEFORE constructing the Trainer, so the model is built for the
    # data's actual level count rather than the config's prior value.
    sample_ins, _ = next(iter(train_dataset.batch(10).take(1)))
    config.levels = nz(sample_ins)

    emulator = Trainer(config)

    # TensorBoard log directory is named after the current working directory.
    id_ = pathlib.Path(os.getcwd()).name

    with tf.summary.create_file_writer(f"data/emulator/{id_}").as_default():
        emulator.batch_fit(train_dataset.shuffle(100_000),
                           validation_data=test_dataset)

    train_scores = emulator.score(train_dataset)
    test_scores = emulator.score(test_dataset)

    # os.makedirs("") would fail, so the directory is created only when a path
    # is set. An empty output_path makes the joins below relative to the cwd.
    if config.output_path:
        os.makedirs(config.output_path, exist_ok=True)

    with open(os.path.join(config.output_path, "scores.json"), "w") as f:
        json.dump({"train": train_scores, "test": test_scores}, f)

    emulator.dump(os.path.join(config.output_path, "model"))

    if config.wandb_logger:
        model = wandb.Artifact("model", type="model")
        model.add_dir(os.path.join(config.output_path, "model"))
        wandb.log_artifact(model)
Ejemplo n.º 12
0
def get_parser() -> argparse.ArgumentParser:
    """Build and return an ArgumentParser with Config's options registered on it.

    Registration is delegated to ``Config.register_parser``.
    """
    cli_parser = argparse.ArgumentParser()
    Config.register_parser(cli_parser)
    return cli_parser
Ejemplo n.º 13
0
    id_ = pathlib.Path(os.getcwd()).name

    with tf.summary.create_file_writer(f"data/emulator/{id_}").as_default():
        emulator.batch_fit(train_dataset.shuffle(100_000),
                           validation_data=test_dataset)

    train_scores = emulator.score(train_dataset)
    test_scores = emulator.score(test_dataset)

    if config.output_path:
        os.makedirs(config.output_path, exist_ok=True)

    with open(os.path.join(config.output_path, "scores.json"), "w") as f:
        json.dump({"train": train_scores, "test": test_scores}, f)

    emulator.dump(os.path.join(config.output_path, "model"))

    if config.wandb_logger:
        model = wandb.Artifact(f"model", type="model")
        model.add_dir(os.path.join(config.output_path, "model"))
        wandb.log_artifact(model)


if __name__ == "__main__":
    # Script entry point: parse the command line into a Config and run training.
    # NOTE: ``args`` must remain a module-level name here; main() reads it as
    # a global (args.timestep, args.nfiles, and as the wandb config).
    parser = get_parser()
    args = parser.parse_args()
    config = Config.from_args(args)
    print(config)
    main(config)