Exemple #1
0
def main(args, unknown_args=None):
    """Train the configured model and dump it to ``args.output_path``.

    Args:
        args: object providing the attributes read here:
            ``training_config`` and ``training_data_config`` (paths of YAML
            files), ``validation_data_config`` (path of a YAML file, or
            None), ``output_path`` and ``local_download_path``.
        unknown_args: when not None, passed to
            ``get_arg_updated_config_dict`` together with the loaded
            training configuration dict; the result replaces that dict.
    """

    def output_file(name):
        # Every artifact of this function is placed under args.output_path.
        return os.path.join(args.output_path, name)

    with open(args.training_config, "r") as config_file:
        config_dict = yaml.safe_load(config_file)
    if unknown_args is not None:
        config_dict = get_arg_updated_config_dict(
            args=unknown_args, config_dict=config_dict
        )
    training_config = fv3fit.TrainingConfig.from_dict(config_dict)

    with open(args.training_data_config, "r") as data_file:
        training_data_config = loaders.BatchesLoader.from_dict(
            yaml.safe_load(data_file)
        )

    fv3fit.set_random_seed(training_config.random_seed)

    # Hand both configurations to dump_dataclass before any data is loaded.
    dump_dataclass(training_config, output_file("train.yaml"))
    dump_dataclass(training_data_config, output_file("training_data.yaml"))

    train_batches: loaders.typing.Batches = training_data_config.load_batches(
        variables=training_config.variables
    )

    # Without a validation config, training receives an empty sequence.
    val_batches: Sequence[xr.Dataset] = []
    if args.validation_data_config is not None:
        with open(args.validation_data_config, "r") as validation_file:
            validation_data_config = loaders.BatchesLoader.from_dict(
                yaml.safe_load(validation_file)
            )
        dump_dataclass(validation_data_config, output_file("validation_data.yaml"))
        val_batches = validation_data_config.load_batches(
            variables=training_config.variables
        )

    if args.local_download_path:
        local_root = args.local_download_path
        train_batches = loaders.to_local(
            train_batches, os.path.join(local_root, "train")
        )
        val_batches = loaders.to_local(
            val_batches, os.path.join(local_root, "validation")
        )

    train_fn = fv3fit.get_training_function(training_config.model_type)
    model = train_fn(
        hyperparameters=training_config.hyperparameters,
        train_batches=train_batches,
        validation_batches=val_batches,
    )

    derived_outputs = training_config.derived_output_variables
    if len(derived_outputs) > 0:
        # Wrap the trained model when derived output variables are configured.
        model = fv3fit.DerivedModel(model, derived_outputs)
    fv3fit.dump(model, args.output_path)
Exemple #2
0
def train_identity_model(model_type, sample_func, hyperparameters=None):
    """Train a ``model_type`` model on data obtained from ``get_dataset``.

    Args:
        model_type: name passed to ``get_dataset``,
            ``get_default_hyperparameters`` and
            ``fv3fit.get_training_function``.
        sample_func: passed through to ``get_dataset``.
        hyperparameters: hyperparameters to train with; when None they are
            built by ``get_default_hyperparameters`` from the variables of
            the first ``get_dataset`` call.

    Returns:
        A ``TrainingResult`` built from the trained model, the output
        variables and dataset of the second ``get_dataset`` call, and the
        hyperparameters used.
    """
    input_variables, output_variables, train_dataset = get_dataset(
        model_type, sample_func
    )
    if hyperparameters is None:
        hyperparameters = get_default_hyperparameters(
            model_type, input_variables, output_variables
        )

    # A second get_dataset call supplies the validation/test dataset; its
    # output variable names are the ones reported in the result.
    _, output_variables, test_dataset = get_dataset(model_type, sample_func)

    train_fn = fv3fit.get_training_function(model_type)
    # The same training dataset object is repeated as ten batches.
    model = train_fn(hyperparameters, [train_dataset] * 10, [test_dataset])
    return TrainingResult(
        model, output_variables, test_dataset, hyperparameters
    )
