def test_compare_column_schemas():
    """compare_column_schemas should ignore column order but reject missing
    columns and type mismatches."""
    expected = StructType([
        StructField("a_float", FloatType()),
        StructField("b_long", LongType()),
        StructField("c_str", StringType()),
    ])

    # Missing "c_str" column -> schemas do not match.
    missing_col = StructType([
        StructField("a_float", FloatType()),
        StructField("b_long", LongType()),
    ])
    assert not spark_util.compare_column_schemas(expected, missing_col)

    # "c_str" declared as LongType instead of StringType -> mismatch.
    incorrect_type = StructType([
        StructField("b_long", LongType()),
        StructField("a_float", FloatType()),
        StructField("c_str", LongType()),
    ])
    assert not spark_util.compare_column_schemas(expected, incorrect_type)

    # Same columns and types, different order -> match (order-insensitive).
    actual = StructType([
        StructField("b_long", LongType()),
        StructField("a_float", FloatType()),
        StructField("c_str", StringType()),
    ])
    assert spark_util.compare_column_schemas(expected, actual)
def test_get_expected_schema_from_context_parquet(ctx_obj, get_context):
    """expected_schema_from_context should build a parquet schema whose
    nullability mirrors each raw feature's ``required`` flag
    (required=True -> nullable=False)."""
    ctx_obj["environment"] = {
        "data": {
            "type": "parquet",
            "schema": [
                {"column_name": "a_str", "feature_name": "a_str"},
                {"column_name": "b_float", "feature_name": "b_float"},
                {"column_name": "c_long", "feature_name": "c_long"},
            ],
        }
    }

    ctx_obj["raw_features"] = {
        "b_float": {"name": "b_float", "type": "FLOAT_FEATURE", "required": True, "id": "-"},
        "c_long": {"name": "c_long", "type": "INT_FEATURE", "required": False, "id": "-"},
        "a_str": {"name": "a_str", "type": "STRING_FEATURE", "required": True, "id": "-"},
    }

    ctx = get_context(ctx_obj)

    # c_long is optional -> nullable; b_float and a_str are required -> not nullable.
    expected_output = StructType([
        StructField("c_long", LongType(), True),
        StructField("b_float", FloatType(), False),
        StructField("a_str", StringType(), False),
    ])

    actual = spark_util.expected_schema_from_context(ctx)
    assert spark_util.compare_column_schemas(actual, expected_output)
def test_get_expected_schema_from_context_csv(ctx_obj, get_context):
    """expected_schema_from_context should build a csv schema (flat list of
    column names) whose nullability mirrors each raw feature's ``required``
    flag (required=True -> nullable=False)."""
    ctx_obj["environment"] = {
        "data": {
            "type": "csv",
            "schema": ["income", "years_employed", "prior_default"],
        }
    }

    ctx_obj["raw_features"] = {
        "income": {"name": "income", "type": "FLOAT_FEATURE", "required": True, "id": "-"},
        "years_employed": {
            "name": "years_employed",
            "type": "INT_FEATURE",
            "required": False,
            "id": "-",
        },
        "prior_default": {
            "name": "prior_default",
            "type": "STRING_FEATURE",
            "required": True,
            "id": "-",
        },
    }

    ctx = get_context(ctx_obj)

    # years_employed is optional -> nullable; income and prior_default are
    # required -> not nullable.
    expected_output = StructType([
        StructField("years_employed", LongType(), True),
        StructField("income", FloatType(), False),
        StructField("prior_default", StringType(), False),
    ])

    actual = spark_util.expected_schema_from_context(ctx)
    assert spark_util.compare_column_schemas(actual, expected_output)