Ejemplo n.º 1
0
def test_nested_array():
    """A struct holding an array of longs must survive parse_schema.

    The ``simpleString()`` rendering and both the ``bigint`` and ``long``
    spellings of the element type must all parse back to the same
    ``StructType``.
    """
    expected = StructType(
        [
            StructField("id", IntegerType()),
            StructField("scores", ArrayType(LongType())),
        ]
    )
    spellings = (
        expected.simpleString(),
        "STRUCT<id:int,scores:ARRAY<bigint>>",
        "STRUCT<id:int,scores:ARRAY<long>>",
    )
    for text in spellings:
        # Include the offending spelling in the failure message.
        assert expected == parse_schema(text), text
Ejemplo n.º 2
0
def test_modelspec(mlflow_client, resnet_model_uri):
    """An MlflowModelSpec built from a ``runs:/`` URI exposes run metadata.

    ``resnet_model_uri`` is not used in the body. Presumably the fixture is
    requested so the ``rikai-test`` model is registered first — TODO confirm.
    """
    first_version = mlflow_client.search_model_versions("name='rikai-test'")[0]
    run_id = first_version.run_id
    run = mlflow_client.get_run(run_id=run_id)
    model_uri = f"runs:/{run_id}/model"
    spec = MlflowModelSpec(
        model_uri,
        run.data.tags,
        run.data.params,
        tracking_uri="fake",
    )

    assert spec.flavor == "pytorch"
    expected_schema = parse_schema(
        "struct<boxes:array<array<float>>, "
        "scores:array<float>, labels:array<int>>"
    )
    assert spec.schema == expected_schema

    transform_pkg = "rikai.contrib.torch.transforms.fasterrcnn_resnet50_fpn"
    transforms = spec._spec["transforms"]
    assert transforms["pre"] == transform_pkg + ".pre_processing"
    assert transforms["post"] == transform_pkg + ".post_processing"
    assert spec.uri == model_uri
Ejemplo n.º 3
0
def check_ml_predict(spark: SparkSession, model_name: str):
    """Run ``ML_PREDICT`` with ``model_name`` over two images and verify it.

    Registers a temp view ``df`` with one ``uri`` column, then selects
    ``ML_PREDICT(<model_name>, uri)`` as ``predictions``. Checks the result
    schema twice: once against an explicitly built ``StructType`` and once
    against the ``parse_schema`` form. Finally checks that two rows come back.
    """
    # TODO: Replace uri string with Image class after GH#90 is released with
    # the upstream spark
    image_uris = (
        # http://cocodataset.org/#explore?id=484912
        "http://farm2.staticflickr.com/1129/4726871278_4dd241a03a_z.jpg",
        # https://cocodataset.org/#explore?id=433013
        "http://farm4.staticflickr.com/3726/9457732891_87c6512b62_z.jpg",
    )
    df = spark.createDataFrame([Row(uri=uri) for uri in image_uris])
    df.createOrReplaceTempView("df")

    query = f"SELECT ML_PREDICT({model_name}, uri) as predictions FROM df"
    predictions = spark.sql(query)
    predictions.show()

    prediction_struct = StructType(
        [
            StructField("boxes", ArrayType(ArrayType(FloatType()))),
            StructField("scores", ArrayType(FloatType())),
            StructField("labels", ArrayType(IntegerType())),
        ]
    )
    assert predictions.schema == StructType(
        [StructField("predictions", prediction_struct)]
    )

    parsed_struct = parse_schema(
        "STRUCT<boxes:ARRAY<ARRAY<float>>, "
        "scores:ARRAY<float>, labels:ARRAY<int>>"
    )
    assert predictions.schema == StructType(
        [StructField("predictions", parsed_struct)]
    )

    assert predictions.count() == 2
