def _parse_runid_ref(parsed: ParseResult, client: MlflowClient):
    """Resolve a run-based model reference into a ``runs:/`` URI plus metadata.

    Parameters
    ----------
    parsed : ParseResult
        Parsed model URI; the hostname component holds the run id and the
        path component (if any) points at an artifact within the run.
    client : MlflowClient
        Client used to look up the run and, when needed, its artifacts.

    Returns
    -------
    tuple
        ``(uri, tags, params)`` where ``uri`` is a ``runs:/<runid>/<path>``
        URI and ``tags``/``params`` come from the run's data.

    Raises
    ------
    SpecError
        If the run has no artifacts, or has several artifacts and no
        explicit path was given in the URI.
    """
    runid = parsed.hostname
    run = client.get_run(runid)
    tags, params = run.data.tags, run.data.params

    artifact_path = parsed.path.lstrip("/")
    if not artifact_path:
        # No explicit path in the URI: only unambiguous when the run has
        # exactly one artifact.
        artifacts = client.list_artifacts(runid)
        if not artifacts:
            raise SpecError("Run {} has no artifacts".format(runid))
        if len(artifacts) > 1:
            # TODO allow setting default path from config
            raise SpecError(
                (
                    "Run {} has more than 1 artifact ({})."
                    "Please specify path like "
                    "mlflows://<runid>/path/to/artifact in "
                    "CREATE MODEL or ML_PREDICT"
                ).format(runid, [x.path for x in artifacts])
            )
        artifact_path = artifacts[0].path

    return "runs:/{}/{}".format(runid, artifact_path), tags, params
def codegen_from_yaml(
    spark: SparkSession,
    uri: str,
    name: Optional[str] = None,
    options: Optional[Dict[str, str]] = None,
) -> str:
    """Generate code from a YAML file.

    Parameters
    ----------
    spark : SparkSession
        A live spark session
    uri : str
        the model spec URI
    name : str, optional
        The name of the model.
    options : dict, optional
        Optional parameters passed to the model.

    Returns
    -------
    str
        Spark UDF function name for the generated data.

    Raises
    ------
    SpecError
        If the spec version is unsupported or the model flavor is unknown.
    """
    with open_uri(uri) as fobj:
        spec = ModelSpec(fobj, options=options)
    # BUGFIX: the version may arrive as the float 1.0 (YAML scalar) or the
    # string "1.0" (as udf_from_spec assumes); comparing against the float
    # literal rejected string-typed versions.  Normalize via str() so both
    # representations are accepted, keeping this consistent with
    # udf_from_spec's string comparison.
    if str(spec.version) != "1.0":
        raise SpecError(
            f"Only spec version 1.0 is supported, got {spec.version}"
        )
    if spec.flavor == "pytorch":
        from rikai.spark.sql.codegen.pytorch import generate_udf

        udf = generate_udf(
            spec.uri,
            spec.schema,
            spec.options,
            pre_processing=spec.pre_processing,
            post_processing=spec.post_processing,
        )
    else:
        raise SpecError(f"Unsupported model flavor: {spec.flavor}")
    # Suffix with a random token so repeated registrations don't collide.
    func_name = f"{name}_{secrets.token_hex(4)}"
    spark.udf.register(func_name, udf)
    logger.info(f"Created model inference pandas_udf with name {func_name}")
    return func_name
def load_model(self):
    """Load the model referenced by ``self.uri`` using this spec's flavor.

    Raises
    ------
    SpecError
        If the flavor is not one of the supported flavors.
    """
    # Only the pytorch flavor is currently supported; reject anything else
    # up front.
    if self.flavor != "pytorch":
        raise SpecError("Unsupported flavor {}".format(self.flavor))
    from rikai.spark.sql.codegen.pytorch import load_model_from_uri

    return load_model_from_uri(self.uri)
