def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        self.log.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
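            # Note: csv is expected to be unicodecsv here (as in the original
            # Airflow operator), which is why the writer takes an encoding argument.
            # field_dict maps each MySQL column name to a Hive type via type_map so
            # that load_file can (re)create the Hive table with a matching schema.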
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties)
    def _query_mysql(self):
        """
        Queries mysql and returns a cursor to the results.
        """
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        return cursor
Example #3
    def test_mysql_to_hive_type_conversion(self, mock_load_file):
        mysql_conn_id = 'airflow_ci'
        mysql_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        m = MySqlHook(mysql_conn_id)

        try:
            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                c.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT,
                        c1 SMALLINT,
                        c2 MEDIUMINT,
                        c3 INT,
                        c4 BIGINT
                    )
                """.format(mysql_table))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            t = MySqlToHiveTransfer(
                task_id='test_m2h',
                mysql_conn_id=mysql_conn_id,
                hive_cli_conn_id='beeline_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table='test_mysql_to_hive',
                dag=self.dag)
            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            mock_load_file.assert_called_once()
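            # expected field_dict: each MySQL integer type is widened by one size,
            # presumably so unsigned values still fit in the Hive column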
            d = OrderedDict()
            d["c0"] = "SMALLINT"
            d["c1"] = "INT"
            d["c2"] = "INT"
            d["c3"] = "BIGINT"
            d["c4"] = "DECIMAL(38,0)"
            self.assertEqual(mock_load_file.call_args[1]["field_dict"], d)
        finally:
            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #4
    def test_mysql_hook_test_bulk_load(self):
        records = ("foo", "bar", "baz")

        import tempfile
        with tempfile.NamedTemporaryFile() as t:
            t.write("\n".join(records).encode('utf8'))
            t.flush()

            from airflow.hooks.mysql_hook import MySqlHook
            h = MySqlHook('airflow_ci')
            with h.get_conn() as c:
                c.execute("""
                    CREATE TABLE IF NOT EXISTS test_airflow (
                        dummy VARCHAR(50)
                    )
                """)
                c.execute("TRUNCATE TABLE test_airflow")
                h.bulk_load("test_airflow", t.name)
                c.execute("SELECT dummy FROM test_airflow")
                results = tuple(result[0] for result in c.fetchall())
                self.assertEqual(sorted(results), sorted(records))
Example #5
def run_check_table(db_name, table_name, conn_id, hive_table_name, server_name,
                    **kwargs):
    # SHOW TABLES in oride_db LIKE 'data_aa'
    check_sql = 'SHOW TABLES in %s LIKE \'%s\'' % (HIVE_DB, hive_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)
    if len(cursor.fetchall()) == 0:
        logging.info('Create Hive Table: %s.%s', HIVE_DB, hive_table_name)
        # get table column
        column_sql = '''
                SELECT
                    COLUMN_NAME,
                    DATA_TYPE,
                    NUMERIC_PRECISION,
                    NUMERIC_SCALE,
                    COLUMN_COMMENT
                FROM
                    information_schema.columns
                WHERE
                    table_schema='{db_name}' and table_name='{table_name}'
            '''.format(db_name=db_name, table_name=table_name)
        mysql_hook = MySqlHook(conn_id)
        mysql_conn = mysql_hook.get_conn()
        mysql_cursor = mysql_conn.cursor()
        mysql_cursor.execute(column_sql)
        results = mysql_cursor.fetchall()
        rows = []
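        # map each MySQL column to a Hive-compatible type; a column named 'dt' is
        # renamed to '_dt', presumably to avoid clashing with the partition column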
        for result in results:
            if result[0] == 'dt':
                col_name = '_dt'
            else:
                col_name = result[0]
            if result[1] in ('timestamp', 'varchar', 'char', 'text',
                             'longtext', 'datetime', 'mediumtext'):
                data_type = 'string'
            elif result[1] == 'decimal':
                data_type = '%s(%s,%s)' % (result[1], result[2], result[3])
            else:
                data_type = result[1]
            rows.append("`%s` %s comment '%s'" %
                        (col_name, data_type, str(result[4]).replace(
                            '\n', '').replace('\r', '').replace('\'', '\\\'')))
        mysql_conn.close()

        # hive create table
        hive_hook = HiveCliHook()
        sql = ODS_CREATE_TABLE_SQL.format(
            db_name=HIVE_DB,
            table_name=hive_table_name,
            columns=",\n".join(rows),
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=db_name,
                table_name=table_name)))
        logging.info('Executing: %s', sql)
        hive_hook.run_cli(sql)

    else:
        sqoopSchema = SqoopSchemaUpdate()
        response = sqoopSchema.append_hive_schema(
            hive_db=HIVE_DB,
            hive_table=hive_table_name,
            mysql_db=db_name,
            mysql_table=table_name,
            mysql_conn=conn_id,
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=db_name,
                table_name=table_name)))
        if response:
            return True
    return
Example #6
class MongoMetadataOperator(BaseOperator):
    @apply_defaults
    def __init__(self, mongo_conn_id, *args, **kwargs):
        self.mongo_conn_id = mongo_conn_id
        super(MongoMetadataOperator, self).__init__(*args, **kwargs)

    def execute(self, context):

        self.mongo_hook = MongoHook(self.mongo_conn_id)
        self.mysql_metadata_hook = MySqlHook("airflow_connection")
        self.insert_tables()
        return True

    def insert_tables(self):

        mongo_conn = self.mongo_hook.get_conn()
        db = mongo_conn['sw_client']
        collections = db.collection_names(include_system_collections=False)
        mongo_conn.close()

        tables = []
        db = self.mysql_metadata_hook.get_conn()
        generator = self.mysql_statement_generator()
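        # the generator yields statements in a fixed order: create main table,
        # create staging table, clear staging, insert into staging, clear main,
        # copy staging back into main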

        mysql_main_table = generator.next()
        self.mysql_statement_executor(db, mysql_main_table)

        mysql_main_staging_table = generator.next()
        self.mysql_statement_executor(db, mysql_main_staging_table)

        for table in collections:
            tables.append(('sw_client', table, 1))

        delete_metadata = generator.next()
        self.mysql_statement_executor(db, delete_metadata)

        insert_metadata = generator.next()
        insert_metadata = insert_metadata + ",".join("(%s,%s ,%s)"
                                                     for _ in tables)
        flattened_values = [item for sublist in tables for item in sublist]
        self.mysql_statement_executor(db, insert_metadata, flattened_values)

        delete_data_main_table = generator.next()
        self.mysql_statement_executor(db, delete_data_main_table)

        insert_data_main_table = generator.next()
        self.mysql_statement_executor(db, insert_data_main_table)

        db.commit()
        db.close()

        return 0

    def mysql_statement_generator(self):

        yield '''CREATE TABLE IF NOT EXISTS mongo_metadata (id INT(10) AUTO_INCREMENT, db_name VARCHAR(50), 
            table_name VARCHAR(128) , is_active TINYINT(4) , created_at TIMESTAMP NOT NULL  DEFAULT CURRENT_TIMESTAMP,  
             updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP , PRIMARY KEY(id) ) '''

        yield ''' CREATE TABLE IF NOT EXISTS mongo_metadata_staging (id INT(10) AUTO_INCREMENT, db_name VARCHAR(50), 
            table_name VARCHAR(128) , is_active TINYINT(4) , created_at TIMESTAMP NOT NULL  DEFAULT CURRENT_TIMESTAMP,  
            updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ,PRIMARY KEY(id)) '''

        yield "DELETE FROM mongo_metadata_staging"

        yield "INSERT INTO mongo_metadata_staging (db_name, table_name, is_active) VALUES "

        yield "DELETE FROM mongo_metadata"

        yield "INSERT INTO mongo_metadata SELECT * FROM mysql_metadata_staging"

        return

    def mysql_statement_executor(self, db, mysql, values=None):
        cursor = db.cursor()
        if values is None:
            cursor.execute(mysql)
        else:
            cursor.execute(mysql, values)
        result = cursor.fetchall()
        cursor.close()
        return result
Example #7
    def execute(self, context):
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        tmpfile = None
        result = None

        selected_columns = []

        count = 0
        with closing(vertica.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self.sql)
                selected_columns = [d.name for d in cursor.description]

                if self.bulk_load:
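                    # bulk mode: stream rows into a tab-delimited temp file and load
                    # it later with LOAD DATA LOCAL INFILE instead of buffering the
                    # whole result set in memory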
                    tmpfile = NamedTemporaryFile("w")

                    logging.info("Selecting rows from Vertica to local file " +
                                 str(tmpfile.name) + "...")
                    logging.info(self.sql)

                    csv_writer = csv.writer(tmpfile,
                                            delimiter='\t',
                                            encoding='utf-8')
                    for row in cursor.iterate():
                        csv_writer.writerow(row)
                        count += 1

                    tmpfile.flush()
                else:
                    logging.info("Selecting rows from Vertica...")
                    logging.info(self.sql)

                    result = cursor.fetchall()
                    count = len(result)

                logging.info("Selected rows from Vertica " + str(count))

        if self.mysql_preoperator:
            logging.info("Running MySQL preoperator...")
            mysql.run(self.mysql_preoperator)

        try:
            if self.bulk_load:
                logging.info("Bulk inserting rows into MySQL...")
                with closing(mysql.get_conn()) as conn:
                    with closing(conn.cursor()) as cursor:
                        cursor.execute(
                            "LOAD DATA LOCAL INFILE '%s' INTO TABLE %s LINES TERMINATED BY '\r\n' (%s)"
                            % (tmpfile.name, self.mysql_table,
                               ", ".join(selected_columns)))
                        conn.commit()
                tmpfile.close()
            else:
                logging.info("Inserting rows into MySQL...")
                mysql.insert_rows(table=self.mysql_table,
                                  rows=result,
                                  target_fields=selected_columns)
            logging.info("Inserted rows into MySQL " + str(count))
        except:
            logging.error("Inserted rows into MySQL 0")
            raise

        if self.mysql_postoperator:
            logging.info("Running MySQL postoperator...")
            mysql.run(self.mysql_postoperator)

        logging.info("Done")
Example #8
class TestMySqlHookConn(unittest.TestCase):
    def setUp(self):
        super().setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_from_connection(self, mock_connect):
        conn = Connection(login='******',
                          password='******',
                          host='host',
                          schema='schema')
        hook = MySqlHook(connection=conn)
        hook.get_conn()
        mock_connect.assert_called_once_with(user='******',
                                             passwd='password-conn',
                                             host='host',
                                             db='schema',
                                             port=3306)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_from_connection_with_schema(self, mock_connect):
        conn = Connection(login='******',
                          password='******',
                          host='host',
                          schema='schema')
        hook = MySqlHook(connection=conn, schema='schema-override')
        hook.get_conn()
        mock_connect.assert_called_once_with(user='******',
                                             passwd='password-conn',
                                             host='host',
                                             db='schema-override',
                                             port=3306)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        self.connection.port = 3307
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['unix_socket'], '/tmp/socket')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    @mock.patch('airflow.contrib.hooks.aws_hook.AwsHook.get_client_type')
    def test_get_conn_rds_iam(self, mock_client, mock_connect):
        self.connection.extra = '{"iam":true}'
        mock_client.return_value.generate_db_auth_token.return_value = 'aws_token'
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user='******',
            passwd='aws_token',
            host='host',
            db='schema',
            port=3306,
            read_default_group='enable-cleartext-plugin')
Example #9
args = {
    'owner': 'linan',
    'start_date': datetime(2020, 3, 19),
    'depends_on_past': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
    # 'email': ['*****@*****.**'],
    # 'email_on_failure': True,
    # 'email_on_retry': False,
}

dag = airflow.DAG('bussiness_monitor',
                  schedule_interval="*/5 * * * *",
                  default_args=args)

mysql_hook = MySqlHook("bussiness_mysql")
mysql_conn = mysql_hook.get_conn()
mysql_cursor = mysql_conn.cursor()

exec_command = """
    influx -database 'serverDB' -execute '{sql}'  -format='csv' > {metrics_name}.txt && echo 1 || echo 0
"""

ssh = paramiko.SSHClient()
key = paramiko.AutoAddPolicy()
ssh.set_missing_host_key_policy(key)
ssh.connect('10.52.5.233', 22, 'airflow', '', timeout=5)

scp = SCPClient(ssh.get_transport())

cat_command = """
    cat /home/airflow/{metrics_name}.txt && echo 1 || echo 0
Example #10
def run_sqoop_check_table(mysql_db_name, mysql_table_name, conn_id,
                          hive_table_name, **kwargs):
    sqoopSchema = SqoopSchemaUpdate()
    response = sqoopSchema.update_hive_schema(hive_db=HIVE_SQOOP_TEMP_DB,
                                              hive_table=hive_table_name,
                                              mysql_db=mysql_db_name,
                                              mysql_table=mysql_table_name,
                                              mysql_conn=conn_id)
    if response:
        return True

    # SHOW TABLES in oride_db LIKE 'data_aa'
    check_sql = 'SHOW TABLES in %s LIKE \'%s\'' % (HIVE_DB, hive_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)
    if len(cursor.fetchall()) == 0:
        logging.info('Create Hive Table: %s.%s', HIVE_DB, hive_table_name)
        # get table column
        column_sql = '''
            SELECT
                COLUMN_NAME,
                DATA_TYPE,
                NUMERIC_PRECISION,
                NUMERIC_SCALE,COLUMN_COMMENT
            FROM
                information_schema.columns
            WHERE
                table_schema='{db_name}' and table_name='{table_name}'
        '''.format(db_name=mysql_db_name, table_name=mysql_table_name)
        mysql_hook = MySqlHook(conn_id)
        mysql_conn = mysql_hook.get_conn()
        mysql_cursor = mysql_conn.cursor()
        mysql_cursor.execute(column_sql)
        results = mysql_cursor.fetchall()
        rows = []
        for result in results:
            if result[0] == 'dt':
                col_name = '_dt'
            else:
                col_name = result[0]
            if result[1] in ('timestamp', 'varchar', 'char', 'text',
                             'datetime', 'mediumtext', 'enum', 'json'):
                data_type = 'string'
                # elif result[1] == 'decimal':
                #     data_type = result[1] + "(" + str(result[2]) + "," + str(result[3]) + ")"
                # some tables with JSON columns are read via the insert path, so
                # decimal is switched to double here
            elif result[1] == 'decimal':
                data_type = 'double'
            elif result[1] == 'mediumint':
                data_type = 'int'
            else:
                data_type = result[1]
            rows.append("`%s` %s comment '%s'" %
                        (col_name, data_type, result[4]))
        mysql_conn.close()

        # hive create table
        hive_hook = HiveCliHook()
        sql = ODS_SQOOP_CREATE_TABLE_SQL.format(
            db_name=HIVE_SQOOP_TEMP_DB,
            table_name=hive_table_name,
            columns=",\n".join(rows),
            ufile_path=UFILE_PATH % (mysql_db_name, mysql_table_name))
        logging.info('Executing: %s', sql)
        hive_hook.run_cli(sql)
    return
Example #11
    def test_mysql_to_hive_verify_loaded_values(self, mock_popen,
                                                mock_temp_dir):
        mock_subprocess = MockSubProcess()
        mock_popen.return_value = mock_subprocess
        mock_temp_dir.return_value = "test_mysql_to_hive"

        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        hook = MySqlHook()

        try:
            minmax = (255, 65535, 16777215, 4294967295, 18446744073709551615,
                      -128, -32768, -8388608, -2147483648,
                      -9223372036854775808)

            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                conn.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                conn.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            with mock.patch.dict('os.environ', self.env_vars):
                op = MySqlToHiveTransfer(
                    task_id='test_m2h',
                    hive_cli_conn_id='hive_cli_default',
                    sql="SELECT * FROM {}".format(mysql_table),
                    hive_table=hive_table,
                    recreate=True,
                    delimiter=",",
                    dag=self.dag)
                op.run(start_date=DEFAULT_DATE,
                       end_date=DEFAULT_DATE,
                       ignore_ti_state=True)

                mock_cursor = MockConnectionCursor()
                mock_cursor.iterable = [minmax]
                hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)

                result = hive_hook.get_records(
                    "SELECT * FROM {}".format(hive_table))
                self.assertEqual(result[0], minmax)

                hive_cmd = [
                    u'hive', u'-hiveconf',
                    u'[email protected]', u'-hiveconf',
                    u'airflow.ctx.dag_id=test_dag_id', u'-hiveconf',
                    u'airflow.ctx.dag_owner=airflow', u'-hiveconf',
                    u'airflow.ctx.dag_run_id=55', u'-hiveconf',
                    u'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00',
                    u'-hiveconf', u'airflow.ctx.task_id=test_task_id',
                    u'-hiveconf', u'mapreduce.job.queuename=airflow',
                    u'-hiveconf', u'mapred.job.queue.name=airflow',
                    u'-hiveconf', u'tez.queue.name=airflow', u'-f',
                    u'/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive'
                ]

                mock_popen.assert_called_with(
                    hive_cmd,
                    stdout=mock_subprocess.PIPE,
                    stderr=mock_subprocess.STDOUT,
                    cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
                    close_fds=True)

        finally:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #12
class TestMySqlHookConn(unittest.TestCase):
    def setUp(self):
        super().setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        self.connection.port = 3307
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['unix_socket'], '/tmp/socket')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)
Example #13
    def test_mysql_to_hive_verify_loaded_values(self):
        mysql_conn_id = 'airflow_ci'
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        m = MySqlHook(mysql_conn_id)

        try:
            minmax = (
                255,
                65535,
                16777215,
                4294967295,
                18446744073709551615,
                -128,
                -32768,
                -8388608,
                -2147483648,
                -9223372036854775808
            )

            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                c.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                c.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            t = MySqlToHiveTransfer(
                task_id='test_m2h',
                mysql_conn_id=mysql_conn_id,
                hive_cli_conn_id='beeline_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                dag=self.dag)
            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            from airflow.hooks.hive_hooks import HiveServer2Hook
            h = HiveServer2Hook()
            r = h.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(r[0], minmax)
        finally:
            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #14
class Sync:
    def __init__(self, sync_task, task_id=None, dag_id=None):
        self.sync_task = sync_task
        self.sync_task_id = task_id
        self.sync_task_dag_id = dag_id
        self.adb_table = sync_task['adb_table']
        self.dml_operator = sync_task['dml_operator']
        self.mysql = MySqlHook(mysql_conn_id='adb_default')
        self.log = logger

    def unfinished(self):
        dag_id = self.sync_task["dag_id"]
        task_id = self.sync_task["task_id"]
        # filter the latest dependent task
        dep_sql = f"select end_date from task_instance where dag_id='{dag_id}' and task_id='{task_id}' and state='success' order by end_date desc limit 1"
        dep_res = get_mysql_dataset(mysql_conn_id="airflow_emr",
                                    schema="airflow",
                                    sql=dep_sql)
        self.log.info(f'dependent task:{dep_sql} {dep_res}')
        # filter the latest succeeded sync task
        sql = f"select end_date from task_instance where dag_id='{self.sync_task_dag_id}' and task_id='{self.sync_task_id}' and state='success' order by end_date desc limit 1"
        res = get_mysql_dataset(mysql_conn_id="airflow_db",
                                schema="airflow",
                                sql=sql)
        self.log.info(f'recent task run: {sql} {res}')
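        # a run is needed only when the dependency has succeeded and this task
        # either never succeeded or last succeeded before the dependency did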

        need_running = bool(
            dep_res
            and (not res or res[0]["end_date"] < dep_res[0]["end_date"]))
        if not need_running:
            return True

        # task execution log after upstream task success
        upstream_task_latest_success_time = res[0]["end_date"].strftime(
            "%Y-%m-%d %H:%M:%S") if res else '1970-01-01 00:00:00'
        task_state_sql = f"select state from task_instance where dag_id='{self.sync_task_dag_id}' and task_id='{self.sync_task_id}' and execution_date > '{upstream_task_latest_success_time}'"
        target_tasks = get_mysql_dataset(mysql_conn_id="airflow_db",
                                         schema="airflow",
                                         sql=task_state_sql)
        unfinished_tasks = [
            _ for _ in target_tasks if _["state"] in State.unfinished()
        ]

        # more than one unfinished record means the task is still running
        return len(unfinished_tasks) > 1

    def mysql_run(self, table, values, columns, columns_num):
        """insert data to mysql table"""
        self.log.info(f'loading data to {table}')
        placeholders = ",".join(["%s"] * columns_num)
        sql = f'{self.dml_operator} {table} ({columns}) values({placeholders})'
        with closing(self.mysql.get_conn()) as mysql_conn:
            mysql_conn.cursor().executemany(sql, values)
            mysql_conn.commit()
            self.log.info("load data done")

    def rename_adb_table(self, table_name, new_table_name):
        """ rename mysql table """
        with closing(self.mysql.get_conn()) as mysql_conn:
            self.log.info(f"Renaming {table_name} to {new_table_name}")
            mysql_conn.cursor().execute(
                f"rename table {table_name} to {new_table_name}")
            self.log.info(f"Done Renaming {table_name}")

    def create_adb_table(self, table_name, create_sql):
        """ create mysql table """
        try:
            with closing(self.mysql.get_conn()) as mysql_conn:
                self.log.info(
                    f"Creating table {table_name}, sql: {create_sql}")
                mysql_conn.cursor().execute(create_sql)
                self.log.info(f"Done creating table {table_name}")
        except _mysql_exceptions.OperationalError as e:
            self.log.debug(e)
            # the other thread has created table
            if e.args[0] != 1050:
                raise e

    def drop_adb_table(self, table_name):
        """ delete mysql table """
        with closing(self.mysql.get_conn()) as mysql_conn:
            self.log.info(f"Dropping table {table_name}")
            drop_sql = f"drop table if exists `{table_name}`;"
            mysql_conn.cursor().execute(drop_sql)
            self.log.info(f"Done dropping table {table_name}")

    def is_adb_table_exists(self, table_name):
        with closing(self.mysql.get_conn()) as mysql_conn:
            check_sql = f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'dm' AND table_name = '{table_name}'"
            result = mysql_conn.cursor().execute(check_sql)
            return bool(result)

    def generate_create_sql(self, hive_table, mysql_table):
        """ generate create table sql """
        hive_table_columns = self.get_mysql_dataset(
            mysql_conn_id="hivemeta_db",
            schema="hivemeta",
            sql=self.get_hive_sql(*hive_table.split(".")),
        )
        partition_columns = self.get_mysql_dataset(
            mysql_conn_id="hivemeta_db",
            schema="hivemeta",
            sql=self.get_hive_partition(*hive_table.split(".")),
        )
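        # build the DDL from Hive metastore columns plus partition key columns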
        create_sql = [f"create table `{mysql_table}` ("]
        for column in hive_table_columns + partition_columns:
            column_name, column_type, column_comment = (
                column["column_name"],
                self.convert_column_type(column["column_type"]),
                column["column_comment"],
            )
            if column_comment:
                create_sql.append(
                    f"`{column_name}` {column_type} COMMENT '{column_comment}',"
                )
            else:
                create_sql.append(f"`{column_name}` {column_type},")
        extra_sql = self.sync_task["extra_sql"]
        if extra_sql:
            if not extra_sql.startswith('primary key'):
                create_sql[-1] = create_sql[-1].rstrip(",")
            create_sql.append(extra_sql)
        else:
            create_sql[-1] = create_sql[-1].rstrip(",")
            create_sql.append(") INDEX_ALL='Y';")
        return "".join(create_sql)

    def get_target_fields(self, hive_table):
        hive_table_columns = self.get_mysql_dataset(
            mysql_conn_id="hivemeta_db",
            schema="hivemeta",
            sql=self.get_hive_sql(*hive_table.split(".")),
        )
        partition_columns = self.get_mysql_dataset(
            mysql_conn_id="hivemeta_db",
            schema="hivemeta",
            sql=self.get_hive_partition(*hive_table.split(".")),
        )
        return ", ".join([
            f'`{column["column_name"]}`'
            for column in hive_table_columns + partition_columns
        ])

    def _is_empty_table(self, adb_table_name):
        with closing(self.mysql.get_conn()) as mysql_conn:
            cursor = mysql_conn.cursor()
            cursor.execute(f"select count(*) from {adb_table_name}")
            # cursor.execute() returns the number of result rows (always 1 here),
            # so fetch the actual count before comparing
            row_count = cursor.fetchone()[0]
            self.log.info(f"{adb_table_name} rows: {row_count}")
            return row_count == 0

    @staticmethod
    def get_mysql_dataset(**kwargs):
        if ("mysql_conn_id" not in kwargs or "schema" not in kwargs
                or "sql" not in kwargs):
            raise Exception("Miss parameter mysql_conn_id or metadata or sql.")

        maxrows = 0 if "maxrows" not in kwargs else kwargs["maxrows"]
        how = 1 if "how" not in kwargs else kwargs["how"]

        mysql = MySqlHook(mysql_conn_id=kwargs["mysql_conn_id"],
                          schema=kwargs["schema"])
        conn = mysql.get_conn()
        if not conn.open:
            raise Exception("Could not open connection.")
        conn.query(kwargs["sql"])
        result = conn.store_result()
        dataset = result.fetch_row(maxrows=maxrows, how=how)
        conn.close()

        return dataset

    @staticmethod
    def get_hive_sql(catalog, table_name):
        return f"""
        select distinct tb.TBL_NAME as table_name, c.COLUMN_NAME as column_name, c.TYPE_NAME as column_type, c.INTEGER_IDX as seq, c.COMMENT as column_comment
        from hivemeta.DBS db
            inner join hivemeta.TBLS tb on db.DB_ID = tb.DB_ID
            inner join hivemeta.SDS sd on tb.SD_ID = sd.SD_ID
            inner join hivemeta.COLUMNS_V2 c on sd.CD_ID = c.CD_ID
        where db.NAME = '{catalog}'
          and tb.TBL_NAME = '{table_name}'
        order by c.INTEGER_IDX;
        """

    @staticmethod
    def get_hive_partition(catalog, table_name):
        return f"""
        select distinct tb.TBL_NAME as table_name, c.PKEY_NAME as column_name, c.PKEY_TYPE as column_type, c.INTEGER_IDX as seq, c.PKEY_COMMENT as column_comment
        from hivemeta.DBS db
            inner join hivemeta.TBLS tb on db.DB_ID = tb.DB_ID
            inner join hivemeta.PARTITION_KEYS c on tb.TBL_ID = c.TBL_ID
        where db.NAME = '{catalog}'
          and tb.TBL_NAME = '{table_name}'
        order by c.INTEGER_IDX;
        """

    @staticmethod
    def convert_column_type(column_type):
        if column_type == "string":
            return "varchar"
        elif column_type.startswith("array") or column_type.startswith("<<"):
            return "json"
        else:
            return column_type
Example #15
    def execute(self, context):

        started_at = datetime.utcnow()
        _keep_going = True
        while _keep_going:

            _force_run_data = self.get_force_run_data()
            _logger.info("Force run data: {}".format(_force_run_data))

            if not _force_run_data:
                if (datetime.utcnow() -
                        started_at).total_seconds() > self.timeout:
                    raise AirflowSkipException('Snap. Time is OUT.')
                sleep(self.poke_interval)
                continue

            for row in _force_run_data:
                _keep_going = False
                biowardrobe_uid = row['uid']
                #  TODO: Check if dag is running in airflow

                #  TODO: If not running!
                data = self.get_record_data(biowardrobe_uid)
                if not data:
                    _logger.error(
                        'No biowardrobe data {}'.format(biowardrobe_uid))
                    continue
                #
                #  Actual Force RUN
                basedir = data['output_folder']
                try:
                    os.chdir(basedir)

                    for root, dirs, files in os.walk(".", topdown=False):
                        for name in files:
                            if "fastq" in name:
                                continue
                            os.remove(os.path.join(root, name))
                    rmtree(os.path.join(basedir, 'tophat'), True)
                except:
                    pass

                if int(data['deleted']) == 0:
                    cmd = 'bunzip2 {}*.fastq.bz2'.format(biowardrobe_uid)
                    try:
                        check_output(cmd, shell=True)
                    except Exception as e:
                        _logger.error("Can't uncompress: {} {}".format(
                            cmd, str(e)))

                    if not os.path.isfile(biowardrobe_uid + '.fastq'):
                        _logger.error(
                            "File does not exist: {}".format(biowardrobe_uid))
                        continue
                    if not os.path.isfile(biowardrobe_uid +
                                          '_2.fastq') and data['pair']:
                        _logger.error("File 2 does not exist: {}".format(
                            biowardrobe_uid))
                        continue
                else:
                    rmtree(basedir, True)

                mysql = MySqlHook(mysql_conn_id=biowardrobe_connection_id)
                with closing(mysql.get_conn()) as conn:
                    with closing(conn.cursor()) as cursor:
                        self.drop_sql(cursor, data)
                        if int(data['deleted']) == 0:
                            cursor.execute(
                                "update labdata set libstatustxt=%s, libstatus=10, forcerun=0, tagstotal=0,"
                                "tagsmapped=0,tagsribo=0,tagsused=0,tagssuppressed=0 where uid=%s",
                                ("Ready to be reanalyzed", biowardrobe_uid))
                            conn.commit()
                        else:
                            cursor.execute(
                                "update labdata set libstatustxt=%s,deleted=2,datedel=CURDATE() where uid=%s",
                                ("Deleted", biowardrobe_uid))
                            conn.commit()
                            _logger.info("Deleted: {}".format(biowardrobe_uid))
                            continue

                _dag_id = os.path.basename(
                    os.path.splitext(data['workflow'])[0])
                _run_id = 'forcerun__{}__{}'.format(biowardrobe_uid,
                                                    uuid.uuid4())
                session = settings.Session()
                dr = DagRun(dag_id=_dag_id,
                            run_id=_run_id,
                            conf={
                                'biowardrobe_uid': biowardrobe_uid,
                                'run_id': _run_id
                            },
                            execution_date=datetime.now(),
                            start_date=datetime.now(),
                            external_trigger=True)
                logging.info("Creating DagRun {}".format(dr))
                session.add(dr)
                session.commit()
                session.close()
Example #16
    def get_record_data(self, biowardrobe_uid):
        mysql = MySqlHook(mysql_conn_id=biowardrobe_connection_id)
        with closing(mysql.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                return get_biowardrobe_data(cursor=cursor,
                                            biowardrobe_uid=biowardrobe_uid)
Example #17
import airflow
from datetime import datetime, timedelta
from airflow.operators.bash_operator import BashOperator
from airflow.sensors.named_hive_partition_sensor import NamedHivePartitionSensor
from airflow.sensors.hive_partition_sensor import HivePartitionSensor
from airflow.sensors import UFileSensor
from plugins.TaskTimeoutMonitor import TaskTimeoutMonitor
from plugins.TaskTouchzSuccess import TaskTouchzSuccess
from airflow.sensors import OssSensor
from airflow.hooks.mysql_hook import MySqlHook
import json
import logging
from airflow.models import Variable
import requests
import os

opos_mysql_hook = MySqlHook("mysql_dw")
opos_mysql_conn = opos_mysql_hook.get_conn()
opos_mysql_cursor = opos_mysql_conn.cursor()

args = {
    'owner': 'yuanfeng',
    'start_date': datetime(2019, 11, 24),
    'depends_on_past': False,
    'retries': 3,
    'retry_delay': timedelta(minutes=2),
    #'email': ['*****@*****.**'],
    'email_on_failure': True,
    'email_on_retry': False,
}

dag = airflow.DAG(
    'app_opos_shop_target_week_import_d',
Example #18
def _update_firewall_name_mappings(*, aws_conn_id: str, mysql_conn_id: str,
                                   state_table_name: str,
                                   mappings_bucket_name: str,
                                   mappings_prefix: str,
                                   mappings_timestamp_pattern: str,
                                   firewall_table: str, msource_id: int,
                                   task: BaseOperator, **_) -> str:
    """Python callable for the `UpdateFirewallNameMappingsOperator`."""
    # pylint: disable=too-many-locals

    log = task.log

    pipeline_state = PipelineStateHook(state_table_name,
                                       aws_conn_id=aws_conn_id)
    mappings_bucket = S3Hook(aws_conn_id).get_bucket(mappings_bucket_name)

    # get latest processed mappings file timestamp
    state_key = f'name_mappings.{firewall_table}.{msource_id}.latest_processed'
    state_value = pipeline_state.get_state(state_key)
    if state_value is not None:
        latest_processed_ts = _get_latest_processed_ts(state_value)
        latest_processed_etags = _get_latest_processed_etags(state_value)
    else:
        latest_processed_ts = '0'
        latest_processed_etags = set()

    # list files in the bucket and find the newest one
    mappings_timestamp_re = re.compile(mappings_timestamp_pattern)
    latest_available_ts = '0'
    latest_available_files = dict()
    for mappings_file in mappings_bucket.objects.filter(
            Prefix=mappings_prefix):
        match = mappings_timestamp_re.search(mappings_file.key)
        if match is not None:
            available_ts = match.group(1)
            if available_ts > latest_available_ts:
                latest_available_ts = available_ts
                latest_available_files = {mappings_file.e_tag: mappings_file}
            elif latest_available_ts == available_ts:
                latest_available_files.update(
                    {mappings_file.e_tag: mappings_file})

    # If the files are for the same date, skip the already-processed ones
    if latest_available_ts == latest_processed_ts:
        for latest_processed_etag in latest_processed_etags:
            log.info(
                "skipping the file=%s, etag=%s",
                _get_filename(
                    latest_available_files[latest_processed_etag].key),
                latest_available_files[latest_processed_etag].e_tag)
            del latest_available_files[latest_processed_etag]

    # check if no newer file
    if (not latest_available_files or latest_available_ts < latest_processed_ts
            or (latest_available_ts == latest_processed_ts
                and not latest_available_files)):
        raise AirflowSkipException("no newer mappings file found")

    # connect to the firewall database and process the mappings
    log.info("connecting to firewall database")
    mysql = MySqlHook(mysql_conn_id=mysql_conn_id)
    with closing(mysql.get_conn()) as conn:
        mysql.set_autocommit(conn, False)
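        # autocommit is disabled so each mappings file is applied as one explicit
        # transaction, committed after _load_and_process_mappings finishes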
        for latest_available_file in latest_available_files.values():
            log.info("processing filename=%s, etag=%s, ts=%s",
                     _get_filename(latest_available_file.key),
                     latest_available_file.e_tag, latest_available_ts)
            with closing(conn.cursor()) as cur:
                _load_and_process_mappings(cur=cur,
                                           mappings_file=latest_available_file,
                                           firewall_table=firewall_table,
                                           msource_id=msource_id,
                                           log=log)
            log.info("committing transaction")
            conn.commit()

    # update state table
    new_state_value = _create_state(state_value, latest_available_ts,
                                    latest_available_files.values())
    log.info('saving new state: key=%s, value=%s', state_key, new_state_value)
    pipeline_state.save_state(state_key, json.dumps(new_state_value))

    # done
    return "new mappings have been processed"
Example #19
class TestMySqlHookConn(unittest.TestCase):
    def setUp(self):
        super(TestMySqlHookConn, self).setUp()

        self.connection = models.Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        self.connection.port = 3307
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)
Example #20
class TestMySqlHookConn(unittest.TestCase):

    def setUp(self):
        super(TestMySqlHookConn, self).setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        self.connection.port = 3307
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['unix_socket'], '/tmp/socket')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)
Example #21
    def test_mysql_to_hive_verify_loaded_values(self):
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        m = MySqlHook()

        try:
            minmax = (
                255,
                65535,
                16777215,
                4294967295,
                18446744073709551615,
                -128,
                -32768,
                -8388608,
                -2147483648,
                -9223372036854775808
            )

            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                c.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                c.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            t = MySqlToHiveTransfer(
                task_id='test_m2h',
                hive_cli_conn_id='beeline_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                dag=self.dag)
            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            from airflow.hooks.hive_hooks import HiveServer2Hook
            h = HiveServer2Hook()
            r = h.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(r[0], minmax)
        finally:
            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #22
0
    def execute(self, context):
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        tmpfile = None
        result = None

        selected_columns = []

        count = 0
        with closing(vertica.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self.sql)
                selected_columns = [d.name for d in cursor.description]

                if self.bulk_load:
                    tmpfile = NamedTemporaryFile("w")

                    self.log.info(
                        "Selecting rows from Vertica to local file %s...",
                        tmpfile.name)
                    self.log.info(self.sql)

                    csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
                    for row in cursor.iterate():
                        csv_writer.writerow(row)
                        count += 1

                    tmpfile.flush()
                else:
                    self.log.info("Selecting rows from Vertica...")
                    self.log.info(self.sql)

                    result = cursor.fetchall()
                    count = len(result)

                self.log.info("Selected rows from Vertica %s", count)

        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator...")
            mysql.run(self.mysql_preoperator)

        try:
            if self.bulk_load:
                self.log.info("Bulk inserting rows into MySQL...")
                with closing(mysql.get_conn()) as conn:
                    with closing(conn.cursor()) as cursor:
                        cursor.execute("LOAD DATA LOCAL INFILE '%s' INTO "
                                       "TABLE %s LINES TERMINATED BY '\r\n' (%s)" %
                                       (tmpfile.name,
                                        self.mysql_table,
                                        ", ".join(selected_columns)))
                        conn.commit()
                tmpfile.close()
            else:
                self.log.info("Inserting rows into MySQL...")
                mysql.insert_rows(table=self.mysql_table,
                                  rows=result,
                                  target_fields=selected_columns)
            self.log.info("Inserted rows into MySQL %s", count)
        except (MySQLdb.Error, MySQLdb.Warning):
            self.log.info("Inserted rows into MySQL 0")
            raise

        if self.mysql_postoperator:
            self.log.info("Running MySQL postoperator...")
            mysql.run(self.mysql_postoperator)

        self.log.info("Done")
Example #23
0
class MySQLToCSVOperator(BaseOperator):

    @apply_defaults
    def __init__(self,
                 mysql_conn_id,
                 redshift_connection_id,
                 aws_conn_id,
                 s3_bucket,
                 approx_max_file_size_bytes=19000000,
                 *args, **kwargs):

        self.mysql_conn_id = mysql_conn_id
        self.approx_max_file_size_bytes = approx_max_file_size_bytes
        self.aws_conn_id = aws_conn_id
        self.s3_bucket_name = s3_bucket
        self.redshift_connection_id = redshift_connection_id

        super(MySQLToCSVOperator, self).__init__(*args, **kwargs)

    def execute(self, context):

        self.mysql_hook = MySqlHook(self.mysql_conn_id)  # data source
        self.mysql_metadata_hook = MySqlHook("airflow_connection")

        self.table_name = self.fetch_table_name()

        if self.table_name is None:
            return True

        self.query_process(self.table_name)
        return True

    def query_process(self, data):

        table_data = data[0]
        table_name = table_data[2]
        cursor = self._query_mysql(hook=self.mysql_hook,
                                   mysql_query="select * from " + table_name + " limit 1")
        files_to_upload = self._write_local_data_files(cursor, table_name)

        fetch_columns_redshift, count_field_redshift = self._fetch_field_names_table(table_name)

        # Create the main and staging tables in Redshift
        redshift_statement_staging, redshift_statement_main = self._create_table_statement(cursor, table_name)
        self._redshift_staging_table(redshift_sql=redshift_statement_main + ";" + redshift_statement_staging)

        # Alter table statements for columns present in MySQL but missing in Redshift
        if count_field_redshift > 0 and count_field_redshift - 1 != len(cursor.description):
            redshift_field = frozenset([field[0] for field in fetch_columns_redshift])
            mysql_fields = frozenset(self._write_local_file_schema(cursor))
            difference = mysql_fields.difference(redshift_field)
            alter_fields = []
            for elem in cursor.description:
                if elem[0] in difference:
                    alter_fields.append(elem)
            redshift_alter_table_commands = self._alter_table(table_name, alter_fields)
            self._redshift_staging_table(";".join(redshift_alter_table_commands))

        self._upload_files(files_to_upload)
        cursor.close()

        return True


    def _alter_table(self, table_name, alter_fields):

        redshift_sql_main = "alter table " + table_name + " add column "
        redshift_sql_staging = "alter table " + table_name + "_staging" + " add column "
        redshift_alter_fields = []
        for mysql_fields in alter_fields:
            redshift_alter_fields.append(redshift_sql_main + mysql_fields[0] + " " + self.type_map(mysql_fields[1]))
            redshift_alter_fields.append(redshift_sql_staging + mysql_fields[0] + " " + self.type_map(mysql_fields[1]))
        return redshift_alter_fields


    def fetch_table_name(self):

        table_name_query = ("select * from mysql_metadata where is_active = 1 "
                            "and db_name = '" + self.mysql_conn_id + "' limit 1")
        cursor = self._query_mysql(hook=self.mysql_metadata_hook, mysql_query=table_name_query)
        result = cursor.fetchall()
        cursor.close()

        if len(list(result)) == 0:
            return None

        logging.info("The table name under process is {}".format(result))
        update_data = 'update mysql_metadata set is_active = 0 where id in ('
        update_data += ",".join("%s" for _ in result)
        update_data += ")"
        results = [row[0] for row in result]
        logging.info("The update statement fired: {}".format(update_data))
        update_cursor = self._query_mysql(hook=self.mysql_metadata_hook,
                                          mysql_query=update_data, value=results)
        update_cursor.close()
        return result



    def _query_mysql(self, hook, mysql_query, value=None):
        """
        Queries MySQL and returns a cursor to the results.
        """
        connection = hook.get_conn()
        cursor = connection.cursor()
        if value is None:
            cursor.execute(mysql_query)
        else:
            cursor.execute(mysql_query, value)
        logging.info("Cursor executed statement: {}".format(mysql_query))
        connection.commit()
        return cursor


    def _create_csv_file(self, mysql_query, csv):

        df = pd.read_sql_query(mysql_query, self.mysql_hook.get_conn())
        df = df.replace('\n', '', regex=True)
        total_rows = len(df)
        event_row_data = ['writerows'] * total_rows
        data = pd.Series(event_row_data)
        df.insert(loc=0, column='event', value=data)

        # Strip quotes and commas from string columns (skip the inserted 'event' column)
        for column in df.select_dtypes(['object']).columns[1:]:
            df[column] = df[column].str.replace(r"[\"\',]", '', regex=True)

        df.to_csv(path_or_buf=csv.name, sep='|', index=False)
        return csv


    def _write_local_data_files(self, cursor, table_name):
        """
        Takes a cursor and writes the table's data to a local CSV file.
        cursor.description provides the column metadata used to name the file.
        """

        primary_key = cursor.description[cursor.lastrowid][0]
        file_key = "m&{0}&{1}.csv".format(table_name, primary_key)
        file_handle = codecs.open(file_key, 'w', errors='ignore', encoding='utf-8')
        file_handle = self._create_csv_file(mysql_query="select * from " + table_name, csv=file_handle)
        return {file_key: file_handle}

    def _write_local_file_schema(self, cursor):
        """
        Takes a cursor and returns the list of column names,
        with a view towards building the Redshift schema.
        """
        fields = []
        for field in cursor.description:
            field_name = field[0]
            fields.append(field_name)
        return fields

    def _upload_files(self, files):

        dest_s3 = S3Hook(aws_conn_id=self.aws_conn_id)

        if dest_s3 is None:
            raise AirflowException("Unable to connect to the S3 bucket")

        if not dest_s3.check_for_bucket(self.s3_bucket_name):
            raise AirflowException("Could not find the bucket {0}".format(self.s3_bucket_name))

        for key, file_handle in files.items():
            if os.path.exists(file_handle.name):
                dest_s3.load_file(bucket_name=self.s3_bucket_name, filename=file_handle.name, key=key, replace=True)
            else:
                raise AirflowException("File not found")

        for key, file_handle in files.items():
            if os.path.exists(file_handle.name):
                os.remove(file_handle.name)

        return True


    def _create_table_statement(self, cursor, table_name):

        redshift_statement = "CREATE TABLE IF NOT EXISTS "

        redshift_fields = ["event varchar(20)"]
        primary_key = cursor.description[cursor.lastrowid]

        # TODO: $ field in Mongo tables

        diststyle_key_validate = False
        diststyle_keys = ""

        sort_key = {}
        sort_key[primary_key[0]] = 1

        for mysql_fields in cursor.description:
            if mysql_fields[0] == 'company_id' or mysql_fields[0] == 'vendor_id':
                diststyle_key_validate = True
                diststyle_keys = mysql_fields[0]
                sort_key[mysql_fields[0]] = 1

            redshift_fields.append(mysql_fields[0] + " " + self.type_map(mysql_fields[1]))

        redshift_fields.append("PRIMARY KEY({0})".format(primary_key[0]))
        redshift_fields = ",".join(redshift_fields)
        redshift_statement_staging = redshift_statement + table_name + "_staging (" + redshift_fields + ")"
        redshift_statement_main = redshift_statement + table_name + "(" + redshift_fields + ")"

        redshift_statement_staging += " DISTSTYLE ALL" if not diststyle_key_validate \
            else " DISTKEY({0})".format(diststyle_keys)
        redshift_statement_main += " DISTSTYLE ALL" if not diststyle_key_validate \
            else " DISTKEY({0})".format(diststyle_keys)

        redshift_statement_main += " SORTKEY ({0})".format(",".join(sort_key.keys()))
        redshift_statement_staging += " SORTKEY ({0})".format(",".join(sort_key.keys()))

        logging.info(redshift_statement_main)

        return redshift_statement_staging, redshift_statement_main
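    # Illustrative only: assuming cursor.lastrowid is 0 after the SELECT (so the first
    # column is treated as the primary key), a hypothetical table "orders" with INT
    # columns "id" and "company_id" (both mapped to DECIMAL by type_map below) would
    # yield statements roughly of the form:
    #
    #   CREATE TABLE IF NOT EXISTS orders(event varchar(20),id DECIMAL,company_id DECIMAL,
    #       ...,PRIMARY KEY(id)) DISTKEY(company_id) SORTKEY (id,company_id)
    #   CREATE TABLE IF NOT EXISTS orders_staging (event varchar(20),id DECIMAL,company_id DECIMAL,
    #       ...,PRIMARY KEY(id)) DISTKEY(company_id) SORTKEY (id,company_id)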

    def _redshift_staging_table(self, redshift_sql):

        hook = PostgresHook(self.redshift_connection_id)
        connection = hook.get_conn()
        cursor = connection.cursor()
        redshift_sql = "begin transaction; " + redshift_sql
        redshift_sql += "; end transaction;"
        cursor.execute(redshift_sql)
        cursor.close()
        return True


    def _fetch_field_names_table(self, table_name):

        table_columns = "select \"column\", type from pg_table_def where tablename = %s"
        hook = PostgresHook(self.redshift_connection_id)
        connection = hook.get_conn()
        cursor = connection.cursor()
        cursor.execute(table_columns, [table_name])
        columns = cursor.fetchall()
        cursor.close()
        return columns, len(columns)


    @classmethod
    def type_map(cls, mysql_type):
        """
        Helper function that maps from MySQL fields to Redshift fields. Used
        when a schema_filename is set.
        """
        d = {
            FIELD_TYPE.INT24: 'DECIMAL',
            FIELD_TYPE.TINY: 'DECIMAL',
            FIELD_TYPE.BIT: 'DECIMAL',
            FIELD_TYPE.DATETIME: 'TIMESTAMP',
            FIELD_TYPE.DATE: 'DATE',
            FIELD_TYPE.DECIMAL: 'FLOAT',
            FIELD_TYPE.NEWDECIMAL: 'FLOAT',
            FIELD_TYPE.DOUBLE: 'VARCHAR(200)',
            FIELD_TYPE.FLOAT: 'FLOAT',
            FIELD_TYPE.LONG: 'DECIMAL',
            FIELD_TYPE.LONGLONG: 'DECIMAL',
            FIELD_TYPE.SHORT: 'INTEGER',
            FIELD_TYPE.TIMESTAMP: 'TIMESTAMP',
            FIELD_TYPE.YEAR: 'INTEGER',
        }
        return d.get(mysql_type, 'VARCHAR(20000)')
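A minimal sketch of wiring this operator into a DAG; every id, bucket, and name below is illustrative, and the module is assumed to already provide the imports the class relies on (MySqlHook, PostgresHook, S3Hook, pandas as pd, FIELD_TYPE, codecs, logging, os):

from datetime import datetime

from airflow import DAG

dag = DAG(
    dag_id='mysql_to_redshift_csv',          # hypothetical DAG id
    start_date=datetime(2019, 1, 1),
    schedule_interval='@daily',
)

export_task = MySQLToCSVOperator(
    task_id='export_mysql_table',
    mysql_conn_id='mysql_source',            # illustrative connection ids
    redshift_connection_id='redshift_default',
    aws_conn_id='aws_default',
    s3_bucket='my-data-bucket',              # hypothetical bucket
    dag=dag,
)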