Code Example #1
 def query(self):
     """Queries mysql and returns a cursor to the results."""
     mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
     conn = mysql.get_conn()
     cursor = conn.cursor()
     if self.ensure_utc:
         # Ensure TIMESTAMP results are in UTC
         tz_query = "SET time_zone = '+00:00'"
         self.log.info('Executing: %s', tz_query)
         cursor.execute(tz_query)
     self.log.info('Executing: %s', self.sql)
     cursor.execute(self.sql)
     return cursor
Code Example #2
File: presto_to_mysql.py Project: yqian1991/airflow
    def execute(self, context):
        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info("Extracting data from Presto: %s", self.sql)
        results = presto.get_records(self.sql)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator")
            self.log.info(self.mysql_preoperator)
            mysql.run(self.mysql_preoperator)

        self.log.info("Inserting rows into MySQL")
        mysql.insert_rows(table=self.mysql_table, rows=results)
Code Example #3
File: trino_to_mysql.py Project: zarrarrana/airflow
    def execute(self, context: Dict) -> None:
        trino = TrinoHook(trino_conn_id=self.trino_conn_id)
        self.log.info("Extracting data from Trino: %s", self.sql)
        results = trino.get_records(self.sql)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator")
            self.log.info(self.mysql_preoperator)
            mysql.run(self.mysql_preoperator)

        self.log.info("Inserting rows into MySQL")
        mysql.insert_rows(table=self.mysql_table, rows=results)
Code Example #4
def load_review_data(data):
    mysql_hook = MySqlHook(mysql_conn_id='sky')
    engine = mysql_hook.get_sqlalchemy_engine(
        engine_kwargs={'connect_args': {
            'charset': 'utf8mb4'
        }})
    connection = engine.connect()

    for review_values in data.itertuples(index=False, name=None):
        replace_into(connection, review_values)

    connection.close()
    engine.dispose()
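Note: Example #4 calls a `replace_into` helper that is not included in the snippet. A minimal sketch of what such a helper might look like, assuming a target table named `reviews` and that the tuple's column order matches the table definition (both assumptions, not taken from the original):

from sqlalchemy import text

def replace_into(connection, review_values):
    # Hypothetical helper: write one review row with MySQL's REPLACE INTO so that
    # re-running the load overwrites rows sharing the table's primary key.
    # The table name 'reviews' and the positional column order are assumptions.
    placeholders = ", ".join(f":p{i}" for i in range(len(review_values)))
    stmt = text(f"REPLACE INTO reviews VALUES ({placeholders})")
    connection.execute(stmt, {f"p{i}": value for i, value in enumerate(review_values)})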
Code Example #5
 def get_database_hook(
         self, connection: Connection) -> Union[PostgresHook, MySqlHook]:
     """
     Retrieve database hook. This is the actual Postgres or MySQL database hook
     that uses proxy or connects directly to the Google Cloud SQL database.
     """
     if self.database_type == 'postgres':
         self.db_hook = PostgresHook(connection=connection,
                                     schema=self.database)
     else:
         self.db_hook = MySqlHook(connection=connection,
                                  schema=self.database)
     return self.db_hook
Code Example #6
File: vertica_to_mysql.py Project: vipadm/airflow
    def execute(self, context: 'Context'):
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        if self.bulk_load:
            self._bulk_load_transfer(mysql, vertica)
        else:
            self._non_bulk_load_transfer(mysql, vertica)

        if self.mysql_postoperator:
            self.log.info("Running MySQL postoperator...")
            mysql.run(self.mysql_postoperator)

        self.log.info("Done")
Code Example #7
File: mysql_to_s3.py Project: werbolis/airflow
    def execute(self, context) -> None:
        mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        data_df = mysql_hook.get_pandas_df(self.query)
        self.log.info("Data from MySQL obtained")

        self._fix_int_dtypes(data_df)
        with NamedTemporaryFile(mode='r+', suffix='.csv') as tmp_csv:
            data_df.to_csv(tmp_csv.name, **self.pd_csv_kwargs)
            s3_conn.load_file(filename=tmp_csv.name, key=self.s3_key, bucket_name=self.s3_bucket)

        if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
            file_location = os.path.join(self.s3_bucket, self.s3_key)
            self.log.info("File saved correctly in %s", file_location)
Code Example #8
File: test_mysql.py Project: kangyunseok/airflow
    def setUp(self):
        super().setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
            extra='{"client": "mysql-connector-python"}',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection
Code Example #9
def mysql_to_pq(source_transform,
                name_of_dataset='project_four_airflow',
                by_row_batch=1000):
    '''
    Extract rows from the MySQL database and save them to a local Parquet file such as ``tmp/sales-date.pq``.
    This function takes the last row id of the BigQuery dataset and compares it against the current
    MySQL database to avoid duplication, so only new data is extracted and loaded from MySQL to BigQuery.
    If the dataset does not exist, it is created with the given name.

    Args:

        1. source_transform = 'path/local/file.pq'

        2. by_row_batch = number of rows to extract per batch ``int``

    return:
        ``str`` of the local Parquet file path
    '''
    client = BigQueryHook(gcp_conn_id='google_cloud_default').get_client()
    row_id = client.query(
        'select id from project_four_airflow.sales order by id desc limit 1')
    try:
        for i in row_id:
            last_row_id = i[0]
            print(i[0])
    except GoogleAPIError:
        # The query failed because the table does not exist yet (error reason 'notFound'),
        # so start from row 0 and create the dataset.
        last_row_id = 0
        print('no dataset.table')
        client.create_dataset(name_of_dataset)
        print('new dataset, {} created'.format(name_of_dataset))
    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    cur = conn.cursor()
    cur.execute('use sales_records_airflow')
    cur.execute('select * from sales where id>={} and id<={}'.format(
        last_row_id + 1, last_row_id + by_row_batch))

    list_row = cur.fetchall()
    rows_of_extracted_mysql = []
    for i in list_row:
        rows_of_extracted_mysql.append(list(i))
    print('extracting from mysql')
    df = pd.DataFrame(rows_of_extracted_mysql,
                      columns=[
                          'id', 'region', 'country', 'item_type',
                          'sales_channel', 'Order Priority', 'order_date',
                          'order_id', 'ship_date', 'units_sold', 'unit_price',
                          'unit_cost', 'total_revenue', 'total_cost',
                          'total_profit'
                      ])
    df.to_parquet(source_transform)
    print('task complete check,', source_transform)
    return source_transform
Code Example #10
File: test_mysql.py Project: kangyunseok/airflow
    def setUp(self):
        super().setUp()

        self.connection = Connection(
            conn_type='mysql',
            login='******',
            password='******',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection
Code Example #11
 def test_mysql_hook_test_bulk_dump(self, client):
     with MySqlContext(client):
         hook = MySqlHook('airflow_db')
         priv = hook.get_first("SELECT @@global.secure_file_priv")
         # Use random names to allow re-running
         if priv and priv[0]:
             # Confirm that no error occurs
             hook.bulk_dump(
                 "INFORMATION_SCHEMA.TABLES",
                 os.path.join(priv[0], "TABLES_{}-{}".format(client, uuid.uuid1())),
             )
         elif priv == ("",):
             hook.bulk_dump("INFORMATION_SCHEMA.TABLES", "TABLES_{}_{}".format(client, uuid.uuid1()))
         else:
             self.skipTest("Skip test_mysql_hook_test_bulk_load since file output is not permitted")
Code Example #12
    def test_mysql_to_hive_verify_csv_special_char(self):
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.providers.mysql.hooks.mysql import MySqlHook
        hook = MySqlHook()

        try:
            db_record = (
                'c0',
                '["true"]'
            )
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                conn.execute("""
                    CREATE TABLE {} (
                        c0 VARCHAR(25),
                        c1 VARCHAR(25)
                    )
                """.format(mysql_table))
                conn.execute("""
                    INSERT INTO {} VALUES (
                        '{}', '{}'
                    )
                """.format(mysql_table, *db_record))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            import unicodecsv as csv
            op = MySqlToHiveTransfer(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                quoting=csv.QUOTE_NONE,
                quotechar='',
                escapechar='@',
                dag=self.dag)
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
            hive_hook = HiveServer2Hook()
            result = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(result[0], db_record)
        finally:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Code Example #13
    def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn):
        mock_execute = mock.MagicMock()
        mock_get_conn.return_value.cursor.return_value.execute = mock_execute

        hook = MySqlHook('airflow_db')
        table = "INFORMATION_SCHEMA.TABLES"
        tmp_file = "/path/to/output/file"
        hook.bulk_dump(table, tmp_file)

        from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
        assert mock_execute.call_count == 1
        query = """
            SELECT * INTO OUTFILE '{tmp_file}'
            FROM {table}
        """.format(tmp_file=tmp_file, table=table)
        assert_equal_ignore_multiple_spaces(self, mock_execute.call_args[0][0], query)
Code Example #14
    def test_mysql_to_hive_type_conversion(self, mock_load_file):
        mysql_table = 'test_mysql_to_hive'

        hook = MySqlHook()

        try:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                conn.execute(
                    """
                    CREATE TABLE {} (
                        c0 TINYINT,
                        c1 SMALLINT,
                        c2 MEDIUMINT,
                        c3 INT,
                        c4 BIGINT,
                        c5 TIMESTAMP
                    )
                """.format(
                        mysql_table
                    )
                )

            op = MySqlToHiveOperator(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table='test_mysql_to_hive',
                dag=self.dag,
            )
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            assert mock_load_file.call_count == 1
            ordered_dict = OrderedDict()
            ordered_dict["c0"] = "SMALLINT"
            ordered_dict["c1"] = "INT"
            ordered_dict["c2"] = "INT"
            ordered_dict["c3"] = "BIGINT"
            ordered_dict["c4"] = "DECIMAL(38,0)"
            ordered_dict["c5"] = "TIMESTAMP"
            self.assertEqual(mock_load_file.call_args[1]["field_dict"], ordered_dict)
        finally:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Code Example #15
    def test_mysql_hook_test_bulk_load(self):
        records = ("foo", "bar", "baz")

        import tempfile
        with tempfile.NamedTemporaryFile() as f:
            f.write("\n".join(records).encode('utf8'))
            f.flush()

            hook = MySqlHook('airflow_db')
            with hook.get_conn() as conn:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS test_airflow (
                        dummy VARCHAR(50)
                    )
                """)
                conn.execute("TRUNCATE TABLE test_airflow")
                hook.bulk_load("test_airflow", f.name)
                conn.execute("SELECT dummy FROM test_airflow")
                results = tuple(result[0] for result in conn.fetchall())
                self.assertEqual(sorted(results), sorted(records))
Code Example #16
    def execute(self, context) -> None:
        mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        data_df = mysql_hook.get_pandas_df(self.query)
        self.log.info("Data from MySQL obtained")

        self._fix_int_dtypes(data_df)
        file_options = FILE_OPTIONS_MAP[self.file_format]
        with NamedTemporaryFile(mode=file_options.mode,
                                suffix=file_options.suffix) as tmp_file:
            if self.file_format == FILE_FORMAT.CSV:
                data_df.to_csv(tmp_file.name, **self.pd_kwargs)
            else:
                data_df.to_parquet(tmp_file.name, **self.pd_kwargs)
            s3_conn.load_file(filename=tmp_file.name,
                              key=self.s3_key,
                              bucket_name=self.s3_bucket)

        if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
            file_location = os.path.join(self.s3_bucket, self.s3_key)
            self.log.info("File saved correctly in %s", file_location)
Code Example #17
 def index(self):
     """Create default view"""
     sql = """
     SELECT
         a.name as db, db_location_uri as location,
         count(1) as object_count, a.desc as description
     FROM DBS a
     JOIN TBLS b ON a.DB_ID = b.DB_ID
     GROUP BY a.name, db_location_uri, a.desc
     """
     hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
     df = hook.get_pandas_df(sql)
     df.db = '<a href="/metastorebrowserview/db/?db=' + df.db + '">' + df.db + '</a>'
     table = df.to_html(
         classes="table table-striped table-bordered table-hover",
         index=False,
         escape=False,
         na_rep='',
     )
     return self.render_template("metastore_browser/dbs.html",
                                 table=Markup(table))
Code Example #18
def check_data(task_instance, create_table_query_file):
    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    cur = conn.cursor()
    try:
        cur.execute('use sales_records_airflow')
        cur.execute('select count(*) from sales')
        total_rows = cur.fetchone()[0]
        task_instance.xcom_push(key='mysql_total_rows', value=total_rows)
        # Check the exact row count first; otherwise any integer count means new data to append.
        if total_rows == 50000:
            print('up to date')
            return 'check_dataset'
        elif type(total_rows) is int:
            print('appending new data')
            return 'csv_file_exist'
    except cur.OperationalError:
        print('sql_file execute')
        sql_file = open(create_table_query_file, 'r')
        sql_query = sql_file.read()
        for query in sql_query.split(';', maxsplit=2):
            cur.execute('{}'.format(query))
            conn.commit()
        return 'csv_file_not_exist'
Code Example #19
 def execute(self, context: 'Context') -> None:
     big_query_hook = BigQueryHook(
         gcp_conn_id=self.gcp_conn_id,
         delegate_to=self.delegate_to,
         location=self.location,
         impersonation_chain=self.impersonation_chain,
     )
     mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id)
     for rows in bigquery_get_data(
         self.log,
         self.dataset_id,
         self.table_id,
         big_query_hook,
         self.batch_size,
         self.selected_fields,
     ):
         mysql_hook.insert_rows(
             table=self.mysql_table,
             rows=rows,
             target_fields=self.selected_fields,
             replace=self.replace,
         )
Code Example #20
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info("Extracting data from Hive: %s", self.sql)
        hive_conf = context_to_airflow_vars(context)
        if self.hive_conf:
            hive_conf.update(self.hive_conf)
        if self.bulk_load:
            tmp_file = NamedTemporaryFile()
            hive.to_csv(
                self.sql,
                tmp_file.name,
                delimiter='\t',
                lineterminator='\n',
                output_header=False,
                hive_conf=hive_conf,
            )
        else:
            hive_results = hive.get_records(self.sql, hive_conf=hive_conf)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator")
            mysql.run(self.mysql_preoperator)

        self.log.info("Inserting rows into MySQL")
        if self.bulk_load:
            mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
            tmp_file.close()
        else:
            mysql.insert_rows(table=self.mysql_table, rows=hive_results)

        if self.mysql_postoperator:
            self.log.info("Running MySQL postoperator")
            mysql.run(self.mysql_postoperator)

        self.log.info("Done.")
Code Example #21
    def execute(self, context: dict) -> None:
        """
        Executes the transfer operation from S3 to MySQL.

        :param context: The context that is being provided when executing.
        :type context: dict
        """
        self.log.info('Loading %s to MySql table %s...', self.s3_source_key,
                      self.mysql_table)

        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
        file = s3_hook.download_file(key=self.s3_source_key)

        try:
            mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
            mysql.bulk_load_custom(
                table=self.mysql_table,
                tmp_file=file,
                duplicate_key_handling=self.mysql_duplicate_key_handling,
                extra_options=self.mysql_extra_options)
        finally:
            # Remove file downloaded from s3 to be idempotent.
            os.remove(file)
Code Example #22
File: test_mysql.py Project: kangyunseok/airflow
    def test_mysql_hook_test_bulk_load(self, client):
        with MySqlContext(client):
            records = ("foo", "bar", "baz")

            import tempfile

            with tempfile.NamedTemporaryFile() as f:
                f.write("\n".join(records).encode('utf8'))
                f.flush()

                hook = MySqlHook('airflow_db')
                with closing(hook.get_conn()) as conn:
                    with closing(conn.cursor()) as cursor:
                        cursor.execute("""
                            CREATE TABLE IF NOT EXISTS test_airflow (
                                dummy VARCHAR(50)
                            )
                        """)
                        cursor.execute("TRUNCATE TABLE test_airflow")
                        hook.bulk_load("test_airflow", f.name)
                        cursor.execute("SELECT dummy FROM test_airflow")
                        results = tuple(result[0]
                                        for result in cursor.fetchall())
                        assert sorted(results) == sorted(records)
Code Example #23
 def execute(self, context):
     """Establish connections to both MySQL & PostgreSQL databases, open
     cursor and begin processing query, loading chunks of rows into
     PostgreSQL. Repeat loading chunks until all rows processed for query.
     """
     source = MySqlHook(mysql_conn_id=self.mysql_conn_id)
     target = PostgresHook(postgres_conn_id=self.postgres_conn_id)
     with closing(source.get_conn()) as conn:
         with closing(conn.cursor()) as cursor:
             cursor.execute(self.sql, self.params)
             target_fields = [x[0] for x in cursor.description]
             row_count = 0
             rows = cursor.fetchmany(self.rows_chunk)
             while len(rows) > 0:
                 row_count += len(rows)
                 target.insert_rows(
                     self.postgres_table,
                     rows,
                     target_fields=target_fields,
                     commit_every=self.rows_chunk,
                 )
                 rows = cursor.fetchmany(self.rows_chunk)
             self.log.info(
                 f"{row_count} row(s) inserted into {self.postgres_table}.")
Code Example #24
File: hive_stats.py Project: folly3/airflow-1
    def execute(self, context: Optional[Dict[str, Any]] = None) -> None:
        metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
        table = metastore.get_table(table_name=self.table)
        field_types = {col.name: col.type for col in table.sd.cols}

        exprs: Any = {('', 'count'): 'COUNT(*)'}
        for col, col_type in list(field_types.items()):
            if self.assignment_func:
                assign_exprs = self.assignment_func(col, col_type)
                if assign_exprs is None:
                    assign_exprs = self.get_default_exprs(col, col_type)
            else:
                assign_exprs = self.get_default_exprs(col, col_type)
            exprs.update(assign_exprs)
        exprs.update(self.extra_exprs)
        exprs = OrderedDict(exprs)
        exprs_str = ",\n        ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

        where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
        where_clause = " AND\n        ".join(where_clause_)
        sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
            exprs_str=exprs_str, table=self.table, where_clause=where_clause
        )

        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info('Executing SQL check: %s', sql)
        row = presto.get_first(hql=sql)
        self.log.info("Record: %s", row)
        if not row:
            raise AirflowException("The query returned None")

        part_json = json.dumps(self.partition, sort_keys=True)

        self.log.info("Deleting rows from previous runs if they exist")
        mysql = MySqlHook(self.mysql_conn_id)
        sql = """
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}'
        LIMIT 1;
        """.format(
            table=self.table, part_json=part_json, dttm=self.dttm
        )
        if mysql.get_records(sql):
            sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{part_json}' AND
                dttm='{dttm}';
            """.format(
                table=self.table, part_json=part_json, dttm=self.dttm
            )
            mysql.run(sql)

        self.log.info("Pivoting and loading cells into the Airflow db")
        rows = [
            (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)
        ]
        mysql.insert_rows(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ],
        )
Code Example #25
 def tearDown(self):
     drop_tables = {'test_mysql_to_mysql', 'test_airflow'}
     with MySqlHook().get_conn() as conn:
         for table in drop_tables:
             conn.execute(f"DROP TABLE IF EXISTS {table}")
Code Example #26
File: test_mysql.py Project: kangyunseok/airflow
 def tearDown(self):
     drop_tables = {'test_mysql_to_mysql', 'test_airflow'}
     with closing(MySqlHook().get_conn()) as conn:
         with closing(conn.cursor()) as cursor:
             for table in drop_tables:
                 cursor.execute(f"DROP TABLE IF EXISTS {table}")
Code Example #27
    def setUp(self):
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        dag = DAG(TEST_DAG_ID, default_args=args)
        self.dag = dag

        rows = [
            (1880, "John", 0.081541, "boy"),
            (1880, "William", 0.080511, "boy"),
            (1880, "James", 0.050057, "boy"),
            (1880, "Charles", 0.045167, "boy"),
            (1880, "George", 0.043292, "boy"),
            (1880, "Frank", 0.02738, "boy"),
            (1880, "Joseph", 0.022229, "boy"),
            (1880, "Thomas", 0.021401, "boy"),
            (1880, "Henry", 0.020641, "boy"),
            (1880, "Robert", 0.020404, "boy"),
            (1880, "Edward", 0.019965, "boy"),
            (1880, "Harry", 0.018175, "boy"),
            (1880, "Walter", 0.014822, "boy"),
            (1880, "Arthur", 0.013504, "boy"),
            (1880, "Fred", 0.013251, "boy"),
            (1880, "Albert", 0.012609, "boy"),
            (1880, "Samuel", 0.008648, "boy"),
            (1880, "David", 0.007339, "boy"),
            (1880, "Louis", 0.006993, "boy"),
            (1880, "Joe", 0.006174, "boy"),
            (1880, "Charlie", 0.006165, "boy"),
            (1880, "Clarence", 0.006165, "boy"),
            (1880, "Richard", 0.006148, "boy"),
            (1880, "Andrew", 0.005439, "boy"),
            (1880, "Daniel", 0.00543, "boy"),
            (1880, "Ernest", 0.005194, "boy"),
            (1880, "Will", 0.004966, "boy"),
            (1880, "Jesse", 0.004805, "boy"),
            (1880, "Oscar", 0.004594, "boy"),
            (1880, "Lewis", 0.004366, "boy"),
            (1880, "Peter", 0.004189, "boy"),
            (1880, "Benjamin", 0.004138, "boy"),
            (1880, "Frederick", 0.004079, "boy"),
            (1880, "Willie", 0.00402, "boy"),
            (1880, "Alfred", 0.003961, "boy"),
            (1880, "Sam", 0.00386, "boy"),
            (1880, "Roy", 0.003716, "boy"),
            (1880, "Herbert", 0.003581, "boy"),
            (1880, "Jacob", 0.003412, "boy"),
            (1880, "Tom", 0.00337, "boy"),
            (1880, "Elmer", 0.00315, "boy"),
            (1880, "Carl", 0.003142, "boy"),
            (1880, "Lee", 0.003049, "boy"),
            (1880, "Howard", 0.003015, "boy"),
            (1880, "Martin", 0.003015, "boy"),
            (1880, "Michael", 0.00299, "boy"),
            (1880, "Bert", 0.002939, "boy"),
            (1880, "Herman", 0.002931, "boy"),
            (1880, "Jim", 0.002914, "boy"),
            (1880, "Francis", 0.002905, "boy"),
            (1880, "Harvey", 0.002905, "boy"),
            (1880, "Earl", 0.002829, "boy"),
            (1880, "Eugene", 0.00277, "boy"),
        ]

        self.env_vars = {
            'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
            'AIRFLOW_CTX_TASK_ID': 'test_task_id',
            'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
            'AIRFLOW_CTX_DAG_RUN_ID': '55',
            'AIRFLOW_CTX_DAG_OWNER': 'airflow',
            'AIRFLOW_CTX_DAG_EMAIL': '*****@*****.**',
        }

        with MySqlHook().get_conn() as cur:
            cur.execute('''
            CREATE TABLE IF NOT EXISTS baby_names (
              org_year integer(4),
              baby_name VARCHAR(25),
              rate FLOAT(7,6),
              sex VARCHAR(4)
            )
            ''')

        for row in rows:
            cur.execute("INSERT INTO baby_names VALUES(%s, %s, %s, %s);", row)
Code Example #28
    def execute(self, context):
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        tmpfile = None
        result = None

        selected_columns = []

        count = 0
        with closing(vertica.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self.sql)
                selected_columns = [d.name for d in cursor.description]

                if self.bulk_load:
                    tmpfile = NamedTemporaryFile("w")

                    self.log.info(
                        "Selecting rows from Vertica to local file %s...",
                        tmpfile.name)
                    self.log.info(self.sql)

                    csv_writer = csv.writer(tmpfile,
                                            delimiter='\t',
                                            encoding='utf-8')
                    for row in cursor.iterate():
                        csv_writer.writerow(row)
                        count += 1

                    tmpfile.flush()
                else:
                    self.log.info("Selecting rows from Vertica...")
                    self.log.info(self.sql)

                    result = cursor.fetchall()
                    count = len(result)

                self.log.info("Selected rows from Vertica %s", count)

        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator...")
            mysql.run(self.mysql_preoperator)

        try:
            if self.bulk_load:
                self.log.info("Bulk inserting rows into MySQL...")
                with closing(mysql.get_conn()) as conn:
                    with closing(conn.cursor()) as cursor:
                        cursor.execute(
                            "LOAD DATA LOCAL INFILE '%s' INTO "
                            "TABLE %s LINES TERMINATED BY '\r\n' (%s)" %
                            (tmpfile.name, self.mysql_table,
                             ", ".join(selected_columns)))
                        conn.commit()
                tmpfile.close()
            else:
                self.log.info("Inserting rows into MySQL...")
                mysql.insert_rows(table=self.mysql_table,
                                  rows=result,
                                  target_fields=selected_columns)
            self.log.info("Inserted rows into MySQL %s", count)
        except (MySQLdb.Error, MySQLdb.Warning):  # pylint: disable=no-member
            self.log.info("Inserted rows into MySQL 0")
            raise

        if self.mysql_postoperator:
            self.log.info("Running MySQL postoperator...")
            mysql.run(self.mysql_postoperator)

        self.log.info("Done")
Code Example #29
    def test_mysql_to_hive_verify_loaded_values(self, mock_popen,
                                                mock_temp_dir):
        mock_subprocess = MockSubProcess()
        mock_popen.return_value = mock_subprocess
        mock_temp_dir.return_value = "test_mysql_to_hive"

        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        hook = MySqlHook()

        try:
            minmax = (
                255,
                65535,
                16777215,
                4294967295,
                18446744073709551615,
                -128,
                -32768,
                -8388608,
                -2147483648,
                -9223372036854775808,
            )

            with hook.get_conn() as conn:
                conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
                conn.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                conn.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            with mock.patch.dict('os.environ', self.env_vars):
                op = MySqlToHiveOperator(
                    task_id='test_m2h',
                    hive_cli_conn_id='hive_cli_default',
                    sql=f"SELECT * FROM {mysql_table}",
                    hive_table=hive_table,
                    recreate=True,
                    delimiter=",",
                    dag=self.dag,
                )
                op.run(start_date=DEFAULT_DATE,
                       end_date=DEFAULT_DATE,
                       ignore_ti_state=True)

                mock_cursor = MockConnectionCursor()
                mock_cursor.iterable = [minmax]
                hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)

                result = hive_hook.get_records(f"SELECT * FROM {hive_table}")
                assert result[0] == minmax

                hive_cmd = [
                    'beeline',
                    '-u',
                    '"jdbc:hive2://localhost:10000/default"',
                    '-hiveconf',
                    'airflow.ctx.dag_id=unit_test_dag',
                    '-hiveconf',
                    'airflow.ctx.task_id=test_m2h',
                    '-hiveconf',
                    'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00',
                    '-hiveconf',
                    'airflow.ctx.dag_run_id=55',
                    '-hiveconf',
                    'airflow.ctx.dag_owner=airflow',
                    '-hiveconf',
                    'airflow.ctx.dag_email=[email protected]',
                    '-hiveconf',
                    'mapreduce.job.queuename=airflow',
                    '-hiveconf',
                    'mapred.job.queue.name=airflow',
                    '-hiveconf',
                    'tez.queue.name=airflow',
                    '-f',
                    '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
                ]

                mock_popen.assert_called_with(
                    hive_cmd,
                    stdout=mock_subprocess.PIPE,
                    stderr=mock_subprocess.STDOUT,
                    cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
                    close_fds=True,
                )

        finally:
            with hook.get_conn() as conn:
                conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
Code Example #30
 def execute(self, context: Dict) -> None:
     self.log.info('Executing: %s', self.sql)
     hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
     hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)