Code Example #1
File: _fs.py Project: telegit/aws-data-wrangler
 def close(self) -> None:
     """Clean up the cache."""
     if self.closed:  # pylint: disable=using-constant-test
         return None
     if self.writable():
         _logger.debug("Closing: %s parts", self._parts_count)
         if self._parts_count > 0:
             self.flush(force=True)
             parts: List[Dict[str, Union[str, int]]] = self._upload_proxy.close()
             part_info: Dict[str, List[Dict[str, Any]]] = {"Parts": parts}
             _logger.debug("Running complete_multipart_upload...")
             _utils.try_it(
                 f=self._client.complete_multipart_upload,
                 ex=_S3_RETRYABLE_ERRORS,
                 base=0.5,
                 max_num_tries=6,
                 Bucket=self._bucket,
                 Key=self._key,
                 UploadId=self._mpu["UploadId"],
                 MultipartUpload=part_info,
                 **get_botocore_valid_kwargs(
                     function_name="complete_multipart_upload",
                     s3_additional_kwargs=self._s3_additional_kwargs),
             )
             _logger.debug("complete_multipart_upload done!")
         elif self._buffer.tell() > 0:
             _logger.debug("put_object")
             _utils.try_it(
                 f=self._client.put_object,
                 ex=_S3_RETRYABLE_ERRORS,
                 base=0.5,
                 max_num_tries=6,
                 Bucket=self._bucket,
                 Key=self._key,
                 Body=self._buffer.getvalue(),
                 **get_botocore_valid_kwargs(
                     function_name="put_object",
                     s3_additional_kwargs=self._s3_additional_kwargs),
             )
         self._parts_count = 0
         self._upload_proxy.close()
         self._buffer.seek(0)
         self._buffer.truncate(0)
         self._buffer.close()
     elif self.readable():
         self._cache = b""
     else:
         raise RuntimeError(f"Invalid mode: {self._mode}")
     super().close()
     return None
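
Every example on this page funnels its AWS call through `_utils.try_it`. The helper is internal to aws-data-wrangler, but its call sites imply the shape sketched below: call `f(**kwargs)`, retry on the given exception types with exponential backoff, and, when `ex_code` is set, retry a botocore ClientError only if its error code matches. This is an illustrative reconstruction under those assumptions, not the library's actual implementation.

import time
from typing import Any, Callable, Optional, Tuple, Type, Union

def try_it(
    f: Callable[..., Any],
    ex: Union[Type[Exception], Tuple[Type[Exception], ...]],
    ex_code: Optional[str] = None,
    base: float = 1.0,
    max_num_tries: int = 3,
    **kwargs: Any,
) -> Any:
    """Retry f(**kwargs) with exponential backoff (sketch of assumed semantics)."""
    delay: float = base
    for attempt in range(max_num_tries):
        try:
            return f(**kwargs)
        except ex as error:
            # When ex_code is given, only retry matching botocore error codes.
            if ex_code is not None:
                code = getattr(error, "response", {}).get("Error", {}).get("Code")
                if code != ex_code:
                    raise
            if attempt == max_num_tries - 1:
                raise
            time.sleep(delay)
            delay *= 2.0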
Code Example #2
def _describe_object(
    path: str,
    boto3_session: boto3.Session,
    s3_additional_kwargs: Optional[Dict[str, Any]],
    version_id: Optional[str] = None,
) -> Tuple[str, Dict[str, Any]]:
    client_s3: boto3.client = _utils.client(service_name="s3",
                                            session=boto3_session)
    bucket: str
    key: str
    bucket, key = _utils.parse_path(path=path)
    if s3_additional_kwargs:
        extra_kwargs: Dict[str, Any] = _fs.get_botocore_valid_kwargs(
            function_name="head_object",
            s3_additional_kwargs=s3_additional_kwargs)
    else:
        extra_kwargs = {}
    desc: Dict[str, Any]
    if version_id:
        extra_kwargs["VersionId"] = version_id
    desc = _utils.try_it(f=client_s3.head_object,
                         ex=client_s3.exceptions.NoSuchKey,
                         Bucket=bucket,
                         Key=key,
                         **extra_kwargs)
    return path, desc
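
`_describe_object` is a private worker; in awswrangler it sits behind the public `wr.s3.describe_objects()` batch call. Calling it directly would look roughly like this (a sketch with an assumed example path, for illustration only):

import boto3

path, desc = _describe_object(
    path="s3://my-bucket/data/file.parquet",  # hypothetical object
    boto3_session=boto3.Session(),
    s3_additional_kwargs=None,
)
print(desc["ContentLength"], desc["LastModified"])  # head_object response fields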
Code Example #3
def get_work_group(
        workgroup: str,
        boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
    """Return information about the workgroup with the specified name.

    Parameters
    ----------
    workgroup : str
        Work Group name.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, Any]
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.get_work_group

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.get_work_group(workgroup='workgroup_name')

    """
    client_athena: boto3.client = _utils.client(service_name="athena",
                                                session=boto3_session)
    return cast(
        Dict[str, Any],
        _utils.try_it(
            f=client_athena.get_work_group,
            ex=botocore.exceptions.ClientError,
            ex_code="ThrottlingException",
            max_num_tries=5,
            WorkGroup=workgroup,
        ),
    )
Code Example #4
 def _caller(
     bucket: str,
     key: str,
     part: int,
     upload_id: str,
     data: bytes,
     boto3_primitives: _utils.Boto3PrimitivesType,
     boto3_kwargs: Dict[str, Any],
 ) -> Dict[str, Union[str, int]]:
     _logger.debug("Upload part %s started.", part)
     boto3_session: boto3.Session = _utils.boto3_from_primitives(primitives=boto3_primitives)
     client: boto3.client = _utils.client(service_name="s3", session=boto3_session)
     resp: Dict[str, Any] = _utils.try_it(
         f=client.upload_part,
         ex=_S3_RETRYABLE_ERRORS,
         base=0.5,
         max_num_tries=6,
         Bucket=bucket,
         Key=key,
         Body=data,
         PartNumber=part,
         UploadId=upload_id,
         **boto3_kwargs,
     )
     _logger.debug("Upload part %s done.", part)
     return {"PartNumber": part, "ETag": resp["ETag"]}
Code Example #5
File: _fs.py Project: Westerley/aws-data-wrangler
def _fetch_range(
    range_values: Tuple[int, int],
    bucket: str,
    key: str,
    boto3_primitives: _utils.Boto3PrimitivesType,
    boto3_kwargs: Dict[str, Any],
) -> Tuple[int, bytes]:
    start, end = range_values
    _logger.debug("Fetching: s3://%s/%s - Range: %s-%s", bucket, key, start,
                  end)
    boto3_session: boto3.Session = _utils.boto3_from_primitives(
        primitives=boto3_primitives)
    client: boto3.client = _utils.client(service_name="s3",
                                         session=boto3_session)
    resp: Dict[str, Any] = _utils.try_it(
        f=client.get_object,
        ex=_S3_RETRYABLE_ERRORS,
        base=0.5,
        max_num_tries=6,
        Bucket=bucket,
        Key=key,
        Range=f"bytes={start}-{end - 1}",
        **boto3_kwargs,
    )
    return start, cast(bytes, resp["Body"].read())
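
Returning the starting offset alongside the bytes lets callers issue these ranged GETs concurrently and reassemble the results in any order. A sketch of that fan-out, assuming `fetch` is `_fetch_range` with bucket, key, session primitives, and kwargs already bound (e.g. via functools.partial):

import concurrent.futures
from typing import Callable, List, Tuple

def read_parallel(size: int, chunk: int,
                  fetch: Callable[[Tuple[int, int]], Tuple[int, bytes]]) -> bytes:
    """Fan ranged GETs out to a thread pool and stitch the chunks back together."""
    ranges: List[Tuple[int, int]] = [
        (start, min(start + chunk, size)) for start in range(0, size, chunk)
    ]
    buf = bytearray(size)
    with concurrent.futures.ThreadPoolExecutor() as pool:
        for start, data in pool.map(fetch, ranges):
            buf[start:start + len(data)] = data
    return bytes(buf)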
Code Example #6
def get_query_execution(
        query_execution_id: str,
        boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
    """Fetch query execution details.

    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.get_query_execution

    Parameters
    ----------
    query_execution_id : str
        Athena query execution ID.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, Any]
        Dictionary with the get_query_execution response.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.get_query_execution(query_execution_id='query-execution-id')

    """
    client_athena: boto3.client = _utils.client(service_name="athena",
                                                session=boto3_session)
    response: Dict[str, Any] = _utils.try_it(
        f=client_athena.get_query_execution,
        ex=botocore.exceptions.ClientError,
        ex_code="ThrottlingException",
        max_num_tries=5,
        QueryExecutionId=query_execution_id,
    )
    return cast(Dict[str, Any], response["QueryExecution"])
Code Example #7
def _write_batch(
    database: str,
    table: str,
    cols_names: List[str],
    measure_type: str,
    batch: List[Any],
    boto3_primitives: _utils.Boto3PrimitivesType,
) -> List[Dict[str, str]]:
    boto3_session: boto3.Session = _utils.boto3_from_primitives(
        primitives=boto3_primitives)
    client: boto3.client = _utils.client(
        service_name="timestream-write",
        session=boto3_session,
        botocore_config=Config(read_timeout=20,
                               max_pool_connections=5000,
                               retries={"max_attempts": 10}),
    )
    try:
        _utils.try_it(
            f=client.write_records,
            ex=(client.exceptions.ThrottlingException,
                client.exceptions.InternalServerException),
            max_num_tries=5,
            DatabaseName=database,
            TableName=table,
            Records=[{
                "Dimensions": [{
                    "Name": name,
                    "DimensionValueType": "VARCHAR",
                    "Value": str(value)
                } for name, value in zip(cols_names[2:], rec[2:])],
                "MeasureName":
                cols_names[1],
                "MeasureValueType":
                measure_type,
                "MeasureValue":
                str(rec[1]),
                "Time":
                str(round(rec[0].timestamp() * 1_000)),
                "TimeUnit":
                "MILLISECONDS",
            } for rec in batch],
        )
    except client.exceptions.RejectedRecordsException as ex:
        return cast(List[Dict[str, str]], ex.response["RejectedRecords"])
    return []
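
The layout this helper assumes is positional: column 0 is the timestamp, column 1 the measure, and everything from column 2 on is a dimension. Assuming UTC, a row `(datetime(2021, 1, 1), 22.5, "room-a")` with `cols_names=["time", "temperature", "sensor"]` and `measure_type="DOUBLE"` renders into the following record (a worked example, not captured API traffic):

record = {
    "Dimensions": [
        {"Name": "sensor", "DimensionValueType": "VARCHAR", "Value": "room-a"},
    ],
    "MeasureName": "temperature",
    "MeasureValueType": "DOUBLE",
    "MeasureValue": "22.5",
    "Time": "1609459200000",  # round(rec[0].timestamp() * 1_000), milliseconds
    "TimeUnit": "MILLISECONDS",
}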
Code Example #8
def _describe_object(path: str, boto3_session: boto3.Session) -> Tuple[str, Dict[str, Any]]:
    client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
    bucket: str
    key: str
    bucket, key = _utils.parse_path(path=path)
    desc: Dict[str, Any] = _utils.try_it(
        f=client_s3.head_object, ex=client_s3.exceptions.NoSuchKey, Bucket=bucket, Key=key
    )
    return path, desc
Code Example #9
File: _utils.py Project: Westerley/aws-data-wrangler
def extract_cloudformation_outputs():
    outputs = {}
    client = boto3.client("cloudformation")
    response = try_it(client.describe_stacks, botocore.exceptions.ClientError, max_num_tries=5)
    for stack in response.get("Stacks"):
        if (stack["StackName"] in ["aws-data-wrangler-base", "aws-data-wrangler-databases"]) and (
            stack["StackStatus"] in CFN_VALID_STATUS
        ):
            for output in stack.get("Outputs"):
                outputs[output.get("OutputKey")] = output.get("OutputValue")
    return outputs
Code Example #10
def validate_workgroup_key(workgroup):
    if "ResultConfiguration" in workgroup["Configuration"]:
        if "EncryptionConfiguration" in workgroup["Configuration"]["ResultConfiguration"]:
            if "KmsKey" in workgroup["Configuration"]["ResultConfiguration"]["EncryptionConfiguration"]:
                kms_client = boto3.client("kms")
                key = try_it(
                    kms_client.describe_key,
                    kms_client.exceptions.NotFoundException,
                    KeyId=workgroup["Configuration"]["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"],
                )["KeyMetadata"]
                if key["KeyState"] != "Enabled":
                    return False
    return True
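
The triple `in` check above walks an Athena workgroup description shaped like the snippet below (trimmed to the keys the function touches; the key ARN is elided and purely illustrative). Only when all three levels exist does the KMS lookup run:

workgroup = {
    "Configuration": {
        "ResultConfiguration": {
            "EncryptionConfiguration": {
                "EncryptionOption": "SSE_KMS",
                "KmsKey": "arn:aws:kms:us-east-1:111111111111:key/...",  # illustrative
            }
        }
    }
}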
Code Example #11
File: _utils.py Project: trsilva32/aws-data-wrangler
def _start_query_execution(
    sql: str,
    wg_config: _WorkGroupConfig,
    database: Optional[str] = None,
    data_source: Optional[str] = None,
    s3_output: Optional[str] = None,
    workgroup: Optional[str] = None,
    encryption: Optional[str] = None,
    kms_key: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> str:
    args: Dict[str, Any] = {"QueryString": sql}
    session: boto3.Session = _utils.ensure_session(session=boto3_session)

    # s3_output
    args["ResultConfiguration"] = {
        "OutputLocation": _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=session)
    }

    # encryption
    if wg_config.enforced is True:
        if wg_config.encryption is not None:
            args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": wg_config.encryption}
            if wg_config.kms_key is not None:
                args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = wg_config.kms_key
    else:
        if encryption is not None:
            args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
            if kms_key is not None:
                args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key

    # database
    if database is not None:
        args["QueryExecutionContext"] = {"Database": database}
        if data_source is not None:
            args["QueryExecutionContext"]["Catalog"] = data_source

    # workgroup
    if workgroup is not None:
        args["WorkGroup"] = workgroup

    client_athena: boto3.client = _utils.client(service_name="athena", session=session)
    _logger.debug("args: \n%s", pprint.pformat(args))
    response: Dict[str, Any] = _utils.try_it(
        f=client_athena.start_query_execution,
        ex=botocore.exceptions.ClientError,
        ex_code="ThrottlingException",
        max_num_tries=5,
        **args,
    )
    return cast(str, response["QueryExecutionId"])
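
`start_query_execution` only returns an ID; a caller that needs the result must poll until the query leaves the QUEUED/RUNNING states. A minimal polling sketch built on the `get_query_execution` wrapper from Code Example #6 (fixed one-second interval; real code would likely back off):

import time
from typing import Any, Dict

def wait_for_query(query_execution_id: str) -> Dict[str, Any]:
    """Poll Athena until the query reaches a terminal state (sketch)."""
    while True:
        execution = get_query_execution(query_execution_id=query_execution_id)
        if execution["Status"]["State"] in ("SUCCEEDED", "FAILED", "CANCELLED"):
            return execution
        time.sleep(1.0)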
Code Example #12
def _get_table_objects(
    catalog_id: Optional[str],
    database: str,
    table: str,
    transaction_id: str,
    boto3_session: Optional[boto3.Session],
    partition_cols: Optional[List[str]] = None,
    partitions_types: Optional[Dict[str, str]] = None,
    partitions_values: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Get Governed Table Objects from Lake Formation Engine."""
    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    client_lakeformation: boto3.client = _utils.client(
        service_name="lakeformation", session=session)

    scan_kwargs: Dict[str, Union[str, int]] = _catalog_id(
        catalog_id=catalog_id,
        **_transaction_id(transaction_id=transaction_id,
                          DatabaseName=database,
                          TableName=table,
                          MaxResults=100),
    )
    if partition_cols and partitions_types and partitions_values:
        scan_kwargs["PartitionPredicate"] = _build_partition_predicate(
            partition_cols=partition_cols,
            partitions_types=partitions_types,
            partitions_values=partitions_values)

    next_token: str = "init_token"  # Dummy token
    table_objects: List[Dict[str, Any]] = []
    while next_token:
        response = _utils.try_it(
            f=client_lakeformation.get_table_objects,
            ex=botocore.exceptions.ClientError,
            ex_code="ResourceNotReadyException",
            base=1.0,
            max_num_tries=5,
            **scan_kwargs,
        )
        for objects in response["Objects"]:
            for table_object in objects["Objects"]:
                if objects["PartitionValues"]:
                    table_object["PartitionValues"] = objects[
                        "PartitionValues"]
                table_objects.append(table_object)
        next_token = response.get("NextToken", None)
        scan_kwargs["NextToken"] = next_token
    return table_objects
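
The `while next_token` loop is the standard manual NextToken pagination; the dummy `"init_token"` merely makes the first iteration run, and the loop stops once the service omits `NextToken`. The same pattern generalizes to most boto3 list/get calls (a generic sketch, not library code):

from typing import Any, Callable, Dict, Iterator

def paginate(call: Callable[..., Dict[str, Any]], **kwargs: Any) -> Iterator[Dict[str, Any]]:
    """Yield each page of a NextToken-paginated boto3 call."""
    token = None
    while True:
        if token is not None:
            kwargs["NextToken"] = token
        response = call(**kwargs)
        yield response
        token = response.get("NextToken")
        if not token:
            break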
Code Example #13
File: _get.py Project: trsilva32/aws-data-wrangler
def get_connection(
        name: str,
        catalog_id: Optional[str] = None,
        boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
    """Get Glue connection details.

    Parameters
    ----------
    name : str
        Connection name.
    catalog_id : str, optional
        The ID of the Data Catalog from which to retrieve Databases.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, Any]
        API Response for:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue.html#Glue.Client.get_connection

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.catalog.get_connection(name='my_connection')

    """
    client_glue: boto3.client = _utils.client(service_name="glue",
                                              session=boto3_session)

    res = _utils.try_it(
        f=client_glue.get_connection,
        ex=botocore.exceptions.ClientError,
        ex_code="ThrottlingException",
        max_num_tries=3,
        **_catalog_id(catalog_id=catalog_id, Name=name, HidePassword=False),
    )["Connection"]

    if "ENCRYPTED_PASSWORD" in res["ConnectionProperties"]:
        client_kms = _utils.client(service_name="kms", session=boto3_session)
        pwd = client_kms.decrypt(
            CiphertextBlob=base64.b64decode(res["ConnectionProperties"]["ENCRYPTED_PASSWORD"])
        )["Plaintext"].decode("utf-8")
        res["ConnectionProperties"]["PASSWORD"] = pwd
    return cast(Dict[str, Any], res)
Code Example #14
File: _fs.py Project: telegit/aws-data-wrangler
 def flush(self, force: bool = False) -> None:
     """Write buffered data to S3."""
     if self.closed:  # pylint: disable=using-constant-test
         raise RuntimeError("I/O operation on closed file.")
     if self.writable() and self._buffer.closed is False:
         total_size: int = self._buffer.tell()
         if total_size < _MIN_WRITE_BLOCK and force is False:
             return None
         if total_size == 0:
             return None
         _logger.debug("Flushing: %s bytes", total_size)
         self._mpu = self._mpu or _utils.try_it(
             f=self._client.create_multipart_upload,
             ex=_S3_RETRYABLE_ERRORS,
             base=0.5,
             max_num_tries=6,
             Bucket=self._bucket,
             Key=self._key,
             **get_botocore_valid_kwargs(
                 function_name="create_multipart_upload",
                 s3_additional_kwargs=self._s3_additional_kwargs),
         )
         self._buffer.seek(0)
         for chunk_size in _utils.get_even_chunks_sizes(
                 total_size=total_size,
                 chunk_size=_MIN_WRITE_BLOCK,
                 upper_bound=False):
             _logger.debug("chunk_size: %s bytes", chunk_size)
             self._parts_count += 1
             self._upload_proxy.upload(
                 bucket=self._bucket,
                 key=self._key,
                 part=self._parts_count,
                 upload_id=self._mpu["UploadId"],
                 data=self._buffer.read(chunk_size),
                 boto3_session=self._boto3_session,
                 boto3_kwargs=get_botocore_valid_kwargs(
                     function_name="upload_part",
                     s3_additional_kwargs=self._s3_additional_kwargs),
             )
         self._buffer.seek(0)
         self._buffer.truncate(0)
         self._buffer.close()
         self._buffer = io.BytesIO()
     return None
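
The early return when `total_size < _MIN_WRITE_BLOCK and force is False` exists because S3 rejects multipart parts smaller than 5 MiB (except for the final part), so the buffer must accumulate at least one full-size part before a non-forced flush uploads anything. A sketch of the even-chunking idea behind `get_even_chunks_sizes`, under assumed semantics (the library's actual helper may differ):

from typing import List

def even_chunk_sizes(total_size: int, min_chunk: int) -> List[int]:
    """Split total_size into the fewest chunks of roughly equal size >= min_chunk."""
    num_chunks = max(total_size // min_chunk, 1)
    base, remainder = divmod(total_size, num_chunks)
    # Spread the remainder one byte at a time over the first chunks.
    return [base + 1 if i < remainder else base for i in range(num_chunks)]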
Code Example #15
def create_workgroup(wkg_name, config):
    client = boto3.client("athena")
    wkgs = list_workgroups()
    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
    deleted = False
    if wkg_name in wkgs:
        wkg = try_it(client.get_work_group, botocore.exceptions.ClientError, max_num_tries=5, WorkGroup=wkg_name)[
            "WorkGroup"
        ]
        if validate_workgroup_key(workgroup=wkg) is False:
            client.delete_work_group(WorkGroup=wkg_name, RecursiveDeleteOption=True)
            deleted = True
    if wkg_name not in wkgs or deleted is True:
        client.create_work_group(
            Name=wkg_name,
            Configuration=config,
            Description=f"AWS Data Wrangler Test - {wkg_name}",
        )
    return wkg_name
Code Example #16
def _fetch_range(
    range_values: Tuple[int, int],
    bucket: str,
    key: str,
    s3_client: boto3.client,
    boto3_kwargs: Dict[str, Any],
) -> Tuple[int, bytes]:
    start, end = range_values
    _logger.debug("Fetching: s3://%s/%s - Range: %s-%s", bucket, key, start, end)
    resp: Dict[str, Any] = _utils.try_it(
        f=s3_client.get_object,
        ex=_S3_RETRYABLE_ERRORS,
        base=0.5,
        max_num_tries=6,
        Bucket=bucket,
        Key=key,
        Range=f"bytes={start}-{end - 1}",
        **boto3_kwargs,
    )
    return start, cast(bytes, resp["Body"].read())
Code Example #17
File: db.py Project: vikramshitole/aws-data-wrangler
def to_sql(df: pd.DataFrame, con: sqlalchemy.engine.Engine,
           **pandas_kwargs: Any) -> None:
    """Write records stored in a DataFrame to a SQL database.

    Support for **Redshift**, **PostgreSQL** and **MySQL**.

    Support for all pandas to_sql() arguments:
    https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html

    Note
    ----
    Redshift: For large DataFrames (1MM+ rows) consider the function **wr.db.copy_to_redshift()**.

    Note
    ----
    Redshift: `index=False` will be forced.

    Parameters
    ----------
    df : pandas.DataFrame
        Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
    con : sqlalchemy.engine.Engine
        SQLAlchemy Engine. Please use
        wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine().
    pandas_kwargs
        KEYWORD arguments forwarded to pandas.DataFrame.to_sql(). You can NOT pass `pandas_kwargs` explicitly;
        just add valid Pandas arguments in the function call and Wrangler will accept them.
        e.g. wr.db.to_sql(df, con=con, name="table_name", schema="schema_name", if_exists="replace", index=False)
        https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html

    Returns
    -------
    None
        None.

    Examples
    --------
    Writing to Redshift with temporary credentials

    >>> import awswrangler as wr
    >>> import pandas as pd
    >>> wr.db.to_sql(
    ...     df=pd.DataFrame({'col': [1, 2, 3]}),
    ...     con=wr.db.get_redshift_temp_engine(cluster_identifier="...", user="******"),
    ...     name="table_name",
    ...     schema="schema_name"
    ... )

    Writing to Redshift with temporary credentials and using pandas_kwargs

    >>> import awswrangler as wr
    >>> import pandas as pd
    >>> wr.db.to_sql(
    ...     df=pd.DataFrame({'col': [1, 2, 3]}),
    ...     con=wr.db.get_redshift_temp_engine(cluster_identifier="...", user="******"),
    ...     name="table_name",
    ...     schema="schema_name",
    ...     if_exists="replace",
    ...     index=False,
    ... )

    Writing to Redshift from Glue Catalog Connections

    >>> import awswrangler as wr
    >>> import pandas as pd
    >>> wr.db.to_sql(
    ...     df=pd.DataFrame({'col': [1, 2, 3]}),
    ...     con=wr.catalog.get_engine(connection="..."),
    ...     name="table_name",
    ...     schema="schema_name"
    ... )

    """
    if "pandas_kwargs" in pandas_kwargs:
        raise exceptions.InvalidArgument(
            "You can NOT pass `pandas_kwargs` explicit, just add valid "
            "Pandas arguments in the function call and Wrangler will accept it."
            "e.g. wr.db.to_sql(df, con, name='...', schema='...', if_exists='replace')"
        )
    if df.empty is True:
        raise exceptions.EmptyDataFrame()
    if not isinstance(con, sqlalchemy.engine.Engine):
        raise exceptions.InvalidConnection(
            "Invalid 'con' argument, please pass a "
            "SQLAlchemy Engine. Use wr.db.get_engine(), "
            "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()")
    if "dtype" in pandas_kwargs:
        cast_columns: Dict[str, VisitableType] = pandas_kwargs["dtype"]
    else:
        cast_columns = {}
    dtypes: Dict[str, VisitableType] = _data_types.sqlalchemy_types_from_pandas(
        df=df, db_type=con.name, dtype=cast_columns
    )
    pandas_kwargs["dtype"] = dtypes
    pandas_kwargs["con"] = con
    if pandas_kwargs["con"].name.lower(
    ) == "redshift":  # Redshift does not accept index
        pandas_kwargs["index"] = False
    _utils.try_it(f=df.to_sql,
                  ex=sqlalchemy.exc.InternalError,
                  **pandas_kwargs)
Code Example #18
File: _create.py Project: telegit/aws-data-wrangler
def _create_table(  # pylint: disable=too-many-branches,too-many-statements
    database: str,
    table: str,
    description: Optional[str],
    parameters: Optional[Dict[str, str]],
    mode: str,
    catalog_versioning: bool,
    boto3_session: Optional[boto3.Session],
    table_input: Dict[str, Any],
    table_exist: bool,
    projection_enabled: bool,
    partitions_types: Optional[Dict[str, str]],
    columns_comments: Optional[Dict[str, str]],
    projection_types: Optional[Dict[str, str]],
    projection_ranges: Optional[Dict[str, str]],
    projection_values: Optional[Dict[str, str]],
    projection_intervals: Optional[Dict[str, str]],
    projection_digits: Optional[Dict[str, str]],
    catalog_id: Optional[str],
) -> None:
    # Description
    mode = _update_if_necessary(dic=table_input,
                                key="Description",
                                value=description,
                                mode=mode)

    # Parameters
    parameters = parameters if parameters else {}
    for k, v in parameters.items():
        mode = _update_if_necessary(dic=table_input["Parameters"],
                                    key=k,
                                    value=v,
                                    mode=mode)

    # Projection
    if projection_enabled is True:
        table_input["Parameters"]["projection.enabled"] = "true"
        partitions_types = partitions_types if partitions_types else {}
        projection_types = projection_types if projection_types else {}
        projection_ranges = projection_ranges if projection_ranges else {}
        projection_values = projection_values if projection_values else {}
        projection_intervals = projection_intervals if projection_intervals else {}
        projection_digits = projection_digits if projection_digits else {}
        projection_types = {sanitize_column_name(k): v for k, v in projection_types.items()}
        projection_ranges = {sanitize_column_name(k): v for k, v in projection_ranges.items()}
        projection_values = {sanitize_column_name(k): v for k, v in projection_values.items()}
        projection_intervals = {sanitize_column_name(k): v for k, v in projection_intervals.items()}
        projection_digits = {sanitize_column_name(k): v for k, v in projection_digits.items()}
        for k, v in projection_types.items():
            dtype: Optional[str] = partitions_types.get(k)
            if dtype is None:
                raise exceptions.InvalidArgumentCombination(
                    f"Column {k} appears as projected column but not as partitioned column."
                )
            if dtype == "date":
                table_input["Parameters"][
                    f"projection.{k}.format"] = "yyyy-MM-dd"
            elif dtype == "timestamp":
                table_input["Parameters"][
                    f"projection.{k}.format"] = "yyyy-MM-dd HH:mm:ss"
                table_input["Parameters"][
                    f"projection.{k}.interval.unit"] = "SECONDS"
                table_input["Parameters"][f"projection.{k}.interval"] = "1"
        for k, v in projection_types.items():
            mode = _update_if_necessary(dic=table_input["Parameters"],
                                        key=f"projection.{k}.type",
                                        value=v,
                                        mode=mode)
        for k, v in projection_ranges.items():
            mode = _update_if_necessary(dic=table_input["Parameters"],
                                        key=f"projection.{k}.range",
                                        value=v,
                                        mode=mode)
        for k, v in projection_values.items():
            mode = _update_if_necessary(dic=table_input["Parameters"],
                                        key=f"projection.{k}.values",
                                        value=v,
                                        mode=mode)
        for k, v in projection_intervals.items():
            mode = _update_if_necessary(dic=table_input["Parameters"],
                                        key=f"projection.{k}.interval",
                                        value=str(v),
                                        mode=mode)
        for k, v in projection_digits.items():
            mode = _update_if_necessary(dic=table_input["Parameters"],
                                        key=f"projection.{k}.digits",
                                        value=str(v),
                                        mode=mode)
    else:
        table_input["Parameters"]["projection.enabled"] = "false"

    # Column comments
    columns_comments = columns_comments if columns_comments else {}
    columns_comments = {
        sanitize_column_name(k): v
        for k, v in columns_comments.items()
    }
    if columns_comments:
        for col in table_input["StorageDescriptor"]["Columns"]:
            name: str = col["Name"]
            if name in columns_comments:
                mode = _update_if_necessary(dic=col,
                                            key="Comment",
                                            value=columns_comments[name],
                                            mode=mode)
        for par in table_input["PartitionKeys"]:
            name = par["Name"]
            if name in columns_comments:
                mode = _update_if_necessary(dic=par,
                                            key="Comment",
                                            value=columns_comments[name],
                                            mode=mode)

    _logger.debug("table_input: %s", table_input)

    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    client_glue: boto3.client = _utils.client(service_name="glue",
                                              session=session)
    skip_archive: bool = not catalog_versioning
    if mode not in ("overwrite", "append", "overwrite_partitions", "update"):
        raise exceptions.InvalidArgument(
            f"{mode} is not a valid mode. It must be 'overwrite', 'append', 'overwrite_partitions' or 'update'."
        )
    if table_exist is True and mode == "overwrite":
        delete_all_partitions(table=table,
                              database=database,
                              catalog_id=catalog_id,
                              boto3_session=session)
        _logger.debug("Updating table (%s)...", mode)
        client_glue.update_table(**_catalog_id(catalog_id=catalog_id,
                                               DatabaseName=database,
                                               TableInput=table_input,
                                               SkipArchive=skip_archive))
    elif (table_exist is True) and (mode in ("append", "overwrite_partitions",
                                             "update")):
        if mode == "update":
            _logger.debug("Updating table (%s)...", mode)
            client_glue.update_table(**_catalog_id(catalog_id=catalog_id,
                                                   DatabaseName=database,
                                                   TableInput=table_input,
                                                   SkipArchive=skip_archive))
    elif table_exist is False:
        try:
            _logger.debug("Creating table (%s)...", mode)
            client_glue.create_table(**_catalog_id(catalog_id=catalog_id,
                                                   DatabaseName=database,
                                                   TableInput=table_input))
        except client_glue.exceptions.AlreadyExistsException:
            if mode == "overwrite":
                _utils.try_it(
                    f=_overwrite_table,
                    ex=client_glue.exceptions.AlreadyExistsException,
                    client_glue=client_glue,
                    catalog_id=catalog_id,
                    database=database,
                    table=table,
                    table_input=table_input,
                    boto3_session=boto3_session,
                )
    _logger.debug("Leaving table as is (%s)...", mode)
Code Example #19
            else:
                record["MeasureName"] = measure_cols_names[0]
                record["MeasureValueType"] = "MULTI"
                record["MeasureValues"] = [{
                    "Name": measure_name,
                    "Value": str(measure_value),
                    "Type": measure_value_type
                } for measure_name, measure_value, measure_value_type in zip(
                    measure_cols_names,
                    rec[measure_cols_loc:dimensions_cols_loc], measure_types)]
            records.append(record)
        _utils.try_it(
            f=client.write_records,
            ex=(client.exceptions.ThrottlingException,
                client.exceptions.InternalServerException),
            max_num_tries=5,
            DatabaseName=database,
            TableName=table,
            Records=records,
        )
    except client.exceptions.RejectedRecordsException as ex:
        return cast(List[Dict[str, str]], ex.response["RejectedRecords"])
    return []


def _cast_value(value: str, dtype: str) -> Any:  # pylint: disable=too-many-branches,too-many-return-statements
    if dtype == "VARCHAR":
        return value
    if dtype in ("INTEGER", "BIGINT"):
        return int(value)
    if dtype == "DOUBLE":