Beispiel #1
0
    def test_model_cache(self):
        """Check that ``SparkModelCache`` serves an already-loaded model on repeated requests.

        The model at ``self._model_path`` is registered with Spark, loaded once on the driver,
        and then requested from Spark tasks in two consecutive jobs. Each task returns the
        ``SparkModelCache._cache_hits`` value it observed, which is used to assert reuse.
        """
        archive_path = SparkModelCache.add_local_model(self.spark,
                                                       self._model_path)
        assert archive_path != self._model_path

        # Ensure we can use the model locally.
        local_model = SparkModelCache.get_or_load(archive_path)
        assert local_model.__name__ == "ConstPyfunc"

        # Request the model on all executors, and see how many times we got cache hits.
        def get_model(_):
            model = SparkModelCache.get_or_load(archive_path)
            # NB: Can not use instanceof test as remote does not know about ConstPyfunc class
            assert model.__name__ == "ConstPyfunc"
            return SparkModelCache._cache_hits

        def collect_cache_hits():
            # 100 elements spread over 30 partitions; one result per element.
            rdd = self.spark.sparkContext.parallelize(range(100), 30)
            return rdd.map(get_model).collect()

        # Compare the structured version tuple rather than indexing the free-form
        # ``sys.version`` banner string.
        running_python3 = sys.version_info[0] == 3

        # This will run 30 distinct tasks, and we expect most to reuse an already-loaded model.
        # Note that we can't necessarily expect an even split, or even that there were only
        # exactly 2 python processes launched, due to Spark and its mysterious ways, but we do
        # expect significant reuse.
        results = collect_cache_hits()

        # TODO(tomas): Looks like spark does not reuse python workers with python==3.x
        assert running_python3 or max(results) > 10
        # Running again should see no newly-loaded models.
        results2 = collect_cache_hits()
        assert running_python3 or min(results2) > 0
Beispiel #2
0
def spark_udf(spark, path, run_id=None, result_type="double"):
    """Returns a Spark UDF that can be used to invoke the python-function formatted model.

    Note that parameters passed to the UDF will be forwarded to the model as a DataFrame
    where the names are simply ordinals (0, 1, ...), in the order the arguments were passed.

    Example:
        predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
        df.withColumn("prediction", predict("name", "age")).show()

    Args:
        spark (SparkSession): a SparkSession object
        path (str): A path containing a pyfunc model.
        run_id (str): Optional run ID. When given, ``path`` is first resolved through
            ``tracking._get_model_log_dir(path, run_id)``.
        result_type (str): Spark UDF type returned by the model's prediction method. Default double

    Returns:
        The pandas UDF created by ``pandas_udf`` around the model's ``predict`` call.
    """

    # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
    # functionality.
    from mlflow.pyfunc.spark_model_cache import SparkModelCache
    from pyspark.sql.functions import pandas_udf

    if run_id:
        path = tracking._get_model_log_dir(path, run_id)

    archive_path = SparkModelCache.add_local_model(spark, path)

    def predict(*args):
        model = SparkModelCache.get_or_load(archive_path)
        # Explicitly pass order of columns to avoid lexicographic ordering (i.e., 10 < 2)
        # when the DataFrame is built from a dict whose key order is not guaranteed.
        columns = [str(i) for i in range(len(args))]
        pdf = pandas.DataFrame(dict(zip(columns, args)), columns=columns)
        result = model.predict(pdf)
        return pandas.Series(result)

    return pandas_udf(predict, result_type)
