Code example #1
 def test_uncompress_file(self):
     # Testing txt file type
     self.assertRaisesRegex(NotImplementedError,
                            "^Received .txt format. Only gz and bz2.*",
                            compression.uncompress_file,
                            **{'input_file_name': None,
                               'file_extension': '.txt',
                               'dest_dir': None
                               })
     # Testing gz file type
     fn_txt = self._get_fn('.txt')
     fn_gz = self._get_fn('.gz')
     txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir)
     self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False),
                     msg="Uncompressed file doest match original")
     # Testing bz2 file type
     fn_bz2 = self._get_fn('.bz2')
     txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir)
     self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False),
                     msg="Uncompressed file doest match original")
Code example #2
File: test_compression.py  Project: yqian1991/airflow
 def test_uncompress_file(self):
     # Testing txt file type
     self.assertRaisesRegex(
         NotImplementedError, "^Received .txt format. Only gz and bz2.*",
         compression.uncompress_file, **{
             'input_file_name': None,
             'file_extension': '.txt',
             'dest_dir': None
         })
     # Testing gz file type
     fn_txt = self._get_fn('.txt')
     fn_gz = self._get_fn('.gz')
     txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir)
     self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False),
                     msg="Uncompressed file doest match original")
     # Testing bz2 file type
     fn_bz2 = self._get_fn('.bz2')
     txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir)
     self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False),
                     msg="Uncompressed file doest match original")
Code example #3
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed and
                file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info(
                "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
            )
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = self.s3.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
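
This execute method comes from Airflow's S3-to-Hive transfer operator. As a usage illustration only, a task in a DAG might be declared as below; the class name S3ToHiveTransfer and its import path follow older Airflow releases, and all argument values here are placeholders rather than values taken from the excerpts.

from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer

s3_to_hive = S3ToHiveTransfer(
    task_id='s3_to_hive',
    s3_key='incoming/{{ ds }}/events.csv.gz',  # templated S3 key (placeholder)
    field_dict={'id': 'BIGINT', 'name': 'STRING'},
    hive_table='staging.events',
    delimiter=',',
    create=True,
    recreate=True,
    headers=True,
    check_headers=True,
    input_compressed=True,
    aws_conn_id='aws_default',
    hive_cli_conn_id='hive_cli_default',
)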
Code example #4
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        _, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
            s3_key_object.download_fileobj(f)
            f.flush()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
Code example #5
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed and
                file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info(
                "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
            )
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = self.s3.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
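
The select_key call above pushes the query down to Amazon S3 Select, so only the selected rows cross the network. As a rough sketch of what such a hook method wraps, a standalone version could be built on boto3's select_object_content; the real S3Hook.select_key adds connection and credential handling on top of this.

import boto3


def select_key(bucket_name, key, expression, input_serialization):
    """Run an S3 Select expression and return the matching records as text."""
    client = boto3.client('s3')
    response = client.select_object_content(
        Bucket=bucket_name,
        Key=key,
        Expression=expression,
        ExpressionType='SQL',
        InputSerialization=input_serialization,
        OutputSerialization={'CSV': {}},
    )
    # The response payload is an event stream; only Records events carry data.
    return ''.join(
        event['Records']['Payload'].decode('utf-8')
        for event in response['Payload']
        if 'Records' in event
    )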
Code example #6
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        logging.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(
                    self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException("The key {0} does not exists".format(
                    self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        _, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="w",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file {0} into Hive".format(f.name))
                self.hive.load_file(f.name,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    logging.info("Uncompressing file {0}".format(f.name))
                    fn_uncompressed = uncompress_file(f.name, file_ext,
                                                      tmp_dir)
                    logging.info("Uncompressed to {0}".format(fn_uncompressed))
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    logging.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                logging.info(
                    "Removing header from file {0}".format(fn_uncompressed))
                headless_file = (self._delete_top_row_and_compress(
                    fn_uncompressed, file_ext, tmp_dir))
                logging.info("Headless file {0}".format(headless_file))
                logging.info(
                    "Loading file {0} into Hive".format(headless_file))
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
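
Every variant above leans on two private helpers that the excerpts do not show: _get_top_row_as_list, which reads the header row of the uncompressed file, and _match_headers, which compares it against field_dict. Their call signatures are visible in the code above; the bodies below are a plausible minimal sketch, assuming field_dict is an ordered mapping of column name to Hive type.

    def _get_top_row_as_list(self, file_name):
        # Read just the first line and split it on the configured delimiter.
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
        return header_line.split(self.delimiter)

    def _match_headers(self, header_list):
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = list(self.field_dict.keys())
        # Compare column names case-insensitively, position by position.
        return ([name.lower() for name in field_names] ==
                [col.lower() for col in header_list])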