예제 #1
0
def test_OnlineEmulator_partial_fit(state):
    """Smoke-test a single ``partial_fit`` call on a freshly built emulator.

    ``state`` is passed as both the input and the target, so this only checks
    that one training step runs without raising.
    """
    emulator = get_xarray_emulator(
        Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=63)
    )
    emulator.partial_fit(state, state)
예제 #2
0
def test_checkpointed_model(tmpdir):
    """Round-trip a default ``Trainer`` through ``dump`` and ``Trainer.load``."""
    original = Trainer(Config())
    original.dump(tmpdir)

    # Reload from the same directory, given as a plain string path, and
    # check that the result is a Trainer again.
    restored = Trainer.load(str(tmpdir))
    assert isinstance(restored, Trainer)
예제 #3
0
def test_OnlineEmulator_fails_when_accessing_nonexistant_var(state):
    """Requesting an extra input variable absent from ``state`` raises KeyError."""
    bogus_name = "not a varialbe in any state 332r23r90e9d"
    emulator = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            extra_input_variables=[bogus_name],
            levels=63,
        )
    )

    # Construction stays outside the ``raises`` block so that only the fit
    # call is allowed to produce the KeyError.
    with pytest.raises(KeyError):
        emulator.partial_fit(state, state)
예제 #4
0
def test_OnlineEmulator_fit_predict(state, extra_inputs):
    """Fit once, predict, and check the output container type and dim order."""
    emulator = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            extra_input_variables=extra_inputs,
            levels=63,
        )
    )
    emulator.partial_fit(state, state)

    prediction = emulator.predict(state)
    assert isinstance(prediction, dict)

    wind_dims = list(prediction["eastward_wind"].dims)
    assert wind_dims == ["z", "y", "x"]
예제 #5
0
def test_OnlineEmulator_partial_fit_logged(state, tmpdir):
    """Run ten ``partial_fit`` steps while a TF summary writer is the default.

    Every step runs inside ``writer.as_default()``. Presumably the emulator
    emits ``tf.summary`` records during training (confirm against
    ``partial_fit``), so this checks that fitting works with logging enabled.
    The test passes if no step raises.
    """
    config = Config(batch_size=8,
                    learning_rate=0.0001,
                    momentum=0.0,
                    levels=63)
    # Use a colon-free timestamp. ``isoformat()`` produces ``HH:MM:SS``, and
    # ``:`` is not a legal path character on Windows. The local is not named
    # ``time`` so that it cannot shadow the stdlib module.
    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S%f")
    tf_summary_dir = str(tmpdir.join(timestamp))

    emulator = get_xarray_emulator(config)
    writer = tf.summary.create_file_writer(tf_summary_dir)
    try:
        with writer.as_default():
            for _ in range(10):
                emulator.partial_fit(state, state)
    finally:
        # Release the writer's event-file handle. Otherwise it stays open
        # after the test, which can block tmpdir cleanup on Windows.
        writer.close()
예제 #6
0
def test_dump_load_OnlineEmulator(state, tmpdir, output_exists):
    """Dump a trained emulator and check the reloaded copy predicts identically.

    When ``output_exists`` is true, the emulator is dumped into ``tmpdir``
    itself, which already exists. Otherwise it goes to ``tmpdir/model``,
    which is not created here before ``dump`` is called.
    """
    path = str(tmpdir) if output_exists else str(tmpdir.join("model"))

    num_levels = state["air_temperature"].sizes["z"]
    emulator = get_xarray_emulator(Config(levels=num_levels))
    emulator.partial_fit(state, state)
    emulator.dump(path)

    reloaded = XarrayEmulator.load(path)

    # The air_temperature prediction must survive the round trip unchanged.
    field = "air_temperature"
    from_reloaded = reloaded.predict(state)[field]
    from_original = emulator.predict(state)[field]
    np.testing.assert_array_equal(from_reloaded, from_original)
예제 #7
0
        extra_input_variables=extra_inputs,
        levels=63,
    )

    emulator = get_xarray_emulator(config)
    emulator.partial_fit(state, state)
    stateout = emulator.predict(state)
    assert isinstance(stateout, dict)
    assert list(stateout["eastward_wind"].dims) == ["z", "y", "x"]


@pytest.mark.parametrize("with_validation", [True, False])
@pytest.mark.parametrize(
    "config",
    [
        Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=79),
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            target=QVLossSingleLevel(0),
            levels=79,
        ),
        Config(
            target=RHLossSingleLevel(50),
            levels=79,
        ),
    ],
)
def test_OnlineEmulator_batch_fit(config, with_validation):
    x = to_dict(_get_argsin(config.levels))