Beispiel #3
0
def test_model_cache(spark, model_path):
    """Verify that ``SparkModelCache`` returns a usable model and reuses it across Spark tasks.

    A pyfunc model is saved to ``model_path`` and registered with the cache. ``get_or_load`` is
    then checked on the driver and inside Spark tasks: it must return a ``PyFuncModel`` plus a
    directory (different from ``model_path``) from which the same kind of model can be loaded.
    The ``_cache_hits`` values returned by the tasks are used to assert reuse.
    """
    mlflow.pyfunc.save_model(
        path=model_path,
        loader_module=__name__,
        code_path=[os.path.dirname(tests.__file__)],
    )

    archive_path = SparkModelCache.add_local_model(spark, model_path)
    assert archive_path != model_path

    # Capture only the class *name*: executors compare against this string and never need to
    # resolve ConstantPyfuncWrapper itself, which is only available on the driver.
    constant_model_name = ConstantPyfuncWrapper.__name__

    def check_get_or_load_return_value(model_from_cache,
                                       model_path_from_cache):
        assert model_path_from_cache != model_path
        assert os.path.isdir(model_path_from_cache)
        reloaded_model = mlflow.pyfunc.load_model(model_path_from_cache)
        for candidate in (model_from_cache, reloaded_model):
            assert isinstance(candidate, PyFuncModel)
            # Compare by class name; an isinstance test against ConstantPyfuncWrapper is not
            # possible remotely because the executor does not know that class.
            impl_name = type(candidate._model_impl).__name__
            assert impl_name == constant_model_name

    # The model must be usable on the driver first.
    driver_model, driver_model_path = SparkModelCache.get_or_load(archive_path)
    check_get_or_load_return_value(driver_model, driver_model_path)

    # Executed inside Spark tasks: validate the cached model and report the hit counter.
    def get_model(_):
        task_model, task_model_path = SparkModelCache.get_or_load(
            archive_path)
        check_get_or_load_return_value(task_model, task_model_path)
        return SparkModelCache._cache_hits

    def collect_cache_hits():
        rdd = spark.sparkContext.parallelize(range(100), 30)
        return rdd.map(get_model).collect()

    # 30 distinct tasks are run, and most should reuse an already-loaded model. An even split
    # cannot be expected, nor that exactly 2 python processes were launched, but significant
    # reuse is expected.
    first_pass = collect_cache_hits()
    assert max(first_pass) > 10
    # A second pass should see no newly-loaded models.
    second_pass = collect_cache_hits()
    assert min(second_pass) > 0
Beispiel #4
0
def spark_udf(spark, path, run_id=None, result_type="double"):
    """
    Build a Spark pandas UDF that scores its inputs with a Python function formatted model.

    The arguments passed to the UDF are assembled into a pandas DataFrame whose column names
    are the argument ordinals as strings ("0", "1", ...), kept in call order, and that
    DataFrame is handed to the model's ``predict`` method.

    Example:

    .. code:: python

        predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
        df.withColumn("prediction", predict("name", "age")).show()

    :param spark: A SparkSession object.
    :param path: A path containing a pyfunc model.
    :param run_id: ID of the run that produced this model. If provided, ``run_id`` is used to
                   retrieve the model logged with MLflow.
    :param result_type: Spark UDF type returned by the model's prediction method. Default double.
    :return: The pandas UDF wrapping the model's ``predict`` call.

    """

    # The Spark-related imports are kept local so that pyspark is only required by callers
    # that actually build a UDF.
    from mlflow.pyfunc.spark_model_cache import SparkModelCache
    from pyspark.sql.functions import pandas_udf

    model_location = (
        tracking.utils._get_model_log_dir(path, run_id) if run_id else path)
    archive_path = SparkModelCache.add_local_model(spark, model_location)

    def predict(*args):
        cached_model = SparkModelCache.get_or_load(archive_path)
        # Column order is passed explicitly so that names are not ordered lexicographically
        # (where "10" would sort before "2").
        ordinals = [str(position) for position in range(len(args))]
        frame = pandas.DataFrame(dict(zip(ordinals, args)), columns=ordinals)
        return pandas.Series(cached_model.predict(frame))

    return pandas_udf(predict, result_type)
