Example 1
    def _read_parquet_schema(self, engine) -> Tuple[Dict[str, Any], ...]:
        if engine == "pyarrow":
            from pyarrow import parquet

            from pyathena.arrow.util import to_column_info

            if not self._unload_location:
                raise ProgrammingError("UnloadLocation is none or empty.")
            bucket, key = parse_output_location(self._unload_location)
            try:
                dataset = parquet.ParquetDataset(f"{bucket}/{key}",
                                                 filesystem=self._fs,
                                                 use_legacy_dataset=False)
                return to_column_info(dataset.schema)
            except Exception as e:
                _logger.exception(f"Failed to read schema {bucket}/{key}.")
                raise OperationalError(*e.args) from e
        elif engine == "fastparquet":
            from fastparquet import ParquetFile

            # TODO: https://github.com/python/mypy/issues/1153
            from pyathena.fastparquet.util import to_column_info  # type: ignore

            if not self._data_manifest:
                self._data_manifest = self._read_data_manifest()
            bucket, key = parse_output_location(self._data_manifest[0])
            try:
                file = ParquetFile(f"{bucket}/{key}", open_with=self._fs.open)
                return to_column_info(file.schema)
            except Exception as e:
                _logger.exception(f"Failed to read schema {bucket}/{key}.")
                raise OperationalError(*e.args) from e
        else:
            raise ProgrammingError(
                "Engine must be one of `pyarrow`, `fastparquet`.")
Example 2
    def test_parse_output_location(self):
        # valid
        actual = parse_output_location("s3://bucket/path/to")
        assert actual[0] == "bucket"
        assert actual[1] == "path/to"

        # invalid
        with pytest.raises(DataError):
            parse_output_location("http://foobar")
Example 3
    def test_parse_output_location(self):
        # valid
        actual = parse_output_location("s3://bucket/path/to")
        self.assertEqual(actual[0], "bucket")
        self.assertEqual(actual[1], "path/to")

        # invalid
        with self.assertRaises(DataError):
            parse_output_location("http://foobar")
Example 4
    def test_parse_output_location(self):
        # valid
        actual = parse_output_location('s3://bucket/path/to')
        self.assertEqual(actual[0], 'bucket')
        self.assertEqual(actual[1], 'path/to')

        # invalid
        with self.assertRaises(DataError):
            parse_output_location('http://foobar')
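The three test variants above pin down the same contract for parse_output_location: an "s3://bucket/key" URI is split into a (bucket, key) tuple, and anything else raises DataError. The following is only a minimal sketch compatible with those assertions, not the library's actual implementation; the regular expression and the import path for DataError are assumptions.

import re
from typing import Tuple

from pyathena.error import DataError  # import path assumed


# Hypothetical pattern for "s3://<bucket>/<key>" locations (sketch only).
_S3_LOCATION_PATTERN = re.compile(r"^s3://(?P<bucket>[^/]+)/(?P<key>.+)$")


def parse_output_location(output_location: str) -> Tuple[str, str]:
    # Split an S3 URI into (bucket, key); reject anything that is not s3://.
    match = _S3_LOCATION_PATTERN.search(output_location)
    if match:
        return match.group("bucket"), match.group("key")
    raise DataError("Unknown `output_location` format.")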
Example 5
    def _as_pandas(self):
        import pandas as pd
        if not self.output_location:
            raise ProgrammingError('OutputLocation is none or empty.')
        bucket, key = parse_output_location(self.output_location)
        try:
            response = retry_api_call(self._client.get_object,
                                      config=self._retry_config,
                                      logger=_logger,
                                      Bucket=bucket,
                                      Key=key)
        except Exception as e:
            _logger.exception('Failed to download csv.')
            raise_from(OperationalError(*e.args), e)
        else:
            length = response['ContentLength']
            if length:
                df = pd.read_csv(io.BytesIO(response['Body'].read()),
                                 dtype=self.dtypes,
                                 converters=self.converters,
                                 parse_dates=self.parse_dates,
                                 infer_datetime_format=True)
                df = self._trunc_date(df)
            else:  # Allow empty response
                df = pd.DataFrame()
            return df
Example 6
    def _read_csv(self) -> "Table":
        import pyarrow as pa
        from pyarrow import csv

        if not self.output_location:
            raise ProgrammingError("OutputLocation is none or empty.")
        if not self.output_location.endswith((".csv", ".txt")):
            return pa.Table.from_pydict(dict())
        length = self._get_content_length()
        if length and self.output_location.endswith(".txt"):
            description = self.description if self.description else []
            column_names = [d[0] for d in description]
            read_opts = csv.ReadOptions(
                skip_rows=0,
                column_names=column_names,
                block_size=self._block_size,
                use_threads=True,
            )
            parse_opts = csv.ParseOptions(
                delimiter="\t",
                quote_char=False,
                double_quote=False,
                escape_char=False,
            )
        elif length and self.output_location.endswith(".csv"):
            read_opts = csv.ReadOptions(skip_rows=0,
                                        block_size=self._block_size,
                                        use_threads=True)
            parse_opts = csv.ParseOptions(
                delimiter=",",
                quote_char='"',
                double_quote=True,
                escape_char=False,
            )
        else:
            return pa.Table.from_pydict(dict())

        bucket, key = parse_output_location(self.output_location)
        try:
            return csv.read_csv(
                self._fs.open_input_stream(f"{bucket}/{key}"),
                read_options=read_opts,
                parse_options=parse_opts,
                convert_options=csv.ConvertOptions(
                    quoted_strings_can_be_null=False,
                    timestamp_parsers=self.timestamp_parsers,
                    column_types=self.column_types,
                ),
            )
        except Exception as e:
            _logger.exception(f"Failed to read {bucket}/{key}.")
            raise OperationalError(*e.args) from e
Example 7
    def _as_pandas(self) -> "DataFrame":
        import pandas as pd

        if not self.output_location:
            raise ProgrammingError("OutputLocation is none or empty.")
        bucket, key = parse_output_location(self.output_location)
        try:
            response = retry_api_call(
                self._client.get_object,
                config=self._retry_config,
                logger=_logger,
                Bucket=bucket,
                Key=key,
            )
        except Exception as e:
            _logger.exception("Failed to download csv.")
            raise OperationalError(*e.args) from e
        else:
            length = response["ContentLength"]
            if length:
                if self.output_location.endswith(".txt"):
                    sep = "\t"
                    header = None
                    description = self.description if self.description else []
                    names: Optional[Any] = [d[0] for d in description]
                else:  # csv format
                    sep = ","
                    header = 0
                    names = None
                df = pd.read_csv(
                    response["Body"],
                    sep=sep,
                    header=header,
                    names=names,
                    dtype=self.dtypes,
                    converters=self.converters,
                    parse_dates=self.parse_dates,
                    infer_datetime_format=True,
                    skip_blank_lines=False,
                    keep_default_na=self._keep_default_na,
                    na_values=self._na_values,
                    quoting=self._quoting,
                    **self._kwargs,
                )
                df = self._trunc_date(df)
            else:  # Allow empty response
                df = pd.DataFrame()
            return df
Example 8
    def _read_parquet(self) -> "Table":
        import pyarrow as pa
        from pyarrow import parquet

        manifests = self._read_data_manifest()
        if not manifests:
            return pa.Table.from_pydict(dict())
        if not self._unload_location:
            self._unload_location = "/".join(
                manifests[0].split("/")[:-1]) + "/"

        bucket, key = parse_output_location(self._unload_location)
        try:
            dataset = parquet.ParquetDataset(f"{bucket}/{key}",
                                             filesystem=self._fs,
                                             use_legacy_dataset=False)
            return dataset.read(use_threads=True)
        except Exception as e:
            _logger.exception(f"Failed to read {bucket}/{key}.")
            raise OperationalError(*e.args) from e
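As a concrete illustration of the fallback above: when _unload_location is not set, _read_parquet drops the last path segment of the first manifest entry and reads the whole prefix as a ParquetDataset. A worked example with a made-up S3 path:

# Hypothetical manifest entry, for illustration only.
manifest_entry = "s3://example-bucket/unload/prefix/20230101_000000_00001_abcde"

# Same derivation as in _read_parquet: strip the object name, keep the prefix.
unload_location = "/".join(manifest_entry.split("/")[:-1]) + "/"
print(unload_location)  # s3://example-bucket/unload/prefix/

# parse_output_location then splits this into bucket and key prefix:
# ("example-bucket", "unload/prefix/")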
Example 9
    def _as_pandas(self):
        import pandas as pd
        if not self.output_location:
            raise ProgrammingError('OutputLocation is none or empty.')
        bucket, key = parse_output_location(self.output_location)
        try:
            response = retry_api_call(self._client.get_object,
                                      config=self._retry_config,
                                      logger=_logger,
                                      Bucket=bucket,
                                      Key=key)
        except Exception as e:
            _logger.exception('Failed to download csv.')
            raise_from(OperationalError(*e.args), e)
        else:
            length = response['ContentLength']
            if length:
                if self.output_location.endswith('.txt'):
                    sep = '\t'
                    header = None
                    names = [d[0] for d in self.description]
                else:  # csv format
                    sep = ','
                    header = 0
                    names = None
                df = pd.read_csv(io.BytesIO(response['Body'].read()),
                                 sep=sep,
                                 header=header,
                                 names=names,
                                 dtype=self.dtypes,
                                 converters=self.converters,
                                 parse_dates=self.parse_dates,
                                 infer_datetime_format=True,
                                 skip_blank_lines=False)
                df = self._trunc_date(df)
            else:  # Allow empty response
                df = pd.DataFrame()
            return df
Example 10
def to_sql(
    df: "DataFrame",
    name: str,
    conn: "Connection",
    location: str,
    schema: str = "default",
    index: bool = False,
    index_label: Optional[str] = None,
    partitions: Optional[List[str]] = None,
    chunksize: Optional[int] = None,
    if_exists: str = "fail",
    compression: Optional[str] = None,
    flavor: str = "spark",
    type_mappings: Callable[["Series"], str] = to_sql_type_mappings,
    executor_class: Type[Union[ThreadPoolExecutor,
                               ProcessPoolExecutor]] = ThreadPoolExecutor,
    max_workers: int = (cpu_count() or 1) * 5,
) -> None:
    # TODO: Support orc, avro, json, csv or tsv formats
    if if_exists not in ("fail", "replace", "append"):
        raise ValueError("`{0}` is not valid for if_exists".format(if_exists))
    if compression is not None and not AthenaCompression.is_valid(compression):
        raise ValueError(
            "`{0}` is not valid for compression".format(compression))
    if partitions is None:
        partitions = []

    bucket_name, key_prefix = parse_output_location(location)
    bucket = conn.session.resource("s3",
                                   region_name=conn.region_name,
                                   **conn._client_kwargs).Bucket(bucket_name)
    cursor = conn.cursor()

    table = cursor.execute("""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = '{schema}'
        AND table_name = '{table}'
        """.format(schema=schema, table=name)).fetchall()
    if if_exists == "fail":
        if table:
            raise OperationalError("Table `{0}.{1}` already exists.".format(
                schema, name))
    elif if_exists == "replace":
        if table:
            cursor.execute("""
                DROP TABLE {schema}.{table}
                """.format(schema=schema, table=name))
            objects = bucket.objects.filter(Prefix=key_prefix)
            if list(objects.limit(1)):
                objects.delete()

    if index:
        reset_index(df, index_label)
    with executor_class(max_workers=max_workers) as e:
        futures = []
        session_kwargs = deepcopy(conn._session_kwargs)
        session_kwargs.update({"profile_name": conn.profile_name})
        client_kwargs = deepcopy(conn._client_kwargs)
        client_kwargs.update({"region_name": conn.region_name})
        if partitions:
            for keys, group in df.groupby(by=partitions, observed=True):
                keys = keys if isinstance(keys, tuple) else (keys, )
                group = group.drop(partitions, axis=1)
                partition_prefix = "/".join([
                    "{0}={1}".format(key, val)
                    for key, val in zip(partitions, keys)
                ])
                for chunk in get_chunks(group, chunksize):
                    futures.append(
                        e.submit(
                            to_parquet,
                            chunk,
                            bucket_name,
                            "{0}{1}/".format(key_prefix, partition_prefix),
                            conn._retry_config,
                            session_kwargs,
                            client_kwargs,
                            compression,
                            flavor,
                        ))
        else:
            for chunk in get_chunks(df, chunksize):
                futures.append(
                    e.submit(
                        to_parquet,
                        chunk,
                        bucket_name,
                        key_prefix,
                        conn._retry_config,
                        session_kwargs,
                        client_kwargs,
                        compression,
                        flavor,
                    ))
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            _logger.info("to_parquet: {0}".format(result))

    ddl = generate_ddl(
        df=df,
        name=name,
        location=location,
        schema=schema,
        partitions=partitions,
        compression=compression,
        type_mappings=type_mappings,
    )
    _logger.info(ddl)
    cursor.execute(ddl)
    if partitions:
        repair = "MSCK REPAIR TABLE {0}.{1}".format(schema, name)
        _logger.info(repair)
        cursor.execute(repair)
Example 11
def to_sql(
    df: "DataFrame",
    name: str,
    conn: "Connection",
    location: str,
    schema: str = "default",
    index: bool = False,
    index_label: Optional[str] = None,
    partitions: Optional[List[str]] = None,
    chunksize: Optional[int] = None,
    if_exists: str = "fail",
    compression: Optional[str] = None,
    flavor: str = "spark",
    type_mappings: Callable[["Series"], str] = to_sql_type_mappings,
    executor_class: Type[Union[ThreadPoolExecutor,
                               ProcessPoolExecutor]] = ThreadPoolExecutor,
    max_workers: int = (cpu_count() or 1) * 5,
    repair_table: bool = True,
) -> None:
    # TODO: Support orc, avro, json, csv or tsv formats
    if if_exists not in ("fail", "replace", "append"):
        raise ValueError(f"`{if_exists}` is not valid for if_exists")
    if compression is not None and not AthenaCompression.is_valid(compression):
        raise ValueError(f"`{compression}` is not valid for compression")
    if partitions is None:
        partitions = []
    if not location.endswith("/"):
        location += "/"

    bucket_name, key_prefix = parse_output_location(location)
    bucket = conn.session.resource("s3",
                                   region_name=conn.region_name,
                                   **conn._client_kwargs).Bucket(bucket_name)
    cursor = conn.cursor()

    table = cursor.execute(
        textwrap.dedent(f"""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = '{schema}'
            AND table_name = '{name}'
            """)).fetchall()
    if if_exists == "fail":
        if table:
            raise OperationalError(f"Table `{schema}.{name}` already exists.")
    elif if_exists == "replace":
        if table:
            cursor.execute(
                textwrap.dedent(f"""
                    DROP TABLE {schema}.{name}
                    """))
            objects = bucket.objects.filter(Prefix=key_prefix)
            if list(objects.limit(1)):
                objects.delete()

    if index:
        reset_index(df, index_label)
    with executor_class(max_workers=max_workers) as e:
        futures = []
        session_kwargs = deepcopy(conn._session_kwargs)
        session_kwargs.update({"profile_name": conn.profile_name})
        client_kwargs = deepcopy(conn._client_kwargs)
        client_kwargs.update({"region_name": conn.region_name})
        partition_prefixes = []
        if partitions:
            for keys, group in df.groupby(by=partitions, observed=True):
                keys = keys if isinstance(keys, tuple) else (keys, )
                group = group.drop(partitions, axis=1)
                partition_prefix = "/".join(
                    [f"{key}={val}" for key, val in zip(partitions, keys)])
                partition_prefixes.append((
                    ", ".join([
                        f"`{key}` = '{val}'"
                        for key, val in zip(partitions, keys)
                    ]),
                    f"{location}{partition_prefix}/",
                ))
                for chunk in get_chunks(group, chunksize):
                    futures.append(
                        e.submit(
                            to_parquet,
                            chunk,
                            bucket_name,
                            f"{key_prefix}{partition_prefix}/",
                            conn._retry_config,
                            session_kwargs,
                            client_kwargs,
                            compression,
                            flavor,
                        ))
        else:
            for chunk in get_chunks(df, chunksize):
                futures.append(
                    e.submit(
                        to_parquet,
                        chunk,
                        bucket_name,
                        key_prefix,
                        conn._retry_config,
                        session_kwargs,
                        client_kwargs,
                        compression,
                        flavor,
                    ))
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            _logger.info(f"to_parquet: {result}")

    ddl = generate_ddl(
        df=df,
        name=name,
        location=location,
        schema=schema,
        partitions=partitions,
        compression=compression,
        type_mappings=type_mappings,
    )
    _logger.info(ddl)
    cursor.execute(ddl)
    if partitions and repair_table:
        for partition in partition_prefixes:
            add_partition = textwrap.dedent(f"""
                ALTER TABLE `{schema}`.`{name}`
                ADD IF NOT EXISTS PARTITION ({partition[0]}) LOCATION '{partition[1]}'
                """)
            _logger.info(add_partition)
            cursor.execute(add_partition)
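Finally, a hedged usage sketch of the to_sql helper shown in Examples 10 and 11. The connection parameters, bucket names, table name, and partition column are placeholders, and the import path (pyathena.pandas.util in recent releases, pyathena.util in older ones) may differ depending on the installed version.

import pandas as pd

from pyathena import connect
from pyathena.pandas.util import to_sql  # import path assumed

# Placeholder connection and data; adjust staging directory, region, and bucket.
conn = connect(s3_staging_dir="s3://example-staging-bucket/path/",
               region_name="us-west-2")
df = pd.DataFrame({"id": [1, 2, 3],
                   "dt": ["2023-01-01", "2023-01-01", "2023-01-02"]})

# Write the DataFrame as Parquet under the prefix, register the table in Athena,
# and partition it by the `dt` column, replacing any existing table and objects.
to_sql(df,
       "example_table",
       conn,
       "s3://example-bucket/prefix/",
       schema="default",
       partitions=["dt"],
       if_exists="replace",
       compression="snappy")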