def __init__(self) -> None:
    """Register the known Presto error-message patterns.

    ``self.known_issues`` maps a regular expression (written against the text
    of a Presto error message) to the bound method registered for that kind of
    error. Named groups in each pattern (``line``, ``column``, ``b_type_0``,
    ``f_type_0``, ``table``, ...) capture the details of the error. The
    insertion order of the mapping is kept as-is.

    Also creates a ``utils.ColumnCaster`` instance on ``self.ColumnCaster``.
    """
    # Regex fragment that covers all Presto data types, from single word
    # (eg: integer) to composed (eg: decimal(22, 2)).
    d_t = utils.d_t
    self.known_issues = {
        r"line (?P<line>\d+):(?P<column>\d+): Cannot cast timestamp to (?P<integer_type>tinyint|smallint|integer|int|bigint) \(\d+\)":
            self._cast_timestamp_to_epoch,
        r"line (?P<line>\d+):(?P<column>\d+): Cannot cast (?P<source_type>{d_t}) to (?P<target_type>{d_t}) \(\d+\)".format(d_t=d_t):
            self._cannot_cast_to_type,
        r"line (?P<line>\d+):(?P<column>\d+): '(>|<|[><!]?=)' cannot be applied to (?P<b_type_0>{d_t}), (?P<f_type_0>{d_t}) \(\d+\)".format(d_t=d_t):
            self._cast_both_sides,
        r"Mismatch at column (?P<column>\d+): \'(?P<name>\w+)\' is of type (?P<ddl_type>{d_t}) but expression is of type (?P<expr_type>{d_t}) \(\d+\)".format(d_t=d_t):
            self._column_type_mismatch,
        r"line (?P<line>\d+):(?P<column>\d+): Unexpected parameters \((?P<parameters>.*)\) for function (?P<function_name>\w+)":
            self._unexpected_parameters,
        # The schema/table separator must be a literal dot; an unescaped "."
        # would also accept names such as "abc" or "a-b".
        r"Table '(?P<table>\w+\.\w+)' not found \(\d+\)":
            self._table_not_found,
        r"line (?P<line>\d+):(?P<column>\d+): Cannot check if (?P<b_type_0>{d_t}) is BETWEEN (?P<f_type_0>{d_t}) and (?P<f_type_1>{d_t}) \(\d+\)".format(d_t=d_t):
            self._between,
        r"line (?P<line>\d+):(?P<column>\d+): IN value and list items must be the same type: (?P<type>{d_t}) \(\d+\)".format(d_t=d_t):
            self._cast_in,
        r"line (?P<line>\d+):(?P<column>\d+): value and result of subquery must be of the same type for IN expression: (?P<b_type_0>{d_t}) vs (?P<f_type_0>{d_t}) \(\d+\)".format(d_t=d_t):
            self._cast_in_subquery,
        r"line (?P<line>\d+):(?P<column>\d+): All CASE results must be the same type: (?P<type>{d_t}) \(\d+\)".format(d_t=d_t):
            self._case_statements,
        r"line (?P<line>\d+):(?P<column>\d+): All COALESCE operands must be the same type: (?P<type>{d_t}) \(\d+\)".format(d_t=d_t):
            self._coalesce_statements,
    }
    self.ColumnCaster = utils.ColumnCaster()