Beispiel #5
0
def test_model_cache(spark, model_path):
    """Check that ``SparkModelCache`` serves an already-loaded model on repeated requests.

    A pyfunc model is saved to ``model_path``, registered with Spark, loaded once on the
    driver, and then requested from Spark tasks in two consecutive jobs. Each task returns the
    ``SparkModelCache._cache_hits`` value it observed, which is used to assert reuse.
    """
    mlflow.pyfunc.save_model(
        dst_path=model_path,
        loader_module=__name__,
        code_path=[os.path.dirname(tests.__file__)],
    )

    archive_path = SparkModelCache.add_local_model(spark, model_path)
    assert archive_path != model_path

    # Ensure we can use the model locally.
    local_model = SparkModelCache.get_or_load(archive_path)
    assert isinstance(local_model, ConstantPyfuncWrapper)

    # Define the model class name as a string so that each Spark executor can reference it
    # without attempting to resolve ConstantPyfuncWrapper, which is only available on the driver.
    constant_model_name = ConstantPyfuncWrapper.__name__

    # Request the model on all executors, and see how many times we got cache hits.
    def get_model(_):
        model = SparkModelCache.get_or_load(archive_path)
        # NB: Can not use instanceof test as remote does not know about ConstantPyfuncWrapper class.
        assert type(model).__name__ == constant_model_name
        return SparkModelCache._cache_hits

    def collect_cache_hits():
        # 100 elements spread over 30 partitions; one result per element.
        rdd = spark.sparkContext.parallelize(range(100), 30)
        return rdd.map(get_model).collect()

    # Compare the structured version tuple rather than indexing the free-form
    # ``sys.version`` banner string.
    running_python3 = sys.version_info[0] == 3

    # This will run 30 distinct tasks, and we expect most to reuse an already-loaded model.
    # Note that we can't necessarily expect an even split, or even that there were only
    # exactly 2 python processes launched, due to Spark and its mysterious ways, but we do
    # expect significant reuse.
    results = collect_cache_hits()

    # TODO(tomas): Looks like spark does not reuse python workers with python==3.x
    assert running_python3 or max(results) > 10
    # Running again should see no newly-loaded models.
    results2 = collect_cache_hits()
    assert running_python3 or min(results2) > 0
