Example #1
    def test_restore_to_timestamp(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2)])
        timestampToRestore = DeltaTable.forPath(self.spark, self.tempFile) \
            .history() \
            .head() \
            .timestamp \
            .strftime('%Y-%m-%d %H:%M:%S.%f')

        self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                                   schema=["key_new", "value_new"],
                                   overwriteSchema='true')

        overwritten = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(
            overwritten,
            [Row(key_new='a', value_new=3),
             Row(key_new='b', value_new=2)])

        DeltaTable.forPath(
            self.spark, self.tempFile).restoreToTimestamp(timestampToRestore)

        restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(restored,
                           [Row(key='a', value=1),
                            Row(key='b', value=2)])
Example #2
    def test_restore_to_timestamp(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2)])
        timestampToRestore = DeltaTable.forPath(self.spark, self.tempFile) \
            .history() \
            .head() \
            .timestamp \
            .strftime('%Y-%m-%d %H:%M:%S.%f')

        self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                                   schema=["key_new", "value_new"],
                                   overwriteSchema='true')

        overwritten = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(
            overwritten,
            [Row(key_new='a', value_new=3),
             Row(key_new='b', value_new=2)])

        DeltaTable.forPath(
            self.spark, self.tempFile).restoreToTimestamp(timestampToRestore)

        restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(restored,
                           [Row(key='a', value=1),
                            Row(key='b', value=2)])

        # we cannot test the actual working of restore to timestamp here but we can make sure
        # that the api is being called at least
        def runRestore() -> None:
            DeltaTable.forPath(self.spark,
                               self.tempFile).restoreToTimestamp('05/04/1999')

        self.__intercept(
            runRestore, "The provided timestamp ('05/04/1999') "
            "cannot be converted to a valid timestamp")
Example #3
    def get_current_data(self) -> DeltaTable:
        try:
            delta_output = DeltaTable.forPath(
                self.spark_configuration.spark_session,
                self.delta_src + self.current_data_table_name)
        except AnalysisException:
            print("Delta table does not exist -> creating")
            self._init_current_data()
            delta_output = DeltaTable.forPath(
                self.spark_configuration.spark_session,
                self.delta_src + self.current_data_table_name)
        return delta_output
Example #4
def run() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        dest='input',
        required=True,
        help='input file path',
    )
    parser.add_argument(
        '--output',
        dest='output',
        required=True,
        help='output delta table path',
    )
    parser.add_argument(
        '--action',
        dest='action',
        default='show',
        help='action to apply',
    )
    args = parser.parse_args()

    spark = SparkSession.builder.appName('taxipy') \
        .config('spark.jars.packages', 'io.delta:delta-core_2.12:0.7.0') \
        .config('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension') \
        .config('spark.sql.catalog.spark_catalog', 'org.apache.spark.sql.delta.catalog.DeltaCatalog') \
        .getOrCreate()

    from delta.tables import DeltaTable
    taxi_ride_table = DeltaTable.forPath(spark, args.output)
    taxi_ride_table.update(condition=expr('vendor_id == 2'),
                           set={'total_amount': '123'})
    show_records(spark, args.output)
    spark.stop()
Example #5
    def test_optimize_zorder_by(self) -> None:
        # write an unoptimized delta table
        self.spark.createDataFrame([i for i in range(0, 100)], IntegerType()) \
            .withColumn("col1", floor(col("value") % 7)) \
            .withColumn("col2", floor(col("value") % 27)) \
            .withColumn("p", floor(col("value") % 10)) \
            .repartition(4).write.partitionBy("p").format("delta").save(self.tempFile)

        # create DeltaTable
        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # execute Z-Order Optimization
        optimizer = dt.optimize()
        result = optimizer.executeZOrderBy(["col1", "col2"])
        metrics = result.select("metrics.*").head()

        self.assertTrue(metrics.numFilesAdded == 10)
        self.assertTrue(metrics.numFilesRemoved == 37)
        self.assertTrue(metrics.totalFilesSkipped == 0)
        self.assertTrue(metrics.totalConsideredFiles == 37)
        self.assertTrue(metrics.zOrderStats.strategyName == 'all')
        self.assertTrue(metrics.zOrderStats.numOutputCubes == 10)

        # negative test: Z-Order on partition column
        def optimize() -> None:
            dt.optimize().where("p = 1").executeZOrderBy(["p"])

        self.__intercept(
            optimize, "p is a partition column. "
            "Z-Ordering can only be performed on data columns")
Example #6
    def test_optimize(self) -> None:
        # write an unoptimized delta table
        df = self.spark.createDataFrame([("a", 1), ("a", 2)],
                                        ["key", "value"]).repartition(1)
        df.write.format("delta").save(self.tempFile)
        df = self.spark.createDataFrame([("a", 3), ("a", 4)],
                                        ["key", "value"]).repartition(1)
        df.write.format("delta").save(self.tempFile, mode="append")
        df = self.spark.createDataFrame([("b", 1), ("b", 2)],
                                        ["key", "value"]).repartition(1)
        df.write.format("delta").save(self.tempFile, mode="append")

        # create DeltaTable
        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # execute bin compaction
        optimizer = dt.optimize()
        res = optimizer.executeCompaction()
        op_params = dt.history().first().operationParameters

        # assertions
        self.assertTrue(isinstance(optimizer, DeltaOptimizeBuilder))
        self.assertTrue(isinstance(res, DataFrame))
        self.assertEqual(1, res.first().metrics.numFilesAdded)
        self.assertEqual(3, res.first().metrics.numFilesRemoved)
        self.assertEqual('[]', op_params['predicate'])

        # test non-partition column
        def optimize() -> None:
            dt.optimize().where("key = 'a'").executeCompaction()

        self.__intercept(
            optimize, "Predicate references non-partition column 'key'. "
            "Only the partition columns may be referenced: []")
Example #7
    def test_optimize_zorder_by_w_partition_filter(self) -> None:
        # write an unoptimized delta table
        writer = self.spark.createDataFrame([i for i in range(0, 100)], IntegerType()) \
            .withColumn("col1", floor(col("value") % 7)) \
            .withColumn("col2", floor(col("value") % 27)) \
            .withColumn("p", floor(col("value") % 10)) \
            .repartition(4).write.partitionBy("p")

        writer.format("delta").save(self.tempFile)

        # create DeltaTable
        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # execute Z-OrderBy
        optimizer = dt.optimize().where("p = 2")
        result = optimizer.executeZOrderBy(["col1", "col2"])
        metrics = result.select("metrics.*").head()

        # assertions (partition 'p = 2' has four files)
        self.assertTrue(metrics.numFilesAdded == 1)
        self.assertTrue(metrics.numFilesRemoved == 4)
        self.assertTrue(metrics.totalFilesSkipped == 0)
        self.assertTrue(metrics.totalConsideredFiles == 4)
        self.assertTrue(metrics.zOrderStats.strategyName == 'all')
        self.assertTrue(metrics.zOrderStats.numOutputCubes == 1)
Example #8
def upsert_table(spark,
                 updatesDF,
                 condition,
                 output_file,
                 partition_columns=None):
    '''
    Update/insert transformed immigration data into the fact table.
    In case of a duplicate id, overwrite the old value with the new one; otherwise append the row.
    '''

    import os

    from delta.tables import DeltaTable

    if not os.path.exists(output_file):

        if partition_columns is None:
            updatesDF.write.format('delta').save(output_file)
        else:
            updatesDF.write.format('delta').partitionBy(
                *partition_columns).save(output_file)

    else:

        deltaTable = DeltaTable.forPath(spark, output_file)

        deltaTable.alias("source").merge(
            source = updatesDF.alias("update"),
            condition = condition) \
          .whenMatchedUpdateAll() \
          .whenNotMatchedInsertAll() \
          .execute()
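
A minimal usage sketch (the key column, path, and partition columns below are hypothetical placeholders); note that the merge condition must reference the aliases "source" (the existing table) and "update" (the incoming DataFrame) used above:

# hypothetical call; spark is an active SparkSession and updates_df the new batch
upsert_table(
    spark,
    updates_df,
    condition="source.id = update.id",       # placeholder merge key
    output_file="/tmp/delta/fact_table",     # placeholder table path
    partition_columns=["year", "month"],     # placeholder partition columns
)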
Example #9
    def merge_write(logger, df_dict: Dict[str, DataFrame],
                    rules: Dict[str, str], output_path: str,
                    spark: SparkSession):
        """
        Write data if the dataset doesn't exist or merge it to the existing dataset
        Args:
            logger: Logger instance used to log events
            df_dict: Dictionary of the datasets with the structure {Name: Dataframe}
            rules: Matching rules used to merge, keyed by dataset name
            output_path: Path to write the data
            spark: Spark instance

        Returns:

        """
        try:
            from delta.tables import DeltaTable
            for df_name, df in df_dict.items():
                file_path = path.join(output_path, df_name)
                if DeltaTable.isDeltaTable(spark, file_path):
                    delta_table = DeltaTable.forPath(spark, file_path)
                    delta_table.alias("old").merge(
                        df.alias("new"), rules.get(df_name)
                    ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
                else:
                    df.write.format("delta").save(file_path)

        except Exception as e:
            logger.error(
                "Writing sanitized data couldn't be performed: {}".format(e),
                traceback.format_exc())
            raise e
        else:
            logger.info("Sanitized dataframes written in {} folder".format(
                output_path))
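
A minimal usage sketch (the logger, DataFrame, rule, and output folder are hypothetical, and the call is shown as a direct invocation); the matching rule must reference the aliases "old" and "new" used in the merge above:

import logging

# hypothetical call; df_events is a DataFrame carrying an 'id' column
merge_write(
    logger=logging.getLogger(__name__),
    df_dict={"events": df_events},
    rules={"events": "old.id = new.id"},   # placeholder merge key
    output_path="/tmp/sanitized",          # placeholder output folder
    spark=spark,
)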
Example #10
    def _merge_into_table(self, df, destination_path, checkpoints_path,
                          condition):
        """ Merges data from the given dataframe into the delta table at the specified destination_path, based on the given condition.
            If no delta table exists at the specified destination_path, a new delta table is created and the data from the given dataframe is inserted.
            e.g., merge_into_table(df_lookup, np_destination_path, source_path + '/_checkpoints/delta_np', "current.id_pseudonym = updates.id_pseudonym")
        """
        if DeltaTable.isDeltaTable(spark, destination_path):
            dt = DeltaTable.forPath(spark, destination_path)

            def upsert(batch_df, batchId):
                dt.alias("current").merge(
                    batch_df.alias("updates"), condition).whenMatchedUpdateAll(
                    ).whenNotMatchedInsertAll().execute()

            query = df.writeStream.format("delta").foreachBatch(
                upsert).outputMode("update").trigger(once=True).option(
                    "checkpointLocation", checkpoints_path)
        else:
            logger.info(
                f'Delta table does not yet exist at {destination_path} - creating one now and inserting initial data.'
            )
            query = df.writeStream.format("delta").outputMode(
                "append").trigger(once=True).option("checkpointLocation",
                                                    checkpoints_path)
        query = query.start(destination_path)
        query.awaitTermination(
        )  # block until query is terminated, with stop() or with error; A StreamingQueryException will be thrown if an exception occurs.
        logger.info(query.lastProgress)
Example #11
    def test_restore_to_version(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2)])
        self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                                   schema=["key_new", "value_new"],
                                   overwriteSchema='true')

        overwritten = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(
            overwritten,
            [Row(key_new='a', value_new=3),
             Row(key_new='b', value_new=2)])

        DeltaTable.forPath(self.spark, self.tempFile).restoreToVersion(0)
        restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()

        self.__checkAnswer(restored,
                           [Row(key='a', value=1),
                            Row(key='b', value=2)])
Example #12
def get_delta_table(path):
    try:
        dt = DeltaTable.forPath(spark, path)
    except AnalysisException as e:
        if "doesn't exist;" in str(e).lower() or "is not a delta table." in str(e).lower():
            print("Error occurred due to: " + str(e))
            return None
        else:
            raise e
    return dt
Example #13
    def _get_delta_table(spark, delta_path, delta_table):

        if [delta_path, delta_table].count(None) == 2:
            raise ValueError("either delta_path or delta_table must be provided")

        if delta_path is not None:
            delta_table = DeltaTable.forPath(spark, delta_path)

        else:
            delta_table = DeltaTable.forName(spark, delta_table)

        return delta_table
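
A minimal usage sketch (the path and table name are hypothetical placeholders); exactly one of the two arguments is expected, the other being None:

# resolve by filesystem path (hypothetical location)
dt_by_path = _get_delta_table(spark, "/tmp/delta/customers", None)

# resolve by table name registered in the metastore (hypothetical name)
dt_by_name = _get_delta_table(spark, None, "analytics.customers")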
Example #14
    def test_history(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
        self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)])
        dt = DeltaTable.forPath(self.spark, self.tempFile)
        operations = dt.history().select('operation')
        self.__checkAnswer(
            operations, [Row("WRITE"), Row("WRITE")],
            StructType([StructField("operation", StringType(), True)]))

        lastMode = dt.history(1).select('operationParameters.mode')
        self.__checkAnswer(
            lastMode, [Row("Overwrite")],
            StructType(
                [StructField("operationParameters.mode", StringType(), True)]))
Example #15
    def _get_delta_table(self, spark, table_or_path, update_delta_table):

        try:
            deltaTable = DeltaTable.forPath(spark, table_or_path)
        except Exception:
            # fall back to resolving the argument as a table name
            deltaTable = DeltaTable.forName(spark, table_or_path)

        if update_delta_table:
            return deltaTable

        return deltaTable.toDF()
Example #16
def merge(spark, update, tableName, cols, key):
    """
    Merge a DataFrame into a Delta table. The insert clause requires the DataFrame
    to contain every column of the Delta table (Delta 0.5 behavior). When merge is
    used to update/insert only some of the table's columns, columns that do not
    exist in the DataFrame are set to null.

    Note: the columns written from the DataFrame must match those of the Delta table.
    :param spark: SparkSession instance
    :param update: Spark DataFrame containing the updates
    :param tableName: path of the Delta table to update
    :param cols: columns to update/insert from the DataFrame
    :param key: column used as the merge key
    """
    # If there is no dt column, add one with the current date
    if "dt" not in cols:
        update = update.withColumn("dt", f.current_date())
        cols.append("dt")

    # 1. Build the merge condition
    mergeExpr = f"origin.{key}=update.{key}"
    print(f"merge expression:{mergeExpr}")

    # 2. Build the update expressions
    updateExpr = {}
    for c in cols:
        updateExpr[c] = f"update.{c}"

    print(f"update expression:{updateExpr}")

    origin = DeltaTable.forPath(spark, tableName)
    origin_cols = origin.toDF().columns

    # 3. Build the insert expressions
    insertExpr = {}
    for origin_col in origin_cols:
        if origin_col in cols:
            insertExpr[origin_col] = f"update.{origin_col}"
        else:
            # Column not present in cols: insert a SQL null (not the string "null")
            insertExpr[origin_col] = "null"

    print(f"insert expression:{insertExpr}")

    # for origin_col in origin_cols:
    #     if origin_col not in cols:
    #         update=update.withColumn(origin_col,f.lit(None))

    origin.alias("origin") \
        .merge(update.alias("update"), mergeExpr) \
        .whenMatchedUpdate(set=updateExpr) \
        .whenNotMatchedInsert(values=insertExpr) \
        .execute()
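
A minimal usage sketch (the path, column list, and key are hypothetical placeholders); tableName is the path handed to DeltaTable.forPath above, and "dt" is appended automatically when it is missing from cols:

# hypothetical call; updates_df holds the new rows keyed by "id"
merge(
    spark,
    update=updates_df,
    tableName="/data/delta/events",   # placeholder Delta table path
    cols=["id", "value"],             # columns present in updates_df
    key="id",
)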
Example #17
    def test_vacuum(self):
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
        dt = DeltaTable.forPath(self.spark, self.tempFile)
        self.__createFile('abc.txt', 'abcde')
        self.__createFile('bac.txt', 'abcdf')
        self.assertEqual(True, self.__checkFileExists('abc.txt'))
        dt.vacuum()  # will not delete files as default retention is used.

        self.assertEqual(True, self.__checkFileExists('bac.txt'))
        retentionConf = "spark.databricks.delta.retentionDurationCheck.enabled"
        self.spark.conf.set(retentionConf, "false")
        dt.vacuum(0.0)
        self.spark.conf.set(retentionConf, "true")
        self.assertEqual(False, self.__checkFileExists('bac.txt'))
        self.assertEqual(False, self.__checkFileExists('abc.txt'))
Example #18
def get_delta_table(
        spark: SparkSession,
        schema: StructType,
        delta_library_jar: str,
        delta_path: str):
    # load delta library jar, so we can use delta module
    spark.sparkContext.addPyFile(delta_library_jar)
    from delta.tables import DeltaTable

    # check existence of delta table
    if not DeltaTable.isDeltaTable(spark, delta_path):
        print(f">>> Delta table: {delta_path} is not initialized, performing initialization..")
        df = spark.createDataFrame([], schema=schema)
        df.write.format("delta").save(delta_path)

    return DeltaTable.forPath(spark, delta_path)
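
A minimal usage sketch (the schema, jar location, and table path are hypothetical placeholders):

from pyspark.sql.types import LongType, StringType, StructField, StructType

# hypothetical schema for the empty table created on first use
events_schema = StructType([
    StructField("id", LongType(), True),
    StructField("name", StringType(), True),
])

dt = get_delta_table(
    spark,
    schema=events_schema,
    delta_library_jar="/opt/jars/delta-core_2.12-0.7.0.jar",  # placeholder jar path
    delta_path="/tmp/delta/events",                           # placeholder table path
)
dt.toDF().show()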
Example #19
    def update_sum_count_or_insert(self, new_data: DataFrame, table: str, id_col: str):
        try:
            delta_table = DeltaTable.forPath(self.spark_configuration.spark_session, self.delta_src + table)
        except AnalysisException:
            # If the delta table does not exist yet, just create it
            new_data.write \
                .format("delta") \
                .save(self.delta_src + table)
            return
        delta_table.alias("current_data").merge(
            new_data.alias("updates"),
            "current_data.{0} = updates.{0}".format(id_col)) \
            .whenMatchedUpdate(set={
                "count": "current_data.count + updates.count"
            }) \
            .whenNotMatchedInsertAll() \
            .execute()
Example #20
    def merge(
            self,
            df: DataFrame,
            location: str,
            condition: str,  # Only supports SQL-like string condition
            match_update_dict: dict,  # "target_column": "expression"
            insert_when_not_matched: bool = False,  # Set to True for upsert
            save_mode: str = 'table'):
        '''Merge a dataframe into a target table or path.

        This merge operation can represent both update and upsert operations.
        The source and target tables are aliased as 'SRC' and 'TGT' by default;
        these aliases can be used in the condition string and in the update/insert expressions.
        Args:
            df (DataFrame): The source dataframe to write.
            save_mode (str): 'table' or 'path'
            location (str): The table name or path to merge into.
            condition (str): The merge condition as a SQL-like string.
            match_update_dict (dict): Contains ("target_column": "expression").
                This represents the updated values when a row matches.
                NOTE: "target_column"s come without an alias prefix ("SRC" or "TGT").
            insert_when_not_matched (bool): If True, rows that do not match are
                inserted as well (upsert), using the same ("target_column": "expression")
                mapping as match_update_dict; columns that are not specified will be null.
        '''
        super(DeltaDataSource,
              self).merge(df,
                          condition,
                          match_update_dict,
                          insert_when_not_matched=insert_when_not_matched)
        save_mode = save_mode.lower()
        if save_mode == "table":
            target_table = DeltaTable.forName(self.spark, location)
        elif save_mode == "path":
            target_table = DeltaTable.forPath(self.spark, location)
        else:
            raise ValueError("save_mode should be 'path' or 'table'.")

        merger = target_table.alias("TGT").merge(df.alias("SRC"), condition)
        merger = merger.whenMatchedUpdate(set=match_update_dict)

        if insert_when_not_matched:
            merger = merger.whenNotMatchedInsert(values=match_update_dict)

        merger.execute()
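
A minimal usage sketch (the table name, condition, and column mapping are hypothetical); expressions may reference the default aliases 'SRC' (the incoming dataframe) and 'TGT' (the target table):

# hypothetical upsert into a metastore table; datasource is a DeltaDataSource instance
datasource.merge(
    df=updates_df,
    location="analytics.customers",                 # placeholder table name
    condition="TGT.customer_id = SRC.customer_id",  # placeholder merge key
    match_update_dict={"name": "SRC.name", "updated_at": "SRC.updated_at"},
    insert_when_not_matched=True,                   # upsert: insert unmatched rows
    save_mode="table",
)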
Example #21
    def test_generate(self) -> None:
        # create a delta table
        numFiles = 10
        self.spark.range(100).repartition(numFiles).write.format("delta").save(self.tempFile)
        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # Generate the symlink format manifest
        dt.generate("symlink_format_manifest")

        # check the contents of the manifest
        # NOTE: this is not a correctness test, we are testing correctness in the scala suite
        manifestPath = os.path.join(self.tempFile,
                                    os.path.join("_symlink_format_manifest", "manifest"))
        files = []
        with open(manifestPath) as f:
            files = f.readlines()

        # the number of files we write should equal the number of lines in the manifest
        assert(len(files) == numFiles)
Example #22
    def test_delete(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # delete with condition as str
        dt.delete("key = 'a'")
        self.__checkAnswer(dt.toDF(), [('b', 2), ('c', 3), ('d', 4)])

        # delete with condition as Column
        dt.delete(col("key") == lit("b"))
        self.__checkAnswer(dt.toDF(), [('c', 3), ('d', 4)])

        # delete without condition
        dt.delete()
        self.__checkAnswer(dt.toDF(), [])

        # bad args
        with self.assertRaises(TypeError):
            dt.delete(condition=1)  # type: ignore[arg-type]
Example #23
    def test_restore_invalid_inputs(self) -> None:
        df = self.spark.createDataFrame([('a', 1), ('b', 2), ('c', 3)],
                                        ["key", "value"])
        df.write.format("delta").save(self.tempFile)

        dt = DeltaTable.forPath(self.spark, self.tempFile)

        def runRestoreToTimestamp() -> None:
            dt.restoreToTimestamp(12342323232)  # type: ignore[arg-type]

        self.__intercept(
            runRestoreToTimestamp,
            "timestamp needs to be a string but got '<class 'int'>'")

        def runRestoreToVersion() -> None:
            dt.restoreToVersion("0")  # type: ignore[arg-type]

        self.__intercept(runRestoreToVersion,
                         "version needs to be an int but got '<class 'str'>'")
Example #24
    def test_update(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # update with condition as str and with set exprs as str
        dt.update("key = 'a' or key = 'b'", {"value": "1"})
        self.__checkAnswer(dt.toDF(), [('a', 1), ('b', 1), ('c', 3), ('d', 4)])

        # update with condition as Column and with set exprs as Columns
        dt.update(expr("key = 'a' or key = 'b'"), {"value": expr("0")})
        self.__checkAnswer(dt.toDF(), [('a', 0), ('b', 0), ('c', 3), ('d', 4)])

        # update without condition
        dt.update(set={"value": "200"})
        self.__checkAnswer(dt.toDF(), [('a', 200), ('b', 200), ('c', 200),
                                       ('d', 200)])

        # bad args
        with self.assertRaisesRegex(ValueError, "cannot be None"):
            dt.update({"value": "200"})  # type: ignore[call-overload]

        with self.assertRaisesRegex(ValueError, "cannot be None"):
            dt.update(condition='a')  # type: ignore[call-overload]

        with self.assertRaisesRegex(TypeError, "must be a dict"):
            dt.update(set=1)  # type: ignore[call-overload]

        with self.assertRaisesRegex(TypeError,
                                    "must be a Spark SQL Column or a string"):
            dt.update(1, {})  # type: ignore[call-overload]

        with self.assertRaisesRegex(TypeError,
                                    "Values of dict in .* must contain only"):
            dt.update(set={"value": 1})  # type: ignore[dict-item]

        with self.assertRaisesRegex(TypeError,
                                    "Keys of dict in .* must contain only"):
            dt.update(set={1: ""})  # type: ignore[dict-item]

        with self.assertRaises(TypeError):
            dt.update(set=1)  # type: ignore[call-overload]
Example #25
def update_silver_table(spark: SparkSession, silverPath: str) -> bool:
    from delta.tables import DeltaTable

    silver_df = load_dataframe(spark, format="delta", path=silverPath)
    silverTable = DeltaTable.forPath(spark, silverPath)

    update_match = """
    health_tracker.eventtime = updates.eventtime
    AND
    health_tracker.device_id = updates.device_id
    """

    update = {"heartrate": "updates.heartrate"}

    updates_df = prepare_interpolated_updates_dataframe(spark, silver_df)

    (silverTable.alias("health_tracker").merge(
        updates_df.alias("updates"),
        update_match).whenMatchedUpdate(set=update).execute())

    return True
Example #26
    def test_protocolUpgrade(self) -> None:
        try:
            self.spark.conf.set('spark.databricks.delta.minWriterVersion', '2')
            self.spark.conf.set('spark.databricks.delta.minReaderVersion', '1')
            self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
            dt = DeltaTable.forPath(self.spark, self.tempFile)
            dt.upgradeTableProtocol(1, 3)
        finally:
            self.spark.conf.unset('spark.databricks.delta.minWriterVersion')
            self.spark.conf.unset('spark.databricks.delta.minReaderVersion')

        # cannot downgrade once upgraded
        failed = False
        try:
            dt.upgradeTableProtocol(1, 2)
        except BaseException:
            failed = True
        self.assertTrue(
            failed,
            "The upgrade should have failed, because downgrades aren't allowed"
        )

        # bad args
        with self.assertRaisesRegex(ValueError, "readerVersion"):
            dt.upgradeTableProtocol("abc", 3)  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "readerVersion"):
            dt.upgradeTableProtocol([1], 3)  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "readerVersion"):
            dt.upgradeTableProtocol([], 3)  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "readerVersion"):
            dt.upgradeTableProtocol({}, 3)  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "writerVersion"):
            dt.upgradeTableProtocol(1, "abc")  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "writerVersion"):
            dt.upgradeTableProtocol(1, [3])  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "writerVersion"):
            dt.upgradeTableProtocol(1, [])  # type: ignore[arg-type]
        with self.assertRaisesRegex(ValueError, "writerVersion"):
            dt.upgradeTableProtocol(1, {})  # type: ignore[arg-type]
Example #27
    def test_alias_and_toDF(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
        dt = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(
            dt.alias("myTable").select('myTable.key', 'myTable.value'),
            [('a', 1), ('b', 2), ('c', 3)])
Example #28
    def test_forPath(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
        dt = DeltaTable.forPath(self.spark, self.tempFile).toDF()
        self.__checkAnswer(dt, [('a', 1), ('b', 2), ('c', 3)])
Example #29
    def test_merge(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
        source = self.spark.createDataFrame([('a', -1), ('b', 0), ('e', -5), ('f', -6)], ["k", "v"])

        def reset_table() -> None:
            self.__overwriteDeltaTable([('a', 1), ('b', 2), ('c', 3), ('d', 4)])

        dt = DeltaTable.forPath(self.spark, self.tempFile)

        # ============== Test basic syntax ==============

        # String expressions in merge condition and dicts
        reset_table()
        dt.merge(source, "key = k") \
            .whenMatchedUpdate(set={"value": "v + 0"}) \
            .whenNotMatchedInsert(values={"key": "k", "value": "v + 0"}) \
            .execute()
        self.__checkAnswer(dt.toDF(),
                           ([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

        # Column expressions in merge condition and dicts
        reset_table()
        dt.merge(source, expr("key = k")) \
            .whenMatchedUpdate(set={"value": col("v") + 0}) \
            .whenNotMatchedInsert(values={"key": "k", "value": col("v") + 0}) \
            .execute()
        self.__checkAnswer(dt.toDF(),
                           ([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

        # ============== Test clause conditions ==============

        # String expressions in all conditions and dicts
        reset_table()
        dt.merge(source, "key = k") \
            .whenMatchedUpdate(condition="k = 'a'", set={"value": "v + 0"}) \
            .whenMatchedDelete(condition="k = 'b'") \
            .whenNotMatchedInsert(condition="k = 'e'", values={"key": "k", "value": "v + 0"}) \
            .execute()
        self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))

        # Column expressions in all conditions and dicts
        reset_table()
        dt.merge(source, expr("key = k")) \
            .whenMatchedUpdate(
                condition=expr("k = 'a'"),
                set={"value": col("v") + 0}) \
            .whenMatchedDelete(condition=expr("k = 'b'")) \
            .whenNotMatchedInsert(
                condition=expr("k = 'e'"),
                values={"key": "k", "value": col("v") + 0}) \
            .execute()
        self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))

        # Positional arguments
        reset_table()
        dt.merge(source, "key = k") \
            .whenMatchedUpdate("k = 'a'", {"value": "v + 0"}) \
            .whenMatchedDelete("k = 'b'") \
            .whenNotMatchedInsert("k = 'e'", {"key": "k", "value": "v + 0"}) \
            .execute()
        self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))

        # ============== Test updateAll/insertAll ==============

        # No clause conditions and insertAll/updateAll + aliases
        reset_table()
        dt.alias("t") \
            .merge(source.toDF("key", "value").alias("s"), expr("t.key = s.key")) \
            .whenMatchedUpdateAll() \
            .whenNotMatchedInsertAll() \
            .execute()
        self.__checkAnswer(dt.toDF(),
                           ([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

        # String expressions in all clause conditions and insertAll/updateAll + aliases
        reset_table()
        dt.alias("t") \
            .merge(source.toDF("key", "value").alias("s"), "s.key = t.key") \
            .whenMatchedUpdateAll("s.key = 'a'") \
            .whenNotMatchedInsertAll("s.key = 'e'") \
            .execute()
        self.__checkAnswer(dt.toDF(), ([('a', -1), ('b', 2), ('c', 3), ('d', 4), ('e', -5)]))

        # Column expressions in all clause conditions and insertAll/updateAll + aliases
        reset_table()
        dt.alias("t") \
            .merge(source.toDF("key", "value").alias("s"), expr("t.key = s.key")) \
            .whenMatchedUpdateAll(expr("s.key = 'a'")) \
            .whenNotMatchedInsertAll(expr("s.key = 'e'")) \
            .execute()
        self.__checkAnswer(dt.toDF(), ([('a', -1), ('b', 2), ('c', 3), ('d', 4), ('e', -5)]))

        # ============== Test bad args ==============
        # ---- bad args in merge()
        with self.assertRaisesRegex(TypeError, "must be DataFrame"):
            dt.merge(1, "key = k")  # type: ignore[arg-type]

        with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
            dt.merge(source, 1)  # type: ignore[arg-type]

        # ---- bad args in whenMatchedUpdate()
        with self.assertRaisesRegex(ValueError, "cannot be None"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenMatchedUpdate({"value": "v"}))

        with self.assertRaisesRegex(ValueError, "cannot be None"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenMatchedUpdate(1))

        with self.assertRaisesRegex(ValueError, "cannot be None"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenMatchedUpdate(condition="key = 'a'"))

        with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenMatchedUpdate(1, {"value": "v"}))

        with self.assertRaisesRegex(TypeError, "must be a dict"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenMatchedUpdate("k = 'a'", 1))

        with self.assertRaisesRegex(TypeError, "Values of dict in .* must contain only"):
            (dt
                .merge(source, "key = k")
                .whenMatchedUpdate(set={"value": 1}))  # type: ignore[dict-item]

        with self.assertRaisesRegex(TypeError, "Keys of dict in .* must contain only"):
            (dt
                .merge(source, "key = k")
                .whenMatchedUpdate(set={1: ""}))  # type: ignore[dict-item]

        with self.assertRaises(TypeError):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenMatchedUpdate(set="k = 'a'", condition={"value": 1}))

        # bad args in whenMatchedDelete()
        with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
            dt.merge(source, "key = k").whenMatchedDelete(1)  # type: ignore[arg-type]

        # ---- bad args in whenNotMatchedInsert()
        with self.assertRaisesRegex(ValueError, "cannot be None"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenNotMatchedInsert({"value": "v"}))

        with self.assertRaisesRegex(ValueError, "cannot be None"):
            dt.merge(source, "key = k").whenNotMatchedInsert(1)  # type: ignore[call-overload]

        with self.assertRaisesRegex(ValueError, "cannot be None"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenNotMatchedInsert(condition="key = 'a'"))

        with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenNotMatchedInsert(1, {"value": "v"}))

        with self.assertRaisesRegex(TypeError, "must be a dict"):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenNotMatchedInsert("k = 'a'", 1))

        with self.assertRaisesRegex(TypeError, "Values of dict in .* must contain only"):
            (dt
                .merge(source, "key = k")
                .whenNotMatchedInsert(values={"value": 1}))  # type: ignore[dict-item]

        with self.assertRaisesRegex(TypeError, "Keys of dict in .* must contain only"):
            (dt
                .merge(source, "key = k")
                .whenNotMatchedInsert(values={1: "value"}))  # type: ignore[dict-item]

        with self.assertRaises(TypeError):
            (dt  # type: ignore[call-overload]
                .merge(source, "key = k")
                .whenNotMatchedInsert(values="k = 'a'", condition={"value": 1}))
Example #30
    def _load(self) -> DeltaTable:
        load_path = self._fs_prefix + str(self._filepath)
        return DeltaTable.forPath(self._get_spark(), load_path)