def _parse_model_ref(parsed: ParseResult, client: MlflowClient):
    """Resolve a registry-based model reference into a ``models:/`` URI.

    The hostname component of the parsed URI is the registered model name;
    the path component is either a numeric version or a stage name.

    Parameters
    ----------
    parsed : ParseResult
        Parsed model URI.
    client : MlflowClient
        Client used to look up model versions and runs.

    Returns
    -------
    tuple
        ``(uri, tags, params)`` where ``uri`` is a
        ``models:/<model>/<version>`` URI and ``tags``/``params`` come from
        the run that produced that version.

    Raises
    ------
    SpecError
        If no versions exist for the model in the resolved stage.
    """
    model = parsed.hostname
    ref = parsed.path.lstrip("/")

    def _result(version, run):
        # Shared shape for both the pinned-version and latest-in-stage paths.
        return (
            "models:/{}/{}".format(model, version),
            run.data.tags,
            run.data.params,
        )

    if ref.isdigit():
        # Path is a concrete version number — resolve it directly.
        mv = client.get_model_version(model, int(ref))
        return _result(ref, client.get_run(mv.run_id))

    # Otherwise treat the path as a stage name; an empty path falls back to
    # the "none" stage.
    # TODO allow setting default stage from config
    stage = ref.lower() if ref else "none"
    results = client.get_latest_versions(model, stages=[stage])
    if not results:
        raise SpecError(
            "No versions found for model {} in stage {}".format(model, stage)
        )
    latest = results[0]
    return _result(latest.version, client.get_run(latest.run_id))
def get_model_version(
    self, model, stage_or_version=None
) -> "Tuple[str, mlflow.entities.Run]":
    """
    Get the model uri that mlflow model registry understands for
    loading a model and the corresponding Run with metadata needed
    for the spec.

    Parameters
    ----------
    model : str
        Registered model name.
    stage_or_version : str, optional
        Either a numeric version string (e.g. ``"3"``) or a stage name;
        defaults to the ``"none"`` stage.

    Returns
    -------
    tuple
        ``(uri, run)`` where ``uri`` is a ``models:/<model>/<version>``
        URI and ``run`` is the Run that produced that version.

    Raises
    ------
    SpecError
        If no versions exist for the model in the requested stage.
    """
    # BUGFIX: the annotation was the tuple literal `(str, mlflow.entities.Run)`,
    # which is not a valid type hint and forced `mlflow` to resolve at def
    # time; the string-form `Tuple[...]` is lazy and tool-readable.
    # TODO allow default stage from config
    stage_or_version = stage_or_version or "none"
    if stage_or_version.isdigit():
        # Pegged to an explicit version number
        version = int(stage_or_version)
        run_id = self.tracking_client.get_model_version(
            model, version
        ).run_id
    else:
        # Latest version in the given stage
        results = self.tracking_client.get_latest_versions(
            model, stages=[stage_or_version.lower()]
        )
        if not results:
            msg = "No versions found for model {} in stage {}".format(
                model, stage_or_version
            )
            raise SpecError(msg)
        run_id, version = results[0].run_id, results[0].version
    run = self.tracking_client.get_run(run_id)
    return "models:/{}/{}".format(model, version), run
def validate(self):
    """Validate model spec

    Raises
    ------
    SpecError
        If the spec is not well-formatted.
    """
    logger.debug("Validating spec: %s", self._spec)
    try:
        validate(instance=self._spec, schema=MODEL_SPEC_SCHEMA)
    except ValidationError as err:
        # Surface schema violations as the package's own error type while
        # keeping the jsonschema exception as the chained cause.
        raise SpecError(err.message) from err
def udf_from_spec(spec: ModelSpec):
    """Return a UDF from a given ModelSpec

    Parameters
    ----------
    spec : ModelSpec
        A model spec

    Returns
    -------
    str
        Spark UDF function name for the generated data.

    Raises
    ------
    SpecError
        If the spec version is not 1.0 or the flavor is unsupported.
    """
    # Guard clauses: reject unsupported versions and flavors up front.
    if spec.version != "1.0":
        raise SpecError(
            f"Only spec version 1.0 is supported, got {spec.version}")
    if spec.flavor != "pytorch":
        raise SpecError(f"Unsupported model flavor: {spec.flavor}")
    from rikai.spark.sql.codegen.pytorch import generate_udf

    return generate_udf(spec)