Beispiel #6
0
def spark_udf(spark, model_uri, result_type="double"):
    """
    A Spark UDF that can be used to invoke the Python function formatted model.

    Parameters passed to the UDF are forwarded to the model as a DataFrame where the column names
    are ordinals (0, 1, ...). On some versions of Spark, it is also possible to wrap the input in a
    struct. In that case, the data will be passed as a DataFrame with column names given by the
    struct definition (e.g. when invoked as my_udf(struct('x', 'y')), the model will get the data
    as a pandas DataFrame with 2 columns 'x' and 'y').

    If the model declares an input schema and the columns are passed unnamed, the schema's column
    names are applied positionally; surplus input columns are dropped and missing ones raise.

    The predictions are filtered to contain only the columns that can be represented as the
    ``result_type``. If the ``result_type`` is string or array of strings, all predictions are
    converted to string. If the result type is not an array type, the left most column with
    matching type is returned.

    .. code-block:: python
        :caption: Example

        predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
        df.withColumn("prediction", predict("name", "age")).show()

    :param spark: A SparkSession object.
    :param model_uri: The location, in URI format, of the MLflow model with the
                      :py:mod:`mlflow.pyfunc` flavor. For example:

                      - ``/Users/me/path/to/local/model``
                      - ``relative/path/to/local/model``
                      - ``s3://my_bucket/path/to/model``
                      - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                      - ``models:/<model_name>/<model_version>``
                      - ``models:/<model_name>/<stage>``

                      For more information about supported URI schemes, see
                      `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
                      artifact-locations>`_.

    :param result_type: the return type of the user-defined function. The value can be either a
        ``pyspark.sql.types.DataType`` object or a DDL-formatted type string. Only a primitive
        type or an array ``pyspark.sql.types.ArrayType`` of primitive type are allowed.
        The following classes of result type are supported:

        - "int" or ``pyspark.sql.types.IntegerType``: The leftmost integer that can fit in an
          ``int32`` or an exception if there is none.

        - "long" or ``pyspark.sql.types.LongType``: The leftmost long integer that can fit in an
          ``int64`` or an exception if there is none.

        - ``ArrayType(IntegerType|LongType)``: All integer columns that can fit into the requested
          size.

        - "float" or ``pyspark.sql.types.FloatType``: The leftmost numeric result cast to
          ``float32`` or an exception if there is none.

        - "double" or ``pyspark.sql.types.DoubleType``: The leftmost numeric result cast to
          ``double`` or an exception if there is none.

        - ``ArrayType(FloatType|DoubleType)``: All numeric columns cast to the requested type or
          an exception if there are no numeric columns.

        - "string" or ``pyspark.sql.types.StringType``: The leftmost column converted to ``string``.

        - ``ArrayType(StringType)``: All columns converted to ``string``.

    :return: Spark UDF that applies the model's ``predict`` method to the data and returns a
             type specified by ``result_type``, which by default is a double.
    :raises MlflowException: if ``result_type`` (or its array element type) is not one of the
             supported types.
    """

    # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
    # functionality.
    from mlflow.pyfunc.spark_model_cache import SparkModelCache
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import _parse_datatype_string
    from pyspark.sql.types import ArrayType, DataType as SparkDataType
    from pyspark.sql.types import DoubleType, IntegerType, FloatType, LongType, StringType

    if not isinstance(result_type, SparkDataType):
        result_type = _parse_datatype_string(result_type)

    elem_type = result_type
    if isinstance(elem_type, ArrayType):
        elem_type = elem_type.elementType

    supported_types = [
        IntegerType, LongType, FloatType, DoubleType, StringType
    ]

    if not isinstance(elem_type, tuple(supported_types)):
        raise MlflowException(
            message=
            "Invalid result_type '{}'. Result type can only be one of or an array of one "
            "of the following types types: {}".format(str(elem_type),
                                                      str(supported_types)),
            error_code=INVALID_PARAMETER_VALUE,
        )

    with TempDir() as local_tmpdir:
        local_model_path = _download_artifact_from_uri(
            artifact_uri=model_uri, output_path=local_tmpdir.path())
        archive_path = SparkModelCache.add_local_model(spark, local_model_path)

    def predict(*args):
        model = SparkModelCache.get_or_load(archive_path)
        input_schema = model.metadata.get_input_schema()
        pdf = None

        # A struct input column arrives as a single pandas DataFrame and is used as-is.
        for x in args:
            if isinstance(x, pandas.DataFrame):
                if len(args) != 1:
                    raise Exception(
                        "If passing a StructType column, there should be only one "
                        "input column, but got %d" % len(args))
                pdf = x
        if pdf is None:
            args = list(args)
            if input_schema is None:
                names = [str(i) for i in range(len(args))]
            else:
                names = input_schema.column_names()
                # Unnamed inputs are matched to the schema by position: extra inputs are
                # ignored, missing inputs are an error.
                if len(args) > len(names):
                    args = args[:len(names)]
                if len(args) < len(names):
                    message = (
                        "Model input is missing columns. Expected {0} input columns {1},"
                        " but the model received only {2} unnamed input columns"
                        " (Since the columns were passed unnamed they are expected to be in"
                        " the order specified by the schema).".format(
                            len(names), names, len(args)))
                    raise MlflowException(message)
            pdf = pandas.DataFrame(
                data={names[i]: x
                      for i, x in enumerate(args)}, columns=names)

        result = model.predict(pdf)

        if not isinstance(result, pandas.DataFrame):
            result = pandas.DataFrame(data=result)

        elem_type = result_type.elementType if isinstance(
            result_type, ArrayType) else result_type

        if isinstance(elem_type, IntegerType):
            result = result.select_dtypes(
                [np.byte, np.ubyte, np.short, np.ushort,
                 np.int32]).astype(np.int32)
        elif isinstance(elem_type, LongType):
            # ``np.int`` was only an alias of the builtin ``int``; it was deprecated in
            # NumPy 1.20 and removed in 1.24. List the fixed-width integer dtypes explicitly
            # and widen to int64 so the output matches the requested type.
            result = result.select_dtypes(
                [np.byte, np.ubyte, np.short, np.ushort, np.int32,
                 np.int64]).astype(np.int64)

        elif isinstance(elem_type, FloatType):
            result = result.select_dtypes(include=(np.number, )).astype(
                np.float32)

        elif isinstance(elem_type, DoubleType):
            result = result.select_dtypes(include=(np.number, )).astype(
                np.float64)

        if len(result.columns) == 0:
            raise MlflowException(
                message=
                "The the model did not produce any values compatible with the requested "
                "type '{}'. Consider requesting udf with StringType or "
                "Arraytype(StringType).".format(str(elem_type)),
                error_code=INVALID_PARAMETER_VALUE,
            )

        if isinstance(elem_type, StringType):
            result = result.applymap(str)

        if isinstance(result_type, ArrayType):
            # One list per row, containing every retained column.
            return pandas.Series(result.to_numpy().tolist())
        else:
            # Scalar result types return the leftmost retained column.
            return result[result.columns[0]]

    return pandas_udf(predict, result_type)
