Example #1
def completed_rundir(configuration, tmpdir_factory):

    tendency_dataset_path = tmpdir_factory.mktemp("tendencies")

    if configuration == ConfigEnum.predictor:
        model = get_mock_predictor()
        model_path = str(tmpdir_factory.mktemp("model"))
        fv3fit.dump(model, model_path)  # model_path is already a str
        config = get_ml_config(model_path)
    elif configuration == ConfigEnum.nudging:
        tendencies = _tendency_dataset()
        tendencies.to_zarr(
            str(tendency_dataset_path.join("ds.zarr")), consolidated=True
        )
        config = get_nudging_config(str(tendency_dataset_path.join("ds.zarr")))
    elif configuration == ConfigEnum.microphys_emulation:
        model_path = str(tmpdir_factory.mktemp("model").join("model.tf"))
        model = create_emulation_model()
        model.save(model_path)
        config = get_emulation_config(model_path)
    else:
        raise NotImplementedError()

    rundir = tmpdir_factory.mktemp("rundir").join("subdir")
    run_native(config, str(rundir))
    return rundir
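
completed_rundir reads like a pytest fixture parametrized over ConfigEnum, though the decorator is not shown in this excerpt. A minimal sketch of how such a fixture might be wired up and consumed; the decorator, the configuration fixture, and the test below are assumptions, not part of the source:

import pytest

# hypothetical parametrization; ConfigEnum and its members come from the
# branches in the fixture above
@pytest.fixture(params=[
    ConfigEnum.predictor, ConfigEnum.nudging, ConfigEnum.microphys_emulation
])
def configuration(request):
    return request.param

def test_rundir_was_created(completed_rundir):
    # completed_rundir is a py.path.local produced via tmpdir_factory
    assert completed_rundir.check(dir=True)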
Example #2
def test_loaded_DenseModel_predicts_with_clipped_inputs(tmpdir):
    hyperparameters = DenseHyperparameters(
        ["a", "b"],
        ["c"],
        clip_config=PackerConfig({"a": {
            "z": SliceConfig(None, 3)
        }}),
    )
    model = DenseModel(["a", "b"], ["c"], hyperparameters)

    nz = 5
    dims = ["x", "y", "z"]
    shape = (2, 2, nz)
    arr = np.arange(np.prod(shape)).reshape(shape).astype(float)
    input_data = xr.Dataset({
        "a": (dims, arr),
        "b": (dims, arr),
        "c": (dims, arr + 1)
    })
    model.fit([input_data])
    prediction = model.predict(input_data)
    output_path = str(tmpdir.join("trained_model"))
    fv3fit.dump(model, output_path)
    model_loaded = fv3fit.load(output_path)
    loaded_prediction = model_loaded.predict(input_data)
    xr.testing.assert_allclose(prediction, loaded_prediction)
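
A note on the clip_config above: SliceConfig(None, 3) plausibly acts like Python's slice(None, 3) along the named "z" dimension, so only the first three of the nz=5 levels of "a" reach the network. The slice equivalence is an assumption about fv3fit internals, illustrated here with plain NumPy:

import numpy as np

levels = np.arange(5)             # stand-in for the nz=5 levels of "a"
clipped = levels[slice(None, 3)]  # what SliceConfig(None, 3) is assumed to select
assert list(clipped) == [0, 1, 2]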
Example #3
def test_adapter_regression(state, regtest, tmpdir_factory):
    model_path = str(tmpdir_factory.mktemp("model"))
    mock = get_mock_predictor(dQ1_tendency=1 / 86400)
    fv3fit.dump(mock, model_path)

    adapted_model = Adapter(
        Config(model_path, {
            "air_temperature": "dQ1",
            "specific_humidity": "dQ2"
        }), 900)
    transform = StepTransformer(
        adapted_model,
        MockDerivedState(state),
        "machine_learning",
        diagnostic_variables={
            "tendency_of_specific_humidity_due_to_machine_learning",
            "tendency_of_air_temperature_due_to_machine_learning",
            "tendency_of_internal_energy_due_to_machine_learning",
        },
        timestep=900,
    )

    def add_one_to_temperature():
        state["air_temperature"] += 1
        return {"some_diag": state["specific_humidity"]}

    out = transform(add_one_to_temperature)()

    # ensure tendency of internal energy is non-zero somewhere (GH#1433)
    max_tendency = abs(out["tendency_of_internal_energy_due_to_machine_learning"]).max()
    assert max_tendency.values.item() > 1e-6

    # sort to make the check deterministic
    regression_state(out, regtest)
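
The two occurrences of 900 are the physics timestep in seconds, passed to both the Adapter and the StepTransformer. With dQ1_tendency=1 / 86400 (one unit per day), the per-step increment implied by the mock predictor can be checked by hand:

dQ1 = 1 / 86400  # tendency of the mock predictor above, in units per second
timestep = 900   # seconds, matching both 900s in the test
increment = dQ1 * timestep
assert abs(increment - 900 / 86400) < 1e-15  # about 0.0104 units per step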
Example #4
def test_dump_and_load_default_maintains_prediction(model_type):
    n_sample, n_tile, nx, ny, n_feature = 1, 6, 12, 12, 2
    sample_func = get_uniform_sample_func(size=(n_sample, n_tile, nx, ny,
                                                n_feature))
    result = train_identity_model(model_type, sample_func=sample_func)

    original_result = result.model.predict(result.test_dataset)
    with tempfile.TemporaryDirectory() as tmpdir:
        fv3fit.dump(result.model, tmpdir)
        loaded_model = fv3fit.load(tmpdir)
    loaded_result = loaded_model.predict(result.test_dataset)
    xr.testing.assert_equal(loaded_result, original_result)
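
Example #4 is the dump/load round trip in its purest form. A minimal sketch of the same pattern, where model (any trained fv3fit predictor) and ds (a dataset holding its input variables) are assumed to exist:

import tempfile
import fv3fit
import xarray as xr

# model and ds are placeholders; see the lead-in above
with tempfile.TemporaryDirectory() as tmpdir:
    fv3fit.dump(model, tmpdir)    # serialize the predictor to a directory
    loaded = fv3fit.load(tmpdir)  # reconstruct it from the same directory
xr.testing.assert_equal(model.predict(ds), loaded.predict(ds))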
Example #5
def main(args, unknown_args=None):
    with open(args.training_config, "r") as f:
        config_dict = yaml.safe_load(f)
        if unknown_args is not None:
            config_dict = get_arg_updated_config_dict(args=unknown_args,
                                                      config_dict=config_dict)
        training_config = fv3fit.TrainingConfig.from_dict(config_dict)

    with open(args.training_data_config, "r") as f:
        training_data_config = loaders.BatchesLoader.from_dict(
            yaml.safe_load(f))

    fv3fit.set_random_seed(training_config.random_seed)

    dump_dataclass(training_config, os.path.join(args.output_path,
                                                 "train.yaml"))
    dump_dataclass(training_data_config,
                   os.path.join(args.output_path, "training_data.yaml"))

    train_batches: loaders.typing.Batches = training_data_config.load_batches(
        variables=training_config.variables)
    if args.validation_data_config is not None:
        with open(args.validation_data_config, "r") as f:
            validation_data_config = loaders.BatchesLoader.from_dict(
                yaml.safe_load(f))
        dump_dataclass(
            validation_data_config,
            os.path.join(args.output_path, "validation_data.yaml"),
        )
        val_batches = validation_data_config.load_batches(
            variables=training_config.variables)
    else:
        val_batches: Sequence[xr.Dataset] = []

    if args.local_download_path:
        train_batches = loaders.to_local(
            train_batches, os.path.join(args.local_download_path, "train"))
        val_batches = loaders.to_local(
            val_batches, os.path.join(args.local_download_path, "validation"))

    train = fv3fit.get_training_function(training_config.model_type)
    model = train(
        hyperparameters=training_config.hyperparameters,
        train_batches=train_batches,
        validation_batches=val_batches,
    )
    if len(training_config.derived_output_variables) > 0:
        model = fv3fit.DerivedModel(model,
                                    training_config.derived_output_variables)
    fv3fit.dump(model, args.output_path)
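
For reference, main only reads a handful of attributes off args. A hypothetical invocation with a namespace that mirrors exactly those attributes (the file names are placeholders):

from types import SimpleNamespace

args = SimpleNamespace(
    training_config="train.yaml",               # read and parsed first
    training_data_config="training_data.yaml",  # parsed into a BatchesLoader
    validation_data_config=None,                # optional; guarded by "is not None"
    local_download_path=None,                   # optional; guarded by truthiness
    output_path="output/",                      # configs and the model land here
)
main(args)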
Example #6
def test_offline_diags_integration(data_path, grid_dataset_path):  # noqa: F811
    """
    Test the bash endpoint for computing offline diagnostics
    """

    batches_kwargs = {
        "needs_grid": False,
        "res": "c8_random_values",
        "timesteps_per_batch": 1,
        "timesteps": ["20160801.001500"],
    }
    trained_model = fv3fit.testing.ConstantOutputPredictor(
        input_variables=["air_temperature", "specific_humidity"],
        output_variables=["dQ1", "dQ2"],
    )
    trained_model.set_outputs(dQ1=np.zeros([19]), dQ2=np.zeros([19]))
    data_config = loaders.BatchesFromMapperConfig(
        loaders.MapperConfig(
            function="open_zarr",
            kwargs={"data_path": data_path},
        ),
        function="batches_from_mapper",
        kwargs=batches_kwargs,
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        model_dir = os.path.join(tmpdir, "trained_model")
        fv3fit.dump(trained_model, model_dir)
        data_config_filename = os.path.join(tmpdir, "data_config.yaml")
        with open(data_config_filename, "w") as f:
            yaml.safe_dump(dataclasses.asdict(data_config), f)
        compute_diags_args = ComputeDiagsArgs(
            model_path=model_dir,
            output_path=os.path.join(tmpdir, "offline_diags"),
            data_yaml=data_config_filename,
            grid=grid_dataset_path,
        )
        compute.main(compute_diags_args)
        if isinstance(data_config, loaders.BatchesFromMapperConfig):
            assert "transect_lon0.nc" in os.listdir(
                os.path.join(tmpdir, "offline_diags"))
        create_report_args = CreateReportArgs(
            input_path=os.path.join(tmpdir, "offline_diags"),
            output_path=os.path.join(tmpdir, "report"),
        )
        create_report(create_report_args)
        with open(os.path.join(tmpdir, "report/index.html")) as f:
            report = f.read()
        if isinstance(data_config, loaders.BatchesFromMapperConfig):
            assert "Transect snapshot at" in report
Example #7
def test_constant_model_predict_after_dump_and_load(input_variables,
                                                    output_variables, nz):
    gridded_dataset = get_gridded_dataset(nz)
    outputs = get_first_columns(gridded_dataset, output_variables)
    predictor = get_predictor(input_variables, output_variables, outputs)
    with tempfile.TemporaryDirectory() as tempdir:
        fv3fit.dump(predictor, tempdir)
        predictor = fv3fit.load(tempdir)

    ds_pred = predictor.predict(gridded_dataset)

    assert sorted(list(ds_pred.data_vars.keys())) == sorted(output_variables)
    for name in output_variables:
        assert np.all(
            stack_non_vertical(ds_pred[name]).values == outputs[name][None, :])
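
The final assertion leans on NumPy broadcasting: stack_non_vertical flattens all non-vertical dimensions into one leading axis, and outputs[name][None, :] reshapes a single column to (1, nz) so it compares against every row. The same comparison in isolation, with a tiled array standing in for the stacked prediction:

import numpy as np

column = np.arange(3.0)            # one vertical column, nz=3
stacked = np.tile(column, (4, 1))  # stand-in for stack_non_vertical(...).values
assert np.all(stacked == column[None, :])  # (4, 3) broadcast against (1, 3)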
Example #8
def test_dump_and_load(tmpdir):
    # base_model is defined in the surrounding test module (not shown here)
    derived_model = DerivedModel(
        base_model, derived_output_variables=["net_shortwave_sfc_flux_derived"]
    )
    ds_in = xr.Dataset(
        data_vars={
            "input": xr.DataArray(np.zeros([3, 3, 5]), dims=["x", "y", "z"]),
            "surface_diffused_shortwave_albedo": xr.DataArray(
                np.zeros([3, 3]), dims=["x", "y"]
            ),
        }
    )
    prediction = derived_model.predict(ds_in)

    fv3fit.dump(derived_model, str(tmpdir))
    loaded_model = fv3fit.load(str(tmpdir))

    prediction_after_load = loaded_model.predict(ds_in)
    assert prediction_after_load.identical(prediction)
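
A note on the assertion style: Dataset.identical also compares names and attributes, while the assert_equal/assert_allclose helpers used in the other examples compare values only (allclose within a tolerance). A small illustration with throwaway data:

import numpy as np
import xarray as xr

a = xr.Dataset({"v": ("x", np.zeros(3))}, attrs={"units": "K"})
b = a.copy(deep=True)
b.attrs = {}               # drop the attributes
assert a.equals(b)         # values still match
assert not a.identical(b)  # identical() notices the missing attrs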
Example #9
def test_reloaded_model_gives_same_outputs(sample_dim_name, dt):
    train_dataset = get_train_dataset(sample_dim_name, dt)
    model = _BPTTTrainer(
        sample_dim_name,
        ["a", "b"],
        n_units=32,
        n_hidden_layers=4,
        kernel_regularizer=None,
        train_batch_size=48,
        optimizer="adam",
    )
    model.fit_statistics(train_dataset)
    model.fit(train_dataset)

    with tempfile.TemporaryDirectory() as tmpdir:
        fv3fit.dump(model.predictor_model, tmpdir)
        loaded = fv3fit.load(tmpdir)

    first_timestep = train_dataset.isel(time=0)
    reference_output = model.predictor_model.predict(first_timestep)
    # test that loaded model gives the same predictions
    loaded_output = loaded.predict(first_timestep)
    xr.testing.assert_equal(reference_output, loaded_output)
Example #10
        # ... (model construction omitted from this excerpt) ...
        model.fit_statistics(ds)

    train_filenames = filenames[:-1]
    validation = xr.open_dataset(filenames[-1])

    base_epoch = 0
    for i_epoch in range(config["total_epochs"]):
        epoch = base_epoch + i_epoch
        print(f"starting epoch {epoch}")
        for i, ds in enumerate(
            loaders.OneAheadIterator(shuffled(train_filenames), function=load_dataset)
        ):
            model.fit(ds, epochs=1)
        val_loss = model.loss(validation)
        print(f"val_loss: {val_loss}")
        dirname = os.path.join(
            args.model_output_dir, f"model-epoch_{epoch:03d}-loss_{val_loss:.04f}"
        )
        os.makedirs(dirname, exist_ok=True)
        # dump doesn't need an estimator's fit methods, it only needs .dump to exist
        # which predictor_model has defined
        fv3fit.dump(model.predictor_model, dirname)  # type: ignore
        if i_epoch == config["decrease_learning_rate_epoch"] - 1:
            optimizer_kwargs["lr"] = config["decreased_learning_rate"]
            # train_keras_model will not be None because we call fit above
            # (assuming there is any training data)
            model.train_keras_model.compile(  # type: ignore
                optimizer=optimizer_class(**optimizer_kwargs), loss=model.losses
            )
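
The checkpoint directory name encodes both the epoch and the validation loss through format specs, for example (hypothetical values):

epoch, val_loss = 7, 0.1234
name = f"model-epoch_{epoch:03d}-loss_{val_loss:.04f}"
assert name == "model-epoch_007-loss_0.1234"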