def _cast_divisions_to_double(self, query: str) -> str:
    """Cast both operands of every ``/`` division in the query to double.

    By default, Presto does an integer division when encountering two
    integers around a / sign. For instance, 3/2 = 1. Therefore, to properly
    translate it at least one side needs to be cast to double (both sides
    are cast here).

    Only the first statement parsed from ``query`` is inspected.

    Args:
        query (str): Input SQL

    Returns:
        str: Transformed SQL. A query for which sqlparse yields no statement
        (e.g. empty or whitespace-only input) is returned unchanged.
    """
    ColumnCaster = utils.ColumnCaster()

    statements = sqlparse.parse(query)
    if not statements:
        # sqlparse returns an empty tuple here; indexing [0] would raise
        # IndexError.
        return query

    logging.debug("Flattening SQL...")
    start = time.perf_counter()
    # Count how many division operators there are (flattening is very
    # intensive, so it is timed).
    division_operators = sum(
        1
        for token in statements[0].flatten()
        if token.ttype == Operator and token.value == "/"
    )
    logging.debug(f"SQL was flattened in {time.perf_counter() - start} s!")
    logging.debug(f"Found {division_operators} division operator(s)")

    # Multi stage query copy/paste: every cast rewrites `query`, so it is
    # re-parsed before locating the next division. The n-th "/" is found by
    # counting operator tokens in the rewritten text.
    for division_operator in range(division_operators):
        logging.debug(
            f"Fixing division operation {division_operator + 1}/{division_operators}"
        )
        counter = 0
        # Running character offset: sum of the lengths of the flattened
        # tokens seen before the current one.
        idx = 0
        for token in sqlparse.parse(query)[0].flatten():
            if token.ttype == Operator and token.value == "/":
                if counter == division_operator:
                    query = ColumnCaster.cast_non_trivial_tokens(
                        query, token, idx, "double",
                        {"b_type_0": "", "f_type_0": ""},
                    )  # Cast both sides
                    break
                counter += 1
            idx += len(token.value)
    return query
def test_cast_non_trivial_tokens_ValueError(sql: str, loc: List[int], cast_to: str, groupdict: Dict) -> None:
    """``cast_non_trivial_tokens`` must raise ValueError for these inputs.

    The problematic token is located *before* entering ``pytest.raises``, so
    a ValueError raised while locating it propagates and errors the test
    instead of satisfying the expectation.
    """
    caster = utils.ColumnCaster()
    problem_token, offset = caster.get_problematic_token(sql, *loc)
    with pytest.raises(ValueError):
        caster.cast_non_trivial_tokens(
            sql,
            problem_token,
            offset,
            cast_to,
            groupdict,
            count_backward_tokens=10,
        )
def test_cast_non_trivial_tokens(sql: str, loc: List[int], cast_to: str, bck: int, fwd: int, groupdict: Dict, expected: str) -> None:
    """Check the SQL text returned by ``ColumnCaster.cast_non_trivial_tokens``.

    Args:
        sql: Input SQL.
        loc: Position arguments unpacked into ``get_problematic_token``
            (line and column).
        cast_to: Type passed through as the cast target.
        bck: Value for ``count_backward_tokens``.
        fwd: Value for ``count_forward_tokens``.
        groupdict: Regex group dictionary passed through unchanged.
        expected: The exact rewritten SQL string expected back. This was
            previously annotated as ``Dict``, but the method returns the
            transformed query string.
    """
    caster = utils.ColumnCaster()
    token, idx = caster.get_problematic_token(sql, *loc)
    result = caster.cast_non_trivial_tokens(
        sql, token, idx, cast_to, groupdict,
        count_backward_tokens=bck,
        count_forward_tokens=fwd,
    )
    assert result == expected
def test_light_cast(token, cast_to: str, data_type: str, expected: str) -> None:
    """Check ``ColumnCaster._light_cast`` on the first token parsed from ``token``.

    ``token`` is SQL text. It is parsed with sqlparse, and the first top-level
    token of the first statement is handed to ``_light_cast`` together with
    ``cast_to`` and ``data_type``. The output must equal ``expected``.
    """
    caster = utils.ColumnCaster()
    first_token = sqlparse.parse(token)[0].tokens[0]
    outcome = caster._light_cast(first_token, cast_to, data_type)
    assert outcome == expected
def test_get_problematic_token_ValueError(sql: str, line: int, column: int) -> None:
    """``get_problematic_token`` must reject this ``(line, column)`` with ValueError."""
    # Construct the caster outside pytest.raises so that only the lookup
    # itself is expected to raise.
    locate = utils.ColumnCaster().get_problematic_token
    with pytest.raises(ValueError):
        locate(sql, line, column)
def test_get_problematic_token(sql: str, line: int, column: int, expected: Dict) -> None:
    """Check the token and index returned by ``get_problematic_token``.

    ``expected`` holds the token's text under ``"value"`` and the returned
    index under ``"idx"``.
    """
    caster = utils.ColumnCaster()
    token, idx = caster.get_problematic_token(sql, line, column)
    # Two separate asserts: a failure then reports which of the two values
    # is wrong, instead of a single opaque `a and b` failure.
    assert token.value == expected["value"]
    assert idx == expected["idx"]