Exemple #3
0
def mock_train_dense_model():
    """Yield a MagicMock registered in place of the "dense" training function.

    The mock is specced on the training function currently returned for
    "dense", and its return value is a MagicMock specced on
    ``fv3fit.Predictor``. While the generator is suspended at the yield,
    the mock is what is registered under "dense"; on exit the original
    function is registered again and the "mock" entry is removed from
    ``register._model_types``.
    """

    def install(func):
        # Register func as the "dense" training function; the registration
        # helper is looked up at call time.
        fv3fit._shared.config.register_training_function(
            "dense", fv3fit.DenseHyperparameters
        )(func)

    real_train = fv3fit.get_training_function("dense")
    train_mock = mock.MagicMock(name="train_dense_model", spec=real_train)
    fake_model = mock.MagicMock(
        name="train_dense_model_return", spec=fv3fit.Predictor
    )
    train_mock.return_value = fake_model
    # Register the class reported by the return-value mock under "mock".
    register("mock")(fake_model.__class__)
    try:
        install(train_mock)
        yield train_mock
    finally:
        # Undo both registrations, even if the consumer raised.
        install(real_train)
        register._model_types.pop("mock")
Exemple #4
0
def test_train_dense_model_clipped_inputs_outputs():
    """Check that a clipped dense model predicts zeros outside the output clip.

    A "dense" model is trained with a clip configuration on ``var_in_0`` and
    ``var_out_0``. The prediction for ``var_out_0`` must be exactly zero at
    "z" indices below 4 and at "z" indices 8 and above.
    """
    da = xr.DataArray(
        np.arange(1500).reshape(6, 5, 5, 10) * 1.0,
        dims=["tile", "x", "y", "z"],
    )
    train_dataset = xr.Dataset(data_vars={
        "var_in_0": da,
        "var_in_1": da,
        "var_out_0": da,
        "var_out_1": da
    })
    train_batches = [train_dataset for _ in range(2)]
    val_batches = train_batches
    train = fv3fit.get_training_function("dense")

    input_variables = ["var_in_0", "var_in_1"]
    output_variables = ["var_out_0", "var_out_1"]

    hyperparameters = get_default_hyperparameters("dense", input_variables,
                                                  output_variables)
    hyperparameters.clip_config = ClipConfig({
        "var_in_0": {
            "z": SliceConfig(2, 5)
        },
        "var_out_0": {
            "z": SliceConfig(4, 8)
        }
    })
    model = train(
        hyperparameters,
        train_batches,
        val_batches,
    )
    prediction = model.predict(train_dataset)

    below_clip = prediction["var_out_0"].isel(z=slice(None, 4)).values
    above_clip = prediction["var_out_0"].isel(z=slice(8, None)).values
    # Compare element-wise and reduce with np.all: asserting on
    # `np.unique(values) == 0.0` raises ValueError (ambiguous truth value)
    # rather than failing the assertion when several distinct values exist.
    assert np.all(below_clip == 0.0)
    assert np.all(above_clip == 0.0)
Exemple #5
0
def test_train_predict_multiple_stacked_dims(model_type):
    """Train ``model_type`` on ("x", "y", "z") data and run a prediction.

    The test makes no assertions on the prediction; it passes when training
    and ``predict`` complete without raising.
    """
    input_variables = ["var_in_0", "var_in_1"]
    output_variables = ["var_out_0", "var_out_1"]

    ones = xr.DataArray(
        np.full(shape=(5, 10, 15), fill_value=1.0),
        dims=["x", "y", "z"],
    )
    # Every input and output variable shares the same constant array.
    train_dataset = xr.Dataset(
        data_vars={
            "var_in_0": ones,
            "var_in_1": ones,
            "var_out_0": ones,
            "var_out_1": ones,
        }
    )

    hyperparameters = get_default_hyperparameters(
        model_type, input_variables, output_variables
    )
    train_fn = fv3fit.get_training_function(model_type)
    # Two identical training batches and no validation batches.
    model = train_fn(hyperparameters, [train_dataset] * 2, [])
    model.predict(train_dataset)