Beispiel #7
0
def spark_udf(spark, path, run_id=None, result_type="double"):
    """
    A Spark UDF that can be used to invoke the Python function formatted model.

    Parameters passed to the UDF are forwarded to the model as a DataFrame where the names are
    ordinals (0, 1, ...).

    The predictions are filtered to contain only the columns that can be represented as the
    ``result_type``. If the ``result_type`` is string or array of strings, all predictions are
    converted to string. If the result type is not an array type, the left most column with
    matching type will be returned. The filtering is applied regardless of whether the model
    returns a pandas DataFrame or another structure that is converted to one.

    >>> predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
    >>> df.withColumn("prediction", predict("name", "age")).show()

    :param spark: A SparkSession object.
    :param path: A path containing a :py:mod:`mlflow.pyfunc` model.
    :param run_id: ID of the run that produced this model. If provided, ``run_id`` is used to
                   retrieve the model logged with MLflow.
    :param result_type: the return type of the user-defined function. The value can be either a
                        :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
                        Only a primitive type or an array (pyspark.sql.types.ArrayType) of primitive
                        types are allowed. The following classes of result type are supported:
                        - "int" or pyspark.sql.types.IntegerType: The leftmost integer that can fit
                          in int32 result is returned or exception is raised if there is none.
                        - "long" or pyspark.sql.types.LongType: The leftmost long integer that can
                          fit in int64 result is returned or exception is raised if there is none.
                        - ArrayType(IntegerType|LongType): Return all integer columns that can fit
                          into the requested size.
                        - "float" or pyspark.sql.types.FloatType: The leftmost numeric result cast
                          to float32 is returned or exception is raised if there is none.
                        - "double" or pyspark.sql.types.DoubleType: The leftmost numeric result cast
                          to double is returned or exception is raised if there is none.
                        - ArrayType(FloatType|DoubleType): Return all numeric columns cast to the
                          requested type. Exception is raised if there are no numeric columns.
                        - "string" or pyspark.sql.types.StringType: Result is the leftmost column
                          converted to string.
                        - ArrayType(StringType): Return all columns converted to string.

    :return: Spark UDF which will apply model's prediction method to the data. Default double.
    :raises MlflowException: if ``result_type`` (or its array element type) is not one of the
                             supported types.
    """

    # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
    # functionality.
    from mlflow.pyfunc.spark_model_cache import SparkModelCache
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import _parse_datatype_string
    from pyspark.sql.types import ArrayType, DataType
    from pyspark.sql.types import DoubleType, IntegerType, FloatType, LongType, StringType

    if not isinstance(result_type, DataType):
        result_type = _parse_datatype_string(result_type)

    elem_type = result_type
    if isinstance(elem_type, ArrayType):
        elem_type = elem_type.elementType

    supported_types = [
        IntegerType, LongType, FloatType, DoubleType, StringType
    ]

    if not isinstance(elem_type, tuple(supported_types)):
        raise MlflowException(
            message=
            "Invalid result_type '{}'. Result type can only be one of or an array of one "
            "of the following types types: {}".format(str(elem_type),
                                                      str(supported_types)),
            error_code=INVALID_PARAMETER_VALUE)

    if run_id:
        path = tracking.artifact_utils._get_model_log_dir(path, run_id)

    archive_path = SparkModelCache.add_local_model(spark, path)

    def predict(*args):
        model = SparkModelCache.get_or_load(archive_path)
        # Explicitly pass order of columns to avoid lexicographic ordering (i.e., 10 < 2)
        columns = [str(i) for i in range(len(args))]
        pdf = pandas.DataFrame(dict(zip(columns, args)), columns=columns)
        result = model.predict(pdf)
        if not isinstance(result, pandas.DataFrame):
            result = pandas.DataFrame(data=result)

        # This must be an independent ``if``: chaining it as ``elif`` to the conversion above
        # would skip the dtype filtering whenever the model returned a non-DataFrame result.
        if isinstance(elem_type, IntegerType):
            result = result.select_dtypes(
                [np.byte, np.ubyte, np.short, np.ushort,
                 np.int32]).astype(np.int32)

        elif isinstance(elem_type, LongType):
            # ``np.int`` was only an alias of the builtin ``int``; it was deprecated in
            # NumPy 1.20 and removed in 1.24. List the fixed-width integer dtypes explicitly
            # and widen to int64 so the output matches the requested type.
            result = result.select_dtypes(
                [np.byte, np.ubyte, np.short, np.ushort, np.int32,
                 np.int64]).astype(np.int64)

        elif isinstance(elem_type, FloatType):
            result = result.select_dtypes(include=np.number).astype(np.float32)

        elif isinstance(elem_type, DoubleType):
            result = result.select_dtypes(include=np.number).astype(np.float64)

        if len(result.columns) == 0:
            raise MlflowException(
                message=
                "The the model did not produce any values compatible with the requested "
                "type '{}'. Consider requesting udf with StringType or "
                "Arraytype(StringType).".format(str(elem_type)),
                error_code=INVALID_PARAMETER_VALUE)

        if isinstance(elem_type, StringType):
            result = result.applymap(str)

        if isinstance(result_type, ArrayType):
            # One list per row, containing every retained column; avoids building a Series
            # object per row as ``iterrows`` does.
            return pandas.Series(result.values.tolist())
        else:
            # Scalar result types return the leftmost retained column.
            return result[result.columns[0]]

    return pandas_udf(predict, result_type)
