def execute(self, context):
    """Dump the result of ``self.sql`` from MySQL to a local file and load it into Hive.

    Rows are fetched one at a time through a ``MySQLdb.cursors.SSCursor``
    cursor and written with ``csv_writer`` to a temporary file. The cursor
    and connection are always closed, even when the dump fails, and they are
    released before the Hive load starts.

    :param context: task execution context (not used by this method)
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    logging.info("Dumping MySQL query results to local file")
    with NamedTemporaryFile("wb") as f:
        conn = mysql.get_conn()
        try:
            cursor = conn.cursor(MySQLdb.cursors.SSCursor)
            try:
                cursor.execute(self.sql)
                csv_writer = csv.writer(f, delimiter=self.delimiter,
                                        encoding="utf-8")
                # Column name -> self.type_map(type code), in result-set order.
                field_dict = OrderedDict()
                for field in cursor.description:
                    field_dict[field[0]] = self.type_map(field[1])
                while True:
                    row = cursor.fetchone()
                    if not row:
                        break
                    csv_writer.writerow(row)
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()

        # The temp file must still be open here: it is removed on close.
        logging.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate)
def execute(self, context):
    """Dump the result of ``self.sql`` from SQL Server to a local file and load it into Hive.

    The temporary file is the outermost context so that it still exists when
    ``hive.load_file`` runs (a ``NamedTemporaryFile`` is removed on close),
    while the connection and cursor contexts are exited before the Hive load
    starts.

    :param context: task execution context (not used by this method)
    """
    mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
    self.log.info(
        "Dumping Microsoft SQL Server query results to local file")

    with NamedTemporaryFile("w") as tmp_file:
        with mssql.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(self.sql)
                csv_writer = csv.writer(tmp_file, delimiter=self.delimiter,
                                        encoding='utf-8')
                field_dict = OrderedDict()
                for col_count, field in enumerate(cursor.description, 1):
                    # Columns with an empty name get a positional name.
                    col_position = "Column{position}".format(
                        position=col_count)
                    field_dict[col_position if field[0] == '' else field[0]] = \
                        self.type_map(field[1])
                csv_writer.writerows(cursor)
                tmp_file.flush()

        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Loading file into Hive")
        hive.load_file(
            tmp_file.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties)
def execute(self, context):
    """Dump the result of ``self.sql`` from SQL Server to a local file and load it into Hive.

    The cursor and connection are always closed, even when the query or the
    dump fails, and they are released before the Hive load starts.

    :param context: task execution context (not used by this method)
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)

    logging.info("Dumping Microsoft SQL Server query results to local file")
    with NamedTemporaryFile("w") as f:
        conn = mssql.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(self.sql)
                csv_writer = csv.writer(f, delimiter=self.delimiter,
                                        encoding='utf-8')
                field_dict = OrderedDict()
                for col_count, field in enumerate(cursor.description, 1):
                    # Columns with an empty name get a positional name.
                    col_position = "Column{position}".format(
                        position=col_count)
                    field_dict[col_position if field[0] == '' else field[0]] = \
                        self.type_map(field[1])
                csv_writer.writerows(cursor)
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()

        # The temp file must still be open here: it is removed on close.
        logging.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate)
