Code example #1
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        logging.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor(MySQLdb.cursors.SSCursor)
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    encoding="utf-8")
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            # csv_writer.writerows(cursor)
            while True:
                row = cursor.fetchone()
                if not row:
                    break
                csv_writer.writerow(row)
            f.flush()
            cursor.close()
            conn.close()
            logging.info("Loading file into Hive")
            hive.load_file(f.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate)
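
The `self.type_map(field[1])` call above converts the DB-API type code found in `cursor.description` into a Hive column type before the CSV is loaded. The operator's real mapping table is not shown in this excerpt; the snippet below is only a minimal sketch of what such a mapping might look like for MySQLdb, with a fallback to STRING assumed for anything unmapped.

# Sketch only: a plausible MySQLdb -> Hive type mapping; the operator's
# actual table may differ, and the STRING fallback is an assumption.
from MySQLdb.constants import FIELD_TYPE


def type_map(mysql_type):
    """Map a MySQLdb field type code (cursor.description[i][1]) to a Hive type."""
    types = {
        FIELD_TYPE.TINY: 'SMALLINT',
        FIELD_TYPE.SHORT: 'SMALLINT',
        FIELD_TYPE.LONG: 'INT',
        FIELD_TYPE.LONGLONG: 'BIGINT',
        FIELD_TYPE.FLOAT: 'DOUBLE',
        FIELD_TYPE.DOUBLE: 'DOUBLE',
        FIELD_TYPE.DECIMAL: 'DOUBLE',
        FIELD_TYPE.TIMESTAMP: 'TIMESTAMP',
        FIELD_TYPE.DATETIME: 'TIMESTAMP',
    }
    # Anything unmapped (VARCHAR, BLOB, ...) falls back to STRING.
    return types.get(mysql_type, 'STRING')
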
Code example #2
    def execute(self, context):
        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
        self.log.info(
            "Dumping Microsoft SQL Server query results to local file")
        with mssql.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(self.sql)
                with NamedTemporaryFile("w") as tmp_file:
                    csv_writer = csv.writer(tmp_file,
                                            delimiter=self.delimiter,
                                            encoding='utf-8')
                    field_dict = OrderedDict()
                    col_count = 0
                    for field in cursor.description:
                        col_count += 1
                        col_position = "Column{position}".format(
                            position=col_count)
                        field_dict[col_position if field[0] ==
                                   '' else field[0]] = self.type_map(field[1])
                    csv_writer.writerows(cursor)
                    tmp_file.flush()

                    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
                    self.log.info("Loading file into Hive")
                    hive.load_file(tmp_file.name,
                                   self.hive_table,
                                   field_dict=field_dict,
                                   create=self.create,
                                   partition=self.partition,
                                   delimiter=self.delimiter,
                                   recreate=self.recreate,
                                   tblproperties=self.tblproperties)
Code example #3
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)

        logging.info("Dumping Microsoft SQL Server query results to local file")
        conn = mssql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("w") as f:
            csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
            field_dict = OrderedDict()
            col_count = 0
            for field in cursor.description:
                col_count += 1
                col_position = "Column{position}".format(position=col_count)
                field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            logging.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate)
Code example #4
File: vertica_to_hive.py  Project: zyh1690/airflow
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)

        self.log.info("Dumping Vertica query results to local file")
        conn = vertica.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("w") as f:
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    encoding='utf-8')
            field_dict = OrderedDict()
            col_count = 0
            for field in cursor.description:
                col_count += 1
                col_position = "Column{position}".format(position=col_count)
                field_dict[col_position if field[0] == '' else field[0]] = \
                    self.type_map(field[1])
            csv_writer.writerows(cursor.iterate())
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(f.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate)
Code example #5
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        self.log.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties)
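
For context, here is a hedged sketch of how an operator with an execute() like the one above might be wired into a DAG. The import path and constructor signature follow the Airflow 1.x MySqlToHiveTransfer convention and are assumptions here; the query and table names are placeholders.

# Sketch only: assumes the Airflow 1.x import path and constructor signature.
from datetime import datetime

from airflow import DAG
from airflow.operators.mysql_to_hive import MySqlToHiveTransfer

dag = DAG(dag_id='mysql_to_hive_example', start_date=datetime(2020, 1, 1),
          schedule_interval=None)

load_orders = MySqlToHiveTransfer(
    task_id='mysql_orders_to_hive',
    sql='SELECT id, customer_id, amount FROM orders',  # hypothetical query
    hive_table='staging.orders',                       # hypothetical table
    mysql_conn_id='mysql_default',
    hive_cli_conn_id='hive_cli_default',
    create=True,
    recreate=False,
    dag=dag,
)
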
Code example #6
    def test_load_file_create_table(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"
        field_dict = OrderedDict([("name", "string"), ("gender", "string")])
        fields = ",\n    ".join([k + ' ' + v for k, v in field_dict.items()])

        hook = HiveCliHook()
        hook.load_file(filepath=filepath,
                       table=table,
                       field_dict=field_dict,
                       create=True,
                       recreate=True)

        create_table = ("DROP TABLE IF EXISTS {table};\n"
                        "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
                        "ROW FORMAT DELIMITED\n"
                        "FIELDS TERMINATED BY ','\n"
                        "STORED AS textfile\n;".format(table=table,
                                                       fields=fields))

        load_data = ("LOAD DATA LOCAL INPATH '{filepath}' "
                     "OVERWRITE INTO TABLE {table} ;\n".format(
                         filepath=filepath, table=table))
        calls = [mock.call(create_table), mock.call(load_data)]
        mock_run_cli.assert_has_calls(calls, any_order=True)
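
The test above receives mock_run_cli from a patch decorator that is not part of the excerpt. A minimal sketch of the likely scaffolding is shown below; the exact patch target path is an assumption based on the Airflow 1.x module layout.

# Sketch only: the patch target path is an assumption.
import unittest
from collections import OrderedDict
from unittest import mock

from airflow.hooks.hive_hooks import HiveCliHook


class TestHiveCliHookLoadFile(unittest.TestCase):

    @mock.patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli')
    def test_load_file_create_table(self, mock_run_cli):
        # body as in code example #6 above
        ...
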
Code example #7
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        self.log.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    quoting=self.quoting,
                                    quotechar=self.quotechar,
                                    escapechar=self.escapechar,
                                    encoding="utf-8")
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(f.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate,
                           tblproperties=self.tblproperties)
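
Code example #7 differs from #1 and #5 only in the extra quoting, quotechar and escapechar arguments passed to the CSV writer. The standalone snippet below illustrates their effect with the standard library csv module (the operators above use a unicode-aware csv drop-in, so the encoding argument is omitted here).

import csv
import io

rows = [(1, 'plain'), (2, 'has,comma'), (3, 'has"quote')]

buf = io.StringIO()
writer = csv.writer(buf,
                    delimiter=',',
                    quoting=csv.QUOTE_MINIMAL,  # quote only when needed
                    quotechar='"',
                    escapechar='\\')
writer.writerows(rows)
print(buf.getvalue())
# 1,plain
# 2,"has,comma"
# 3,"has""quote"
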
Code example #8
    def test_load_file(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = ("LOAD DATA LOCAL INPATH '{filepath}' "
                 "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath,
                                                           table=table))
        mock_run_cli.assert_called_with(query)
Code example #9
    def test_load_file(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = ("LOAD DATA LOCAL INPATH '{filepath}' "
                 "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath,
                                                           table=table))
        calls = [mock.call(';'), mock.call(query)]
        mock_run_cli.assert_has_calls(calls, any_order=True)
Code example #10
    def test_load_file(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = (
            "LOAD DATA LOCAL INPATH '{filepath}' "
            "OVERWRITE INTO TABLE {table} \n"
            .format(filepath=filepath, table=table)
        )
        mock_run_cli.assert_called_with(query)
Code example #11
    def run(self, *args, **kwargs):
        aws_access_key_id = Variable.get("aws_access_key_id")
        aws_secret_access_key = Variable.get("aws_secret_access_key")
        aws_bucket_name = Variable.get("aws_bucket_name")
        hive_cli_conn_id = Variable.get("hive_cli_conn_id")
        field_dict = Variable.get("field_dict")
        self.storage = S3StorageDriver(aws_access_key_id,
                                       aws_secret_access_key, aws_bucket_name)
        url = self.get_input_filename("file")
        parse_url = urlparse(url)
        cli_hook = HiveCliHook(hive_cli_conn_id=hive_cli_conn_id)
        path = '{}{}{}{}{}{}{}{}'.format(parse_url.scheme, "n://",
                                         aws_access_key_id, ":",
                                         aws_secret_access_key, "@",
                                         str(parse_url.netloc), parse_url.path)
        bucketname = '{}{}'.format("s3_", str(parse_url.path).split("/").pop())
        cli_hook.load_file(field_dict=ast.literal_eval(field_dict),
                           table=bucketname,
                           filepath=path)
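
Code example #11 pulls its Hive column mapping out of an Airflow Variable and turns the stored string back into a dict with ast.literal_eval. Below is a hedged sketch of what setting and reading such a Variable could look like; the variable key and the column mapping are hypothetical.

# Sketch only: the variable key and the column mapping are hypothetical.
import ast

from airflow.models import Variable

# Stored once, e.g. from the Airflow UI or a setup script:
Variable.set("field_dict",
             "{'id': 'BIGINT', 'name': 'STRING', 'created_at': 'TIMESTAMP'}")

# Read back inside the task, as code example #11 does:
field_dict = ast.literal_eval(Variable.get("field_dict"))
assert field_dict['id'] == 'BIGINT'
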
Code example #12
    def execute(self, context):
        # Downloading file from S3
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not s3_hook.check_for_wildcard_key(self.s3_key):
                raise AirflowException(f"No key matches {self.s3_key}")
            s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
        else:
            if not s3_hook.check_for_key(self.s3_key):
                raise AirflowException(f"The key {self.s3_key} does not exists")
            s3_key_object = s3_hook.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed and
                file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info(
                "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
            )
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = s3_hook.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                hive_hook.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                hive_hook.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
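
When select_expression is set, code example #12 asks S3 Select to filter the object server-side through S3Hook.select_key. That call presumably maps onto boto3's select_object_content; the standalone sketch below shows that lower-level call with the same CSV/GZIP serialization options. Bucket, key and expression are placeholders, and the wrapping assumption is mine, not something stated in the example.

# Sketch only: bucket, key and expression are placeholders; error handling omitted.
import boto3

s3 = boto3.client('s3')
response = s3.select_object_content(
    Bucket='my-bucket',
    Key='incoming/data.csv.gz',
    ExpressionType='SQL',
    Expression="SELECT * FROM s3object s",
    InputSerialization={
        'CSV': {'FileHeaderInfo': 'USE', 'FieldDelimiter': ','},
        'CompressionType': 'GZIP',
    },
    OutputSerialization={'CSV': {}},
)

# The result arrives as an event stream; Records events carry the payload bytes.
content = b''.join(
    event['Records']['Payload']
    for event in response['Payload']
    if 'Records' in event
)
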
Code example #13
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_conn_id: destination hive connection
    :type hive_conn_id: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            s3_conn_id='s3_default',
            hive_cli_conn_id='hive_cli_default',
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id

    def execute(self, context):
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        logging.info("Downloading S3 file")
        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        with NamedTemporaryFile("w") as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file into Hive")
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate)
            else:
                with open(f.name, 'r') as tmpf:
                    if self.check_headers:
                        header_l = tmpf.readline()
                        header_line = header_l.rstrip()
                        header_list = header_line.split(self.delimiter)
                        field_names = list(self.field_dict.keys())
                        test_field_match = [h1.lower() == h2.lower() for h1, h2
                                            in zip(header_list, field_names)]
                        if not all(test_field_match):
                            logging.warning("Headers do not match field names"
                                            "File headers:\n {header_list}\n"
                                            "Field names: \n {field_names}\n"
                                            "".format(**locals()))
                            raise AirflowException("Headers do not match the "
                                            "field_dict keys")
                    with NamedTemporaryFile("w") as f_no_headers:
                        tmpf.seek(0)
                        next(tmpf)
                        for line in tmpf:
                            f_no_headers.write(line)
                        f_no_headers.flush()
                        logging.info("Loading file without headers into Hive")
                        self.hive.load_file(
                            f_no_headers.name,
                            self.hive_table,
                            field_dict=self.field_dict,
                            create=self.create,
                            partition=self.partition,
                            delimiter=self.delimiter,
                            recreate=self.recreate)
Code example #14
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(self,
                 s3_key,
                 field_dict,
                 hive_table,
                 delimiter=',',
                 create=True,
                 recreate=False,
                 partition=None,
                 headers=False,
                 check_headers=False,
                 wildcard_match=False,
                 s3_conn_id='s3_default',
                 hive_cli_conn_id='hive_cli_default',
                 input_compressed=False,
                 tblproperties=None,
                 *args,
                 **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties

        if (self.check_headers
                and not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        logging.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(
                    self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException("The key {0} does not exists".format(
                    self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        root, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="w",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file {0} into Hive".format(f.name))
                self.hive.load_file(f.name,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    logging.info("Uncompressing file {0}".format(f.name))
                    fn_uncompressed = uncompress_file(f.name, file_ext,
                                                      tmp_dir)
                    logging.info("Uncompressed to {0}".format(fn_uncompressed))
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    logging.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                logging.info(
                    "Removing header from file {0}".format(fn_uncompressed))
                headless_file = (self._delete_top_row_and_compress(
                    fn_uncompressed, file_ext, tmp_dir))
                logging.info("Headless file {0}".format(headless_file))
                logging.info(
                    "Loading file {0} into Hive".format(headless_file))
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            logging.warning("Headers count mismatch"
                            "File headers:\n {header_list}\n"
                            "Field names: \n {field_names}\n"
                            "".format(**locals()))
            return False
        test_field_match = [
            h1.lower() == h2.lower()
            for h1, h2 in zip(header_list, field_names)
        ]
        if not all(test_field_match):
            logging.warning("Headers do not match field names"
                            "File headers:\n {header_list}\n"
                            "Field names: \n {field_names}\n"
                            "".format(**locals()))
            return False
        else:
            return True

    def _delete_top_row_and_compress(self, input_file_name, output_file_ext,
                                     dest_dir):
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        os_fh_output, fn_output = \
            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        with open(input_file_name, 'rb') as f_in,\
                open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
Code example #15
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param aws_conn_id: source s3 connection
    :type aws_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            aws_conn_id='aws_default',
            hive_cli_conn_id='hive_cli_default',
            input_compressed=False,
            tblproperties=None,
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.aws_conn_id = aws_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties

        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        root, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
            s3_key_object.download_fileobj(f)
            f.flush()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            self.log.warning("Headers count mismatch"
                              "File headers:\n {header_list}\n"
                              "Field names: \n {field_names}\n"
                              "".format(**locals()))
            return False
        test_field_match = [h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)]
        if not all(test_field_match):
            self.log.warning("Headers do not match field names"
                              "File headers:\n {header_list}\n"
                              "Field names: \n {field_names}\n"
                              "".format(**locals()))
            return False
        else:
            return True

    def _delete_top_row_and_compress(
            self,
            input_file_name,
            output_file_ext,
            dest_dir):
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        os_fh_output, fn_output = \
            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        with open(input_file_name, 'rb') as f_in,\
                open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
Code example #16
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3. (templated)
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database. (templated)
    :type hive_table: str
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values. (templated)
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param aws_conn_id: source s3 connection
    :type aws_conn_id: str
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.
        You can provide the following values:

        - ``False``: do not validate SSL certificates. SSL will still be used
                 (unless use_ssl is False), but SSL certificates will not be
                 verified.
        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
                 You can specify this argument if you want to use a different
                 CA cert bundle than the one used by botocore.
    :type verify: bool or str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    :param select_expression: S3 Select expression
    :type select_expression: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key: str,
            field_dict: Dict,
            hive_table: str,
            delimiter: str = ',',
            create: bool = True,
            recreate: bool = False,
            partition: Optional[Dict] = None,
            headers: bool = False,
            check_headers: bool = False,
            wildcard_match: bool = False,
            aws_conn_id: str = 'aws_default',
            verify: Optional[Union[bool, str]] = None,
            hive_cli_conn_id: str = 'hive_cli_default',
            input_compressed: bool = False,
            tblproperties: Optional[Dict] = None,
            select_expression: Optional[str] = None,
            *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.aws_conn_id = aws_conn_id
        self.verify = verify
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties
        self.select_expression = select_expression

        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed and
                file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info(
                "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
            )
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = self.s3.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        with open(file_name, 'rt') as file:
            header_line = file.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            self.log.warning(
                "Headers count mismatch File headers:\n %s\nField names: \n %s\n", header_list, field_names
            )
            return False
        test_field_match = [h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)]
        if not all(test_field_match):
            self.log.warning(
                "Headers do not match field names File headers:\n %s\nField names: \n %s\n",
                header_list, field_names
            )
            return False
        else:
            return True

    @staticmethod
    def _delete_top_row_and_compress(
            input_file_name,
            output_file_ext,
            dest_dir):
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        _, fn_output = tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        with open(input_file_name, 'rb') as f_in, open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
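
Going by the docstring in code example #16, here is a hedged sketch of how S3ToHiveTransfer might be instantiated in a DAG. The import path varies across Airflow versions and is an assumption here; the S3 key, table name and field_dict are placeholders.

# Sketch only: import path, key, table and field_dict are assumptions/placeholders.
from collections import OrderedDict
from datetime import datetime

from airflow import DAG
from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer

dag = DAG(dag_id='s3_to_hive_example', start_date=datetime(2020, 1, 1),
          schedule_interval=None)

s3_to_hive = S3ToHiveTransfer(
    task_id='load_events',
    s3_key='incoming/events_{{ ds }}.csv.gz',  # templated, hypothetical key
    field_dict=OrderedDict([('event_id', 'BIGINT'),
                            ('payload', 'STRING')]),
    hive_table='staging.events',
    headers=True,
    check_headers=True,
    input_compressed=True,
    delimiter=',',
    aws_conn_id='aws_default',
    hive_cli_conn_id='hive_cli_default',
    dag=dag,
)
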
Code example #17
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_conn_id: destination hive connection
    :type hive_conn_id: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(self,
                 s3_key,
                 field_dict,
                 hive_table,
                 delimiter=',',
                 create=True,
                 recreate=False,
                 partition=None,
                 headers=False,
                 check_headers=False,
                 wildcard_match=False,
                 s3_conn_id='s3_default',
                 hive_cli_conn_id='hive_cli_default',
                 *args,
                 **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id

    def execute(self, context):
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        _log.info("Downloading S3 file")
        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(
                    self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException("The key {0} does not exists".format(
                    self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        with NamedTemporaryFile("w") as f:
            _log.info("Dumping S3 key {0} contents to local"
                      " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                _log.info("Loading file into Hive")
                self.hive.load_file(f.name,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate)
            else:
                with open(f.name, 'r') as tmpf:
                    if self.check_headers:
                        header_l = tmpf.readline()
                        header_line = header_l.rstrip()
                        header_list = header_line.split(self.delimiter)
                        field_names = list(self.field_dict.keys())
                        test_field_match = [
                            h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)
                        ]
                        if not all(test_field_match):
                            _log.warning("Headers do not match field names"
                                         "File headers:\n {header_list}\n"
                                         "Field names: \n {field_names}\n"
                                         "".format(**locals()))
                            raise AirflowException("Headers do not match the "
                                                   "field_dict keys")
                    with NamedTemporaryFile("w") as f_no_headers:
                        tmpf.seek(0)
                        next(tmpf)
                        for line in tmpf:
                            f_no_headers.write(line)
                        f_no_headers.flush()
                        _log.info("Loading file without headers into Hive")
                        self.hive.load_file(f_no_headers.name,
                                            self.hive_table,
                                            field_dict=self.field_dict,
                                            create=self.create,
                                            partition=self.partition,
                                            delimiter=self.delimiter,
                                            recreate=self.recreate)