Example #1: to_sql rejects invalid if_exists and compression values
    def test_to_sql_invalid_args(self, cursor):
        df = pd.DataFrame({"col_int": np.int32([1])})
        table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", ""))
        location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX,
                                        table_name)
        # invalid if_exists
        with self.assertRaises(ValueError):
            to_sql(
                df,
                table_name,
                cursor._connection,
                location,
                schema=SCHEMA,
                if_exists="foobar",
                compression="snappy",
            )
        # invalid compression
        with self.assertRaises(ValueError):
            to_sql(
                df,
                table_name,
                cursor._connection,
                location,
                schema=SCHEMA,
                if_exists="fail",
                compression="foobar",
            )
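For contrast, a minimal sketch of a call that passes validation, reusing the names from the test above; the values exercised across these examples are if_exists in {"fail", "replace", "append"} and compression="snappy":

# Valid arguments; no ValueError is raised.
to_sql(
    df,
    table_name,
    cursor._connection,
    location,
    schema=SCHEMA,
    if_exists="fail",
    compression="snappy",
)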
Example #2: helper that loads a dataframe into a test Athena database
def load_dataframe_into_test_athena_database_as_table(
    df: pd.DataFrame,
    table_name: str,
    connection,
    data_location_bucket: Optional[str] = None,
    data_location: Optional[str] = None,
) -> None:
    """

    Args:
        df: dataframe containing data.
        table_name: name of table to write.
        connection: connection to database.
        data_location_bucket: name of bucket where data is located.
        data_location: path to data from bucket without leading / e.g.
            "data/stuff/" in path "s3://my-bucket/data/stuff/"

    Returns:
        None
    """

    from pyathena.pandas.util import to_sql

    if not data_location_bucket:
        data_location_bucket = os.getenv("ATHENA_DATA_BUCKET")
    if not data_location:
        data_location = "data/ten_trips_from_each_month/"
    location: str = f"s3://{data_location_bucket}/{data_location}"
    to_sql(
        df=df,
        name=table_name,
        conn=connection,
        location=location,
        if_exists="replace",
    )
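A minimal usage sketch for the helper above; the staging directory, region, and dataframe contents are hypothetical, and ATHENA_DATA_BUCKET is assumed to be set in the environment:

import pandas as pd
from pyathena import connect

# Hypothetical staging dir and region; substitute your own.
connection = connect(s3_staging_dir="s3://my-athena-staging/",
                     region_name="us-east-1")
df = pd.DataFrame({"trip_id": [1, 2, 3]})
# Falls back to os.getenv("ATHENA_DATA_BUCKET") and the default data_location.
load_dataframe_into_test_athena_database_as_table(df, "ten_trips", connection)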
Example #3: writing a table partitioned on multiple columns
    def test_to_sql_with_multiple_partitions(self, cursor):
        df = pd.DataFrame({
            "col_int": np.int32([i for i in range(10)]),
            "col_bigint": np.int64([12345 for _ in range(10)]),
            "col_string": ["a" for _ in range(5)] + ["b" for _ in range(5)],
        })
        table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", ""))
        location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX,
                                        table_name)
        to_sql(
            df,
            table_name,
            cursor._connection,
            location,
            schema=SCHEMA,
            partitions=["col_int", "col_string"],
            if_exists="fail",
            compression="snappy",
        )
        cursor.execute("SHOW PARTITIONS {0}".format(table_name))
        self.assertEqual(
            sorted(cursor.fetchall()),
            [("col_int={0}/col_string=a".format(i),) for i in range(5)]
            + [("col_int={0}/col_string=b".format(i),) for i in range(5, 10)],
        )
        cursor.execute("SELECT COUNT(*) FROM {0}".format(table_name))
        self.assertEqual(cursor.fetchall(), [(10,)])
Example #4: loading a parquet file into Athena with explicit credentials
def load_data():
    """Load the house parquet data into an AWS Athena table."""
    # Credentials file is expected to contain
    # "<aws_access_key_id> <aws_secret_access_key>".
    with open('/home/tom/AWS_CREDS.txt', 'r') as f:
        data = f.read()  # read in AWS credentials

    aws_access_key_id = data.split(' ')[0].strip()
    aws_secret_access_key = data.split(' ')[1].strip()
    con = connect(s3_staging_dir="s3://testingbucket1003/",
                  region_name="eu-west-2",
                  aws_secret_access_key=aws_secret_access_key,
                  aws_access_key_id=aws_access_key_id)
    df = pd.read_parquet('/home/tom/Documents/csv_files/house_parquet.parquet')
    # Load the dataframe into the Athena table "house_data".
    to_sql(df, 'house_data', con,
           "s3://testingbucket1003/athenadata/")
Example #5: writing the dataframe index as a column
def test_to_sql_with_index(cursor):
    df = pd.DataFrame({"col_int": np.int32([1])})
    table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", ""))
    location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, ENV.schema, table_name)
    to_sql(
        df,
        table_name,
        cursor._connection,
        location,
        schema=ENV.schema,
        if_exists="fail",
        compression="snappy",
        index=True,
        index_label="col_index",
    )
    cursor.execute("SELECT * FROM {0}".format(table_name))
    assert cursor.fetchall() == [(0, 1)]
    assert [(d[0], d[1]) for d in cursor.description] == [
        ("col_index", "bigint"),
        ("col_int", "integer"),
    ]
Example #6: writing a table partitioned on a single column
def test_to_sql_with_partitions(cursor):
    df = pd.DataFrame(
        {
            "col_int": np.int32([i for i in range(10)]),
            "col_bigint": np.int64([12345 for _ in range(10)]),
            "col_string": ["a" for _ in range(10)],
        }
    )
    table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", ""))
    location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, ENV.schema, table_name)
    to_sql(
        df,
        table_name,
        cursor._connection,
        location,
        schema=ENV.schema,
        partitions=["col_int"],
        if_exists="fail",
        compression="snappy",
    )
    cursor.execute("SHOW PARTITIONS {0}".format(table_name))
    assert sorted(cursor.fetchall()) == [("col_int={0}".format(i),) for i in range(10)]
    cursor.execute("SELECT COUNT(*) FROM {0}".format(table_name))
    assert cursor.fetchall() == [(10,)]
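As the SHOW PARTITIONS assertion confirms, the data is laid out Hive-style under col_int=<value> prefixes, so filtering on the partition column prunes the scan to a single partition. A hypothetical follow-up query against the same table:

cursor.execute("SELECT COUNT(*) FROM {0} WHERE col_int = 3".format(table_name))
assert cursor.fetchall() == [(1,)]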
Example #7: round-tripping all supported column types with fail/replace/append
    def test_to_sql(self, cursor):
        df = pd.DataFrame({
            "col_int": np.int32([1]),
            "col_bigint": np.int64([12345]),
            "col_float": np.float32([1.0]),
            "col_double": np.float64([1.2345]),
            "col_string": ["a"],
            "col_boolean": np.bool_([True]),
            "col_timestamp": [datetime(2020, 1, 1, 0, 0, 0)],
            "col_date": [date(2020, 12, 31)],
            "col_binary": "foobar".encode(),
        })
        # Explicitly specify column order
        df = df[[
            "col_int",
            "col_bigint",
            "col_float",
            "col_double",
            "col_string",
            "col_boolean",
            "col_timestamp",
            "col_date",
            "col_binary",
        ]]
        table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", ""))
        location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX,
                                        table_name)
        to_sql(
            df,
            table_name,
            cursor._connection,
            location,
            schema=SCHEMA,
            if_exists="fail",
            compression="snappy",
        )
        # table already exists
        with self.assertRaises(OperationalError):
            to_sql(
                df,
                table_name,
                cursor._connection,
                location,
                schema=SCHEMA,
                if_exists="fail",
                compression="snappy",
            )
        # replace
        to_sql(
            df,
            table_name,
            cursor._connection,
            location,
            schema=SCHEMA,
            if_exists="replace",
            compression="snappy",
        )

        cursor.execute("SELECT * FROM {0}".format(table_name))
        self.assertEqual(
            cursor.fetchall(),
            [(
                1,
                12345,
                1.0,
                1.2345,
                "a",
                True,
                datetime(2020, 1, 1, 0, 0, 0),
                date(2020, 12, 31),
                "foobar".encode(),
            )],
        )
        self.assertEqual(
            [(d[0], d[1]) for d in cursor.description],
            [
                ("col_int", "integer"),
                ("col_bigint", "bigint"),
                ("col_float", "float"),
                ("col_double", "double"),
                ("col_string", "varchar"),
                ("col_boolean", "boolean"),
                ("col_timestamp", "timestamp"),
                ("col_date", "date"),
                ("col_binary", "varbinary"),
            ],
        )

        # append
        to_sql(
            df,
            table_name,
            cursor._connection,
            location,
            schema=SCHEMA,
            if_exists="append",
            compression="snappy",
        )
        cursor.execute("SELECT * FROM {0}".format(table_name))
        self.assertEqual(
            cursor.fetchall(),
            [
                (
                    1,
                    12345,
                    1.0,
                    1.2345,
                    "a",
                    True,
                    datetime(2020, 1, 1, 0, 0, 0),
                    date(2020, 12, 31),
                    "foobar".encode(),
                ),
                (
                    1,
                    12345,
                    1.0,
                    1.2345,
                    "a",
                    True,
                    datetime(2020, 1, 1, 0, 0, 0),
                    date(2020, 12, 31),
                    "foobar".encode(),
                ),
            ],
        )
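Putting the pieces together, a minimal end-to-end sketch outside the test harness; the staging directory and region are hypothetical, and the table lands in the default schema since no schema argument is passed:

import uuid

import numpy as np
import pandas as pd
from pyathena import connect
from pyathena.pandas.util import to_sql

# Hypothetical staging dir and region; substitute your own.
con = connect(s3_staging_dir="s3://my-staging/", region_name="us-east-1")
df = pd.DataFrame({"col_int": np.int32([1]), "col_string": ["a"]})
table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", ""))
to_sql(df, table_name, con,
       "s3://my-staging/{0}/".format(table_name),
       if_exists="fail", compression="snappy")
cursor = con.cursor()
cursor.execute("SELECT * FROM {0}".format(table_name))
print(cursor.fetchall())  # [(1, 'a')]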