Ejemplo n.º 4
0
def test_yaml_model(spark: SparkSession, resnet_spec: str):
    """A model created from a YAML spec file can be used by ML_PREDICT.

    Registers ``resnet_m`` from the local ``resnet_spec`` path and predicts
    over two image URIs. Checks the ``predictions`` column schema both
    explicitly and via ``parse_schema``, then checks the row count.
    """
    spark.sql(f"CREATE MODEL resnet_m USING 'file://{resnet_spec}'")

    # TODO: Replace uri string with Image class after GH#90 is released with
    # the upstream spark
    image_uris = (
        # http://cocodataset.org/#explore?id=484912
        "http://farm2.staticflickr.com/1129/4726871278_4dd241a03a_z.jpg",
        # https://cocodataset.org/#explore?id=433013
        "http://farm4.staticflickr.com/3726/9457732891_87c6512b62_z.jpg",
    )
    df = spark.createDataFrame([Row(uri=uri) for uri in image_uris])
    df.createOrReplaceTempView("df")

    predictions = spark.sql(
        "SELECT ML_PREDICT(resnet_m, uri) as predictions FROM df"
    )
    predictions.show()

    prediction_struct = StructType(
        [
            StructField("boxes", ArrayType(ArrayType(FloatType()))),
            StructField("scores", ArrayType(FloatType())),
            StructField("labels", ArrayType(IntegerType())),
        ]
    )
    assert predictions.schema == StructType(
        [StructField("predictions", prediction_struct)]
    )

    parsed_struct = parse_schema(
        "struct<boxes:array<array<float>>, "
        "scores:array<float>, labels:array<int>>"
    )
    assert predictions.schema == StructType(
        [StructField("predictions", parsed_struct)]
    )

    assert predictions.count() == 2
Ejemplo n.º 5
0
def test_modelspec(mlflow_client: MlflowClient):
    """An MlflowModelSpec built from a ``models:/`` URI exposes metadata."""
    version = mlflow_client.search_model_versions("name='rikai-test'")[0]
    run = mlflow_client.get_run(run_id=version.run_id)
    model_uri = "models:/rikai-test/{}".format(version.version)
    spec = MlflowModelSpec(model_uri, run.data.tags, tracking_uri="fake")

    assert spec.flavor == "pytorch"
    expected_schema = parse_schema(
        "STRUCT<boxes:ARRAY<ARRAY<float>>,"
        "scores:ARRAY<float>, labels:ARRAY<int>>"
    )
    assert spec.schema == expected_schema

    transform_pkg = "rikai.contrib.torch.transforms.fasterrcnn_resnet50_fpn"
    transforms = spec._spec["transforms"]
    assert transforms["pre"] == transform_pkg + ".pre_processing"
    assert transforms["post"] == transform_pkg + ".post_processing"
    assert spec.model_uri == model_uri
Ejemplo n.º 6
0
 def schema(self) -> "DataType":
     """Return the output schema of the model.

     The ``"schema"`` entry of ``self._spec`` is passed through
     ``parse_schema``. The result is therefore a parsed Spark type (for
     example a ``StructType``), not the raw schema string.

     Raises:
         KeyError: presumably, if ``self._spec`` is a dict with no
             ``"schema"`` key — TODO confirm the type of ``_spec``.
     """
     return parse_schema(self._spec["schema"])
Ejemplo n.º 7
0
def test_bad_schema():
    """Malformed schema strings must raise SchemaError."""
    bad_inputs = [
        ("", r".*Invalid schema.*"),
        ("foo,bar", r".*Can not recognize type.*"),
    ]
    for text, expected_message in bad_inputs:
        with pytest.raises(SchemaError, match=expected_message):
            parse_schema(text)
Ejemplo n.º 8
0
def test_invalid_identifier():
    """Field names and type names starting with a digit are rejected."""
    for text in ("STRUCT<0id:int>", "STRUCT<id:8float>"):
        with pytest.raises(
            SchemaError, match=r".*can not start with a digit.*"
        ):
            parse_schema(text)
Ejemplo n.º 9
0
def test_primitives():
    """Each primitive type name and its aliases parse to the same Spark type."""
    aliases_by_type = [
        (BooleanType(), ("bool", "boolean")),
        (ByteType(), ("byte", "tinyint")),
        (ShortType(), ("short", "smallint")),
        (IntegerType(), ("int",)),
        (FloatType(), ("float",)),
        (DoubleType(), ("double",)),
        (StringType(), ("string",)),
        (BinaryType(), ("binary",)),
    ]
    for expected, names in aliases_by_type:
        for name in names:
            # Include the failing alias in the assertion message.
            assert expected == parse_schema(name), name
Ejemplo n.º 10
0
def test_parse_schema():
    """A flat STRUCT string parses into the equivalent StructType."""
    parsed = parse_schema("STRUCT<foo:int,bar:string>")
    expected = StructType(
        [
            StructField("foo", IntegerType()),
            StructField("bar", StringType()),
        ]
    )
    assert parsed == expected