def test_should_combine_disjoint_fields():
    """A child struct's schema should be the parent's fields followed by its own."""
    # given: a parent struct and a child adding one extra field
    class ParentStruct(Struct):
        parent_field = String()

    class ChildStruct(ParentStruct):
        child_field = Float()

    # when
    parent_schema = schema(ParentStruct)
    child_schema = schema(ChildStruct)

    # then: parent keeps its own field; child inherits it and appends child_field
    assert parent_schema == StructType([StructField("parent_field", StringType())])
    assert child_schema == StructType(
        [StructField("parent_field", StringType()), StructField("child_field", FloatType())]
    )
def test_should_inherit_includes():
    """Fields pulled in via ``Meta.includes`` on a parent should propagate to children."""
    # given: a parent that includes a sibling struct's fields, and a child of that parent
    class SiblingStruct(Struct):
        sibling_field = Integer()

    class ParentStruct(Struct):
        class Meta:
            includes = [SiblingStruct]

        parent_field = String()

    class ChildStruct(ParentStruct):
        child_field = Float()

    # when
    generated = schema(ChildStruct)

    # then: parent field, then included sibling field, then the child's own field
    expected = StructType(
        [
            StructField("parent_field", StringType()),
            StructField("sibling_field", IntegerType()),
            StructField("child_field", FloatType()),
        ]
    )
    assert generated == expected
def test_should_structise_object_containing_array_of_objects():
    """An ``Array`` of nested structs should map to an ``ArrayType`` of ``StructType``."""
    # given: an article holding an array of tag structs
    class Tag(Struct):
        id = String(nullable=False)
        name = String()

    class Article(Struct):
        id = String(nullable=False)
        tags = Array(Tag(nullable=True))

    # when
    actual = schema(Article)

    # then: tags becomes an ArrayType whose element is the Tag struct,
    # with nullability carried through from the field declarations
    tag_struct = StructType(
        [
            StructField("id", StringType(), nullable=False),
            StructField("name", StringType()),
        ]
    )
    expected = StructType(
        [
            StructField("id", StringType(), nullable=False),
            StructField(
                "tags",
                ArrayType(containsNull=True, elementType=tag_struct),
                nullable=True,
            ),
        ]
    )
    assert actual == expected
def test_should_structise_deep_object():
    """Nested struct fields should produce nested ``StructType``s with renamed fields honoured."""
    # given: an article whose author is itself a struct, with field renames on both levels
    class User(Struct):
        id = String(nullable=False)
        age = Float()
        full_name = String(name="name")

    class Article(Struct):
        author = User(name="article_author", nullable=False)
        title = String(nullable=False)
        date = Timestamp()

    # when
    actual = schema(Article)

    # then: the author field is emitted under its explicit name with a nested StructType
    author_struct = StructType(
        [
            StructField("id", StringType(), nullable=False),
            StructField("age", FloatType()),
            StructField("name", StringType()),
        ]
    )
    expected = StructType(
        [
            StructField("article_author", nullable=False, dataType=author_struct),
            StructField("title", StringType(), nullable=False),
            StructField("date", TimestampType()),
        ]
    )
    assert actual == expected
def test_stringified_schema():
    """``pretty_schema`` of the Conference schema should match the expected fixture text."""
    # when
    actual = pretty_schema(schema(Conference))

    # then: fixture is stripped of surrounding whitespace before comparison
    assert actual == prettified_schema.strip()
def test_sparkql_stringified_schema():
    """``pretty_schema`` of the arrays example schema should match its fixture text."""
    # when
    actual = pretty_schema(schema(arrays.Article))

    # then: fixture is stripped of surrounding whitespace before comparison
    assert actual == arrays.prettified_schema.strip()
def test_should_allow_override_with_same_type():
    """Re-declaring an inherited field with the same type should not duplicate it."""
    # given: a child that redeclares the parent's field with an identical type
    class ParentStruct(Struct):
        parent_field = String()

    class ChildStruct(ParentStruct):
        parent_field = String()

    # when
    child_schema = schema(ChildStruct)

    # then: the field appears exactly once
    assert child_schema == StructType([StructField("parent_field", StringType())])
def test_should_combine_overlapping_includes():
    """Includes that share an identically-typed field should merge it into one."""
    # given: two included structs declaring the same field with the same type
    class AnObject(Struct):
        field_z = String()

    class AnotherObject(Struct):
        field_z = String()

    # when
    class CompositeObject(Struct):
        class Meta:
            includes = [AnObject, AnotherObject]

    composite_schema = schema(CompositeObject)

    # expect: the overlapping field is deduplicated
    assert composite_schema == StructType([StructField("field_z", StringType())])
def test_should_structise_flat_object():
    """A flat struct should map 1:1 onto a ``StructType``, honouring field renames."""
    # given: a struct with nullability options and one explicitly renamed field
    class User(Struct):
        id = String(nullable=False)
        age = Float()
        full_name = String(name="name")

    # when
    actual = schema(User)

    # then: full_name is emitted under its explicit name "name"
    expected = StructType(
        [
            StructField("id", StringType(), nullable=False),
            StructField("age", FloatType()),
            StructField("name", StringType()),
        ]
    )
    assert actual == expected
def test_should_combine_disjoint_includes():
    """Disjoint included fields should be appended after the struct's own fields."""
    # given: two included structs with non-overlapping fields
    class AnObject(Struct):
        field_a = String()

    class AnotherObject(Struct):
        field_b = String()

    # when
    class CompositeObject(Struct):
        class Meta:
            includes = [AnObject, AnotherObject]

        native_field = String()

    composite_schema = schema(CompositeObject)

    # expect: native field first, then the included fields in include order
    expected = StructType(
        [
            StructField("native_field", StringType()),
            StructField("field_a", StringType()),
            StructField("field_b", StringType()),
        ]
    )
    assert composite_schema == expected
def test_inheritance_registration_schema():
    """The inheritance example's schema should pretty-print to its stored fixture."""
    actual = pretty_schema(schema(inheritance.RegistrationEvent))
    expected = inheritance.prettified_registration_event_schema.strip()
    assert actual == expected
def test_includes_registration_schema():
    """The includes example's schema should pretty-print to its stored fixture."""
    actual = pretty_schema(schema(includes.RegistrationEvent))
    expected = includes.prettified_registration_event_schema.strip()
    assert actual == expected
""" # # Create a data frame with some dummy data spark = SparkSession.builder.appName( "conferences-comparison-demo").getOrCreate() dframe = spark.createDataFrame([{ str(Conference.name): "PyCon UK 2019", str(Conference.city): { str(City.name): "Cardiff", str(City.latitude): 51.48, str(City.longitude): -3.18 } }], schema=schema(Conference)) # # Munge some data dframe = dframe.withColumn("city_name", path_col(Conference.city.name)) # # Here's what the output looks like expected_rows = [{ 'name': 'PyCon UK 2019', 'city': { 'name': 'Cardiff', 'latitude': 51.47999954223633, 'longitude': -3.180000066757202