def main(args, unknown_args=None):
    """Train a model from command-line arguments and dump it to disk.

    Args:
        args: parsed namespace providing ``training_config``,
            ``training_data_config``, ``validation_data_config``,
            ``local_download_path`` and ``output_path`` attributes.
        unknown_args: optional extra CLI arguments used to override
            values in the training config YAML before parsing it.
    """
    training_config = _load_training_config(args.training_config, unknown_args)
    with open(args.training_data_config, "r") as f:
        training_data_config = loaders.BatchesLoader.from_dict(yaml.safe_load(f))

    fv3fit.set_random_seed(training_config.random_seed)

    # Persist the resolved configs next to the model for reproducibility.
    dump_dataclass(training_config, os.path.join(args.output_path, "train.yaml"))
    dump_dataclass(
        training_data_config, os.path.join(args.output_path, "training_data.yaml")
    )

    train_batches: loaders.typing.Batches = training_data_config.load_batches(
        variables=training_config.variables
    )
    val_batches = _load_validation_batches(args, training_config)

    if args.local_download_path:
        # Cache batches on local disk so repeated passes avoid remote reads.
        train_batches = loaders.to_local(
            train_batches, os.path.join(args.local_download_path, "train")
        )
        val_batches = loaders.to_local(
            val_batches, os.path.join(args.local_download_path, "validation")
        )

    train = fv3fit.get_training_function(training_config.model_type)
    model = train(
        hyperparameters=training_config.hyperparameters,
        train_batches=train_batches,
        validation_batches=val_batches,
    )
    if len(training_config.derived_output_variables) > 0:
        model = fv3fit.DerivedModel(model, training_config.derived_output_variables)
    fv3fit.dump(model, args.output_path)


def _load_training_config(path, unknown_args):
    """Read the training config YAML, applying CLI overrides when given."""
    with open(path, "r") as f:
        config_dict = yaml.safe_load(f)
    if unknown_args is not None:
        config_dict = get_arg_updated_config_dict(
            args=unknown_args, config_dict=config_dict
        )
    return fv3fit.TrainingConfig.from_dict(config_dict)


def _load_validation_batches(args, training_config):
    """Load validation batches if a validation config was given, else []."""
    if args.validation_data_config is None:
        return []
    with open(args.validation_data_config, "r") as f:
        validation_data_config = loaders.BatchesLoader.from_dict(yaml.safe_load(f))
    # Persist the validation config alongside the other dumped configs.
    dump_dataclass(
        validation_data_config,
        os.path.join(args.output_path, "validation_data.yaml"),
    )
    return validation_data_config.load_batches(variables=training_config.variables)
def train_identity_model(model_type, sample_func, hyperparameters=None):
    """Train a model to reproduce its inputs and return the training result.

    Builds a training set from ``sample_func``, trains ``model_type`` on
    ten copies of it, and evaluates against a freshly sampled test dataset.
    Default hyperparameters are used when none are supplied.
    """
    in_names, out_names, training_ds = get_dataset(model_type, sample_func)
    if hyperparameters is None:
        hyperparameters = get_default_hyperparameters(
            model_type, in_names, out_names
        )
    batches = [training_ds] * 10
    # A second independent sample serves as both validation and test data.
    in_names, out_names, test_ds = get_dataset(model_type, sample_func)
    training_function = fv3fit.get_training_function(model_type)
    model = training_function(hyperparameters, batches, [test_ds])
    return TrainingResult(model, out_names, test_ds, hyperparameters)
def mock_train_dense_model():
    """Yield a MagicMock standing in for the "dense" training function.

    Temporarily registers the mock as the "dense" trainer and registers the
    mock predictor's class under the name "mock"; both registrations are
    undone on exit, even if the consuming test raises.
    """
    real_train = fv3fit.get_training_function("dense")
    mocked_train = mock.MagicMock(name="train_dense_model", spec=real_train)
    mocked_train.return_value = mock.MagicMock(
        name="train_dense_model_return", spec=fv3fit.Predictor
    )
    # Register the mocked predictor's class under the "mock" model type.
    register("mock")(mocked_train.return_value.__class__)
    try:
        fv3fit._shared.config.register_training_function(
            "dense", fv3fit.DenseHyperparameters
        )(mocked_train)
        yield mocked_train
    finally:
        # Restore the real trainer and drop the temporary registration.
        fv3fit._shared.config.register_training_function(
            "dense", fv3fit.DenseHyperparameters
        )(real_train)
        register._model_types.pop("mock")
def test_train_dense_model_clipped_inputs_outputs():
    """Output levels outside the configured clip slice must predict zero.

    Trains a dense model with ``var_in_0`` clipped to z=[2, 5) and
    ``var_out_0`` clipped to z=[4, 8), then checks the un-clipped output
    levels are exactly zero in the prediction.
    """
    da = xr.DataArray(
        np.arange(1500).reshape(6, 5, 5, 10) * 1.0,
        dims=["tile", "x", "y", "z"],
    )
    train_dataset = xr.Dataset(
        data_vars={
            "var_in_0": da,
            "var_in_1": da,
            "var_out_0": da,
            "var_out_1": da,
        }
    )
    train_batches = [train_dataset for _ in range(2)]
    val_batches = train_batches
    train = fv3fit.get_training_function("dense")
    input_variables = ["var_in_0", "var_in_1"]
    output_variables = ["var_out_0", "var_out_1"]
    hyperparameters = get_default_hyperparameters(
        "dense", input_variables, output_variables
    )
    hyperparameters.clip_config = ClipConfig(
        {
            "var_in_0": {"z": SliceConfig(2, 5)},
            "var_out_0": {"z": SliceConfig(4, 8)},
        }
    )
    model = train(
        hyperparameters,
        train_batches,
        val_batches,
    )
    prediction = model.predict(train_dataset)
    # Fix: the previous `assert np.unique(...) == 0.0` relied on truthiness
    # of a single-element array and raises ValueError (instead of failing
    # cleanly) whenever more than one distinct value is present. Assert
    # directly that every clipped level is exactly zero.
    assert (prediction["var_out_0"].isel(z=slice(None, 4)).values == 0.0).all()
    assert (prediction["var_out_0"].isel(z=slice(8, None)).values == 0.0).all()
def test_train_predict_multiple_stacked_dims(model_type):
    """Training and prediction should work on data with multiple non-z dims."""
    data = xr.DataArray(
        np.full(fill_value=1.0, shape=(5, 10, 15)),
        dims=["x", "y", "z"],
    )
    dataset = xr.Dataset(
        data_vars={
            name: data
            for name in ("var_in_0", "var_in_1", "var_out_0", "var_out_1")
        }
    )
    inputs = ["var_in_0", "var_in_1"]
    outputs = ["var_out_0", "var_out_1"]
    hyperparameters = get_default_hyperparameters(model_type, inputs, outputs)
    training_function = fv3fit.get_training_function(model_type)
    # Two identical training batches, no validation batches.
    model = training_function(hyperparameters, [dataset, dataset], [])
    # Only checks that predict runs without error on multi-dim input.
    model.predict(dataset)