def execute(self, context):
    """Dump the result of ``self.sql`` from Vertica to a local file and load it into Hive.

    Rows come from ``cursor.iterate()``. The cursor and connection are always
    closed, even when the query or the dump fails, and they are released
    before the Hive load starts.

    :param context: task execution context (not used by this method)
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)

    self.log.info("Dumping Vertica query results to local file")
    with NamedTemporaryFile("w") as f:
        conn = vertica.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(self.sql)
                csv_writer = csv.writer(f, delimiter=self.delimiter,
                                        encoding='utf-8')
                field_dict = OrderedDict()
                for col_count, field in enumerate(cursor.description, 1):
                    # Columns with an empty name get a positional name.
                    col_position = "Column{position}".format(
                        position=col_count)
                    field_dict[col_position if field[0] == '' else field[0]] = \
                        self.type_map(field[1])
                csv_writer.writerows(cursor.iterate())
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()

        # The temp file must still be open here: it is removed on close.
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate)
def execute(self, context):
    """Dump the result of ``self.sql`` from MySQL to a local file and load it into Hive.

    The cursor and connection are always closed, even when the query or the
    dump fails, and they are released before the Hive load starts.

    :param context: task execution context (not used by this method)
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    with NamedTemporaryFile("wb") as f:
        conn = mysql.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(self.sql)
                csv_writer = csv.writer(f, delimiter=self.delimiter,
                                        encoding="utf-8")
                # Column name -> self.type_map(type code), in result-set order.
                field_dict = OrderedDict()
                for field in cursor.description:
                    field_dict[field[0]] = self.type_map(field[1])
                csv_writer.writerows(cursor)
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()

        # The temp file must still be open here: it is removed on close.
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties)
def test_load_file_create_table(self, mock_run_cli):
    """``load_file`` with ``create`` and ``recreate`` runs the DDL and the LOAD statement."""
    source_path = "/path/to/input/file"
    target_table = "output_table"
    columns = OrderedDict([("name", "string"), ("gender", "string")])

    ddl_template = (
        "DROP TABLE IF EXISTS {table};\n"
        "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
        "ROW FORMAT DELIMITED\n"
        "FIELDS TERMINATED BY ','\n"
        "STORED AS textfile\n;"
    )
    load_template = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} ;\n"
    )
    column_defs = ",\n ".join(
        [col_name + ' ' + hive_type for col_name, hive_type in columns.items()])
    expected_calls = [
        mock.call(ddl_template.format(table=target_table,
                                      fields=column_defs)),
        mock.call(load_template.format(filepath=source_path,
                                       table=target_table)),
    ]

    HiveCliHook().load_file(
        filepath=source_path,
        table=target_table,
        field_dict=columns,
        create=True,
        recreate=True)

    mock_run_cli.assert_has_calls(expected_calls, any_order=True)