Beispiel #8
0
def spark_udf(spark, model_uri, result_type="double"):
    """
    A Spark UDF that can be used to invoke the Python function formatted model.

    Parameters passed to the UDF are forwarded to the model as a DataFrame where the names are
    ordinals (0, 1, ...).

    The predictions are filtered to contain only the columns that can be represented as the
    ``result_type``. If the ``result_type`` is string or array of strings, all predictions are
    converted to string. If the result type is not an array type, the left most column with
    matching type is returned. The filtering is applied regardless of whether the model
    returns a pandas DataFrame or another structure that is converted to one.

    >>> predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
    >>> df.withColumn("prediction", predict("name", "age")).show()

    :param spark: A SparkSession object.
    :param model_uri: The location, in URI format, of the MLflow model with the
                      :py:mod:`mlflow.pyfunc` flavor. For example:

                      - ``/Users/me/path/to/local/model``
                      - ``relative/path/to/local/model``
                      - ``s3://my_bucket/path/to/model``
                      - ``runs:/<mlflow_run_id>/run-relative/path/to/model``

                      For more information about supported URI schemes, see
                      `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
                      artifact-locations>`_.

    :param result_type: the return type of the user-defined function. The value can be either a
        :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. Only a primitive
        type or an array ``pyspark.sql.types.ArrayType`` of primitive type are allowed.
        The following classes of result type are supported:

        - "int" or ``pyspark.sql.types.IntegerType``: The leftmost integer that can fit in an
          ``int32`` or an exception if there is none.

        - "long" or ``pyspark.sql.types.LongType``: The leftmost long integer that can fit in an
          ``int64`` or an exception if there is none.

        - ``ArrayType(IntegerType|LongType)``: All integer columns that can fit into the requested
          size.

        - "float" or ``pyspark.sql.types.FloatType``: The leftmost numeric result cast to
          ``float32`` or an exception if there is none.

        - "double" or ``pyspark.sql.types.DoubleType``: The leftmost numeric result cast to
          ``double`` or an exception if there is none.

        - ``ArrayType(FloatType|DoubleType)``: All numeric columns cast to the requested type or
          an exception if there are no numeric columns.

        - "string" or ``pyspark.sql.types.StringType``: The leftmost column converted to ``string``.

        - ``ArrayType(StringType)``: All columns converted to ``string``.

    :return: Spark UDF that applies the model's ``predict`` method to the data and returns a
             type specified by ``result_type``, which by default is a double.
    :raises MlflowException: if ``result_type`` (or its array element type) is not one of the
             supported types.
    """

    # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
    # functionality.
    from mlflow.pyfunc.spark_model_cache import SparkModelCache
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import _parse_datatype_string
    from pyspark.sql.types import ArrayType, DataType
    from pyspark.sql.types import DoubleType, IntegerType, FloatType, LongType, StringType

    if not isinstance(result_type, DataType):
        result_type = _parse_datatype_string(result_type)

    elem_type = result_type
    if isinstance(elem_type, ArrayType):
        elem_type = elem_type.elementType

    supported_types = [IntegerType, LongType, FloatType, DoubleType, StringType]

    if not isinstance(elem_type, tuple(supported_types)):
        raise MlflowException(
            message="Invalid result_type '{}'. Result type can only be one of or an array of one "
                    "of the following types types: {}".format(str(elem_type), str(supported_types)),
            error_code=INVALID_PARAMETER_VALUE)

    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    archive_path = SparkModelCache.add_local_model(spark, local_model_path)

    def predict(*args):
        model = SparkModelCache.get_or_load(archive_path)
        # Explicitly pass order of columns to avoid lexicographic ordering (i.e., 10 < 2)
        columns = [str(i) for i in range(len(args))]
        pdf = pandas.DataFrame(dict(zip(columns, args)), columns=columns)
        result = model.predict(pdf)
        if not isinstance(result, pandas.DataFrame):
            result = pandas.DataFrame(data=result)

        # This must be an independent ``if``: chaining it as ``elif`` to the conversion above
        # would skip the dtype filtering whenever the model returned a non-DataFrame result.
        if isinstance(elem_type, IntegerType):
            result = result.select_dtypes([np.byte, np.ubyte, np.short, np.ushort,
                                           np.int32]).astype(np.int32)

        elif isinstance(elem_type, LongType):
            # ``np.int`` was only an alias of the builtin ``int``; it was deprecated in
            # NumPy 1.20 and removed in 1.24. List the fixed-width integer dtypes explicitly
            # and widen to int64 so the output matches the requested type.
            result = result.select_dtypes([np.byte, np.ubyte, np.short, np.ushort,
                                           np.int32, np.int64]).astype(np.int64)

        elif isinstance(elem_type, FloatType):
            result = result.select_dtypes(include=(np.number,)).astype(np.float32)

        elif isinstance(elem_type, DoubleType):
            result = result.select_dtypes(include=(np.number,)).astype(np.float64)

        if len(result.columns) == 0:
            raise MlflowException(
                message="The the model did not produce any values compatible with the requested "
                        "type '{}'. Consider requesting udf with StringType or "
                        "Arraytype(StringType).".format(str(elem_type)),
                error_code=INVALID_PARAMETER_VALUE)

        if isinstance(elem_type, StringType):
            result = result.applymap(str)

        if isinstance(result_type, ArrayType):
            # One list per row, containing every retained column; avoids building a Series
            # object per row as ``iterrows`` does.
            return pandas.Series(result.values.tolist())
        else:
            # Scalar result types return the leftmost retained column.
            return result[result.columns[0]]

    return pandas_udf(predict, result_type)
