def test_yaml_metadata_missing_fields(tmp_path, config_yaml, request, test_case_number):
    """Check how model-metadata reading/validation reacts to incomplete YAML.

    ``config_yaml`` is the *name* of a fixture that supplies the YAML text; it
    is resolved with ``request.getfixturevalue``. The text is written to
    ``MODEL_CONFIG_FILENAME`` inside ``tmp_path`` and read back with
    ``read_model_metadata_yaml``. ``test_case_number`` selects the expectation:

    * 1 -- reading succeeds, but ``validate_config_fields`` raises
      ``DrumCommonException`` naming the missing ``validation`` and
      ``environmentID`` keys;
    * 2 -- reading raises because ``negativeClassLabel`` is missing;
    * 3 -- reading raises because neither class labels nor a class labels
      file is provided for multiclass classification;
    * 4 -- reading raises because both are provided;
    * 100 -- reading succeeds without raising.
    """
    # Resolve the fixture by name so one parametrized test can cover several
    # YAML variants; keep the fixture name and its content in separate names.
    metadata_yaml = request.getfixturevalue(config_yaml)
    with open(os.path.join(tmp_path, MODEL_CONFIG_FILENAME), mode="w") as f:
        f.write(metadata_yaml)

    # ``pytest.raises(match=...)`` treats the pattern as a regular expression
    # (re.search). Raw strings keep the escaped brackets/dots from being
    # invalid Python string escapes.
    if test_case_number == 1:
        conf = read_model_metadata_yaml(tmp_path)
        with pytest.raises(
            DrumCommonException, match=r"Missing keys: \['validation', 'environmentID'\]"
        ):
            validate_config_fields(
                conf,
                ModelMetadataKeys.CUSTOM_PREDICTOR,
                ModelMetadataKeys.VALIDATION,
                ModelMetadataKeys.ENVIRONMENT_ID,
            )
    elif test_case_number == 2:
        with pytest.raises(DrumCommonException, match=r"Missing keys: \['negativeClassLabel'\]"):
            read_model_metadata_yaml(tmp_path)
    elif test_case_number == 3:
        with pytest.raises(
            DrumCommonException,
            match=(
                r"Error - for multiclass classification, either the class labels or a class "
                r"labels file must be provided in model-metadata\.yaml file"
            ),
        ):
            read_model_metadata_yaml(tmp_path)
    elif test_case_number == 4:
        with pytest.raises(
            DrumCommonException,
            match=(
                r"Error - for multiclass classification, either the class labels or a class "
                r"labels file should be provided in model-metadata\.yaml file, but not both"
            ),
        ):
            read_model_metadata_yaml(tmp_path)
    elif test_case_number == 100:
        # Valid metadata: reading must simply not raise.
        read_model_metadata_yaml(tmp_path)
def setup_validation_options(options):
    """Prepare validation for the model described by ``options``.

    The model metadata is loaded via ``get_metadata`` and must contain the
    ``ModelMetadataKeys.VALIDATION`` field. Work is then delegated to the
    training or inference setup helper according to the metadata ``type``,
    and that helper's result is returned.

    Raises:
        DrumCommonException: when ``type`` is neither "training" nor
            "inference". ``validate_config_fields`` may also raise if the
            validation field is missing.
    """
    metadata = get_metadata(options)
    validate_config_fields(metadata, ModelMetadataKeys.VALIDATION)

    model_type = metadata["type"]
    if model_type == "training":
        return _setup_training_validation(metadata, options)
    if model_type == "inference":
        return _setup_inference_validation(metadata, options)
    raise DrumCommonException("Unsupported type")
def test_yaml_metadata_missing_fields(custom_predictor_metadata_yaml, tmp_path):
    """Validation must report the keys absent from custom-predictor metadata.

    The ``custom_predictor_metadata_yaml`` text is written to
    ``MODEL_CONFIG_FILENAME`` in ``tmp_path`` and read back. Requiring the
    custom-predictor, validation and environment-ID fields must then raise
    ``DrumCommonException`` naming ``validation`` and ``environmentID``.
    """
    # NOTE(review): this shares its name with another
    # ``test_yaml_metadata_missing_fields``; if both are in the same module,
    # the later definition shadows the earlier one and pytest collects only one.
    with open(os.path.join(tmp_path, MODEL_CONFIG_FILENAME), mode="w") as f:
        f.write(custom_predictor_metadata_yaml)
    conf = read_model_metadata_yaml(tmp_path)

    # ``match`` is a regex (re.search); the raw string keeps ``\[``/``\]`` from
    # being invalid Python string escapes.
    with pytest.raises(
        DrumCommonException, match=r"Missing keys: \['validation', 'environmentID'\]"
    ):
        validate_config_fields(
            conf,
            ModelMetadataKeys.CUSTOM_PREDICTOR,
            ModelMetadataKeys.VALIDATION,
            ModelMetadataKeys.ENVIRONMENT_ID,
        )
def drum_push(options):
    """Validate the model metadata, then push the code in ``options.code_dir``.

    * "training" models must define ``ModelMetadataKeys.ENVIRONMENT_ID``.
    * "inference" models must define ``ModelMetadataKeys.ENVIRONMENT_ID`` and
      ``ModelMetadataKeys.INFERENCE_MODEL``, and that inference section must
      contain ``targetName``.

    Raises:
        DrumCommonException: when the metadata ``type`` is neither "training"
            nor "inference". ``validate_config_fields`` may also raise when a
            required field is missing.
    """
    metadata = get_metadata(options)
    model_type = metadata["type"]

    if model_type == "training":
        validate_config_fields(metadata, ModelMetadataKeys.ENVIRONMENT_ID)
        _push_training(metadata, options.code_dir)
        return

    if model_type == "inference":
        validate_config_fields(
            metadata,
            ModelMetadataKeys.ENVIRONMENT_ID,
            ModelMetadataKeys.INFERENCE_MODEL,
        )
        # The nested inference section carries its own required key.
        inference_section = metadata[ModelMetadataKeys.INFERENCE_MODEL]
        validate_config_fields(inference_section, "targetName")
        _push_inference(metadata, options.code_dir)
        return

    raise DrumCommonException("Unsupported type")