Example #1
    def test_reader_wo_table(self, mock_spark_sql):
        d = [{'foo': 'bar'}]
        mock_spark_sql.return_value = spark.createDataFrame(d)

        reader = HiveReader()
        db = 'default'
        with self.assertRaises(ValueError):
            reader.select(db=db)
        mock_spark_sql.assert_not_called()
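
Each of these test methods takes a mock_spark_sql argument, which implies they sit inside a unittest.TestCase whose spark.sql call is patched for every test. A minimal sketch of the scaffolding they assume, with a local SparkSession and a hypothetical patch target (the real target must be wherever HiveReader calls sql()):

from unittest import TestCase, mock
from pyspark.sql import DataFrame, SparkSession

# Real session used only to build the DataFrames returned by the mock.
spark = SparkSession.builder.master('local[1]').getOrCreate()

# Class-level patch: unittest passes the mock as the extra argument
# (mock_spark_sql) to every test method. The target path is hypothetical.
@mock.patch('hive_reader.spark.sql')
class HiveReaderTest(TestCase):
    ...  # test methods such as test_reader_wo_table go here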
Example #2
    def test_reader_wo_db(self, mock_spark_sql):
        d = [{'foo': 'bar'}]
        mock_spark_sql.return_value = spark.createDataFrame(d)

        reader = HiveReader()
        table = 'test'
        df = reader.select(table=table)
        mock_spark_sql.assert_called_with("select * from `default`.`test`")
        mock_spark_sql.assert_called_once()
        self.assertTrue(isinstance(df, DataFrame))
Example #3
    def test_read_df_w_invalid_partition_not_basestring(self, mock_spark_sql):
        d = [{'foo': 'bar'}]
        mock_spark_sql.return_value = spark.createDataFrame(d)

        reader = HiveReader()
        db = 'default'
        table = 'test'
        condition = ["dt > '2018-01-01'", "dt < '2018-01-02'"]
        with self.assertRaises(ValueError):
            reader.select(db, table, condition=condition)
        mock_spark_sql.assert_not_called()
Example #4
    def test_reader_is_dataframe_w_partition(self, mock_spark_sql):
        d = [{
            'foo': 'bar',
            'dt': '2018-01-01'
        }, {
            'foo': 'qoo',
            'dt': '2018-01-02'
        }]
        mock_spark_sql.return_value = spark.createDataFrame(d)

        reader = HiveReader()
        db = 'default'
        table = 'test'
        condition = "dt > '2018-01-01'"
        df = reader.select(db, table, condition=condition)
        mock_spark_sql.assert_called_with(
            "select * from `default`.`test` where dt > '2018-01-01'")
        mock_spark_sql.assert_called_once()
        self.assertTrue(isinstance(df, DataFrame))
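
Read together, the assertions in Examples #1-#4 pin down the query-building contract of HiveReader.select: db defaults to 'default', a missing table raises ValueError, a non-string condition raises ValueError, and a string condition is appended as a where clause. Below is a rough sketch of that contract, inferred from the tests rather than copied from the library; the repair flag used in Example #6 is omitted, and when the reader is constructed with a Table (as in Example #7), db and table presumably fall back to it:

    def select(self, db='default', table=None, condition=None):
        # Sketch only: behaviour inferred from the assertions above.
        if not table:
            raise ValueError("table is required")
        if condition is not None and not isinstance(condition, str):
            raise ValueError("condition must be a string")
        query = "select * from `{}`.`{}`".format(db, table)
        if condition:
            query = "{} where {}".format(query, condition)
        return spark.sql(query)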
Example #5
    def __attrs_post_init__(self):
        self._table = Table(self.schema_path, self.schema_mapper, self.schema_parser)
        self.schema = self._table.schema
        self.reader = HiveReader(self._table)
        self.db_name = self.reader.db_name
        self.table_name = self.reader.table_name
Example #6
@attr.s  # required for the attr.ib fields below to take effect
class Driver(object):
    src = attr.ib(default='hive')
    schema_path = attr.ib(default=attr.Factory(str))
    schema_mapper = attr.ib(default=attr.Factory(dict))
    schema_parser = attr.ib(default='simple')
    _table = attr.ib(init=False)
    schema = attr.ib(init=False)
    reader = attr.ib(init=False)
    table_name = attr.ib(init=False)
    db_name = attr.ib(init=False)
    _df = attr.ib(init=False, default=None)
    _valid_df = attr.ib(init=False, default=None)
    _sum_df = attr.ib(init=False, default=None)

    def __attrs_post_init__(self):
        self._table = Table(self.schema_path, self.schema_mapper, self.schema_parser)
        self.schema = self._table.schema
        self.reader = HiveReader(self._table)
        self.db_name = self.reader.db_name
        self.table_name = self.reader.table_name

    @property
    def df(self):
        if not self._df:
            raise ValueError("df is not ready, please 'read' it first.")
        else:
            return self._df

    @property
    def valid_df(self):
        if not self._valid_df:
            raise ValueError("valid df is not ready, please 'validate' it first.")
        else:
            return self._valid_df

    @property
    def sum_df(self):
        if not self._sum_df:
            raise ValueError("Please check sum first.")
        else:
            return self._sum_df

    def read(self, condition=None, repair=False):
        self._df = self.reader.select(condition=condition, repair=repair)
        return self

    def validate(self, validate_schema=True, rule='bq'):
        self._valid_df = Validator.validate_data(self.df, self.schema)

        if validate_schema is True:
            self._valid_df = Validator.validate_schema(self._valid_df, rule=rule)

        return self

    def check_sum(self):
        if self.df and self.valid_df:
            sum_df = self.df.count()
            sum_valid_df = self.valid_df.count()
            result = sum_df == sum_valid_df

            if result is True:
                logger.info("Check Sum Passed! Result: {}".format(sum_valid_df))
                self._sum_df = sum_df
                return True
            else:
                logger.error("Check Sum Failed! df: {}, valid_df: {}".format(sum_df, sum_valid_df))
                raise ValueError("Check Sum Failed!")

    def write(self, file_format, path_prefix):
        full_path = '{}/{}/{}'.format(path_prefix, self.db_name, self.table_name)
        self._valid_df.write.format(file_format).save(full_path)
        return True
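
Driver is built to be chained: read() and validate() return self, check_sum() returns True on success (and raises otherwise), and write() persists the validated frame. A short usage sketch follows; the schema path, condition, file format, and output prefix are placeholders, not values from the source:

driver = Driver(schema_path='/path/to/default.test.json')
driver.read(condition="dt = '2018-01-01'").validate(rule='bq')
if driver.check_sum():
    driver.write('parquet', '/path/to/output')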
Example #7
    def test_reader_w_table(self):
        path = schema_path + 'default.test.json'
        table = Table(path, 'bq')
        reader = HiveReader(table)
        self.assertEqual(reader.table_name, table.name)
        self.assertEqual(reader.db_name, table.db_name)