コード例 #1
0
ファイル: test_utils.py プロジェクト: smellslikeml/rikai
def test_df_to_rikai(spark: SparkSession, tmp_path: Path):
    df = spark.createDataFrame(
        [Row(Box2d(1, 2, 3, 4)), Row(Box2d(23, 33, 44, 88))], ["bbox"]
    )
    df_to_rikai(df, str(tmp_path))
    actual_df = spark.read.format("rikai").load(str(tmp_path))
    assert_count_equal(df.collect(), actual_df.collect())
コード例 #2
0
ファイル: test_parquet_udt.py プロジェクト: da-tubi/rikai
def test_bbox(spark: SparkSession, tmp_path: Path):
    test_dir = str(tmp_path)
    df = spark.createDataFrame([Row(b=Box2d(1, 2, 3, 4))])
    df.write.mode("overwrite").format("rikai").save(test_dir)

    records = _read_parquets(test_dir)

    assert_count_equal([{"b": Box2d(1, 2, 3, 4)}], records)
コード例 #3
0
ファイル: test_parquet_udt.py プロジェクト: da-tubi/rikai
def test_images(spark: SparkSession, tmp_path):
    expected = [
        {
            "id": 1,
            "image": Image(uri="s3://123"),
        },
        {
            "id": 2,
            "image": Image(uri="s3://abc"),
        },
    ]
    df = spark.createDataFrame(expected)
    df.write.mode("overwrite").parquet(str(tmp_path))

    records = sorted(_read_parquets(str(tmp_path)), key=lambda x: x["id"])
    assert_count_equal(expected, records)
コード例 #4
0
ファイル: test_dataset.py プロジェクト: smellslikeml/rikai
def test_select_columns(spark: SparkSession, tmp_path: Path):
    """Test reading rikai dataset with selected columns."""
    df = spark.createDataFrame([
        Row(id=1, col1="value", col2=123),
        Row(id=2, col1="more", col2=456),
    ])
    df.write.format("rikai").save(str(tmp_path))

    dataset = Dataset(str(tmp_path), columns=["id", "col1"])
    actual = sorted(list(dataset), key=lambda x: x["id"])

    assert_count_equal([{
        "id": 1,
        "col1": "value"
    }, {
        "id": 2,
        "col1": "more"
    }], actual)
コード例 #5
0
ファイル: test_types.py プロジェクト: da-tubi/rikai
def _check_roundtrip(spark: SparkSession, df: DataFrame, tmp_path: Path):
    df.show()
    df.write.mode("overwrite").format("rikai").save(str(tmp_path))
    actual_df = spark.read.format("rikai").load(str(tmp_path))
    assert_count_equal(df.collect(), actual_df.collect())