Пример #1
0
def test_OnlineEmulator_partial_fit(state):
    """Smoke test: one ``partial_fit`` call using ``state`` as both input and target.

    The test passes as long as building the emulator and fitting it do not raise.
    """
    emulator = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            levels=63,
        )
    )
    emulator.partial_fit(state, state)
Пример #2
0
def test_OnlineEmulator_fails_when_accessing_nonexistant_var(state):
    """``partial_fit`` must raise ``KeyError`` for an unknown extra input variable.

    The configured extra input name is a deliberately bogus string, presumably
    absent from any state the fixture can produce.
    """
    missing_name = "not a varialbe in any state 332r23r90e9d"
    emulator = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            extra_input_variables=[missing_name],
            levels=63,
        )
    )

    # Only the fit call sits inside ``raises``: constructing the emulator
    # itself is expected to succeed.
    with pytest.raises(KeyError):
        emulator.partial_fit(state, state)
Пример #3
0
def test_OnlineEmulator_fit_predict(state, extra_inputs):
    """Fit once on ``state``, then check the container type and field dims of ``predict``."""
    emulator = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            extra_input_variables=extra_inputs,
            levels=63,
        )
    )
    emulator.partial_fit(state, state)

    prediction = emulator.predict(state)
    assert isinstance(prediction, dict)

    # Each predicted field is expected to keep the (z, y, x) dimension order.
    wind_dims = list(prediction["eastward_wind"].dims)
    assert wind_dims == ["z", "y", "x"]
Пример #4
0
def test_OnlineEmulator_partial_fit_logged(state, tmpdir):
    """Run repeated ``partial_fit`` calls while a TF summary writer is the default.

    Smoke test: passes as long as training with a summary writer in scope does
    not raise. The writer logs into a timestamped subdirectory of ``tmpdir``.
    """
    config = Config(batch_size=8,
                    learning_rate=0.0001,
                    momentum=0.0,
                    levels=63)
    # ``isoformat()`` would embed ':' characters, which are not valid in
    # Windows path components; use a compact, filesystem-safe timestamp.
    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S%f")

    tf_summary_dir = str(tmpdir.join(timestamp))

    emulator = get_xarray_emulator(config)
    writer = tf.summary.create_file_writer(tf_summary_dir)
    try:
        with writer.as_default():
            for _ in range(10):
                emulator.partial_fit(state, state)
    finally:
        # ``as_default()`` only installs the writer; it does not close it.
        # Flush and release the event file explicitly.
        writer.close()
Пример #5
0
def test_dump_load_OnlineEmulator(state, tmpdir, output_exists):
    """Round-trip an emulator through ``dump`` / ``XarrayEmulator.load``.

    ``output_exists`` selects the dump target: ``tmpdir`` itself (a directory
    that already exists) or a ``model`` path inside it that has not been
    created yet. The reloaded emulator must reproduce the original emulator's
    ``air_temperature`` prediction exactly.
    """
    path = str(tmpdir) if output_exists else str(tmpdir.join("model"))

    levels = state["air_temperature"].sizes["z"]
    emulator = get_xarray_emulator(Config(levels=levels))
    emulator.partial_fit(state, state)
    emulator.dump(path)
    reloaded = XarrayEmulator.load(path)

    field = "air_temperature"
    actual = reloaded.predict(state)[field]
    expected = emulator.predict(state)[field]
    np.testing.assert_array_equal(actual, expected)
Пример #6
0
 def __post_init__(self: "Adapter"):
     """Build ``self.emulator`` from the ``self.config.emulator`` settings."""
     emulator_config = self.config.emulator
     self.emulator = get_xarray_emulator(emulator_config)