Example #1
def make_geography_db(
    data: Mapping[str, Any],
    con: sa.engine.Engine,
) -> None:
    metadata = sa.MetaData(bind=con)

    with con.begin() as bind:
        for table_name, schema in SCHEMAS.items():
            table = sa.Table(
                table_name,
                metadata,
                *(sa.Column(col_name, col_type)
                  for col_name, col_type in schema),
            )
            table_columns = table.c.keys()
            post_parse = POST_PARSE_FUNCTIONS.get(table_name, toolz.identity)

            table.drop(bind=bind, checkfirst=True)
            table.create(bind=bind)
            bind.execute(
                table.insert().values(),
                [
                    post_parse(dict(zip(table_columns, row)))
                    for row in data[table_name]
                ],
            )
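A minimal usage sketch, assuming a module-level SCHEMAS entry named "countries" with two columns; the database path and sample rows below are invented for illustration:

import sqlalchemy as sa

engine = sa.create_engine("sqlite:///geography.db")  # placeholder database
data = {"countries": [("France", "FR"), ("Japan", "JP")]}  # keys and row shapes must match SCHEMAS
make_geography_db(data, engine)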
Example #2
def auto_migrate(engine: sqlalchemy.engine.Engine):
    """Compares the current database with all defined models and applies the diff"""
    ddl = get_migration_ddl(engine)

    with engine.begin() as connection:
        for statement in ddl:
            sys.stdout.write('\033[1;32m' + statement + '\033[0;0m')
            connection.execute(statement)
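A usage sketch, assuming get_migration_ddl() is defined elsewhere in the module; the connection string is a placeholder:

import sqlalchemy

engine = sqlalchemy.create_engine("postgresql://user:password@localhost/mydb")  # placeholder DSN
auto_migrate(engine)  # prints each generated DDL statement in green, then executes it inside one transaction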
Example #3
def drop_table(tbl_name: str, eng: sa.engine.Engine, dispose_eng=False):
    try:
        if eng.has_table(tbl_name):
            with eng.begin() as con:
                con.execute(f"DROP TABLE {tbl_name}")
    finally:
        if dispose_eng:
            eng.dispose()
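A short call example with a placeholder SQLite file and an invented table name:

import sqlalchemy as sa

eng = sa.create_engine("sqlite:///example.db")  # placeholder DSN
drop_table("stale_results", eng, dispose_eng=True)  # no-op if the table is absent; the engine is disposed afterwards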
Example #4
def insert_into_table(eng: sa.engine.Engine,
                      df: pd.DataFrame,
                      table_name: str,
                      dtypes: dict = None,
                      unique_columns=None,
                      index_columns=None,
                      hash_index_columns=None,
                      dispose_eng=False):
    """
    Adds df to a new table called $table_name
    Args:
        eng: An engine object connecting the db
        df: The dataframe we want to insert to the DB
        table_name: The new table's name, assuming it is not in the DB
        dtypes: The data-types for each column in the DB
        unique_columns: Optional list of columns to cover with a unique key index,
                            needed for using merge_to_db in postgresql. If set, $dtypes also needs to be set
        index_columns: Optional list of columns to create regular (b-tree) indices on
        hash_index_columns: Optional list of columns to create hash indices on (Postgres syntax)
        dispose_eng: Whether to dispose of the engine after the write

    Returns:
        None
    """
    table_name = table_name.lower()
    if unique_columns is not None:
        assert dtypes is not None, "if unique_columns is set, dtypes cannot be None (needed to handle GIS columns correctly)"

    if dtypes is None:
        dtypes = {}
    with eng.begin() as con:
        df.to_sql(table_name,
                  con,
                  if_exists="append",
                  index=False,
                  dtype=dtypes)

        # Oracle has trouble with this statement; the unique index is only needed for Postgres anyway
        if unique_columns is not None and eng.dialect.name == "postgresql":
            from coord2vec.common.db.postgres import get_index_str_for_unique
            con.execute(
                f"CREATE UNIQUE INDEX {table_name}_uind "
                f"ON {table_name} ({get_index_str_for_unique(unique_columns, dtypes)});"
            )

        if index_columns is not None:
            for col in index_columns:
                con.execute(
                    f"CREATE INDEX {table_name}_{col}_ind ON {table_name} (col);"
                )
        if hash_index_columns is not None:
            for col in hash_index_columns:
                con.execute(
                    f"CREATE INDEX {table_name}_{col}_ind ON {table_name} using hash(col);"
                )

    if dispose_eng:
        eng.dispose()
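A usage sketch with invented data; the DSN is a placeholder, and unique_columns is omitted so no dtypes are required:

import pandas as pd
import sqlalchemy as sa

eng = sa.create_engine("postgresql://user:password@localhost/mydb")  # placeholder DSN
df = pd.DataFrame({"id": [1, 2], "value": ["a", "b"]})
insert_into_table(eng, df, "my_table", index_columns=["id"], dispose_eng=True)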
Example #5
def append_to_history(*, engine: sa.engine.Engine, job_name: str) -> None:
    with engine.begin() as con:
        result = con.execute(db.status.select().where(
            db.status.c.job_name == job_name)).first()
        if result:
            con.execute(db.job_history.insert().values(
                job_name=job_name,
                status=result.status,
                started=result.started,
                ended=result.ended,
                skipped_reason=result.skipped_reason,
                error_message=result.error_message,
            ))
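A call sketch, assuming the db module's status and job_history tables already exist in the target database; the DSN and job name are placeholders:

import sqlalchemy as sa

engine = sa.create_engine("sqlite:///jobs.db")  # placeholder DSN
append_to_history(engine=engine, job_name="nightly_etl")  # copies the job's current status row into job_history, if one exists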
Example #6
def copy_files_to_redshift(  # pylint: disable=too-many-locals,too-many-arguments
    path: Union[str, List[str]],
    manifest_directory: str,
    con: sqlalchemy.engine.Engine,
    table: str,
    schema: str,
    iam_role: str,
    parquet_infer_sampling: float = 1.0,
    mode: str = "append",
    diststyle: str = "AUTO",
    distkey: Optional[str] = None,
    sortstyle: str = "COMPOUND",
    sortkey: Optional[List[str]] = None,
    primary_keys: Optional[List[str]] = None,
    varchar_lengths_default: int = 256,
    varchar_lengths: Optional[Dict[str, int]] = None,
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
    s3_additional_kwargs: Optional[Dict[str, str]] = None,
) -> None:
    """Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html

    This function accepts Unix shell-style wildcards in the path argument.
    * (matches everything), ? (matches any single character),
    [seq] (matches any character in seq), [!seq] (matches any character not in seq).

    Note
    ----
    If the table does not exist yet,
    it will be automatically created for you
    using the Parquet metadata to
    infer the columns data types.

    Note
    ----
    In case of `use_threads=True` the number of threads
    that will be spawned is obtained from os.cpu_count().

    Parameters
    ----------
    path : Union[str, List[str]]
        S3 prefix (accepts Unix shell-style wildcards)
        (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
    manifest_directory : str
        S3 prefix (e.g. s3://bucket/prefix)
    con : sqlalchemy.engine.Engine
        SQLAlchemy Engine. Please use
        wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
    table : str
        Table name
    schema : str
        Schema name
    iam_role : str
        AWS IAM role with the related permissions.
    parquet_infer_sampling : float
        Random sample ratio of files that will have the metadata inspected.
        Must be `0.0 < sampling <= 1.0`.
        The higher, the more accurate.
        The lower, the faster.
    mode : str
        Append, overwrite or upsert.
    diststyle : str
        Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
        https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
    distkey : str, optional
        Specifies a column name or positional number for the distribution key.
    sortstyle : str
        Sorting can be "COMPOUND" or "INTERLEAVED".
        https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
    sortkey : List[str], optional
        List of columns to be sorted.
    primary_keys : List[str], optional
        Primary keys.
    varchar_lengths_default : int
        The size that will be set for all VARCHAR columns not specified with varchar_lengths.
    varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
    use_threads : bool
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() will be used as the max number of threads.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
    s3_additional_kwargs : Dict[str, str], optional
        Forward to botocore requests. Valid parameters: "ACL", "Metadata", "ServerSideEncryption", "StorageClass",
        "SSECustomerAlgorithm", "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging".
        e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMY_KEY_ARN'}

    Returns
    -------
    None
        None.

    Examples
    --------
    >>> import awswrangler as wr
    >>> wr.db.copy_files_to_redshift(
    ...     path="s3://bucket/my_parquet_files/",
    ...     con=wr.catalog.get_engine(connection="my_glue_conn_name"),
    ...     table="my_table",
    ...     schema="public",
    ...     iam_role="arn:aws:iam::XXX:role/XXX"
    ... )

    """
    _varchar_lengths: Dict[
        str, int] = {} if varchar_lengths is None else varchar_lengths
    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    paths: List[str] = _path2list(path=path, boto3_session=session)  # pylint: disable=protected-access
    manifest_directory = manifest_directory if manifest_directory.endswith(
        "/") else f"{manifest_directory}/"
    manifest_path: str = f"{manifest_directory}manifest.json"
    write_redshift_copy_manifest(
        manifest_path=manifest_path,
        paths=paths,
        use_threads=use_threads,
        boto3_session=session,
        s3_additional_kwargs=s3_additional_kwargs,
    )
    s3.wait_objects_exist(paths=paths + [manifest_path],
                          use_threads=False,
                          boto3_session=session)
    athena_types, _ = s3.read_parquet_metadata(path=paths,
                                               sampling=parquet_infer_sampling,
                                               dataset=False,
                                               use_threads=use_threads,
                                               boto3_session=session)
    _logger.debug("athena_types: %s", athena_types)
    redshift_types: Dict[str, str] = {}
    for col_name, col_type in athena_types.items():
        length: int = _varchar_lengths[
            col_name] if col_name in _varchar_lengths else varchar_lengths_default
        redshift_types[col_name] = _data_types.athena2redshift(
            dtype=col_type, varchar_length=length)
    with con.begin() as _con:
        created_table, created_schema = _rs_create_table(
            con=_con,
            table=table,
            schema=schema,
            redshift_types=redshift_types,
            mode=mode,
            diststyle=diststyle,
            sortstyle=sortstyle,
            distkey=distkey,
            sortkey=sortkey,
            primary_keys=primary_keys,
        )
        _rs_copy(
            con=_con,
            table=created_table,
            schema=created_schema,
            manifest_path=manifest_path,
            iam_role=iam_role,
            num_files=len(paths),
        )
        if table != created_table:  # upsert
            _rs_upsert(con=_con,
                       schema=schema,
                       table=table,
                       temp_table=created_table,
                       primary_keys=primary_keys)
    s3.delete_objects(path=[manifest_path],
                      use_threads=use_threads,
                      boto3_session=session)
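A hedged sketch of upsert mode, which stages the files in a temporary table and then merges on the listed primary keys; the bucket, connection name, role ARN and key column are placeholders:

import awswrangler as wr

wr.db.copy_files_to_redshift(
    path="s3://bucket/my_parquet_files/",
    manifest_directory="s3://bucket/manifests/",
    con=wr.catalog.get_engine(connection="my_glue_conn_name"),
    table="my_table",
    schema="public",
    iam_role="arn:aws:iam::XXX:role/XXX",
    mode="upsert",
    primary_keys=["id"],
)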
Example #7
def create_tables(*, engine: sa.engine.Engine, recreate: bool = False) -> None:
    with engine.begin() as con:
        if recreate:
            metadata.drop_all(con)
        metadata.create_all(con)
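A brief example, assuming metadata is the module-level sa.MetaData that holds the table definitions; the DSN is a placeholder:

import sqlalchemy as sa

engine = sa.create_engine("sqlite:///app.db")  # placeholder DSN
create_tables(engine=engine, recreate=True)  # drops and recreates every table registered on metadata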
Example #8
def add_postgis_index(eng: sa.engine.Engine, table_name: str, geom_col: str):
    with eng.begin() as con:
        con.execute(
            f"create index {table_name}_{geom_col}_idx on {table_name} using gist ({geom_col});"
        )
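A usage sketch; the table and geometry column names are invented, and the target database must have the PostGIS extension installed:

import sqlalchemy as sa

eng = sa.create_engine("postgresql://user:password@localhost/gisdb")  # placeholder DSN
add_postgis_index(eng, "buildings", "geom")  # creates buildings_geom_idx as a GiST index on the geometry column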
Example #9
def copy_files_to_redshift(  # pylint: disable=too-many-locals,too-many-arguments
    path: Union[str, List[str]],
    manifest_directory: str,
    con: sqlalchemy.engine.Engine,
    table: str,
    schema: str,
    iam_role: str,
    mode: str = "append",
    diststyle: str = "AUTO",
    distkey: Optional[str] = None,
    sortstyle: str = "COMPOUND",
    sortkey: Optional[str] = None,
    primary_keys: Optional[List[str]] = None,
    varchar_lengths_default: int = 256,
    varchar_lengths: Optional[Dict[str, int]] = None,
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
) -> None:
    """Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html

    Note
    ----
    If the table does not exist yet,
    it will be automatically created for you
    using the Parquet metadata to
    infer the columns data types.

    Note
    ----
    In case of `use_threads=True` the number of threads that will be spawned is obtained from os.cpu_count().

    Parameters
    ----------
    path : Union[str, List[str]]
        S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
    manifest_directory : str
        S3 prefix (e.g. s3://bucket/prefix)
    con : sqlalchemy.engine.Engine
        SQLAlchemy Engine. Please use
        wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
    table : str
        Table name
    schema : str
        Schema name
    iam_role : str
        AWS IAM role with the related permissions.
    mode : str
        Append, overwrite or upsert.
    diststyle : str
        Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
        https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
    distkey : str, optional
        Specifies a column name or positional number for the distribution key.
    sortstyle : str
        Sorting can be "COMPOUND" or "INTERLEAVED".
        https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
    sortkey : str, optional
        List of columns to be sorted.
    primary_keys : List[str], optional
        Primary keys.
    varchar_lengths_default : int
        The size that will be set for all VARCHAR columns not specified with varchar_lengths.
    varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
    use_threads : bool
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() will be used as the max number of threads.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    None
        None.

    Examples
    --------
    >>> import awswrangler as wr
    >>> wr.db.copy_files_to_redshift(
    ...     path="s3://bucket/my_parquet_files/",
    ...     con=wr.catalog.get_engine(connection="my_glue_conn_name"),
    ...     table="my_table",
    ...     schema="public",
    ...     iam_role="arn:aws:iam::XXX:role/XXX"
    ... )

    """
    _varchar_lengths: Dict[str, int] = {} if varchar_lengths is None else varchar_lengths
    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    paths: List[str] = s3._path2list(path=path, boto3_session=session)  # pylint: disable=protected-access
    manifest_directory = manifest_directory if manifest_directory.endswith("/") else f"{manifest_directory}/"
    manifest_path: str = f"{manifest_directory}manifest.json"
    write_redshift_copy_manifest(
        manifest_path=manifest_path, paths=paths, use_threads=use_threads, boto3_session=session
    )
    s3.wait_objects_exist(paths=paths + [manifest_path], use_threads=False, boto3_session=session)
    athena_types, _ = s3.read_parquet_metadata(
        path=paths, dataset=False, use_threads=use_threads, boto3_session=session
    )
    _logger.debug(f"athena_types: {athena_types}")
    redshift_types: Dict[str, str] = {}
    for col_name, col_type in athena_types.items():
        length: int = _varchar_lengths[col_name] if col_name in _varchar_lengths else varchar_lengths_default
        redshift_types[col_name] = _data_types.athena2redshift(dtype=col_type, varchar_length=length)
    with con.begin() as _con:
        created_table, created_schema = _rs_create_table(
            con=_con,
            table=table,
            schema=schema,
            redshift_types=redshift_types,
            mode=mode,
            diststyle=diststyle,
            sortstyle=sortstyle,
            distkey=distkey,
            sortkey=sortkey,
            primary_keys=primary_keys,
        )
        _rs_copy(
            con=_con,
            table=created_table,
            schema=created_schema,
            manifest_path=manifest_path,
            iam_role=iam_role,
            num_files=len(paths),
        )
        if table != created_table:  # upsert
            _rs_upsert(con=_con, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
    s3.delete_objects(path=[manifest_path], use_threads=use_threads, boto3_session=session)
Example #10
def merge_to_table(eng: sa.engine.Engine,
                   df: pd.DataFrame,
                   table_name: str,
                   compare_columns: List[str],
                   update_columns: List[str],
                   dtypes: dict,
                   temp_table_name: str = None,
                   dispose_eng=False):
    """
    Merges the dataframe into an existing table by creating a temp table for the df and then merging it into the existing one.
    For rows with matching $compare_columns we UPDATE the values in $update_columns
    Args:
        eng: An engine object connecting the db
        df: The dataframe we want to insert to the DB
        table_name: The existing table's name
        compare_columns: The columns we want to compare existing rows with
        update_columns: The columns we want to update in case a matching row is found
        temp_table_name: optional, a name for the temp table for the DB
        dtypes: The data-types for each column in the DB
        dispose_eng: Whether to dispose of the engine after the write

    Returns:
        None
    """
    table_name = table_name.lower()  # Postgres folds unquoted identifiers to lowercase

    if df.empty:
        return

    if dtypes is None:
        dtypes = {}
    if temp_table_name is None:
        temp_table_name = get_temp_table_name()
    if eng.dialect.name.lower() == "oracle" and (len(temp_table_name) > MAX_TABLE_NAME_ORACLE or \
                                                 len(table_name) > MAX_TABLE_NAME_ORACLE):
        raise Exception('table name is too long')

    if len(df) > 200_000:
        chunk_size = 100_000
        for i in tqdm(range(0, len(df), chunk_size),
                      desc=f"Merging into {table_name}",
                      unit="100_000 chunk"):
            df_chunk = df.iloc[i:min(len(df), i + chunk_size)]
            merge_to_table(eng,
                           df_chunk,
                           table_name,
                           compare_columns,
                           update_columns,
                           dtypes=dtypes)

    else:
        try:
            # logger = logging.getLogger()
            # logger.info(f"Writing {len(df)} rows to {table_name} table")
            if not eng.has_table(table_name):
                insert_into_table(eng, df, table_name, dtypes, compare_columns)
            else:
                if eng.dialect.name.lower() not in ("oracle", "postgresql"):
                    raise RuntimeError(
                        f"merge into does not work for {eng.dialect.name}")

                insert_into_table(eng, df, temp_table_name, dtypes,
                                  compare_columns)

                if eng.dialect.name.lower() == "oracle":
                    on_statment = "\nAND ".join(
                        [f"curr.{col} = tmp.{col}" for col in compare_columns])
                    set_statment = "\n,".join(
                        [f"curr.{col} = tmp.{col}" for col in update_columns])
                    all_columns = compare_columns + update_columns
                    all_columns_names = ",".join(all_columns)
                    all_columns_values = ",".join(
                        [f"tmp.{col}" for col in all_columns])
                    sql = f"""
                    merge into {table_name} curr
                    using (select {all_columns_names} from {temp_table_name}) tmp
                    on ({on_statment})
                    when matched then
                    update set {set_statment}
                    when not matched then
                    insert ({all_columns_names})
                    values ({all_columns_values})
                    """
                else:  # postgresql
                    set_statment = ",".join([
                        f"{col} = EXCLUDED.{col}" for col in update_columns
                    ])  # postgres syntax
                    all_columns = compare_columns + update_columns
                    all_columns_names = ",".join(all_columns)
                    from coord2vec.common.db.postgres import get_index_str_for_unique
                    on_statment = get_index_str_for_unique(
                        compare_columns, dtypes)

                    sql = f"""
                    INSERT INTO {table_name} ({all_columns_names})
                    SELECT {all_columns_names}
                    FROM {temp_table_name} tmp
                    
                    ON CONFLICT ({on_statment})
                    DO UPDATE SET {set_statment};
                    """  # can fail if no key is saved on the on_statement columns

                with eng.begin() as con:
                    con.execute(sql)
                    con.execute(f"drop table {temp_table_name}")
        finally:
            if eng.has_table(temp_table_name):
                with eng.begin() as con:
                    con.execute(f"drop table {temp_table_name}")

    if dispose_eng:
        eng.dispose()
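A usage sketch with invented columns; passing SQLAlchemy types as the dtypes values is an assumption about what the coord2vec get_index_str_for_unique helper expects:

import pandas as pd
import sqlalchemy as sa

eng = sa.create_engine("postgresql://user:password@localhost/mydb")  # placeholder DSN
df = pd.DataFrame({"id": [1, 2], "score": [0.5, 0.9]})
merge_to_table(eng, df, "scores",
               compare_columns=["id"], update_columns=["score"],
               dtypes={"id": sa.Integer(), "score": sa.Float()})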