def test_should_combine_disjoint_fields():
        """Fields of a parent Struct and a child Struct merge without overlap."""
        # given: a child struct extending a parent with a distinct field
        class ParentStruct(Struct):
            parent_field = String()

        class ChildStruct(ParentStruct):
            child_field = Float()

        # when, then: parent schema is unchanged; child schema appends its own field
        expected_parent = StructType([StructField("parent_field", StringType())])
        expected_child = StructType(
            [StructField("parent_field", StringType()), StructField("child_field", FloatType())]
        )
        assert schema(ParentStruct) == expected_parent
        assert schema(ChildStruct) == expected_child
    def test_should_inherit_includes():
        """A child Struct inherits fields pulled in via its parent's Meta.includes."""
        # given: a parent that includes a sibling struct, and a child of that parent
        class SiblingStruct(Struct):
            sibling_field = Integer()

        class ParentStruct(Struct):
            class Meta:
                includes = [SiblingStruct]

            parent_field = String()

        class ChildStruct(ParentStruct):
            child_field = Float()

        # when: the child is converted to a Spark schema
        generated = schema(ChildStruct)

        # then: own, included, and child fields are all present, in this order
        expected_fields = [
            StructField("parent_field", StringType()),
            StructField("sibling_field", IntegerType()),
            StructField("child_field", FloatType()),
        ]
        assert generated == StructType(expected_fields)
    def test_should_structise_object_containing_array_of_objects():
        """An Array of nested Structs maps to an ArrayType wrapping a StructType."""
        # given: an article whose tags are a nullable array of Tag structs
        class Tag(Struct):
            id = String(nullable=False)
            name = String()

        class Article(Struct):
            id = String(nullable=False)
            tags = Array(Tag(nullable=True))

        # when
        generated = schema(Article)

        # then: the tags column is an ArrayType whose elements carry the Tag schema
        tag_schema = StructType(
            [StructField("id", StringType(), nullable=False), StructField("name", StringType())]
        )
        expected = StructType(
            [
                StructField("id", StringType(), nullable=False),
                StructField(
                    "tags",
                    ArrayType(elementType=tag_schema, containsNull=True),
                    nullable=True,
                ),
            ]
        )
        assert generated == expected
    def test_should_structise_deep_object():
        """Nested Structs become nested StructTypes, honouring explicit field names."""
        # given: an article with a nested User author, renamed via `name=`
        class User(Struct):
            id = String(nullable=False)
            age = Float()
            full_name = String(name="name")

        class Article(Struct):
            author = User(name="article_author", nullable=False)
            title = String(nullable=False)
            date = Timestamp()

        # when
        generated = schema(Article)

        # then: the author column carries the explicit name and the User sub-schema
        author_schema = StructType(
            [
                StructField("id", StringType(), nullable=False),
                StructField("age", FloatType()),
                StructField("name", StringType()),
            ]
        )
        expected = StructType(
            [
                StructField("article_author", author_schema, nullable=False),
                StructField("title", StringType(), nullable=False),
                StructField("date", TimestampType()),
            ]
        )
        assert generated == expected
# Example #5
def test_stringified_schema():
    """pretty_schema renders the Conference schema as the stored reference text."""
    # when: the Conference struct's schema is pretty-printed
    rendered = pretty_schema(schema(Conference))

    # then: it matches the reference, ignoring surrounding whitespace
    assert rendered == prettified_schema.strip()
# Example #6
def test_sparkql_stringified_schema():
    """The arrays example's Article schema pretty-prints to its stored reference."""
    # when: the example Article struct's schema is pretty-printed
    rendered = pretty_schema(schema(arrays.Article))

    # then: it matches the example's reference, ignoring surrounding whitespace
    assert rendered == arrays.prettified_schema.strip()
    def test_should_allow_override_with_same_type():
        """Redeclaring an inherited field with an identical type is accepted."""
        # given: a child that re-declares the parent's field with the same type
        class ParentStruct(Struct):
            parent_field = String()

        class ChildStruct(ParentStruct):
            parent_field = String()

        # when, then: the field appears exactly once in the resulting schema
        expected = StructType([StructField("parent_field", StringType())])
        assert schema(ChildStruct) == expected
# Example #8
    def test_should_combine_overlapping_includes():
        """Includes that declare identical fields are merged into a single field."""
        # given: two structs exposing the same field
        class AnObject(Struct):
            field_z = String()

        class AnotherObject(Struct):
            field_z = String()

        # when: both are pulled in via Meta.includes
        class CompositeObject(Struct):
            class Meta:
                includes = [AnObject, AnotherObject]

        generated = schema(CompositeObject)

        # expect: the shared field is deduplicated in the composite schema
        assert generated == StructType([StructField("field_z", StringType())])
    def test_should_structise_flat_object():
        """A flat Struct maps field-for-field onto a Spark StructType."""
        # given: a struct with mixed nullability and an explicit field rename
        class User(Struct):
            id = String(nullable=False)
            age = Float()
            full_name = String(name="name")

        # when
        generated = schema(User)

        # then: the explicit `name=` wins over the attribute name
        expected_fields = [
            StructField("id", StringType(), nullable=False),
            StructField("age", FloatType()),
            StructField("name", StringType()),
        ]
        assert generated == StructType(expected_fields)
# Example #10
    def test_should_combine_disjoint_includes():
        """Native fields come first, followed by the fields of each include."""
        # given: two structs with distinct fields
        class AnObject(Struct):
            field_a = String()

        class AnotherObject(Struct):
            field_b = String()

        # when: a composite declares its own field and includes both structs
        class CompositeObject(Struct):
            class Meta:
                includes = [AnObject, AnotherObject]

            native_field = String()

        generated = schema(CompositeObject)

        # expect: the native field precedes the included ones, in include order
        expected_fields = [
            StructField("native_field", StringType()),
            StructField("field_a", StringType()),
            StructField("field_b", StringType()),
        ]
        assert generated == StructType(expected_fields)
# Example #11
def test_inheritance_registration_schema():
    """The inheritance example's RegistrationEvent pretty-prints to its reference."""
    # when: the example event schema is rendered as text
    rendered = pretty_schema(schema(inheritance.RegistrationEvent))
    # then: it matches the stored reference, ignoring surrounding whitespace
    assert rendered == inheritance.prettified_registration_event_schema.strip()
# Example #12
def test_includes_registration_schema():
    """The includes example's RegistrationEvent pretty-prints to its reference."""
    # when: the example event schema is rendered as text
    rendered = pretty_schema(schema(includes.RegistrationEvent))
    # then: it matches the stored reference, ignoring surrounding whitespace
    assert rendered == includes.prettified_registration_event_schema.strip()
# Example #13
"""

#
# Create a data frame with some dummy data

spark = SparkSession.builder.appName(
    "conferences-comparison-demo").getOrCreate()
dframe = spark.createDataFrame([{
    str(Conference.name): "PyCon UK 2019",
    str(Conference.city): {
        str(City.name): "Cardiff",
        str(City.latitude): 51.48,
        str(City.longitude): -3.18
    }
}],
                               schema=schema(Conference))

#
# Munge some data

dframe = dframe.withColumn("city_name", path_col(Conference.city.name))

#
# Here's what the output looks like

expected_rows = [{
    'name': 'PyCon UK 2019',
    'city': {
        'name': 'Cardiff',
        'latitude': 51.47999954223633,
        'longitude': -3.180000066757202