def execute(self, context):
    """Dump the result of ``self.sql`` from MySQL to a local file and load it into Hive.

    The CSV writer is configured with the operator's delimiter, quoting,
    quote character and escape character. The cursor and connection are
    always closed, even when the query or the dump fails, and they are
    released before the Hive load starts.

    :param context: task execution context (not used by this method)
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    with NamedTemporaryFile("wb") as f:
        conn = mysql.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(self.sql)
                csv_writer = csv.writer(f, delimiter=self.delimiter,
                                        quoting=self.quoting,
                                        quotechar=self.quotechar,
                                        escapechar=self.escapechar,
                                        encoding="utf-8")
                # Column name -> self.type_map(type code), in result-set order.
                field_dict = OrderedDict()
                for field in cursor.description:
                    field_dict[field[0]] = self.type_map(field[1])
                csv_writer.writerows(cursor)
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()

        # The temp file must still be open here: it is removed on close.
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties)
def test_load_file(self, mock_run_cli):
    """``load_file`` without table creation runs only the LOAD DATA statement."""
    source_path = "/path/to/input/file"
    target_table = "output_table"
    load_template = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} ;\n"
    )
    expected_query = load_template.format(filepath=source_path,
                                          table=target_table)

    HiveCliHook().load_file(filepath=source_path,
                            table=target_table,
                            create=False)

    mock_run_cli.assert_called_with(expected_query)
def test_load_file(self, mock_run_cli):
    """``load_file`` without table creation calls ``run_cli`` with ``';'`` and the LOAD statement."""
    source_path = "/path/to/input/file"
    target_table = "output_table"
    load_template = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} ;\n"
    )
    expected_calls = [
        mock.call(';'),
        mock.call(load_template.format(filepath=source_path,
                                       table=target_table)),
    ]

    HiveCliHook().load_file(filepath=source_path,
                            table=target_table,
                            create=False)

    mock_run_cli.assert_has_calls(expected_calls, any_order=True)
def test_load_file(self, mock_run_cli):
    """``load_file`` without table creation runs only the LOAD DATA statement."""
    source_path = "/path/to/input/file"
    target_table = "output_table"
    load_template = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} \n"
    )
    expected_query = load_template.format(filepath=source_path,
                                          table=target_table)

    HiveCliHook().load_file(filepath=source_path,
                            table=target_table,
                            create=False)

    mock_run_cli.assert_called_with(expected_query)
def run(self, *args, **kwargs):
    """Load the S3 object named by this task's ``"file"`` input into a Hive table.

    Credentials, bucket name, Hive connection id and the field dictionary
    are read from ``Variable``. The Hive table is named ``s3_`` followed by
    the last path segment of the input URL.
    """
    access_key = Variable.get("aws_access_key_id")
    secret_key = Variable.get("aws_secret_access_key")
    bucket = Variable.get("aws_bucket_name")
    hive_conn = Variable.get("hive_cli_conn_id")
    raw_field_dict = Variable.get("field_dict")

    self.storage = S3StorageDriver(access_key, secret_key, bucket)

    location = urlparse(self.get_input_filename("file"))
    hook = HiveCliHook(hive_cli_conn_id=hive_conn)

    # NOTE(review): the credentials are embedded verbatim in the path passed
    # to Hive; confirm this string cannot end up in logs.
    source_path = '{scheme}{sep}{key}:{secret}@{host}{path}'.format(
        scheme=location.scheme,
        sep="n://",
        key=access_key,
        secret=secret_key,
        host=str(location.netloc),
        path=location.path)
    last_segment = str(location.path).split("/")[-1]
    table_name = "s3_" + '{}'.format(last_segment)

    # field_dict is stored as a Python literal string in the Variable.
    hook.load_file(
        field_dict=ast.literal_eval(raw_field_dict),
        table=table_name,
        filepath=source_path)
def execute(self, context):
    """Fetch the configured S3 object into a scratch directory and load it into Hive.

    With ``select_expression`` set, the content comes from
    ``s3_hook.select_key``; otherwise the object is downloaded as is. When
    the file has headers and S3 Select was not used, the file is
    uncompressed if needed, optionally checked against ``field_dict``, and
    its first row is removed before loading.

    :param context: task execution context (not used by this method)
    :raises AirflowException: if no key matches, if S3 Select is asked to
        read a compressed file whose extension is not ``.gz``, or if the
        header check fails
    """
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
    hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.log.info("Downloading S3 file")

    # Resolve the key object, failing when nothing matches.
    if not self.wildcard_match:
        if not s3_hook.check_for_key(self.s3_key):
            raise AirflowException(f"The key {self.s3_key} does not exists")
        s3_key_object = s3_hook.get_key(self.s3_key)
    elif s3_hook.check_for_wildcard_key(self.s3_key):
        s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
    else:
        raise AirflowException(f"No key matches {self.s3_key}")

    file_ext = os.path.splitext(s3_key_object.key)[1]
    if self.select_expression and self.input_compressed:
        if file_ext.lower() != '.gz':
            raise AirflowException("GZIP is the only compression "
                                   "format Amazon S3 Select supports")

    def load_into_hive(local_path):
        # Single place for the Hive load arguments shared by both paths.
        hive_hook.load_file(
            local_path,
            self.hive_table,
            field_dict=self.field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties)

    with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, \
            NamedTemporaryFile(mode="wb",
                               dir=tmp_dir,
                               suffix=file_ext) as f:
        self.log.info(
            "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
        )
        if not self.select_expression:
            s3_key_object.download_fileobj(f)
        else:
            csv_options = {}
            if self.headers:
                csv_options['FileHeaderInfo'] = 'USE'
            if self.delimiter:
                csv_options['FieldDelimiter'] = self.delimiter

            serialization = {'CSV': csv_options}
            if self.input_compressed:
                serialization['CompressionType'] = 'GZIP'

            selected = s3_hook.select_key(
                bucket_name=s3_key_object.bucket_name,
                key=s3_key_object.key,
                expression=self.select_expression,
                input_serialization=serialization
            )
            f.write(selected.encode("utf-8"))
        f.flush()

        if self.select_expression or not self.headers:
            self.log.info("Loading file %s into Hive", f.name)
            load_into_hive(f.name)
            return

        # Header handling needs an uncompressed copy of the file.
        if self.input_compressed:
            self.log.info("Uncompressing file %s", f.name)
            fn_uncompressed = uncompress_file(f.name, file_ext, tmp_dir)
            self.log.info("Uncompressed to %s", fn_uncompressed)
            # The uncompressed copy exists now, so close (and thereby
            # delete) the compressed temp file to save disk space.
            f.close()
        else:
            fn_uncompressed = f.name

        if self.check_headers:
            self.log.info("Matching file header against field_dict")
            first_row = self._get_top_row_as_list(fn_uncompressed)
            if not self._match_headers(first_row):
                raise AirflowException("Header check failed")

        self.log.info("Removing header from file %s", fn_uncompressed)
        headless_file = self._delete_top_row_and_compress(
            fn_uncompressed, file_ext, tmp_dir)
        self.log.info("Headless file %s", headless_file)
        self.log.info("Loading file %s into Hive", headless_file)
        load_into_hive(headless_file)
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive column names and types are taken from ``field_dict``.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the tables gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            s3_conn_id='s3_default',
            hive_cli_conn_id='hive_cli_default',
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id

    def execute(self, context):
        """Download the S3 key to a temporary file and load it into Hive.

        When ``headers`` is set, the first line is dropped before loading;
        when ``check_headers`` is also set, that line is first compared
        case-insensitively with the keys of ``field_dict``.

        :param context: task execution context (not used by this method)
        :raises AirflowException: if no key matches ``s3_key`` or if the
            header row does not match the ``field_dict`` keys
        """
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        logging.info("Downloading S3 file")
        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        with NamedTemporaryFile("w") as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file into Hive")
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate)
            else:
                with open(f.name, 'r') as tmpf:
                    if self.check_headers:
                        header_line = tmpf.readline().rstrip()
                        header_list = header_line.split(self.delimiter)
                        field_names = list(self.field_dict.keys())
                        # zip() stops at the shorter sequence, so a header
                        # row with missing or extra columns must be rejected
                        # by comparing the lengths first.
                        headers_match = (
                            len(header_list) == len(field_names) and
                            all(h1.lower() == h2.lower()
                                for h1, h2 in zip(header_list, field_names)))
                        if not headers_match:
                            logging.warning(
                                "Headers do not match field names"
                                "File headers:\n {header_list}\n"
                                "Field names: \n {field_names}\n"
                                "".format(header_list=header_list,
                                          field_names=field_names))
                            raise AirflowException("Headers do not match the "
                                                   "field_dict keys")
                    with NamedTemporaryFile("w") as f_no_headers:
                        # Rewind (the header may have been read above) and
                        # skip the first line before copying the rest.
                        tmpf.seek(0)
                        next(tmpf)
                        for line in tmpf:
                            f_no_headers.write(line)
                        f_no_headers.flush()
                        logging.info("Loading file without headers into Hive")
                        self.hive.load_file(
                            f_no_headers.name,
                            self.hive_table,
                            field_dict=self.field_dict,
                            create=self.create,
                            partition=self.partition,
                            delimiter=self.delimiter,
                            recreate=self.recreate)
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive column names and types are taken from ``field_dict``.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the tables gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(self,
                 s3_key,
                 field_dict,
                 hive_table,
                 delimiter=',',
                 create=True,
                 recreate=False,
                 partition=None,
                 headers=False,
                 check_headers=False,
                 wildcard_match=False,
                 s3_conn_id='s3_default',
                 hive_cli_conn_id='hive_cli_default',
                 input_compressed=False,
                 tblproperties=None,
                 *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties

        # A header check needs both a header row and something to compare to.
        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        """Download the S3 key into a scratch directory and load it into Hive.

        When ``headers`` is set, the file is uncompressed if
        ``input_compressed`` is set, optionally checked against
        ``field_dict``, and its first row is removed before loading.

        :param context: task execution context (not used by this method)
        :raises AirflowException: if no key matches ``s3_key`` or if the
            header check fails
        """
        # Downloading file from S3
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        logging.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(
                    self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException("The key {0} does not exists".format(
                    self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="w",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file {0} into Hive".format(f.name))
                self.hive.load_file(f.name,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    logging.info("Uncompressing file {0}".format(f.name))
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    logging.info("Uncompressed to {0}".format(fn_uncompressed))
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    logging.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                logging.info(
                    "Removing header from file {0}".format(fn_uncompressed))
                headless_file = self._delete_top_row_and_compress(
                    fn_uncompressed, file_ext, tmp_dir)
                logging.info("Headless file {0}".format(headless_file))
                logging.info(
                    "Loading file {0} into Hive".format(headless_file))
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        """Return the first line of ``file_name``, stripped and split on ``self.delimiter``."""
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        """Return whether ``header_list`` matches the ``field_dict`` keys.

        The comparison is positional and case-insensitive; a different
        number of columns counts as a mismatch. Mismatches are logged.

        :raises AirflowException: if ``header_list`` is empty
        """
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            logging.warning("Headers count mismatch"
                            "File headers:\n {header_list}\n"
                            "Field names: \n {field_names}\n"
                            "".format(header_list=header_list,
                                      field_names=field_names))
            return False
        test_field_match = [
            h1.lower() == h2.lower()
            for h1, h2 in zip(header_list, field_names)
        ]
        if not all(test_field_match):
            logging.warning("Headers do not match field names"
                            "File headers:\n {header_list}\n"
                            "Field names: \n {field_names}\n"
                            "".format(header_list=header_list,
                                      field_names=field_names))
            return False
        return True

    def _delete_top_row_and_compress(self,
                                     input_file_name,
                                     output_file_ext,
                                     dest_dir):
        """Copy ``input_file_name`` without its first line into a new file in ``dest_dir``.

        The output is written through ``gzip.GzipFile`` for a ``.gz``
        extension, ``bz2.BZ2File`` for ``.bz2``, and plain ``open``
        otherwise.

        :return: path of the newly written file
        """
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        os_fh_output, fn_output = \
            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        # mkstemp hands back an already-open OS-level descriptor. The file is
        # reopened by name below, so close the descriptor instead of leaking it.
        os.close(os_fh_output)
        with open(input_file_name, 'rb') as f_in,\
                open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive column names and types are taken from ``field_dict``.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the tables gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param aws_conn_id: source s3 connection
    :type aws_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            aws_conn_id='aws_default',
            hive_cli_conn_id='hive_cli_default',
            input_compressed=False,
            tblproperties=None,
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.aws_conn_id = aws_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties

        # A header check needs both a header row and something to compare to.
        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        """Download the S3 key into a scratch directory and load it into Hive.

        When ``headers`` is set, the file is uncompressed if
        ``input_compressed`` is set, optionally checked against
        ``field_dict``, and its first row is removed before loading.

        :param context: task execution context (not used by this method)
        :raises AirflowException: if no key matches ``s3_key`` or if the
            header check fails
        """
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
            s3_key_object.download_fileobj(f)
            f.flush()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = self._delete_top_row_and_compress(
                    fn_uncompressed, file_ext, tmp_dir)
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        """Return the first line of ``file_name``, stripped and split on ``self.delimiter``."""
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        """Return whether ``header_list`` matches the ``field_dict`` keys.

        The comparison is positional and case-insensitive; a different
        number of columns counts as a mismatch. Mismatches are logged.

        :raises AirflowException: if ``header_list`` is empty
        """
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            self.log.warning("Headers count mismatch"
                             "File headers:\n {header_list}\n"
                             "Field names: \n {field_names}\n"
                             "".format(header_list=header_list,
                                       field_names=field_names))
            return False
        test_field_match = [h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)]
        if not all(test_field_match):
            self.log.warning("Headers do not match field names"
                             "File headers:\n {header_list}\n"
                             "Field names: \n {field_names}\n"
                             "".format(header_list=header_list,
                                       field_names=field_names))
            return False
        return True

    def _delete_top_row_and_compress(
            self,
            input_file_name,
            output_file_ext,
            dest_dir):
        """Copy ``input_file_name`` without its first line into a new file in ``dest_dir``.

        The output is written through ``gzip.GzipFile`` for a ``.gz``
        extension, ``bz2.BZ2File`` for ``.bz2``, and plain ``open``
        otherwise.

        :return: path of the newly written file
        """
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        os_fh_output, fn_output = \
            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        # mkstemp hands back an already-open OS-level descriptor. The file is
        # reopened by name below, so close the descriptor instead of leaking it.
        os.close(os_fh_output)
        with open(input_file_name, 'rb') as f_in,\
                open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are taken from ``field_dict``.
    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the tables gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3. (templated)
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database. (templated)
    :type hive_table: str
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values. (templated)
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param aws_conn_id: source s3 connection
    :type aws_conn_id: str
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.
        You can provide the following values:

        - ``False``: do not validate SSL certificates. SSL will still be used
                 (unless use_ssl is False), but SSL certificates will not be
                 verified.
        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
                 You can specify this argument if you want to use a different
                 CA cert bundle than the one used by botocore.
    :type verify: bool or str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    :param select_expression: S3 Select expression
    :type select_expression: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key: str,
            field_dict: Dict,
            hive_table: str,
            delimiter: str = ',',
            create: bool = True,
            recreate: bool = False,
            partition: Optional[Dict] = None,
            headers: bool = False,
            check_headers: bool = False,
            wildcard_match: bool = False,
            aws_conn_id: str = 'aws_default',
            verify: Optional[Union[bool, str]] = None,
            hive_cli_conn_id: str = 'hive_cli_default',
            input_compressed: bool = False,
            tblproperties: Optional[Dict] = None,
            select_expression: Optional[str] = None,
            *args, **kwargs) -> None:
        """
        Store the operator configuration.

        :raises AirflowException: if ``check_headers`` is requested without
            both a ``field_dict`` and ``headers`` being set.
        """
        super().__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.aws_conn_id = aws_conn_id
        self.verify = verify
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties
        self.select_expression = select_expression

        # A header check needs something to compare (field_dict) and a
        # header row to compare it with (headers).
        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        """
        Download the S3 key (or its S3 Select result) to a temporary file
        and load that file into the Hive table.

        When ``headers`` is set and no ``select_expression`` is given, the
        file is uncompressed if ``input_compressed`` is set, its first row is
        optionally checked against ``field_dict``, and the first row is
        removed before loading.

        :raises AirflowException: if the key cannot be found, if a compressed
            S3 Select input is not ``.gz``, or if the header check fails.
        """
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        # The key's extension is reused as the suffix of every local file.
        _, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed and
                file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info(
                "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
            )
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = self.s3.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                # No local header processing in this branch: the file is
                # loaded exactly as it was written above.
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        """
        Read the first line of ``file_name`` and split it on the delimiter.

        :param file_name: path of the text file to read
        :return: list of the fields found on the first line
        """
        with open(file_name, 'rt') as file:
            header_line = file.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        """
        Check a header row against the keys of ``field_dict``.

        Names are compared position by position, ignoring case.

        :param header_list: column names read from the file
        :return: ``True`` when count and names match, ``False`` otherwise
        :raises AirflowException: if ``header_list`` is empty
        """
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            self.log.warning(
                "Headers count mismatch File headers:\n %s\nField names: \n %s\n", header_list, field_names
            )
            return False
        test_field_match = [h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)]
        if not all(test_field_match):
            self.log.warning(
                "Headers do not match field names File headers:\n %s\nField names: \n %s\n",
                header_list, field_names
            )
            return False
        else:
            return True

    @staticmethod
    def _delete_top_row_and_compress(
            input_file_name,
            output_file_ext,
            dest_dir):
        """
        Copy ``input_file_name`` without its first line into a new file.

        The output is written through gzip for a ``.gz`` extension, through
        bz2 for ``.bz2``, and as a plain file for anything else.

        :param input_file_name: path of the file whose first line is dropped
        :param output_file_ext: suffix of the output file; also selects the
            compression
        :param dest_dir: directory in which the output file is created
        :return: path of the newly written file
        """
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        fd_output, fn_output = tempfile.mkstemp(suffix=output_file_ext,
                                                dir=dest_dir)
        # mkstemp returns an already-open descriptor. The file is reopened by
        # name below, so close the descriptor here instead of leaking it.
        os.close(fd_output)

        with open(input_file_name, 'rb') as f_in, \
                open_fn(fn_output, 'wb') as f_out:
            # Skip the header line, then copy the remainder unchanged.
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    a ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are taken from ``field_dict``.
    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the tables gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            s3_conn_id='s3_default',
            hive_cli_conn_id='hive_cli_default',
            *args, **kwargs):
        """Store the operator configuration on the instance."""
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id

    def execute(self, context):
        """
        Download the S3 key to a temporary file and load it into Hive.

        When ``headers`` is set, the first line is optionally checked against
        the keys of ``field_dict`` and is then dropped before loading.

        :raises AirflowException: if the key cannot be found, or if the
            header row does not match the ``field_dict`` keys in count or
            in name (case-insensitive).
        """
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        _log.info("Downloading S3 file")
        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(
                    self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException("The key {0} does not exists".format(
                    self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        with NamedTemporaryFile("w") as f:
            _log.info("Dumping S3 key %s contents to local file %s",
                      s3_key_object.key, f.name)
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()

            if not self.headers:
                _log.info("Loading file into Hive")
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate)
                return

            with open(f.name, 'r') as tmpf:
                if self.check_headers:
                    header_line = tmpf.readline().rstrip()
                    header_list = header_line.split(self.delimiter)
                    field_names = list(self.field_dict.keys())
                    # zip() stops at the shorter sequence, so the column
                    # counts must be compared explicitly; otherwise a file
                    # with missing or extra columns would pass the check.
                    headers_match = (
                        len(header_list) == len(field_names) and
                        all(h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)))
                    if not headers_match:
                        _log.warning(
                            "Headers do not match field names File headers:\n %s\nField names: \n %s\n",
                            header_list, field_names)
                        raise AirflowException("Headers do not match the "
                                               "field_dict keys")

                with NamedTemporaryFile("w") as f_no_headers:
                    # Rewind (the header check may have consumed a line),
                    # skip the header, and copy everything else.
                    tmpf.seek(0)
                    next(tmpf)
                    for line in tmpf:
                        f_no_headers.write(line)
                    f_no_headers.flush()
                    _log.info("Loading file without headers into Hive")
                    self.hive.load_file(
                        f_no_headers.name,
                        self.hive_table,
                        field_dict=self.field_dict,
                        create=self.create,
                        partition=self.partition,
                        delimiter=self.delimiter,
                        recreate=self.recreate)