def test_create_schema_output_path(sample_df):
    location = 'hdfs:///test/folder'
    location_schema = schema.create_schema(sample_df, "test_db", "test_table",
                                           output_path=location)
    default_schema = schema.create_schema(sample_df, "test_db", "test_table")
    assert f"LOCATION '{location}'" in location_schema
    assert "LOCATION" not in default_schema
    assert location not in default_schema
def test_create_schema_partition_col(sample_df):
    non_present_partition = "dt"
    present_partition = "created"
    non_present_schema = schema.create_schema(sample_df, "test_db", "test_table",
                                              partition_col=non_present_partition)
    present_schema = schema.create_schema(sample_df, "test_db", "test_table",
                                          partition_col=present_partition)
    non_partitioned_schema = schema.create_schema(sample_df, "test_db", "test_table")
    assert f"PARTITIONED BY ({non_present_partition} STRING)" in non_present_schema
    assert f"PARTITIONED BY ({present_partition} STRING)" in present_schema
    assert f"{present_partition} string," not in present_schema
    assert f"{present_partition} string," in non_partitioned_schema
def test_read_different_order(spark):
    """
    Here we add a new column before the old one, to check that columns are matched
    by name rather than by position when reading the partitioned table back.
    """
    df1 = (spark
           .range(100)
           .withColumn('p', sf.lit(1))
           )
    df2 = (spark
           .range(100)
           .withColumn('new_id', sf.col('id'))
           .drop('id')
           .withColumn('id', sf.col('new_id'))
           .withColumn('p', sf.lit(2))
           )
    path = '/tmp/test'
    spark.sql('create database test')
    df1.write.partitionBy('p').saveAsTable('test.df', format='parquet', path=path,
                                           mode='overwrite')
    _schema = schema.create_schema(df2, 'test', 'df', external=True,
                                   output_path=path, partition_col='p')
    spark.sql('drop table test.df')
    spark.sql(_schema)
    df2.write.format('parquet').mode('append').partitionBy('p').save(path)
    spark.sql('msck repair table test.df')
    assert set(spark.sql("SELECT DISTINCT p FROM test.df")
               .rdd
               .map(lambda row: int(row['p'])).collect()) == {1, 2}
    # If the column order doesn't matter, then new_id for p = 1 should always be NULL, so the
    # count should be zero. If the order counts (i.e. the new_id column gets filled with data
    # from id in p = 1), then the count won't be 0.
    assert spark.sql("""select distinct new_id
                        from test.df
                        where p = 1 and new_id is not null""").count() == 0
def test_create_schema_happy_flow(sample_df):
    result = """CREATE EXTERNAL TABLE IF NOT EXISTS test_db.test_table (
author struct<active:boolean,avatarUrls:struct<16x16:string,48x48:string>,displayName:string,emailAddress:string,name:string,self:string>,
created string,
id string,
items array<struct<field:string,fieldtype:string,from:string,fromString:string,to:string,toString:string>>
)
PARTITIONED BY (dt STRING)
STORED AS PARQUET
LOCATION 'hdfs:///test/folder'"""
    assert result == schema.create_schema(sample_df, "test_db", "test_table",
                                          partition_col="dt", format_output='parquet',
                                          output_path='hdfs:///test/folder', external=True)
def test_create_schema_faulty_output(sample_df):
    with pytest.raises(KeyError):
        schema.create_schema(sample_df, '', '', format_output="wrong")
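# A possible extra check, sketched here for illustration (not part of the original suite):
# verify that the storage clause follows the requested output format. The expected
# "STORED AS PARQUET" fragment is taken from the happy-flow test above; the exact wording
# of create_schema's output is assumed to match it.
def test_create_schema_stored_as_parquet(sample_df):
    parquet_schema = schema.create_schema(sample_df, "test_db", "test_table",
                                          format_output='parquet')
    assert "STORED AS PARQUET" in parquet_schema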
def main(input, format_output, database='default', table_name='', output_path=None,
         mode_output='append', partition_col='dt', partition_with=None, spark=None, **kwargs):
    r"""
    :param input: Either the location of the data to load, which will be passed to `.load`
        in Spark, or the dataframe that contains the data
    :param format_output str: One of `parquet` and `com.databricks.spark.csv` at the moment
    :param database str: The Hive database where to write the output
    :param table_name str: The Hive table where to write the output
    :param output_path str: The table location
    :param mode_output str: Anything accepted by Spark's `.write.mode()`
    :param partition_col str: The partition column
    :param partition_with: A Spark Column expression for the `partition_col`. If not given,
        `partition_col` should already be present in the input data

    :Keyword Arguments:
        * *spark_config* (``dict``) --
          This dictionary contains options to be passed when building a `SparkSession`
          (for example `{'master': 'yarn'}`)
        * *format* (``str``) --
          The format the data will be in. All options supported by Spark; parquet is the default
        * *header* (``bool``) --
          If reading a csv file, this will tell if the header is present (and use the schema)
        * *schema* (``pyspark.sql.types.StructType``) --
          The input schema
        * *master* (``str``) --
          Specify which `master` should be used
        * *repartition* (``bool``) --
          Whether to repartition the data by the partition column before writing. This reduces
          the number of small files written by Spark
        * *key* (``str``) --
          In principle, any `key` accepted by `spark.read.options`, by `findspark.init()`,
          or by `SparkSession.builder.config`

    :Example:

    >>> import pyspark.sql.functions as sf
    >>> column = 'a_column_with_unixtime'
    >>> partition_function = lambda column: sf.from_unixtime(sf.col(column), format='yyyy-MM-dd')
    >>> from spark_partitionr import main
    >>> main('hdfs:///data/some_data', 'parquet', 'my_db', 'my_tbl', mode_output='overwrite',
    ...      partition_col='dt', partition_with=partition_function('a_col'),
    ...      master='yarn', format='com.databricks.spark.csv',
    ...      header=True, to_unnest=['deeply_nested_column'])
    """
    sanitized_table = sanitize_table_name(table_name)
    if not spark:
        spark = create_spark_session(database, sanitized_table, **kwargs)

    if isinstance(input, str):
        df = load_data(spark, input, **kwargs)
    else:
        df = input

    try:
        old_df = spark.read.table("{}.{}".format(database, sanitized_table))
        new = False
    except Exception:  # Spark raises if the table does not exist yet
        new = True

    old_table_name = sanitized_table
    if not new:
        try:
            par_col = partition_col if partition_with else None
            schema_equal, schema_compatible = check_compatibility(df, old_df,
                                                                  format_output, par_col)
            old_table_name, sanitized_table = check_external(spark, database, sanitized_table,
                                                             schema_equal, schema_compatible)
        except schema.SchemaError:
            # we set this so that no table movement takes place
            schema_compatible = False
            _, schema_backward_compatible = check_compatibility(old_df, df, format_output)
            if schema_backward_compatible:
                df = try_impose_schema(spark, input, old_df.schema, **kwargs)
            else:
                raise schema.SchemaError('Schemas are not compatible in either direction')

    df_schema = schema.create_schema(df, database, sanitized_table, partition_col,
                                     format_output, output_path, **kwargs)
    spark.sql(df_schema)

    partitioned_df = add_partition_column(df, partition_col, partition_with)
    if not output_path:
        output_path = get_output_path(spark, database, sanitized_table)

    write_data(partitioned_df, format_output, mode_output, partition_col, output_path, **kwargs)
    repair_partitions(spark, database, sanitized_table)
    if not new and old_table_name != sanitized_table:
        move_table(spark, from_database=database, from_table=sanitized_table,
                   to_database=database, to_table=old_table_name)
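# A minimal usage sketch, assuming a local SparkSession; the table name and the partition
# literal below are hypothetical. It illustrates the other accepted form of `input`
# documented above: an already-loaded DataFrame, for which `load_data` is skipped.
if __name__ == '__main__':
    import pyspark.sql.functions as sf
    from pyspark.sql import SparkSession

    spark_session = SparkSession.builder.getOrCreate()
    example_df = spark_session.range(10)  # any DataFrame works as `input`
    main(example_df, 'parquet', database='default', table_name='example_table',
         mode_output='overwrite', partition_col='dt',
         partition_with=sf.lit('2017-01-01'), spark=spark_session)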