def pyarrow2redshift(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: pa.DataType, string_type: str
) -> str:
    """Pyarrow to Redshift data types conversion."""
    if pa.types.is_int8(dtype):
        return "SMALLINT"
    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
        return "SMALLINT"
    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
        return "INTEGER"
    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
        return "BIGINT"
    if pa.types.is_uint64(dtype):
        raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
    if pa.types.is_float32(dtype):
        return "FLOAT4"
    if pa.types.is_float64(dtype):
        return "FLOAT8"
    if pa.types.is_boolean(dtype):
        return "BOOL"
    if pa.types.is_string(dtype):
        return string_type
    if pa.types.is_timestamp(dtype):
        return "TIMESTAMP"
    if pa.types.is_date(dtype):
        return "DATE"
    if pa.types.is_decimal(dtype):
        return f"DECIMAL({dtype.precision},{dtype.scale})"
    if pa.types.is_dictionary(dtype):
        return pyarrow2redshift(dtype=dtype.value_type, string_type=string_type)
    if pa.types.is_list(dtype) or pa.types.is_struct(dtype):
        return "SUPER"
    raise exceptions.UnsupportedType(f"Unsupported Redshift type: {dtype}")
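# Illustrative usage sketch (added; not part of the library). Assuming pyarrow
# is imported as ``pa`` as elsewhere in this module: dictionary-encoded columns
# recurse into their value type, and nested types land on Redshift's SUPER.
def _example_pyarrow2redshift() -> None:
    assert pyarrow2redshift(pa.dictionary(pa.int32(), pa.decimal128(10, 2)), string_type="VARCHAR(256)") == "DECIMAL(10,2)"
    assert pyarrow2redshift(pa.list_(pa.int64()), string_type="VARCHAR(256)") == "SUPER"
    assert pyarrow2redshift(pa.string(), string_type="VARCHAR(256)") == "VARCHAR(256)"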
def pyarrow2sqlserver(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: pa.DataType, string_type: str
) -> str:
    """Pyarrow to Microsoft SQL Server data types conversion."""
    if pa.types.is_int8(dtype):
        return "SMALLINT"
    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
        return "SMALLINT"
    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
        return "INT"
    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
        return "BIGINT"
    if pa.types.is_uint64(dtype):
        raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
    if pa.types.is_float32(dtype):
        return "FLOAT(24)"
    if pa.types.is_float64(dtype):
        return "FLOAT"
    if pa.types.is_boolean(dtype):
        return "BIT"
    if pa.types.is_string(dtype):
        return string_type
    if pa.types.is_timestamp(dtype):
        return "DATETIME2"
    if pa.types.is_date(dtype):
        return "DATE"
    if pa.types.is_decimal(dtype):
        return f"DECIMAL({dtype.precision},{dtype.scale})"
    if pa.types.is_dictionary(dtype):
        return pyarrow2sqlserver(dtype=dtype.value_type, string_type=string_type)
    if pa.types.is_binary(dtype):
        return "VARBINARY"
    raise exceptions.UnsupportedType(f"Unsupported SQL Server type: {dtype}")
def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: pa.DataType, db_type: str
) -> VisitableType:
    """Pyarrow to SQLAlchemy data types conversion."""
    if pa.types.is_int8(dtype):
        return sqlalchemy.types.SmallInteger
    if pa.types.is_int16(dtype):
        return sqlalchemy.types.SmallInteger
    if pa.types.is_int32(dtype):
        return sqlalchemy.types.Integer
    if pa.types.is_int64(dtype):
        return sqlalchemy.types.BigInteger
    if pa.types.is_float32(dtype):
        return sqlalchemy.types.Float
    if pa.types.is_float64(dtype):
        if db_type == "mysql":
            return sqlalchemy.dialects.mysql.DOUBLE
        if db_type == "postgresql":
            return sqlalchemy.dialects.postgresql.DOUBLE_PRECISION
        if db_type == "redshift":
            return sqlalchemy_redshift.dialect.DOUBLE_PRECISION
        raise exceptions.InvalidDatabaseType(
            f"{db_type} is an invalid database type, please choose between postgresql, mysql and redshift."
        )  # pragma: no cover
    if pa.types.is_boolean(dtype):
        return sqlalchemy.types.Boolean
    if pa.types.is_string(dtype):
        if db_type == "mysql":
            return sqlalchemy.types.Text
        if db_type == "postgresql":
            return sqlalchemy.types.Text
        if db_type == "redshift":
            return sqlalchemy.types.VARCHAR(length=256)
        raise exceptions.InvalidDatabaseType(
            f"{db_type} is an invalid database type. Please choose between postgresql, mysql and redshift."
        )  # pragma: no cover
    if pa.types.is_timestamp(dtype):
        return sqlalchemy.types.DateTime
    if pa.types.is_date(dtype):
        return sqlalchemy.types.Date
    if pa.types.is_binary(dtype):
        if db_type == "redshift":
            raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.")  # pragma: no cover
        return sqlalchemy.types.Binary
    if pa.types.is_decimal(dtype):
        return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
    if pa.types.is_dictionary(dtype):
        return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)
    if dtype == pa.null():  # pragma: no cover
        raise exceptions.UndetectedType("We cannot infer the data type from an entire null object column.")
    raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
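# Illustrative usage sketch (added; hypothetical helper, assuming sqlalchemy is
# importable as used above). Float64 and string columns dispatch on ``db_type``;
# most other branches return a SQLAlchemy type class rather than an instance.
def _example_pyarrow2sqlalchemy() -> None:
    assert pyarrow2sqlalchemy(pa.int64(), db_type="mysql") is sqlalchemy.types.BigInteger
    assert isinstance(pyarrow2sqlalchemy(pa.string(), db_type="redshift"), sqlalchemy.types.VARCHAR)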
def athena2redshift(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: str, varchar_length: int = 256
) -> str:
    """Athena to Redshift data types conversion."""
    dtype = dtype.lower()
    if dtype == "smallint":
        return "SMALLINT"
    if dtype in ("int", "integer"):
        return "INTEGER"
    if dtype == "bigint":
        return "BIGINT"
    if dtype == "float":
        return "FLOAT4"
    if dtype == "double":
        return "FLOAT8"
    if dtype in ("boolean", "bool"):
        return "BOOL"
    if dtype in ("string", "char", "varchar"):
        return f"VARCHAR({varchar_length})"
    if dtype == "timestamp":
        return "TIMESTAMP"
    if dtype == "date":
        return "DATE"
    if dtype.startswith("decimal"):
        return dtype.upper()
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")  # pragma: no cover
def athena2pandas(dtype: str) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
    """Athena to Pandas data types conversion."""
    dtype = dtype.lower()
    if dtype == "tinyint":
        return "Int8"
    if dtype == "smallint":
        return "Int16"
    if dtype in ("int", "integer"):
        return "Int32"
    if dtype == "bigint":
        return "Int64"
    if dtype == "float":
        return "float32"
    if dtype == "double":
        return "float64"
    if dtype == "boolean":
        return "boolean"
    if (dtype == "string") or dtype.startswith("char") or dtype.startswith("varchar"):
        return "string"
    if dtype in ("timestamp", "timestamp with time zone"):
        return "datetime64"
    if dtype == "date":
        return "date"
    if dtype.startswith("decimal"):
        return "decimal"
    if dtype in ("binary", "varbinary"):
        return "bytes"
    if dtype == "array":  # pragma: no cover
        return "list"
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")  # pragma: no cover
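# Illustrative usage sketch (added; not part of the library). Integer types map
# to pandas nullable extension dtypes (capitalized "Int*"), so NULL values
# survive without the classic int-to-float coercion.
def _example_athena2pandas() -> None:
    assert athena2pandas("bigint") == "Int64"
    assert athena2pandas("varchar(64)") == "string"
    assert athena2pandas("decimal(10,2)") == "decimal"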
def athena2pyarrow(dtype: str) -> pa.DataType:  # pylint: disable=too-many-return-statements
    """Athena to PyArrow data types conversion."""
    dtype = dtype.lower()
    if dtype == "tinyint":
        return pa.int8()
    if dtype == "smallint":
        return pa.int16()
    if dtype in ("int", "integer"):
        return pa.int32()
    if dtype == "bigint":
        return pa.int64()
    if dtype == "float":
        return pa.float32()
    if dtype == "double":
        return pa.float64()
    if dtype == "boolean":
        return pa.bool_()
    if (dtype == "string") or dtype.startswith("char") or dtype.startswith("varchar"):
        return pa.string()
    if dtype == "timestamp":
        return pa.timestamp(unit="ns")
    if dtype == "date":
        return pa.date32()
    # Fixed: the original `dtype in ("binary" or "varbinary")` evaluated to a
    # substring check against "binary" only, silently missing "varbinary".
    if dtype in ("binary", "varbinary"):
        return pa.binary()
    if dtype.startswith("decimal"):
        precision, scale = dtype.replace("decimal(", "").replace(")", "").split(sep=",")
        return pa.decimal128(precision=int(precision), scale=int(scale))
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")  # pragma: no cover
def athena2quicksight(dtype: str) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
    """Athena to Quicksight data types conversion."""
    dtype = dtype.lower()
    if dtype == "tinyint":
        return "INTEGER"
    if dtype == "smallint":
        return "INTEGER"
    if dtype in ("int", "integer"):
        return "INTEGER"
    if dtype == "bigint":
        return "INTEGER"
    if dtype in ("float", "real"):
        return "DECIMAL"
    if dtype == "double":
        return "DECIMAL"
    if dtype in ("boolean", "bool"):
        return "BOOLEAN"
    if dtype in ("string", "char", "varchar"):
        return "STRING"
    if dtype == "timestamp":
        return "DATETIME"
    if dtype == "date":
        return "DATETIME"
    if dtype.startswith("decimal"):
        return "DECIMAL"
    if dtype == "binary":  # pragma: no cover
        return "BIT"
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")  # pragma: no cover
def athena2pyarrow(dtype: str) -> pa.DataType:  # pylint: disable=too-many-return-statements
    """Athena to PyArrow data types conversion."""
    dtype = dtype.lower().replace(" ", "")
    if dtype == "tinyint":
        return pa.int8()
    if dtype == "smallint":
        return pa.int16()
    if dtype in ("int", "integer"):
        return pa.int32()
    if dtype == "bigint":
        return pa.int64()
    if dtype in ("float", "real"):
        return pa.float32()
    if dtype == "double":
        return pa.float64()
    if dtype == "boolean":
        return pa.bool_()
    if (dtype == "string") or dtype.startswith("char") or dtype.startswith("varchar"):
        return pa.string()
    if dtype == "timestamp":
        return pa.timestamp(unit="ns")
    if dtype == "date":
        return pa.date32()
    # Fixed: the original `dtype in ("binary" or "varbinary")` evaluated to a
    # substring check against "binary" only, silently missing "varbinary".
    if dtype in ("binary", "varbinary"):
        return pa.binary()
    if dtype.startswith("decimal"):
        precision, scale = dtype.replace("decimal(", "").replace(")", "").split(sep=",")
        return pa.decimal128(precision=int(precision), scale=int(scale))
    if dtype.startswith("array"):
        return pa.list_(value_type=athena2pyarrow(dtype=dtype[6:-1]), list_size=-1)
    if dtype.startswith("struct"):
        return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in dtype[7:-1].split(",")])
    if dtype.startswith("map"):
        return pa.map_(athena2pyarrow(dtype[4:-1].split(",", 1)[0]), athena2pyarrow(dtype[4:-1].split(",", 1)[1]))
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")
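# Illustrative usage sketch (added; not part of the library). Nested Athena DDL
# strings are parsed recursively. Caveat: the struct branch splits fields on
# every ",", so a comma nested inside a field type (e.g. a decimal(10,2) field)
# would break it.
def _example_athena2pyarrow() -> None:
    assert athena2pyarrow("array<int>") == pa.list_(pa.int32())
    assert athena2pyarrow("map<bigint,string>") == pa.map_(pa.int64(), pa.string())
    assert athena2pyarrow("struct<a:int,b:double>") == pa.struct([("a", pa.int32()), ("b", pa.float64())])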
def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
    """Pyarrow to Athena data types conversion."""
    if pa.types.is_int8(dtype):
        return "tinyint"
    if pa.types.is_int16(dtype):
        return "smallint"
    if pa.types.is_int32(dtype):
        return "int"
    if pa.types.is_int64(dtype):
        return "bigint"
    if pa.types.is_float32(dtype):
        return "float"
    if pa.types.is_float64(dtype):
        return "double"
    if pa.types.is_boolean(dtype):
        return "boolean"
    if pa.types.is_string(dtype):
        return "string"
    if pa.types.is_timestamp(dtype):
        return "timestamp"
    if pa.types.is_date(dtype):
        return "date"
    if pa.types.is_binary(dtype):
        return "binary"
    if pa.types.is_dictionary(dtype):
        return pyarrow2athena(dtype=dtype.value_type)
    if pa.types.is_decimal(dtype):
        return f"decimal({dtype.precision},{dtype.scale})"
    if pa.types.is_list(dtype):
        return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
    if pa.types.is_struct(dtype):  # pragma: no cover
        return f"struct<{', '.join([f'{f.name}: {pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
    if dtype == pa.null():
        raise exceptions.UndetectedType("We cannot infer the data type from an entire null object column.")
    raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
def athena_types_from_pandas(
    df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
) -> Dict[str, str]:
    """Extract the related Athena data types from any Pandas DataFrame."""
    casts: Dict[str, str] = dtype if dtype else {}
    pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
        df=df, index=index, ignore_cols=list(casts.keys()), index_left=index_left
    )
    athena_columns_types: Dict[str, str] = {}
    for k, v in pa_columns_types.items():
        if v is None:
            athena_columns_types[k] = casts[k].replace(" ", "")
        else:
            try:
                athena_columns_types[k] = pyarrow2athena(dtype=v)
            except exceptions.UndetectedType as ex:
                raise exceptions.UndetectedType(
                    "Impossible to infer the equivalent Athena data type "
                    f"for the {k} column. "
                    "It is completely empty (only null values) "
                    f"and has a too generic data type ({df[k].dtype}). "
                    "Please cast this column with a more deterministic data type "
                    f"(e.g. df['{k}'] = df['{k}'].astype('string')) or "
                    "pass the column schema as argument for AWS Data Wrangler "
                    f"(e.g. dtype={{'{k}': 'string'}})."
                ) from ex
            except exceptions.UnsupportedType as ex:
                raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {v} for column {k}") from ex
    _logger.debug("athena_columns_types: %s", athena_columns_types)
    return athena_columns_types
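# Illustrative usage sketch (added; not part of the library). Columns listed in
# ``dtype`` are excluded from inference and taken verbatim, which is the escape
# hatch for all-null or ambiguous columns.
def _example_athena_types_from_pandas() -> None:
    df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"], "blob": [None, None]})
    types = athena_types_from_pandas(df=df, index=False, dtype={"blob": "string"})
    # Expected (given default inference): {"id": "bigint", "name": "string", "blob": "string"}
    print(types)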
def _get_query_metadata(
    query_execution_id: str, categories: Optional[List[str]] = None, boto3_session: Optional[boto3.Session] = None
) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any], List[str]]:
    """Get query metadata."""
    cols_types: Dict[str, str] = get_query_columns_types(
        query_execution_id=query_execution_id, boto3_session=boto3_session
    )
    _logger.debug("cols_types: %s", cols_types)
    dtype: Dict[str, str] = {}
    parse_timestamps: List[str] = []
    parse_dates: List[str] = []
    converters: Dict[str, Any] = {}
    binaries: List[str] = []
    col_name: str
    col_type: str
    for col_name, col_type in cols_types.items():
        if col_type == "array":
            raise exceptions.UnsupportedType(
                "List data type is not supported with ctas_approach=False. "
                "Please use ctas_approach=True for List columns."
            )
        if col_type == "row":
            raise exceptions.UnsupportedType(
                "Struct data type is not supported with ctas_approach=False. "
                "Please use ctas_approach=True for Struct columns."
            )
        pandas_type: str = _data_types.athena2pandas(dtype=col_type)
        if (categories is not None) and (col_name in categories):
            dtype[col_name] = "category"
        elif pandas_type in ["datetime64", "date"]:
            parse_timestamps.append(col_name)
            if pandas_type == "date":
                parse_dates.append(col_name)
        elif pandas_type == "bytes":
            dtype[col_name] = "string"
            binaries.append(col_name)
        elif pandas_type == "decimal":
            converters[col_name] = lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None
        else:
            dtype[col_name] = pandas_type
    _logger.debug("dtype: %s", dtype)
    _logger.debug("parse_timestamps: %s", parse_timestamps)
    _logger.debug("parse_dates: %s", parse_dates)
    _logger.debug("converters: %s", converters)
    _logger.debug("binaries: %s", binaries)
    return dtype, parse_timestamps, parse_dates, converters, binaries
def pyarrow2athena(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: pa.DataType, ignore_null: bool = False
) -> str:
    """Pyarrow to Athena data types conversion."""
    if pa.types.is_int8(dtype):
        return "tinyint"
    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
        return "smallint"
    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
        return "int"
    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
        return "bigint"
    if pa.types.is_uint64(dtype):
        raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
    if pa.types.is_float32(dtype):
        return "float"
    if pa.types.is_float64(dtype):
        return "double"
    if pa.types.is_boolean(dtype):
        return "boolean"
    if pa.types.is_string(dtype):
        return "string"
    if pa.types.is_timestamp(dtype):
        return "timestamp"
    if pa.types.is_date(dtype):
        return "date"
    if pa.types.is_binary(dtype):
        return "binary"
    if pa.types.is_dictionary(dtype):
        return pyarrow2athena(dtype=dtype.value_type)
    if pa.types.is_decimal(dtype):
        return f"decimal({dtype.precision},{dtype.scale})"
    if pa.types.is_list(dtype):
        return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
    if pa.types.is_struct(dtype):
        return f"struct<{','.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
    if pa.types.is_map(dtype):
        return f"map<{pyarrow2athena(dtype=dtype.key_type)}, {pyarrow2athena(dtype=dtype.item_type)}>"
    if dtype == pa.null():
        if ignore_null:
            return ""
        raise exceptions.UndetectedType("We cannot infer the data type from an entire null object column.")
    raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
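# Illustrative usage sketch (added; not part of the library). Nested Arrow types
# serialize back to Athena DDL strings, and ``ignore_null=True`` lets schema
# extraction tolerate all-null columns by returning "" instead of raising.
def _example_pyarrow2athena() -> None:
    assert pyarrow2athena(pa.list_(pa.int32())) == "array<int>"
    assert pyarrow2athena(pa.struct([("a", pa.int64())])) == "struct<a:bigint>"
    assert pyarrow2athena(pa.null(), ignore_null=True) == ""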
def pyarrow2timestream(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
    """Pyarrow to Amazon Timestream data types conversion."""
    if pa.types.is_int8(dtype):
        return "BIGINT"
    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
        return "BIGINT"
    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
        return "BIGINT"
    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
        return "BIGINT"
    if pa.types.is_uint64(dtype):
        return "BIGINT"
    if pa.types.is_float32(dtype):
        return "DOUBLE"
    if pa.types.is_float64(dtype):
        return "DOUBLE"
    if pa.types.is_boolean(dtype):
        return "BOOLEAN"
    if pa.types.is_string(dtype):
        return "VARCHAR"
    raise exceptions.UnsupportedType(f"Unsupported Amazon Timestream measure type: {dtype}")
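# Illustrative usage sketch (added; not part of the library). Timestream
# measures support only a narrow type set, so every Arrow integer width,
# including uint64 (rejected by the other converters here), collapses to BIGINT.
def _example_pyarrow2timestream() -> None:
    assert pyarrow2timestream(pa.uint64()) == "BIGINT"
    assert pyarrow2timestream(pa.float32()) == "DOUBLE"
    assert pyarrow2timestream(pa.string()) == "VARCHAR"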
def pyarrow2mysql(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: pa.DataType, string_type: str
) -> str:
    """Pyarrow to MySQL data types conversion."""
    if pa.types.is_int8(dtype):
        return "TINYINT"
    # MySQL DDL places the UNSIGNED attribute after the type name; the original
    # emitted "UNSIGNED TINYINT" (and friends), which MySQL rejects.
    if pa.types.is_uint8(dtype):
        return "TINYINT UNSIGNED"
    if pa.types.is_int16(dtype):
        return "SMALLINT"
    if pa.types.is_uint16(dtype):
        return "SMALLINT UNSIGNED"
    if pa.types.is_int32(dtype):
        return "INTEGER"
    if pa.types.is_uint32(dtype):
        return "INTEGER UNSIGNED"
    if pa.types.is_int64(dtype):
        return "BIGINT"
    if pa.types.is_uint64(dtype):
        return "BIGINT UNSIGNED"
    if pa.types.is_float32(dtype):
        return "FLOAT"
    if pa.types.is_float64(dtype):
        return "DOUBLE PRECISION"
    if pa.types.is_boolean(dtype):
        return "BOOLEAN"
    if pa.types.is_string(dtype):
        return string_type
    if pa.types.is_timestamp(dtype):
        return "TIMESTAMP"
    if pa.types.is_date(dtype):
        return "DATE"
    if pa.types.is_decimal(dtype):
        return f"DECIMAL({dtype.precision},{dtype.scale})"
    if pa.types.is_dictionary(dtype):
        return pyarrow2mysql(dtype=dtype.value_type, string_type=string_type)
    if pa.types.is_binary(dtype):
        return "BLOB"
    raise exceptions.UnsupportedType(f"Unsupported MySQL type: {dtype}")
def _get_query_metadata(  # pylint: disable=too-many-statements
    query_execution_id: str,
    boto3_session: boto3.Session,
    categories: Optional[List[str]] = None,
    query_execution_payload: Optional[Dict[str, Any]] = None,
) -> _QueryMetadata:
    """Get query metadata."""
    if (query_execution_payload is not None) and (query_execution_payload["Status"]["State"] in _QUERY_FINAL_STATES):
        if query_execution_payload["Status"]["State"] != "SUCCEEDED":
            reason: str = query_execution_payload["Status"]["StateChangeReason"]
            raise exceptions.QueryFailed(f"Query error: {reason}")
        _query_execution_payload: Dict[str, Any] = query_execution_payload
    else:
        _query_execution_payload = wait_query(query_execution_id=query_execution_id, boto3_session=boto3_session)
    cols_types: Dict[str, str] = get_query_columns_types(
        query_execution_id=query_execution_id, boto3_session=boto3_session
    )
    _logger.debug("cols_types: %s", cols_types)
    dtype: Dict[str, str] = {}
    parse_timestamps: List[str] = []
    parse_dates: List[str] = []
    converters: Dict[str, Any] = {}
    binaries: List[str] = []
    col_name: str
    col_type: str
    for col_name, col_type in cols_types.items():
        if col_type == "array":
            raise exceptions.UnsupportedType(
                "List data type is not supported with ctas_approach=False. "
                "Please use ctas_approach=True for List columns."
            )
        if col_type == "row":
            raise exceptions.UnsupportedType(
                "Struct data type is not supported with ctas_approach=False. "
                "Please use ctas_approach=True for Struct columns."
            )
        pandas_type: str = _data_types.athena2pandas(dtype=col_type)
        if (categories is not None) and (col_name in categories):
            dtype[col_name] = "category"
        elif pandas_type in ["datetime64", "date"]:
            parse_timestamps.append(col_name)
            if pandas_type == "date":
                parse_dates.append(col_name)
        elif pandas_type == "bytes":
            dtype[col_name] = "string"
            binaries.append(col_name)
        elif pandas_type == "decimal":
            converters[col_name] = lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None
        else:
            dtype[col_name] = pandas_type
    output_location: Optional[str] = None
    if "ResultConfiguration" in _query_execution_payload:
        output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation")
    athena_statistics: Dict[str, Union[int, str]] = _query_execution_payload.get("Statistics", {})
    # Only stringify when present; the original str(...) turned a missing key into the literal "None".
    manifest_raw: Optional[Union[int, str]] = athena_statistics.get("DataManifestLocation")
    manifest_location: Optional[str] = str(manifest_raw) if manifest_raw is not None else None
    query_metadata: _QueryMetadata = _QueryMetadata(
        execution_id=query_execution_id,
        dtype=dtype,
        parse_timestamps=parse_timestamps,
        parse_dates=parse_dates,
        converters=converters,
        binaries=binaries,
        output_location=output_location,
        manifest_location=manifest_location,
        raw_payload=_query_execution_payload,
    )
    _logger.debug("query_metadata:\n%s", query_metadata)
    return query_metadata