def test_get_max_partition_from_valid_part_specs(self):
    max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
        [{'key1': 'value1', 'key2': 'value2'},
         {'key1': 'value3', 'key2': 'value4'}],
        'key1',
        self.VALID_FILTER_MAP,
    )
    self.assertEqual(max_partition, 'value1')
def test_get_max_partition_from_valid_part_specs_return_type(self):
    max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
        [{'key1': 'value1', 'key2': 'value2'},
         {'key1': 'value3', 'key2': 'value4'}],
        'key1',
        self.VALID_FILTER_MAP,
    )
    assert isinstance(max_partition, str)
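# A minimal sketch of the behavior these tests pin down -- an illustrative
# assumption, not the actual HiveMetastoreHook._get_max_partition_from_part_specs
# source: drop part specs that fail any filter_map pair, then take the max of
# the requested partition key, or None when nothing survives.
def _max_partition_sketch(part_specs, partition_key, filter_map):
    candidates = [
        spec[partition_key]
        for spec in part_specs
        if not filter_map
        or all(spec.get(k) == v for k, v in filter_map.items())
    ]
    return max(candidates) if candidates else None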
def closest_ds_partition(
    table, ds, before=True, schema="default", metastore_conn_id='metastore_default'
):
    """
    This function finds the date in a list closest to the target date.
    An optional parameter can be given to get the closest before or after.

    :param table: A hive table name
    :param ds: A datestamp ``%Y-%m-%d`` e.g. ``yyyy-mm-dd``
    :param before: closest before (True), after (False) or either side of ds
    :param schema: table schema
    :param metastore_conn_id: which metastore connection to use
    :returns: The closest date
    :rtype: str or None

    >>> tbl = 'airflow.static_babynames_partitioned'
    >>> closest_ds_partition(tbl, '2015-01-02')
    '2015-01-01'
    """
    from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook

    if '.' in table:
        schema, table = table.split('.')
    hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
    partitions = hive_hook.get_partitions(schema=schema, table_name=table)
    if not partitions:
        return None
    part_vals = [list(p.values())[0] for p in partitions]
    if ds in part_vals:
        return ds
    parts = [datetime.datetime.strptime(pv, '%Y-%m-%d') for pv in part_vals]
    target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d')
    closest_ds = _closest_date(target_dt, parts, before_target=before)
    return closest_ds.isoformat()
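# A hedged usage sketch (table name and dates come from the doctest above; the
# variable name is illustrative): when the exact ds partition is missing, the
# call falls back to the nearest earlier partition because before=True.
latest_complete_ds = closest_ds_partition(
    'airflow.static_babynames_partitioned',  # dot notation overrides schema
    '2015-01-02',
    before=True,
)
# per the doctest, this yields '2015-01-01' when no 2015-01-02 partition exists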
def max_partition(
    table, schema="default", field=None, filter_map=None, metastore_conn_id='metastore_default'
):
    """
    Gets the max partition for a table.

    :param schema: The hive schema the table lives in
    :type schema: str
    :param table: The hive table you are interested in, supports the dot
        notation as in "my_database.my_table", if a dot is found,
        the schema param is disregarded
    :type table: str
    :param metastore_conn_id: The hive connection you are interested in.
        If your default is set you don't need to use this parameter.
    :type metastore_conn_id: str
    :param filter_map: partition_key:partition_value map used for partition filtering,
        e.g. {'key1': 'value1', 'key2': 'value2'}.
        Only partitions matching all partition_key:partition_value
        pairs will be considered as candidates of max partition.
    :type filter_map: dict
    :param field: the field to get the max value from. If there's only
        one partition field, this will be inferred
    :type field: str

    >>> max_partition('airflow.static_babynames_partitioned')
    '2015-01-01'
    """
    from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook

    if '.' in table:
        schema, table = table.split('.')
    hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
    return hive_hook.max_partition(
        schema=schema, table_name=table, field=field, filter_map=filter_map
    )
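# A hedged usage sketch (the sub_part key and its value are illustrative
# assumptions): with a filter_map, only partitions matching every
# partition_key:partition_value pair compete for the max of the 'ds' field.
latest_filtered_ds = max_partition(
    'airflow.static_babynames_partitioned',
    field='ds',
    filter_map={'sub_part': 'specific_value'},
)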
def execute(self, context: Optional[Dict[str, Any]] = None) -> None:
    metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
    table = metastore.get_table(table_name=self.table)
    field_types = {col.name: col.type for col in table.sd.cols}

    exprs: Any = {('', 'count'): 'COUNT(*)'}
    for col, col_type in list(field_types.items()):
        if self.assignment_func:
            assign_exprs = self.assignment_func(col, col_type)
            if assign_exprs is None:
                assign_exprs = self.get_default_exprs(col, col_type)
        else:
            assign_exprs = self.get_default_exprs(col, col_type)
        exprs.update(assign_exprs)
    exprs.update(self.extra_exprs)
    exprs = OrderedDict(exprs)
    exprs_str = ",\n        ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

    where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
    where_clause = " AND\n        ".join(where_clause_)
    sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
        exprs_str=exprs_str, table=self.table, where_clause=where_clause
    )

    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info('Executing SQL check: %s', sql)
    row = presto.get_first(hql=sql)
    self.log.info("Record: %s", row)
    if not row:
        raise AirflowException("The query returned None")

    part_json = json.dumps(self.partition, sort_keys=True)

    self.log.info("Deleting rows from previous runs if they exist")
    mysql = MySqlHook(self.mysql_conn_id)
    sql = """
    SELECT 1 FROM hive_stats
    WHERE
        table_name='{table}' AND
        partition_repr='{part_json}' AND
        dttm='{dttm}'
    LIMIT 1;
    """.format(
        table=self.table, part_json=part_json, dttm=self.dttm
    )
    if mysql.get_records(sql):
        sql = """
        DELETE FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}';
        """.format(
            table=self.table, part_json=part_json, dttm=self.dttm
        )
        mysql.run(sql)

    self.log.info("Pivoting and loading cells into the Airflow db")
    rows = [
        (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
        for r in zip(exprs, row)
    ]
    mysql.insert_rows(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ],
    )
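# A hedged sketch of a custom assignment_func as execute() consumes it above
# (the function and metric names are illustrative): it receives (col, col_type)
# and returns a dict keyed by (column, metric) tuples mapping to SQL
# expressions; returning None makes execute() fall back to get_default_exprs
# for that column.
def sample_assignment_func(col, col_type):
    if col_type in ('double', 'float'):
        return {
            (col, 'avg'): 'AVG({})'.format(col),
            (col, 'max'): 'MAX({})'.format(col),
        }
    return None  # None -> execute() applies the default expressions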
def test_get_max_partition_from_empty_part_specs(self):
    max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
        [], 'key1', self.VALID_FILTER_MAP
    )
    self.assertIsNone(max_partition)
def tearDown(self):
    hook = HiveMetastoreHook()
    with hook.get_conn() as metastore:
        metastore.drop_table(self.database, self.table, deleteData=True)
class NamedHivePartitionSensor(BaseSensorOperator):
    """
    Waits for a set of partitions to show up in Hive.

    :param partition_names: List of fully qualified names of the
        partitions to wait for. A fully qualified name is of the
        form ``schema.table/pk1=pv1/pk2=pv2``, for example,
        default.users/ds=2016-01-01. This is passed as is to the metastore
        Thrift client ``get_partitions_by_name`` method. Note that
        you cannot use logical or comparison operators as in
        HivePartitionSensor.
    :type partition_names: list[str]
    :param metastore_conn_id: reference to the metastore thrift service
        connection id
    :type metastore_conn_id: str
    """

    template_fields = ('partition_names',)
    ui_color = '#8d99ae'

    @apply_defaults
    def __init__(self,
                 partition_names,
                 metastore_conn_id='metastore_default',
                 poke_interval=60 * 3,
                 hook=None,
                 *args,
                 **kwargs):
        super().__init__(poke_interval=poke_interval, *args, **kwargs)

        self.next_index_to_poke = 0
        if isinstance(partition_names, str):
            raise TypeError('partition_names must be an array of strings')

        self.metastore_conn_id = metastore_conn_id
        self.partition_names = partition_names
        self.hook = hook
        if self.hook and metastore_conn_id != 'metastore_default':
            self.log.warning(
                'A hook was passed but a non-default metastore_conn_id=%s was used',
                metastore_conn_id)

    @staticmethod
    def parse_partition_name(partition):
        """Get schema, table, and partition info."""
        first_split = partition.split('.', 1)
        if len(first_split) == 1:
            schema = 'default'
            table_partition = max(first_split)  # poor man's first()
        else:
            schema, table_partition = first_split
        second_split = table_partition.split('/', 1)
        if len(second_split) == 1:
            raise ValueError('Could not parse ' + partition + ' into table, partition')
        else:
            table, partition = second_split
        return schema, table, partition

    def poke_partition(self, partition):
        """Check for a named partition."""
        if not self.hook:
            from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
            self.hook = HiveMetastoreHook(
                metastore_conn_id=self.metastore_conn_id)

        schema, table, partition = self.parse_partition_name(partition)

        self.log.info('Poking for %s.%s/%s', schema, table, partition)
        return self.hook.check_for_named_partition(schema, table, partition)

    def poke(self, context):
        number_of_partitions = len(self.partition_names)
        poke_index_start = self.next_index_to_poke
        for i in range(number_of_partitions):
            self.next_index_to_poke = (poke_index_start + i) % number_of_partitions
            if not self.poke_partition(self.partition_names[self.next_index_to_poke]):
                return False

        self.next_index_to_poke = 0
        return True
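# A hedged usage sketch (task id and partition names are illustrative
# assumptions): the sensor checks each fully qualified partition name and,
# via next_index_to_poke, resumes from the partition that last came up
# missing rather than restarting from the head of the list.
wait_for_partitions = NamedHivePartitionSensor(
    task_id='wait_for_partitions',
    partition_names=[
        'default.users/ds={{ ds }}',
        'default.events/ds={{ ds }}/type=click',
    ],
    metastore_conn_id='metastore_default',
    poke_interval=10 * 60,
)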