Example 1: _read_parquet_schema
    def _read_parquet_schema(self, engine) -> Tuple[Dict[str, Any], ...]:
        if engine == "pyarrow":
            from pyarrow import parquet

            from pyathena.arrow.util import to_column_info

            if not self._unload_location:
                raise ProgrammingError("UnloadLocation is none or empty.")
            bucket, key = parse_output_location(self._unload_location)
            try:
                dataset = parquet.ParquetDataset(f"{bucket}/{key}",
                                                 filesystem=self._fs,
                                                 use_legacy_dataset=False)
                return to_column_info(dataset.schema)
            except Exception as e:
                _logger.exception(f"Failed to read schema {bucket}/{key}.")
                raise OperationalError(*e.args) from e
        elif engine == "fastparquet":
            from fastparquet import ParquetFile

            # TODO: https://github.com/python/mypy/issues/1153
            from pyathena.fastparquet.util import to_column_info  # type: ignore

            if not self._data_manifest:
                self._data_manifest = self._read_data_manifest()
            # Read the schema from the first Parquet file listed in the manifest.
            bucket, key = parse_output_location(self._data_manifest[0])
            try:
                file = ParquetFile(f"{bucket}/{key}", open_with=self._fs.open)
                return to_column_info(file.schema)
            except Exception as e:
                _logger.exception(f"Failed to read schema {bucket}/{key}.")
                raise OperationalError(*e.args) from e
        else:
            raise ProgrammingError(
                "Engine must be one of `pyarrow`, `fastparquet`.")
Example 2: _read_csv (pyarrow)
    def _read_csv(self) -> "Table":
        import pyarrow as pa
        from pyarrow import csv

        if not self.output_location:
            raise ProgrammingError("OutputLocation is none or empty.")
        if not self.output_location.endswith((".csv", ".txt")):
            return pa.Table.from_pydict(dict())
        length = self._get_content_length()
        if length and self.output_location.endswith(".txt"):
            description = self.description if self.description else []
            column_names = [d[0] for d in description]
            read_opts = csv.ReadOptions(
                skip_rows=0,
                column_names=column_names,
                block_size=self._block_size,
                use_threads=True,
            )
            parse_opts = csv.ParseOptions(
                delimiter="\t",
                quote_char=False,
                double_quote=False,
                escape_char=False,
            )
        elif length and self.output_location.endswith(".csv"):
            read_opts = csv.ReadOptions(skip_rows=0,
                                        block_size=self._block_size,
                                        use_threads=True)
            parse_opts = csv.ParseOptions(
                delimiter=",",
                quote_char='"',
                double_quote=True,
                escape_char=False,
            )
        else:
            return pa.Table.from_pydict(dict())

        bucket, key = parse_output_location(self.output_location)
        try:
            return csv.read_csv(
                self._fs.open_input_stream(f"{bucket}/{key}"),
                read_options=read_opts,
                parse_options=parse_opts,
                convert_options=csv.ConvertOptions(
                    quoted_strings_can_be_null=False,
                    timestamp_parsers=self.timestamp_parsers,
                    column_types=self.column_types,
                ),
            )
        except Exception as e:
            _logger.exception(f"Failed to read {bucket}/{key}.")
            raise OperationalError(*e.args) from e
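
A minimal usage sketch for context, assuming the reader above belongs to pyathena's Arrow result set: it is not called directly, but runs when a query is executed through ArrowCursor. The staging bucket, region, and table name below are placeholders.

from pyathena import connect
from pyathena.arrow.cursor import ArrowCursor

# Placeholder staging bucket and region; adjust to your environment.
cursor = connect(
    s3_staging_dir="s3://YOUR_BUCKET/staging/",
    region_name="us-west-2",
    cursor_class=ArrowCursor,
).cursor()
# The CSV result written by Athena is read back as a pyarrow.Table.
table = cursor.execute("SELECT * FROM many_rows").as_arrow()
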
Example 3: _read_csv (pandas)
    def _read_csv(self) -> "DataFrame":
        import pandas as pd

        if not self.output_location:
            raise ProgrammingError("OutputLocation is none or empty.")
        if not self.output_location.endswith((".csv", ".txt")):
            return pd.DataFrame()
        length = self._get_content_length()
        if length and self.output_location.endswith(".txt"):
            sep = "\t"
            header = None
            description = self.description if self.description else []
            names = [d[0] for d in description]
        elif length and self.output_location.endswith(".csv"):
            sep = ","
            header = 0
            names = None
        else:
            return pd.DataFrame()
        try:
            # TODO chunksize
            df = pd.read_csv(
                self.output_location,
                sep=sep,
                header=header,
                names=names,
                dtype=self.dtypes,
                converters=self.converters,
                parse_dates=self.parse_dates,
                infer_datetime_format=True,
                skip_blank_lines=False,
                keep_default_na=self._keep_default_na,
                na_values=self._na_values,
                quoting=self._quoting,
                storage_options={
                    "profile": self.connection.profile_name,
                    "client_kwargs": {
                        "region_name": self.connection.region_name,
                        **self.connection._client_kwargs,
                    },
                },
                **self._kwargs,
            )
            return self._trunc_date(df)
        except Exception as e:
            _logger.exception(f"Failed to read {self.output_location}.")
            raise OperationalError(*e.args) from e
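
A comparable sketch for the pandas-based reader, assuming it backs pyathena's PandasCursor; bucket, region, and table name are placeholders.

from pyathena import connect
from pyathena.pandas.cursor import PandasCursor

cursor = connect(
    s3_staging_dir="s3://YOUR_BUCKET/staging/",
    region_name="us-west-2",
    cursor_class=PandasCursor,
).cursor()
# The CSV result is parsed with pandas.read_csv and returned as a DataFrame.
df = cursor.execute("SELECT * FROM many_rows").as_pandas()
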
Example 4: _read_parquet (pyarrow)
    def _read_parquet(self) -> "Table":
        import pyarrow as pa
        from pyarrow import parquet

        manifests = self._read_data_manifest()
        if not manifests:
            return pa.Table.from_pydict(dict())
        if not self._unload_location:
            # Derive the unload prefix from the first manifest entry by
            # dropping the file name and keeping the directory.
            self._unload_location = "/".join(
                manifests[0].split("/")[:-1]) + "/"

        bucket, key = parse_output_location(self._unload_location)
        try:
            dataset = parquet.ParquetDataset(f"{bucket}/{key}",
                                             filesystem=self._fs,
                                             use_legacy_dataset=False)
            return dataset.read(use_threads=True)
        except Exception as e:
            _logger.exception(f"Failed to read {bucket}/{key}.")
            raise OperationalError(*e.args) from e
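
A usage sketch assuming this Parquet reader serves ArrowCursor in UNLOAD mode, where Athena writes the result set to S3 as Parquet instead of CSV; names are placeholders.

from pyathena import connect
from pyathena.arrow.cursor import ArrowCursor

# unload=True wraps the query in UNLOAD so the result is fetched as Parquet.
cursor = connect(
    s3_staging_dir="s3://YOUR_BUCKET/staging/",
    region_name="us-west-2",
    cursor_class=ArrowCursor,
).cursor(unload=True)
table = cursor.execute("SELECT * FROM many_rows").as_arrow()
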
Example 5: _read_parquet (pandas)
    def _read_parquet(self, engine) -> "DataFrame":
        import pandas as pd

        self._data_manifest = self._read_data_manifest()
        if not self._data_manifest:
            return pd.DataFrame()
        if not self._unload_location:
            self._unload_location = (
                "/".join(self._data_manifest[0].split("/")[:-1]) + "/")

        if engine == "pyarrow":
            unload_location = self._unload_location
            kwargs = {
                "use_threads": True,
                "use_legacy_dataset": False,
            }
        elif engine == "fastparquet":
            unload_location = f"{self._unload_location}*"
            kwargs = {}
        else:
            raise ProgrammingError(
                "Engine must be one of `pyarrow`, `fastparquet`.")
        kwargs.update(self._kwargs)

        try:
            return pd.read_parquet(
                unload_location,
                engine=engine,
                storage_options={
                    "profile": self.connection.profile_name,
                    "client_kwargs": {
                        "region_name": self.connection.region_name,
                        **self.connection._client_kwargs,
                    },
                },
                use_nullable_dtypes=False,
                **kwargs,
            )
        except Exception as e:
            _logger.exception(f"Failed to read {self.output_location}.")
            raise OperationalError(*e.args) from e
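
The pandas counterpart, assuming PandasCursor accepts the same unload option; the engine argument of the method above corresponds to the pyarrow/fastparquet branches. Names are placeholders.

from pyathena import connect
from pyathena.pandas.cursor import PandasCursor

cursor = connect(
    s3_staging_dir="s3://YOUR_BUCKET/staging/",
    region_name="us-west-2",
    cursor_class=PandasCursor,
).cursor(unload=True)
# Parquet files produced by UNLOAD are read back into a single DataFrame.
df = cursor.execute("SELECT * FROM many_rows").as_pandas()
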
Example 6: to_sql
def to_sql(df,
           name,
           conn,
           location,
           schema='default',
           index=False,
           index_label=None,
           chunksize=None,
           if_exists='fail',
           compression=None,
           flavor='spark',
           type_mappings=to_sql_type_mappings):
    # TODO Supports orc, avro, json, csv or tsv format
    # TODO Supports partitioning
    if if_exists not in ('fail', 'replace', 'append'):
        raise ValueError('`{0}` is not valid for if_exists'.format(if_exists))
    if compression is not None and not AthenaCompression.is_valid(compression):
        raise ValueError(
            '`{0}` is not valid for compression'.format(compression))

    import pyarrow as pa
    import pyarrow.parquet as pq
    bucket_name, key_prefix = parse_output_location(location)
    bucket = conn.session.resource('s3',
                                   region_name=conn.region_name,
                                   **conn._client_kwargs).Bucket(bucket_name)
    cursor = conn.cursor()
    retry_config = conn.retry_config

    table = cursor.execute("""
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = '{schema}'
    AND table_name = '{table}'
    """.format(schema=schema, table=name)).fetchall()
    if if_exists == 'fail':
        if table:
            raise OperationalError('Table `{0}.{1}` already exists.'.format(
                schema, name))
    elif if_exists == 'replace':
        if table:
            cursor.execute("""
            DROP TABLE {schema}.{table}
            """.format(schema=schema, table=name))
            objects = bucket.objects.filter(Prefix=key_prefix)
            if list(objects.limit(1)):
                objects.delete()

    if index:
        reset_index(df, index_label)
    for chunk in get_chunks(df, chunksize):
        table = pa.Table.from_pandas(chunk)
        buf = pa.BufferOutputStream()
        pq.write_table(table, buf, compression=compression, flavor=flavor)
        retry_api_call(bucket.put_object,
                       config=retry_config,
                       Body=buf.getvalue().to_pybytes(),
                       Key=key_prefix + str(uuid.uuid4()))

    ddl = generate_ddl(df=df,
                       name=name,
                       location=location,
                       schema=schema,
                       compression=compression,
                       type_mappings=type_mappings)
    cursor.execute(ddl)
Example 7: to_sql (with partitioning)
def to_sql(
    df,
    name,
    conn,
    location,
    schema="default",
    index=False,
    index_label=None,
    partitions=None,
    chunksize=None,
    if_exists="fail",
    compression=None,
    flavor="spark",
    type_mappings=to_sql_type_mappings,
    executor_class=ThreadPoolExecutor,
    max_workers=(cpu_count() or 1) * 5,
):
    # TODO Supports orc, avro, json, csv or tsv format
    if if_exists not in ("fail", "replace", "append"):
        raise ValueError("`{0}` is not valid for if_exists".format(if_exists))
    if compression is not None and not AthenaCompression.is_valid(compression):
        raise ValueError("`{0}` is not valid for compression".format(compression))
    if partitions is None:
        partitions = []

    bucket_name, key_prefix = parse_output_location(location)
    bucket = conn.session.resource(
        "s3", region_name=conn.region_name, **conn._client_kwargs
    ).Bucket(bucket_name)
    cursor = conn.cursor()

    table = cursor.execute(
        """
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = '{schema}'
    AND table_name = '{table}'
    """.format(
            schema=schema, table=name
        )
    ).fetchall()
    if if_exists == "fail":
        if table:
            raise OperationalError(
                "Table `{0}.{1}` already exists.".format(schema, name)
            )
    elif if_exists == "replace":
        if table:
            cursor.execute(
                """
            DROP TABLE {schema}.{table}
            """.format(
                    schema=schema, table=name
                )
            )
            objects = bucket.objects.filter(Prefix=key_prefix)
            if list(objects.limit(1)):
                objects.delete()

    if index:
        reset_index(df, index_label)
    with executor_class(max_workers=max_workers) as e:
        futures = []
        session_kwargs = deepcopy(conn._session_kwargs)
        session_kwargs.update({"profile_name": conn.profile_name})
        client_kwargs = deepcopy(conn._client_kwargs)
        client_kwargs.update({"region_name": conn.region_name})
        if partitions:
            for keys, group in df.groupby(by=partitions, observed=True):
                keys = keys if isinstance(keys, tuple) else (keys,)
                group = group.drop(partitions, axis=1)
                partition_prefix = "/".join(
                    ["{0}={1}".format(key, val) for key, val in zip(partitions, keys)]
                )
                for chunk in get_chunks(group, chunksize):
                    futures.append(
                        e.submit(
                            to_parquet,
                            chunk,
                            bucket_name,
                            "{0}{1}/".format(key_prefix, partition_prefix),
                            conn._retry_config,
                            session_kwargs,
                            client_kwargs,
                            compression,
                            flavor,
                        )
                    )
        else:
            for chunk in get_chunks(df, chunksize):
                futures.append(
                    e.submit(
                        to_parquet,
                        chunk,
                        bucket_name,
                        key_prefix,
                        conn._retry_config,
                        session_kwargs,
                        client_kwargs,
                        compression,
                        flavor,
                    )
                )
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            _logger.info("to_parquet: {0}".format(result))

    ddl = generate_ddl(
        df=df,
        name=name,
        location=location,
        schema=schema,
        partitions=partitions,
        compression=compression,
        type_mappings=type_mappings,
    )
    _logger.info(ddl)
    cursor.execute(ddl)
    if partitions:
        repair = "MSCK REPAIR TABLE {0}.{1}".format(schema, name)
        _logger.info(repair)
        cursor.execute(repair)
Example 8: to_sql (type-annotated, with partition repair)
def to_sql(
    df: "DataFrame",
    name: str,
    conn: "Connection",
    location: str,
    schema: str = "default",
    index: bool = False,
    index_label: Optional[str] = None,
    partitions: Optional[List[str]] = None,
    chunksize: Optional[int] = None,
    if_exists: str = "fail",
    compression: Optional[str] = None,
    flavor: str = "spark",
    type_mappings: Callable[["Series"], str] = to_sql_type_mappings,
    executor_class: Type[Union[ThreadPoolExecutor,
                               ProcessPoolExecutor]] = ThreadPoolExecutor,
    max_workers: int = (cpu_count() or 1) * 5,
    repair_table: bool = True,
) -> None:
    # TODO Supports orc, avro, json, csv or tsv format
    if if_exists not in ("fail", "replace", "append"):
        raise ValueError(f"`{if_exists}` is not valid for if_exists")
    if compression is not None and not AthenaCompression.is_valid(compression):
        raise ValueError(f"`{compression}` is not valid for compression")
    if partitions is None:
        partitions = []
    if not location.endswith("/"):
        location += "/"

    bucket_name, key_prefix = parse_output_location(location)
    bucket = conn.session.resource("s3",
                                   region_name=conn.region_name,
                                   **conn._client_kwargs).Bucket(bucket_name)
    cursor = conn.cursor()

    table = cursor.execute(
        textwrap.dedent(f"""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = '{schema}'
            AND table_name = '{name}'
            """)).fetchall()
    if if_exists == "fail":
        if table:
            raise OperationalError(f"Table `{schema}.{name}` already exists.")
    elif if_exists == "replace":
        if table:
            cursor.execute(
                textwrap.dedent(f"""
                    DROP TABLE {schema}.{name}
                    """))
            objects = bucket.objects.filter(Prefix=key_prefix)
            if list(objects.limit(1)):
                objects.delete()

    if index:
        reset_index(df, index_label)
    with executor_class(max_workers=max_workers) as e:
        futures = []
        session_kwargs = deepcopy(conn._session_kwargs)
        session_kwargs.update({"profile_name": conn.profile_name})
        client_kwargs = deepcopy(conn._client_kwargs)
        client_kwargs.update({"region_name": conn.region_name})
        partition_prefixes = []
        if partitions:
            for keys, group in df.groupby(by=partitions, observed=True):
                keys = keys if isinstance(keys, tuple) else (keys, )
                group = group.drop(partitions, axis=1)
                partition_prefix = "/".join(
                    [f"{key}={val}" for key, val in zip(partitions, keys)])
                partition_prefixes.append((
                    ", ".join([
                        f"`{key}` = '{val}'"
                        for key, val in zip(partitions, keys)
                    ]),
                    f"{location}{partition_prefix}/",
                ))
                for chunk in get_chunks(group, chunksize):
                    futures.append(
                        e.submit(
                            to_parquet,
                            chunk,
                            bucket_name,
                            f"{key_prefix}{partition_prefix}/",
                            conn._retry_config,
                            session_kwargs,
                            client_kwargs,
                            compression,
                            flavor,
                        ))
        else:
            for chunk in get_chunks(df, chunksize):
                futures.append(
                    e.submit(
                        to_parquet,
                        chunk,
                        bucket_name,
                        key_prefix,
                        conn._retry_config,
                        session_kwargs,
                        client_kwargs,
                        compression,
                        flavor,
                    ))
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            _logger.info(f"to_parquet: {result}")

    ddl = generate_ddl(
        df=df,
        name=name,
        location=location,
        schema=schema,
        partitions=partitions,
        compression=compression,
        type_mappings=type_mappings,
    )
    _logger.info(ddl)
    cursor.execute(ddl)
    if partitions and repair_table:
        for partition in partition_prefixes:
            add_partition = textwrap.dedent(f"""
                ALTER TABLE `{schema}`.`{name}`
                ADD IF NOT EXISTS PARTITION ({partition[0]}) LOCATION '{partition[1]}'
                """)
            _logger.info(add_partition)
            cursor.execute(add_partition)
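
A usage sketch for the final to_sql variant, assuming it is exposed as pyathena.pandas.util.to_sql; bucket, schema, and table names are placeholders.

import pandas as pd
from pyathena import connect
from pyathena.pandas.util import to_sql

conn = connect(
    s3_staging_dir="s3://YOUR_BUCKET/staging/",
    region_name="us-west-2",
)
df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
# Uploads Parquet chunks under the location prefix and issues the CREATE TABLE DDL.
to_sql(
    df,
    "example_table",
    conn,
    "s3://YOUR_BUCKET/path/to/example_table/",
    schema="default",
    index=False,
    if_exists="replace",
    compression="snappy",
)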