Example #1
    def __init__(  # pylint: disable=too-many-arguments
        self,
        query: str,
        columns: Optional[List[str]] = None,
        shuffle: bool = False,
        shuffler_capacity: int = 128,
        seed: Optional[int] = None,
        world_size: int = 1,
        rank: int = 0,
    ):
        self.uri = query
        self.columns = columns
        self.shuffle = shuffle
        self.shuffler_capacity = shuffler_capacity
        self.seed = seed
        self.rank = rank
        self.world_size = world_size
        if self.world_size > 1:
            logger.info("Running in distributed mode, world size=%s, rank=%s",
                        world_size, rank)

        # Provide deterministic ordering between distributed workers.
        self.files = sorted(Resolver.resolve(self.uri))
        logger.info("Loading parquet files: %s", self.files)

        self.spark_row_metadata = Resolver.get_schema(self.uri)
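A hypothetical construction of the class this constructor belongs to; the class name `Dataset` and the dataset URI below are placeholders, since the snippet only shows `__init__`:

ds = Dataset(
    "s3://bucket/datasets/train.parquet",
    columns=["image", "label"],
    shuffle=True,
    seed=42,
    world_size=2,
    rank=0,
)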
Example #2
def init_cb_service(spark: SparkSession, enable_dynamic_port: bool):
    jvm = spark.sparkContext._gateway

    # Set port to 0 to enable dynamic port
    # Re-use the auth-token from the main java/spark process
    if enable_dynamic_port:
        params = CallbackServerParameters(
            daemonize=True,
            daemonize_connections=True,
            # https://www.py4j.org/advanced_topics.html#using-py4j-without-pre-determined-ports-dynamic-port-number
            port=0,
            auth_token=jvm.gateway_parameters.auth_token,
        )
    else:
        params = CallbackServerParameters(
            daemonize=True,
            daemonize_connections=True,
            auth_token=jvm.gateway_parameters.auth_token,
        )

    jvm.start_callback_server(callback_server_parameters=params)

    if enable_dynamic_port:
        # Find out which port the callback server actually bound to and
        # point the JVM's callback client at it.
        python_port = jvm.get_callback_server().get_listening_port()
        jvm.java_gateway_server.resetCallbackClient(
            jvm.java_gateway_server.getCallbackClient().getAddress(),
            python_port,
        )

    logger.info("Spark callback server started")

    cb = CallbackService(spark)
    cb.register()
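A minimal usage sketch, assuming the function above is importable and a SparkSession is available; the app name is made up:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("rikai-demo").getOrCreate()
# Ask Py4J to pick a free callback-server port instead of the default one.
init_cb_service(spark, enable_dynamic_port=True)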
Example #3
def register_udf(spark: SparkSession, udf: Callable, name: str) -> str:
    """
    Register the given UDF with the given Spark session under the given name.
    """
    func_name = f"{name}_{secrets.token_hex(4)}"
    spark.udf.register(func_name, udf)
    logger.info(f"Created model inference pandas_udf with name {func_name}")
    return func_name
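A minimal usage sketch, assuming `register_udf` is importable; the identity UDF and the query are made up for illustration:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("udf-demo").getOrCreate()
func_name = register_udf(spark, lambda x: x, "identity_model")
# The randomized name returned above can be used directly in Spark SQL.
spark.sql(f"SELECT {func_name}('hello') AS out").show()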
Example #4
    def resolve(self, uri: str, name: str, options: Dict[str, str]):
        logger.info(f"Resolving model {name} from {uri}")
        if uri.endswith(".yml") or uri.endswith(".yaml"):
            func_name = codegen_from_yaml(self._spark, uri, name, options)
        else:
            raise ValueError(f"Model URI is not supported: {uri}")

        model = self._jvm.ai.eto.rikai.sql.model.SparkUDFModel(
            name, uri, func_name)
        # TODO: set options
        return model
Example #5
    def resolve(self, uri: str, name: str, options: Dict[str, str]):
        logger.info(f"Resolving model {name} from {uri}")
        model_uri, tags, params = _get_model_info(uri, self.tracking_client)
        spec = MlflowModelSpec(
            model_uri, tags, params, self.mlflow_tracking_uri, options=options
        )
        func_name = codegen_from_spec(self._spark, spec, name)
        model = self._jvm.ai.eto.rikai.sql.model.SparkUDFModel(
            name, uri, func_name
        )
        return model
Example #6
def init_cb_service(spark: SparkSession):
    jvm = spark.sparkContext._gateway
    params = CallbackServerParameters(
        daemonize=True,
        daemonize_connections=True,
        # Re-use the auth-token from the main java/spark process
        auth_token=jvm.gateway_parameters.auth_token,
    )
    jvm.start_callback_server(callback_server_parameters=params)
    logger.info("Spark callback server started")

    cb = CallbackService(spark)
    cb.register()
Example #7
    def setUpClass(cls) -> None:
        import rikai

        jar_dir = os.path.join(os.path.dirname(rikai.__file__), "jars")
        rikai_jar = glob.glob(os.path.join(jar_dir, "*.jar"))
        if len(rikai_jar) == 0:
            raise ValueError(
                f"Rikai Jar is not found on {jar_dir}, please run 'sbt package' first"
            )
        jars = ":".join(rikai_jar)
        logger.info("loading jars for spark testing: %s", jars)
        cls.spark = (
            SparkSession.builder.appName("spark-test")
            .config("spark.jars", jars)
            .master("local[2]")
            .getOrCreate()
        )
Example #8
File: fs.py  Project: eto-ai/rikai
    def resolve(self, spec):
        name = spec.getName()
        uri = spec.getUri()
        options = spec.getOptions()
        logger.info(f"Resolving model {name} from {uri}")
        if uri.endswith(".yml") or uri.endswith(".yaml"):
            func_name = codegen_from_yaml(self._spark, uri, name, options)
        else:
            raise ValueError(f"Model URI is not supported: {uri}")

        model = self._jvm.ai.eto.rikai.sql.model.SparkUDFModel(
            name, uri, func_name)
        # TODO: set options
        return model
Example #9
def codegen_from_yaml(
    spark: SparkSession,
    uri: str,
    name: Optional[str] = None,
    options: Optional[Dict[str, str]] = None,
) -> str:
    """Generate code from a YAML file.

    Parameters
    ----------
    spark : SparkSession
        A live spark session
    uri : str
        the model spec URI
    name : str, optional
        The name of the model.
    options : dict
        Optional parameters passed to the model.

    Returns
    -------
    str
        The name of the Spark UDF registered for model inference.
    """
    with open_uri(uri) as fobj:
        spec = ModelSpec(fobj, options=options)

    if spec.version != 1.0:
        raise SpecError(
            f"Only spec version 1.0 is supported, got {spec.version}"
        )

    if spec.flavor == "pytorch":
        from rikai.spark.sql.codegen.pytorch import generate_udf

        udf = generate_udf(
            spec.uri,
            spec.schema,
            spec.options,
            pre_processing=spec.pre_processing,
            post_processing=spec.post_processing,
        )
    else:
        raise SpecError(f"Unsupported model flavor: {spec.flavor}")

    func_name = f"{name}_{secrets.token_hex(4)}"
    spark.udf.register(func_name, udf)
    logger.info(f"Created model inference pandas_udf with name {func_name}")
    return func_name
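A minimal usage sketch with a hypothetical spec URI and table name; it assumes the YAML at that URI is a valid version-1.0 PyTorch spec as checked above:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("codegen-demo").getOrCreate()
func_name = codegen_from_yaml(spark, "s3://bucket/models/resnet.yml", name="resnet")
spark.sql(f"SELECT {func_name}(image) AS prediction FROM images").show()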
Example #10
def image_copy(img: Image, uri: str) -> Image:
    """Copy the image to a new destination, specified by the URI.

    Parameters
    ----------
    img : Image
        An image object
    uri : str
        The base directory to copy the image to.

    Returns
    -------
    Image
        A new image pointing to the new URI.
    """
    logger.info("Copying image src=%s dest=%s", img.uri, uri)
    return Image(_copy(img.uri, uri))
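A minimal usage sketch with made-up URIs; the `Image(uri)` construction mirrors the return statement above:

img = Image("s3://bucket/raw/cat.png")
backup = image_copy(img, "s3://bucket/backup/")
print(backup.uri)  # URI of the copy under the destination directory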
Example #11
    def resolve(self, raw_spec):
        name = raw_spec.getName()
        uri = raw_spec.getUri()
        logger.info(f"Resolving model {name} from {uri}")
        parsed = urlparse(uri)
        if not parsed.scheme:
            raise ValueError("Scheme must be mlflow. How did you get here?")
        parts = parsed.path.strip("/").split("/", 1)
        model_uri, run = self.get_model_version(*parts)
        spec = MlflowModelSpec(
            model_uri,
            self.get_model_conf(raw_spec, run),
            self.mlflow_tracking_uri,
            options=self.get_options(raw_spec, run),
        )
        func_name = codegen_from_spec(self._spark, spec, name)
        model = self._jvm.ai.eto.rikai.sql.model.SparkUDFModel(
            name, uri, func_name
        )
        return model
Example #12
    def register(self):
        """Register this :py:class:`CallbackService` to SparkSession's JVM."""
        jvm = self.spark.sparkContext._jvm
        jvm.ai.eto.rikai.sql.spark.Python.register(self)
        logger.info(
            "Rikai Python CallbackService is registered to SparkSession")