# Imports assumed for this test module; exact import paths may vary across
# Great Expectations versions.
import numpy as np
import pandas as pd
import pyspark.sql.functions as F
import pytest

from great_expectations.core.metric_domain_types import MetricDomainTypes
from great_expectations.execution_engine import SparkDFExecutionEngine
from great_expectations.execution_engine.sparkdf_execution_engine import (
    SparkDFBatchData,
)


def test_get_compute_domain_with_unmeetable_row_condition(spark_session):
    pd_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, None]})
    df = spark_session.createDataFrame(
        [
            tuple(
                None if isinstance(x, (float, int)) and np.isnan(x) else x
                for x in record.tolist()
            )
            for record in pd_df.to_records(index=False)
        ],
        pd_df.columns.tolist(),
    )
    expected_df = df.filter(F.col("b") > 24)

    engine = SparkDFExecutionEngine()
    engine.load_batch_data(batch_data=df, batch_id="1234")
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"row_condition": "b > 24", "condition_parser": "spark"},
        domain_type=MetricDomainTypes.TABLE,
    )

    # Ensuring data has been properly queried
    assert data.schema == expected_df.schema
    assert data.collect() == expected_df.collect()

    # Ensuring compute kwargs have not been modified
    assert "row_condition" in compute_kwargs.keys()
    assert accessor_kwargs == {}
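

# A hedged companion sketch (not in the original section; test name is
# hypothetical): the same row-condition path as above, but with a condition
# that rows can actually meet, so the filter is observable in the result.
def test_get_compute_domain_with_meetable_row_condition_sketch(spark_session):
    pd_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5]})
    df = spark_session.createDataFrame(pd_df)
    expected_df = df.filter(F.col("b") > 2)

    engine = SparkDFExecutionEngine()
    engine.load_batch_data(batch_data=df, batch_id="1234")
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"row_condition": "b > 2", "condition_parser": "spark"},
        domain_type=MetricDomainTypes.TABLE,
    )

    # Only the rows satisfying the condition should remain in the data.
    assert data.collect() == expected_df.collect()
    assert "row_condition" in compute_kwargs
    assert accessor_kwargs == {}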


def test_dataframe_property_given_loaded_batch(spark_session):
    engine = SparkDFExecutionEngine()
    df = pd.DataFrame({"a": [1, 5, 22, 3, 5, 10]})
    df = spark_session.createDataFrame(df)

    # Loading batch data
    engine.load_batch_data(batch_data=df, batch_id="1234")

    # Ensuring data is not distorted
    assert engine.dataframe == df


def test_dataframe_property_given_loaded_batch_without_fixture():
    # Variant of the test above that builds its own SparkSession rather than
    # using the spark_session fixture.
    from pyspark.sql import SparkSession

    engine = SparkDFExecutionEngine()
    df = pd.DataFrame({"a": [1, 5, 22, 3, 5, 10]})
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(df)

    # Loading batch data
    engine.load_batch_data(batch_data=df, batch_id="1234")

    # Ensuring data is not distorted
    assert engine.dataframe == df
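

# A hedged companion sketch for the two tests above (not in the original
# section; test name is hypothetical). It assumes that accessing the
# `dataframe` property before any batch has been loaded raises a ValueError.
def test_dataframe_property_given_no_loaded_batch():
    engine = SparkDFExecutionEngine()

    # No batch has been loaded, so the property should fail loudly.
    with pytest.raises(ValueError):
        engine.dataframe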


def _build_spark_engine(df, spark_session):
    """Convert a pandas DataFrame (NaN -> None) into a Spark DataFrame and
    load it into a fresh SparkDFExecutionEngine."""
    df = spark_session.createDataFrame(
        [
            tuple(
                None if isinstance(x, (float, int)) and np.isnan(x) else x
                for x in record.tolist()
            )
            for record in df.to_records(index=False)
        ],
        df.columns.tolist(),
    )
    engine = SparkDFExecutionEngine()
    engine.load_batch_data("my_id", SparkDFBatchData(engine, df))
    return engine
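

# A minimal usage sketch for the helper above (hypothetical test name and
# data; assumes the `spark_session` pytest fixture used by the other tests).
def test_build_spark_engine_converts_nan_to_null(spark_session):
    engine = _build_spark_engine(
        pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, np.nan]}),
        spark_session,
    )

    # The helper replaces NaN with None, which Spark stores as null.
    rows = engine.dataframe.collect()
    assert rows[2].b is None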


def test_get_compute_domain_with_column_domain(spark_session):
    pd_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, None]})
    df = spark_session.createDataFrame(
        [
            tuple(
                None if isinstance(x, (float, int)) and np.isnan(x) else x
                for x in record.tolist()
            )
            for record in pd_df.to_records(index=False)
        ],
        pd_df.columns.tolist(),
    )
    engine = SparkDFExecutionEngine()
    engine.load_batch_data(batch_data=df, batch_id="1234")
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"column": "a"}, domain_type=MetricDomainTypes.COLUMN
    )

    # The column is split out as an accessor; the compute data is the full table.
    assert compute_kwargs is not None, "Compute domain kwargs should exist"
    assert accessor_kwargs == {"column": "a"}
    assert data.schema == df.schema
    assert data.collect() == df.collect()
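

# A hedged sketch (not in the original section; test name is hypothetical)
# extending the column-domain test above to a multicolumn domain. It assumes
# MetricDomainTypes.MULTICOLUMN splits "column_list" into the accessor kwargs
# the same way COLUMN splits "column".
def test_get_compute_domain_with_multicolumn_domain_sketch(spark_session):
    pd_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5]})
    df = spark_session.createDataFrame(pd_df)
    engine = SparkDFExecutionEngine()
    engine.load_batch_data(batch_data=df, batch_id="1234")
    data, compute_kwargs, accessor_kwargs = engine.get_compute_domain(
        domain_kwargs={"column_list": ["a", "b"]},
        domain_type=MetricDomainTypes.MULTICOLUMN,
    )

    # The full table remains the compute domain; the columns are accessors.
    assert accessor_kwargs == {"column_list": ["a", "b"]}
    assert data.collect() == df.collect()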