def test_OnlineEmulator_partial_fit(state):
    """A single partial_fit step on ``state`` (as both input and target) runs without raising."""
    model = get_xarray_emulator(
        Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=63)
    )
    model.partial_fit(state, state)
def test_checkpointed_model(tmpdir):
    """Dumping a Trainer and loading it back from the same directory yields a Trainer."""
    # Use one plain-string path for both dump and load so the round trip
    # does not depend on dump() accepting a py.path.local object (load was
    # already given a string).
    path = str(tmpdir)

    # dump a model
    trainer = Trainer(Config())
    trainer.dump(path)

    # load it
    loaded = Trainer.load(path)
    assert isinstance(loaded, Trainer)
def test_OnlineEmulator_fails_when_accessing_nonexistant_var(state):
    """partial_fit raises KeyError when a configured extra input variable is not in the state."""
    # Deliberately nonsensical name so it cannot collide with a real variable.
    missing_variable = "not a varialbe in any state 332r23r90e9d"
    config = Config(
        batch_size=32,
        learning_rate=0.001,
        momentum=0.0,
        extra_input_variables=[missing_variable],
        levels=63,
    )
    model = get_xarray_emulator(config)
    with pytest.raises(KeyError):
        model.partial_fit(state, state)
def test_OnlineEmulator_fit_predict(state, extra_inputs):
    """After one partial_fit, predict returns a dict whose eastward_wind has dims (z, y, x)."""
    model = get_xarray_emulator(
        Config(
            batch_size=32,
            learning_rate=0.001,
            momentum=0.0,
            extra_input_variables=extra_inputs,
            levels=63,
        )
    )
    model.partial_fit(state, state)

    prediction = model.predict(state)
    assert isinstance(prediction, dict)
    wind_dims = list(prediction["eastward_wind"].dims)
    assert wind_dims == ["z", "y", "x"]
def test_OnlineEmulator_partial_fit_logged(state, tmpdir):
    """Repeated partial_fit calls run cleanly while a TF summary writer is the default."""
    config = Config(batch_size=8, learning_rate=0.0001, momentum=0.0, levels=63)
    # Colon-free timestamp: isoformat() contains ':' which is not a legal
    # path character on Windows.
    time = datetime.datetime.now().strftime("%Y%m%dT%H%M%S%f")
    tf_summary_dir = str(tmpdir.join(time))
    emulator = get_xarray_emulator(config)
    writer = tf.summary.create_file_writer(tf_summary_dir)
    try:
        with writer.as_default():
            for _ in range(10):
                emulator.partial_fit(state, state)
    finally:
        # Flush and release the event file even if partial_fit raises.
        writer.close()
def test_dump_load_OnlineEmulator(state, tmpdir, output_exists):
    """Round-tripping through dump / XarrayEmulator.load preserves air_temperature predictions.

    When ``output_exists`` is truthy the dump target is the existing tmpdir
    itself; otherwise it is a ``model`` subdirectory path under tmpdir.
    """
    path = str(tmpdir) if output_exists else str(tmpdir.join("model"))
    levels = state["air_temperature"].sizes["z"]

    original = get_xarray_emulator(Config(levels=levels))
    original.partial_fit(state, state)
    original.dump(path)

    reloaded = XarrayEmulator.load(path)
    # assert that the air_temperature output is unchanged
    field = "air_temperature"
    np.testing.assert_array_equal(
        reloaded.predict(state)[field], original.predict(state)[field]
    )
extra_input_variables=extra_inputs, levels=63, ) emulator = get_xarray_emulator(config) emulator.partial_fit(state, state) stateout = emulator.predict(state) assert isinstance(stateout, dict) assert list(stateout["eastward_wind"].dims) == ["z", "y", "x"] @pytest.mark.parametrize("with_validation", [True, False]) @pytest.mark.parametrize( "config", [ Config(batch_size=32, learning_rate=0.001, momentum=0.0, levels=79), Config( batch_size=32, learning_rate=0.001, momentum=0.0, target=QVLossSingleLevel(0), levels=79, ), Config( target=RHLossSingleLevel(50), levels=79, ), ], ) def test_OnlineEmulator_batch_fit(config, with_validation): x = to_dict(_get_argsin(config.levels))