Beispiel #9
0
def spark_udf(spark, model_uri, features, result_type="double"):
    """
    Create a Spark pandas UDF for the model found at ``model_uri``.

    Unnamed input columns are forwarded to the model as a pandas DataFrame whose columns are
    named, positionally, after the entries of ``features``. If the UDF instead receives a
    single pandas DataFrame argument (a struct column), that DataFrame is passed through
    unchanged and ``features`` is not consulted.

    The prediction is reduced to the columns compatible with ``result_type``: integer columns
    for int/long, numeric columns cast to float32/float64 for float/double, and every column
    converted to ``str`` for string. An array result type yields one list per row containing
    all retained columns; any other result type yields the leftmost retained column.

    :param spark: spark context
    :param model_uri: path to model
    :param features: list containing the feature names
    :param result_type: result type of the model; either a ``pyspark.sql.types.DataType`` or a
                        type string parsed with ``_parse_datatype_string``. Default double.
    :return: the pandas UDF wrapping the model's ``predict`` call
    :raises ValueError: if ``result_type`` (or its array element type) is not an integer, long,
                        float, double or string type
    """
    # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
    # functionality.
    from mlflow.pyfunc.spark_model_cache import SparkModelCache
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import _parse_datatype_string
    from pyspark.sql.types import ArrayType, DataType
    from pyspark.sql.types import DoubleType, IntegerType, FloatType, LongType, StringType

    if not isinstance(result_type, DataType):
        result_type = _parse_datatype_string(result_type)

    elem_type = result_type
    if isinstance(elem_type, ArrayType):
        elem_type = elem_type.elementType

    supported_types = [
        IntegerType, LongType, FloatType, DoubleType, StringType
    ]

    if not isinstance(elem_type, tuple(supported_types)):
        raise ValueError(
            "Invalid result_type '{}'. Result type can only be one of or an array of one "
            "of the following types types: {}".format(str(elem_type),
                                                      str(supported_types)))

    with TempDir() as local_tmpdir:
        local_model_path = _download_artifact_from_uri(
            artifact_uri=model_uri, output_path=local_tmpdir.path())
        archive_path = SparkModelCache.add_local_model(spark, local_model_path)

    def predict(*args):
        model = SparkModelCache.get_or_load(archive_path)
        pdf = None
        # A struct input column arrives as a single pandas DataFrame and is used as-is.
        for x in args:
            if isinstance(x, pandas.DataFrame):
                if len(args) != 1:
                    raise Exception(
                        "If passing a StructType column, there should be only one "
                        "input column, but got %d" % len(args))
                pdf = x
        if pdf is None:
            # Only build the named frame when it is needed, so a struct input never indexes
            # into ``features``.
            if len(args) > len(features):
                raise ValueError(
                    "Received {} input columns but only {} feature names were "
                    "provided.".format(len(args), len(features)))
            # Pass the column order explicitly so it follows ``features`` rather than the
            # iteration order of the dict.
            columns = list(features[:len(args)])
            pdf = pandas.DataFrame(dict(zip(columns, args)), columns=columns)
        result = model.predict(pdf)
        if not isinstance(result, pandas.DataFrame):
            result = pandas.DataFrame(data=result)

        # This must be an independent ``if``: chaining it as ``elif`` to the conversion above
        # would skip the dtype filtering whenever the model returned a non-DataFrame result.
        if isinstance(elem_type, IntegerType):
            result = result.select_dtypes(
                [np.byte, np.ubyte, np.short, np.ushort,
                 np.int32]).astype(np.int32)

        elif isinstance(elem_type, LongType):
            # ``np.int`` was only an alias of the builtin ``int``; it was deprecated in
            # NumPy 1.20 and removed in 1.24. List the fixed-width integer dtypes explicitly
            # and widen to int64 so the output matches the requested type.
            result = result.select_dtypes(
                [np.byte, np.ubyte, np.short, np.ushort, np.int32,
                 np.int64]).astype(np.int64)

        elif isinstance(elem_type, FloatType):
            result = result.select_dtypes(include=(np.number, )).astype(
                np.float32)

        elif isinstance(elem_type, DoubleType):
            result = result.select_dtypes(include=(np.number, )).astype(
                np.float64)

        if len(result.columns) == 0:
            raise ValueError(
                "The the model did not produce any values compatible with the requested "
                "type '{}'. Consider requesting udf with StringType or "
                "Arraytype(StringType).".format(str(elem_type)))

        if isinstance(elem_type, StringType):
            result = result.applymap(str)

        if isinstance(result_type, ArrayType):
            # One list per row, containing every retained column; avoids building a Series
            # object per row as ``iterrows`` does.
            return pandas.Series(result.values.tolist())
        else:
            # Scalar result types return the leftmost retained column.
            return result[result.columns[0]]

    return pandas_udf(predict, result_type)