def test_restore_to_timestamp(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2)])
    timestampToRestore = DeltaTable.forPath(self.spark, self.tempFile) \
        .history() \
        .head() \
        .timestamp \
        .strftime('%Y-%m-%d %H:%M:%S.%f')

    self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                               schema=["key_new", "value_new"],
                               overwriteSchema='true')

    overwritten = DeltaTable.forPath(self.spark, self.tempFile).toDF()
    self.__checkAnswer(overwritten,
                       [Row(key_new='a', value_new=3), Row(key_new='b', value_new=2)])

    DeltaTable.forPath(self.spark, self.tempFile).restoreToTimestamp(timestampToRestore)

    restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()
    self.__checkAnswer(restored, [Row(key='a', value=1), Row(key='b', value=2)])

    # we cannot test the actual working of restore to timestamp here, but we can at
    # least make sure that the API is being called
    def runRestore() -> None:
        DeltaTable.forPath(self.spark, self.tempFile).restoreToTimestamp('05/04/1999')
    self.__intercept(runRestore,
                     "The provided timestamp ('05/04/1999') "
                     "cannot be converted to a valid timestamp")
def get_current_data(self) -> DeltaTable:
    try:
        delta_output = DeltaTable.forPath(
            self.spark_configuration.spark_session,
            self.delta_src + self.current_data_table_name)
    except AnalysisException:
        print("Delta table does not exist -> creating")
        self._init_current_data()
        delta_output = DeltaTable.forPath(
            self.spark_configuration.spark_session,
            self.delta_src + self.current_data_table_name)
    return delta_output
def run() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        dest='input',
        required=True,
        help='input file path',
    )
    parser.add_argument(
        '--output',
        dest='output',
        required=True,
        help='output delta table path',
    )
    parser.add_argument(
        '--action',
        dest='action',
        default='show',
        help='action to apply',
    )
    args = parser.parse_args()

    spark = SparkSession.builder.appName('taxipy') \
        .config('spark.jars.packages', 'io.delta:delta-core_2.12:0.7.0') \
        .config('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension') \
        .config('spark.sql.catalog.spark_catalog',
                'org.apache.spark.sql.delta.catalog.DeltaCatalog') \
        .getOrCreate()

    from delta.tables import DeltaTable

    taxi_ride_table = DeltaTable.forPath(spark, args.output)
    taxi_ride_table.update(condition=expr('vendor_id == 2'),
                           set={'total_amount': '123'})
    show_records(spark, args.output)
    spark.stop()
def test_optimize_zorder_by(self) -> None:
    # write an unoptimized delta table
    self.spark.createDataFrame([i for i in range(0, 100)], IntegerType()) \
        .withColumn("col1", floor(col("value") % 7)) \
        .withColumn("col2", floor(col("value") % 27)) \
        .withColumn("p", floor(col("value") % 10)) \
        .repartition(4).write.partitionBy("p").format("delta").save(self.tempFile)

    # create DeltaTable
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # execute Z-Order optimization
    optimizer = dt.optimize()
    result = optimizer.executeZOrderBy(["col1", "col2"])
    metrics = result.select("metrics.*").head()

    self.assertTrue(metrics.numFilesAdded == 10)
    self.assertTrue(metrics.numFilesRemoved == 37)
    self.assertTrue(metrics.totalFilesSkipped == 0)
    self.assertTrue(metrics.totalConsideredFiles == 37)
    self.assertTrue(metrics.zOrderStats.strategyName == 'all')
    self.assertTrue(metrics.zOrderStats.numOutputCubes == 10)

    # negative test: Z-Order on partition column
    def optimize() -> None:
        dt.optimize().where("p = 1").executeZOrderBy(["p"])
    self.__intercept(optimize,
                     "p is a partition column. "
                     "Z-Ordering can only be performed on data columns")
def test_optimize(self) -> None:
    # write an unoptimized delta table
    df = self.spark.createDataFrame([("a", 1), ("a", 2)], ["key", "value"]).repartition(1)
    df.write.format("delta").save(self.tempFile)
    df = self.spark.createDataFrame([("a", 3), ("a", 4)], ["key", "value"]).repartition(1)
    df.write.format("delta").save(self.tempFile, mode="append")
    df = self.spark.createDataFrame([("b", 1), ("b", 2)], ["key", "value"]).repartition(1)
    df.write.format("delta").save(self.tempFile, mode="append")

    # create DeltaTable
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # execute bin compaction
    optimizer = dt.optimize()
    res = optimizer.executeCompaction()
    op_params = dt.history().first().operationParameters

    # assertions
    self.assertTrue(isinstance(optimizer, DeltaOptimizeBuilder))
    self.assertTrue(isinstance(res, DataFrame))
    self.assertEqual(1, res.first().metrics.numFilesAdded)
    self.assertEqual(3, res.first().metrics.numFilesRemoved)
    self.assertEqual('[]', op_params['predicate'])

    # test non-partition column
    def optimize() -> None:
        dt.optimize().where("key = 'a'").executeCompaction()
    self.__intercept(optimize,
                     "Predicate references non-partition column 'key'. "
                     "Only the partition columns may be referenced: []")
def test_optimize_zorder_by_w_partition_filter(self) -> None:
    # write an unoptimized delta table
    df = self.spark.createDataFrame([i for i in range(0, 100)], IntegerType()) \
        .withColumn("col1", floor(col("value") % 7)) \
        .withColumn("col2", floor(col("value") % 27)) \
        .withColumn("p", floor(col("value") % 10)) \
        .repartition(4).write.partitionBy("p")
    df.format("delta").save(self.tempFile)

    # create DeltaTable
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # execute Z-OrderBy
    optimizer = dt.optimize().where("p = 2")
    result = optimizer.executeZOrderBy(["col1", "col2"])
    metrics = result.select("metrics.*").head()

    # assertions (partition 'p = 2' has four files)
    self.assertTrue(metrics.numFilesAdded == 1)
    self.assertTrue(metrics.numFilesRemoved == 4)
    self.assertTrue(metrics.totalFilesSkipped == 0)
    self.assertTrue(metrics.totalConsideredFiles == 4)
    self.assertTrue(metrics.zOrderStats.strategyName == 'all')
    self.assertTrue(metrics.zOrderStats.numOutputCubes == 1)
def upsert_table(spark, updatesDF, condition, output_file, partition_columns=None):
    '''
    Update/insert transformed immigration data into fact_table.
    In case of a duplicate id, overwrite the old value with the new value;
    otherwise append it to the dataframe.
    '''
    from delta.tables import DeltaTable

    if not os.path.exists(output_file):
        if partition_columns is None:
            updatesDF.write.format('delta').save(output_file)
        else:
            updatesDF.write.format('delta').partitionBy(
                *partition_columns).save(output_file)
    else:
        deltaTable = DeltaTable.forPath(spark, output_file)
        deltaTable.alias("source").merge(
            source=updatesDF.alias("update"),
            condition=condition) \
            .whenMatchedUpdateAll() \
            .whenNotMatchedInsertAll() \
            .execute()
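# Minimal usage sketch for upsert_table above. The staging path, the "cicid" id column,
# the output path and the partition columns are illustrative assumptions, not from the
# original code. The condition must use the aliases hard-coded inside the function:
# "source" for the existing Delta table and "update" for the incoming DataFrame.
def upsert_immigration_example(spark):
    new_batch = spark.read.parquet("/staging/immigration_batch")  # assumed input location
    upsert_table(
        spark,
        updatesDF=new_batch,
        condition="source.cicid = update.cicid",                  # assumed id column
        output_file="/lake/fact_immigration",                     # assumed Delta table path
        partition_columns=["arrival_year", "arrival_month"],      # assumed partitioning
    )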
def merge_write(logger, df_dict: Dict[str, DataFrame], rules: Dict[str, str],
                output_path: str, spark: SparkSession):
    """
    Write data if the dataset doesn't exist or merge it into the existing dataset
    Args:
        logger: Logger instance used to log events
        df_dict: Dictionary of the datasets with the structure {Name: Dataframe}
        rules: Matching rules used to merge
        output_path: Path to write the data
        spark: Spark instance
    Returns:
    """
    try:
        from delta.tables import DeltaTable
        for df_name, df in df_dict.items():
            file_path = path.join(output_path, df_name)
            if DeltaTable.isDeltaTable(spark, file_path):
                delta_table = DeltaTable.forPath(spark, file_path)
                # execute() is required to actually run the merge
                delta_table.alias("old").merge(
                    df.alias("new"), rules.get(df_name)
                ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
            else:
                df.write.format("delta").save(file_path)
    except Exception as e:
        logger.error(
            "Writing sanitized data couldn't be performed: {}".format(e),
            traceback.format_exc())
        raise e
    else:
        logger.info("Sanitized dataframes written in {} folder".format(output_path))
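# Minimal usage sketch for merge_write above. The dataset names, the merge rules and
# the output folder are illustrative assumptions; each rule must reference the aliases
# used inside merge_write ("old" for the existing table, "new" for the update).
import logging

def write_sanitized_example(spark, users_df, orders_df):
    logger = logging.getLogger("sanitizer")            # assumed logger setup
    df_dict = {"users": users_df, "orders": orders_df}
    rules = {
        "users": "old.user_id = new.user_id",          # assumed key column
        "orders": "old.order_id = new.order_id",       # assumed key column
    }
    merge_write(logger, df_dict, rules, "/lake/sanitized", spark)  # assumed output path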
def _merge_into_table(self, df, destination_path, checkpoints_path, condition):
    """
    Merges data from the given dataframe into the delta table at the specified
    destination_path, based on the given condition. If no delta table exists at the
    specified destination_path, a new delta table is created and the data from the
    given dataframe is inserted.
    e.g. merge_into_table(df_lookup, np_destination_path,
                          source_path + '/_checkpoints/delta_np',
                          "current.id_pseudonym = updates.id_pseudonym")
    """
    if DeltaTable.isDeltaTable(spark, destination_path):
        dt = DeltaTable.forPath(spark, destination_path)

        def upsert(batch_df, batchId):
            dt.alias("current").merge(
                batch_df.alias("updates"),
                condition).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()

        query = df.writeStream.format("delta").foreachBatch(upsert) \
            .outputMode("update").trigger(once=True) \
            .option("checkpointLocation", checkpoints_path)
    else:
        logger.info(
            f'Delta table does not yet exist at {destination_path} - '
            f'creating one now and inserting initial data.')
        query = df.writeStream.format("delta").outputMode("append") \
            .trigger(once=True).option("checkpointLocation", checkpoints_path)

    query = query.start(destination_path)
    # block until the query is terminated, either with stop() or with an error;
    # a StreamingQueryException will be thrown if an exception occurs.
    query.awaitTermination()
    logger.info(query.lastProgress)
def test_restore_to_version(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2)])
    self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                               schema=["key_new", "value_new"],
                               overwriteSchema='true')

    overwritten = DeltaTable.forPath(self.spark, self.tempFile).toDF()
    self.__checkAnswer(overwritten,
                       [Row(key_new='a', value_new=3), Row(key_new='b', value_new=2)])

    DeltaTable.forPath(self.spark, self.tempFile).restoreToVersion(0)

    restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()
    self.__checkAnswer(restored, [Row(key='a', value=1), Row(key='b', value=2)])
def get_delta_table(path):
    try:
        dt = DeltaTable.forPath(spark, path)
    except AnalysisException as e:
        if ('doesn\'t exist;' in str(e).lower()
                or 'is not a delta table.' in str(e).lower()):
            print("Error occurred due to: " + str(e))
            return None
        else:
            raise e
    return dt
def _get_delta_table(spark, delta_path, delta_table):
    if [delta_path, delta_table].count(None) == 2:
        raise ValueError("either delta_path or delta_table must be provided")
    if delta_path is not None:
        delta_table = DeltaTable.forPath(spark, delta_path)
    else:
        delta_table = DeltaTable.forName(spark, delta_table)
    return delta_table
def test_history(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
    self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)])
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    operations = dt.history().select('operation')
    self.__checkAnswer(
        operations,
        [Row("WRITE"), Row("WRITE")],
        StructType([StructField("operation", StringType(), True)]))

    lastMode = dt.history(1).select('operationParameters.mode')
    self.__checkAnswer(
        lastMode,
        [Row("Overwrite")],
        StructType([StructField("operationParameters.mode", StringType(), True)]))
def _get_delta_table(self, spark, table_or_path, update_delta_table):
    try:
        deltaTable = DeltaTable.forPath(spark, table_or_path)
    except Exception:
        try:
            deltaTable = DeltaTable.forName(spark, table_or_path)
        except AssertionError as E:
            raise E
    if update_delta_table:
        return deltaTable
    return deltaTable.toDF()
def merge(spark, update, tableName, cols, key):
    """
    Merge a DataFrame into a delta table. The insert clause requires the DataFrame to
    contain every column of the delta table (as of version 0.5). When merge is used to
    update/insert only some of the delta table's columns, the columns that do not exist
    in the DataFrame are set to null.
    Note: the columns written from the DataFrame must match the delta table's columns.
    :param spark: SparkSession instance
    :param update: Spark DataFrame with the updates
    :param tableName: path of the delta table to update
    :param cols: columns to update
    :param key: join key used in the merge condition
    """
    # if there is no dt column, add one with the current date
    if "dt" not in cols:
        update = update.withColumn("dt", f.current_date())
        cols.append("dt")

    # 1. build the merge condition
    mergeExpr = f"origin.{key}=update.{key}"
    print(f"merge expression:{mergeExpr}")

    # 2. build the update expression
    updateExpr = {}
    for c in cols:
        updateExpr[c] = f"update.{c}"
    print(f"update expression:{updateExpr}")

    origin = DeltaTable.forPath(spark, tableName)
    origin_cols = origin.toDF().columns

    # 3. build the insert expression
    insertExpr = {}
    for origin_col in origin_cols:
        if origin_col in cols:
            insertExpr[origin_col] = f"update.{origin_col}"
        else:
            # column not present in the update: insert a null value (not the string "null")
            insertExpr[origin_col] = "null"
    print(f"insert expression:{insertExpr}")

    # for origin_col in origin_cols:
    #     if origin_col not in cols:
    #         update = update.withColumn(origin_col, f.lit(None))

    origin.alias("origin") \
        .merge(update.alias("update"), mergeExpr) \
        .whenMatchedUpdate(set=updateExpr) \
        .whenNotMatchedInsert(values=insertExpr) \
        .execute()
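# Minimal usage sketch for merge() above. The table path, the update schema, the column
# list and the key are illustrative assumptions. Note that `cols` may be mutated (a "dt"
# column is appended when missing), so pass a fresh list on each call.
def merge_example(spark):
    update_df = spark.createDataFrame(
        [(1, "alice", 10.0), (2, "bob", 20.0)],
        ["id", "name", "amount"])                      # assumed update schema
    merge(spark,
          update_df,
          tableName="/lake/delta/customers",           # assumed delta table path
          cols=["id", "name", "amount"],               # columns to update on match
          key="id")                                    # assumed join key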
def test_vacuum(self):
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
    dt = DeltaTable.forPath(self.spark, self.tempFile)
    self.__createFile('abc.txt', 'abcde')
    self.__createFile('bac.txt', 'abcdf')
    self.assertEqual(True, self.__checkFileExists('abc.txt'))

    dt.vacuum()  # will not delete files as default retention is used.
    self.assertEqual(True, self.__checkFileExists('bac.txt'))

    retentionConf = "spark.databricks.delta.retentionDurationCheck.enabled"
    self.spark.conf.set(retentionConf, "false")
    dt.vacuum(0.0)
    self.spark.conf.set(retentionConf, "true")
    self.assertEqual(False, self.__checkFileExists('bac.txt'))
    self.assertEqual(False, self.__checkFileExists('abc.txt'))
def get_delta_table(
        spark: SparkSession,
        schema: StructType,
        delta_library_jar: str,
        delta_path: str):
    # load delta library jar, so we can use the delta module
    spark.sparkContext.addPyFile(delta_library_jar)
    from delta.tables import DeltaTable

    # check existence of the delta table
    if not DeltaTable.isDeltaTable(spark, delta_path):
        print(f">>> Delta table: {delta_path} is not initialized, performing initialization..")
        df = spark.createDataFrame([], schema=schema)
        df.write.format("delta").save(delta_path)

    return DeltaTable.forPath(spark, delta_path)
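# Hypothetical call to get_delta_table above; a minimal sketch. The jar location, the
# event schema and the Delta path are illustrative assumptions, not from the original code.
from pyspark.sql.types import StructType, StructField, StringType, LongType

def init_events_table_example(spark):
    events_schema = StructType([
        StructField("event_id", StringType(), False),
        StructField("ts", LongType(), True),
    ])
    return get_delta_table(
        spark,
        schema=events_schema,
        delta_library_jar="/opt/jars/delta-core_2.12-0.7.0.jar",  # assumed jar path
        delta_path="/lake/delta/events",                          # assumed table path
    )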
def update_sum_count_or_insert(self, new_data: DataFrame, table: str, id_col: str):
    try:
        delta_table = DeltaTable.forPath(self.spark_configuration.spark_session,
                                         self.delta_src + table)
    except AnalysisException:
        # if the delta table does not exist yet, just create it
        new_data.write \
            .format("delta") \
            .save(self.delta_src + table)
        return

    delta_table.alias("current_data").merge(
        new_data.alias("updates"),
        "current_data.{0} = updates.{0}".format(id_col)) \
        .whenMatchedUpdate(set={
            "count": "current_data.count + updates.count"
        }) \
        .whenNotMatchedInsertAll() \
        .execute()
def merge(
        self,
        df: DataFrame,
        location: str,
        condition: str,  # only supports SQL-like string conditions
        match_update_dict: dict,  # "target_column": "expression"
        insert_when_not_matched: bool = False,  # set to True for upsert
        save_mode: str = 'table'):
    '''Merge a dataframe into a target table or path.

    This merge operation can represent both update and upsert operations. The source
    and target tables are aliased as 'SRC' and 'TGT' by default; these aliases can be
    used in the condition string and in the update/insert expressions.

    Args:
        df (DataFrame): The source dataframe to write.
        save_mode (str): 'table' or 'path'
        location (str): The table name or path to merge into.
        condition (str): The merge condition as a SQL-like string.
        match_update_dict (dict): Contains ("target_column": "expression").
            This represents the updated value if matched.
            NOTE: "target_column"s come without a schema prefix ("SRC" or "TGT").
        insert_when_not_matched (bool): If True, unmatched rows are inserted using the
            same ("target_column": "expression") mapping. Columns which are not
            specified shall be null.
    '''
    super(DeltaDataSource, self).merge(df, condition, match_update_dict,
                                       insert_when_not_matched=insert_when_not_matched)
    save_mode = save_mode.lower()
    if save_mode == "table":
        target_table = DeltaTable.forName(self.spark, location)
    elif save_mode == "path":
        target_table = DeltaTable.forPath(self.spark, location)
    else:
        raise ValueError("save_mode should be 'path' or 'table'.")

    merger = target_table.alias("TGT").merge(df.alias("SRC"), condition)
    merger = merger.whenMatchedUpdate(set=match_update_dict)
    if insert_when_not_matched:
        merger = merger.whenNotMatchedInsert(values=match_update_dict)
    merger.execute()
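# Hypothetical usage of the merge() method above. The DeltaDataSource instance, the
# table name, the key column and the update expressions are assumptions for
# illustration; the condition must use the fixed aliases "SRC" (incoming) and "TGT" (target).
def upsert_orders_example(delta_source, updates_df):
    delta_source.merge(
        df=updates_df,
        location="analytics.orders",                 # assumed table name (save_mode='table')
        condition="TGT.order_id = SRC.order_id",     # assumed key column
        match_update_dict={
            "status": "SRC.status",
            "updated_at": "current_timestamp()",
        },
        insert_when_not_matched=True,                # upsert: insert unmatched rows
        save_mode="table",
    )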
def test_generate(self) -> None:
    # create a delta table
    numFiles = 10
    self.spark.range(100).repartition(numFiles).write.format("delta").save(self.tempFile)
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # generate the symlink format manifest
    dt.generate("symlink_format_manifest")

    # check the contents of the manifest
    # NOTE: this is not a correctness test, we are testing correctness in the scala suite
    manifestPath = os.path.join(self.tempFile,
                                os.path.join("_symlink_format_manifest", "manifest"))
    files = []
    with open(manifestPath) as f:
        files = f.readlines()

    # the number of files we write should equal the number of lines in the manifest
    assert(len(files) == numFiles)
def test_delete(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # delete with condition as str
    dt.delete("key = 'a'")
    self.__checkAnswer(dt.toDF(), [('b', 2), ('c', 3), ('d', 4)])

    # delete with condition as Column
    dt.delete(col("key") == lit("b"))
    self.__checkAnswer(dt.toDF(), [('c', 3), ('d', 4)])

    # delete without condition
    dt.delete()
    self.__checkAnswer(dt.toDF(), [])

    # bad args
    with self.assertRaises(TypeError):
        dt.delete(condition=1)  # type: ignore[arg-type]
def test_restore_invalid_inputs(self) -> None:
    df = self.spark.createDataFrame([('a', 1), ('b', 2), ('c', 3)], ["key", "value"])
    df.write.format("delta").save(self.tempFile)

    dt = DeltaTable.forPath(self.spark, self.tempFile)

    def runRestoreToTimestamp() -> None:
        dt.restoreToTimestamp(12342323232)  # type: ignore[arg-type]

    self.__intercept(runRestoreToTimestamp,
                     "timestamp needs to be a string but got '<class 'int'>'")

    def runRestoreToVersion() -> None:
        dt.restoreToVersion("0")  # type: ignore[arg-type]

    self.__intercept(runRestoreToVersion,
                     "version needs to be an int but got '<class 'str'>'")
def test_update(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # update with condition as str and with set exprs as str
    dt.update("key = 'a' or key = 'b'", {"value": "1"})
    self.__checkAnswer(dt.toDF(), [('a', 1), ('b', 1), ('c', 3), ('d', 4)])

    # update with condition as Column and with set exprs as Columns
    dt.update(expr("key = 'a' or key = 'b'"), {"value": expr("0")})
    self.__checkAnswer(dt.toDF(), [('a', 0), ('b', 0), ('c', 3), ('d', 4)])

    # update without condition
    dt.update(set={"value": "200"})
    self.__checkAnswer(dt.toDF(), [('a', 200), ('b', 200), ('c', 200), ('d', 200)])

    # bad args
    with self.assertRaisesRegex(ValueError, "cannot be None"):
        dt.update({"value": "200"})  # type: ignore[call-overload]

    with self.assertRaisesRegex(ValueError, "cannot be None"):
        dt.update(condition='a')  # type: ignore[call-overload]

    with self.assertRaisesRegex(TypeError, "must be a dict"):
        dt.update(set=1)  # type: ignore[call-overload]

    with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
        dt.update(1, {})  # type: ignore[call-overload]

    with self.assertRaisesRegex(TypeError, "Values of dict in .* must contain only"):
        dt.update(set={"value": 1})  # type: ignore[dict-item]

    with self.assertRaisesRegex(TypeError, "Keys of dict in .* must contain only"):
        dt.update(set={1: ""})  # type: ignore[dict-item]

    with self.assertRaises(TypeError):
        dt.update(set=1)  # type: ignore[call-overload]
def update_silver_table(spark: SparkSession, silverPath: str) -> bool:
    from delta.tables import DeltaTable

    silver_df = load_dataframe(spark, format="delta", path=silverPath)
    silverTable = DeltaTable.forPath(spark, silverPath)

    update_match = """
    health_tracker.eventtime = updates.eventtime
    AND
    health_tracker.device_id = updates.device_id
    """
    update = {"heartrate": "updates.heartrate"}

    updates_df = prepare_interpolated_updates_dataframe(spark, silver_df)

    (silverTable.alias("health_tracker")
        .merge(updates_df.alias("updates"), update_match)
        .whenMatchedUpdate(set=update)
        .execute())

    return True
def test_protocolUpgrade(self) -> None:
    try:
        self.spark.conf.set('spark.databricks.delta.minWriterVersion', '2')
        self.spark.conf.set('spark.databricks.delta.minReaderVersion', '1')
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
        dt = DeltaTable.forPath(self.spark, self.tempFile)
        dt.upgradeTableProtocol(1, 3)
    finally:
        self.spark.conf.unset('spark.databricks.delta.minWriterVersion')
        self.spark.conf.unset('spark.databricks.delta.minReaderVersion')

    # cannot downgrade once upgraded
    failed = False
    try:
        dt.upgradeTableProtocol(1, 2)
    except BaseException:
        failed = True
    self.assertTrue(failed,
                    "The upgrade should have failed, because downgrades aren't allowed")

    # bad args
    with self.assertRaisesRegex(ValueError, "readerVersion"):
        dt.upgradeTableProtocol("abc", 3)  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "readerVersion"):
        dt.upgradeTableProtocol([1], 3)  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "readerVersion"):
        dt.upgradeTableProtocol([], 3)  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "readerVersion"):
        dt.upgradeTableProtocol({}, 3)  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "writerVersion"):
        dt.upgradeTableProtocol(1, "abc")  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "writerVersion"):
        dt.upgradeTableProtocol(1, [3])  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "writerVersion"):
        dt.upgradeTableProtocol(1, [])  # type: ignore[arg-type]
    with self.assertRaisesRegex(ValueError, "writerVersion"):
        dt.upgradeTableProtocol(1, {})  # type: ignore[arg-type]
def test_alias_and_toDF(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
    dt = DeltaTable.forPath(self.spark, self.tempFile).toDF()
    self.__checkAnswer(
        dt.alias("myTable").select('myTable.key', 'myTable.value'),
        [('a', 1), ('b', 2), ('c', 3)])
def test_forPath(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
    dt = DeltaTable.forPath(self.spark, self.tempFile).toDF()
    self.__checkAnswer(dt, [('a', 1), ('b', 2), ('c', 3)])
def test_merge(self) -> None:
    self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
    source = self.spark.createDataFrame([('a', -1), ('b', 0), ('e', -5), ('f', -6)], ["k", "v"])

    def reset_table() -> None:
        self.__overwriteDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])

    dt = DeltaTable.forPath(self.spark, self.tempFile)

    # ============== Test basic syntax ==============

    # String expressions in merge condition and dicts
    reset_table()
    dt.merge(source, "key = k") \
        .whenMatchedUpdate(set={"value": "v + 0"}) \
        .whenNotMatchedInsert(values={"key": "k", "value": "v + 0"}) \
        .execute()
    self.__checkAnswer(dt.toDF(),
                       ([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

    # Column expressions in merge condition and dicts
    reset_table()
    dt.merge(source, expr("key = k")) \
        .whenMatchedUpdate(set={"value": col("v") + 0}) \
        .whenNotMatchedInsert(values={"key": "k", "value": col("v") + 0}) \
        .execute()
    self.__checkAnswer(dt.toDF(),
                       ([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

    # ============== Test clause conditions ==============

    # String expressions in all conditions and dicts
    reset_table()
    dt.merge(source, "key = k") \
        .whenMatchedUpdate(condition="k = 'a'", set={"value": "v + 0"}) \
        .whenMatchedDelete(condition="k = 'b'") \
        .whenNotMatchedInsert(condition="k = 'e'", values={"key": "k", "value": "v + 0"}) \
        .execute()
    self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))

    # Column expressions in all conditions and dicts
    reset_table()
    dt.merge(source, expr("key = k")) \
        .whenMatchedUpdate(
            condition=expr("k = 'a'"),
            set={"value": col("v") + 0}) \
        .whenMatchedDelete(condition=expr("k = 'b'")) \
        .whenNotMatchedInsert(
            condition=expr("k = 'e'"),
            values={"key": "k", "value": col("v") + 0}) \
        .execute()
    self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))

    # Positional arguments
    reset_table()
    dt.merge(source, "key = k") \
        .whenMatchedUpdate("k = 'a'", {"value": "v + 0"}) \
        .whenMatchedDelete("k = 'b'") \
        .whenNotMatchedInsert("k = 'e'", {"key": "k", "value": "v + 0"}) \
        .execute()
    self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))

    # ============== Test updateAll/insertAll ==============

    # No clause conditions and insertAll/updateAll + aliases
    reset_table()
    dt.alias("t") \
        .merge(source.toDF("key", "value").alias("s"), expr("t.key = s.key")) \
        .whenMatchedUpdateAll() \
        .whenNotMatchedInsertAll() \
        .execute()
    self.__checkAnswer(dt.toDF(),
                       ([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

    # String expressions in all clause conditions and insertAll/updateAll + aliases
    reset_table()
    dt.alias("t") \
        .merge(source.toDF("key", "value").alias("s"), "s.key = t.key") \
        .whenMatchedUpdateAll("s.key = 'a'") \
        .whenNotMatchedInsertAll("s.key = 'e'") \
        .execute()
    self.__checkAnswer(dt.toDF(), ([('a', -1), ('b', 2), ('c', 3), ('d', 4), ('e', -5)]))

    # Column expressions in all clause conditions and insertAll/updateAll + aliases
    reset_table()
    dt.alias("t") \
        .merge(source.toDF("key", "value").alias("s"), expr("t.key = s.key")) \
        .whenMatchedUpdateAll(expr("s.key = 'a'")) \
        .whenNotMatchedInsertAll(expr("s.key = 'e'")) \
        .execute()
    self.__checkAnswer(dt.toDF(), ([('a', -1), ('b', 2), ('c', 3), ('d', 4), ('e', -5)]))

    # ============== Test bad args ==============
    # ---- bad args in merge()
    with self.assertRaisesRegex(TypeError, "must be DataFrame"):
        dt.merge(1, "key = k")  # type: ignore[arg-type]

    with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
        dt.merge(source, 1)  # type: ignore[arg-type]

    # ---- bad args in whenMatchedUpdate()
    with self.assertRaisesRegex(ValueError, "cannot be None"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenMatchedUpdate({"value": "v"}))

    with self.assertRaisesRegex(ValueError, "cannot be None"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenMatchedUpdate(1))

    with self.assertRaisesRegex(ValueError, "cannot be None"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenMatchedUpdate(condition="key = 'a'"))

    with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenMatchedUpdate(1, {"value": "v"}))

    with self.assertRaisesRegex(TypeError, "must be a dict"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenMatchedUpdate("k = 'a'", 1))

    with self.assertRaisesRegex(TypeError, "Values of dict in .* must contain only"):
        (dt
            .merge(source, "key = k")
            .whenMatchedUpdate(set={"value": 1}))  # type: ignore[dict-item]

    with self.assertRaisesRegex(TypeError, "Keys of dict in .* must contain only"):
        (dt
            .merge(source, "key = k")
            .whenMatchedUpdate(set={1: ""}))  # type: ignore[dict-item]

    with self.assertRaises(TypeError):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenMatchedUpdate(set="k = 'a'", condition={"value": 1}))

    # bad args in whenMatchedDelete()
    with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
        dt.merge(source, "key = k").whenMatchedDelete(1)  # type: ignore[arg-type]

    # ---- bad args in whenNotMatchedInsert()
    with self.assertRaisesRegex(ValueError, "cannot be None"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenNotMatchedInsert({"value": "v"}))

    with self.assertRaisesRegex(ValueError, "cannot be None"):
        dt.merge(source, "key = k").whenNotMatchedInsert(1)  # type: ignore[call-overload]

    with self.assertRaisesRegex(ValueError, "cannot be None"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenNotMatchedInsert(condition="key = 'a'"))

    with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenNotMatchedInsert(1, {"value": "v"}))

    with self.assertRaisesRegex(TypeError, "must be a dict"):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenNotMatchedInsert("k = 'a'", 1))

    with self.assertRaisesRegex(TypeError, "Values of dict in .* must contain only"):
        (dt
            .merge(source, "key = k")
            .whenNotMatchedInsert(values={"value": 1}))  # type: ignore[dict-item]

    with self.assertRaisesRegex(TypeError, "Keys of dict in .* must contain only"):
        (dt
            .merge(source, "key = k")
            .whenNotMatchedInsert(values={1: "value"}))  # type: ignore[dict-item]

    with self.assertRaises(TypeError):
        (dt  # type: ignore[call-overload]
            .merge(source, "key = k")
            .whenNotMatchedInsert(values="k = 'a'", condition={"value": 1}))
def _load(self) -> DeltaTable:
    load_path = self._fs_prefix + str(self._filepath)
    return DeltaTable.forPath(self._get_spark(), load_path)