Example No. 1
def test_get_max_partition_from_valid_part_specs(self):
    max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
        [{
            'key1': 'value1',
            'key2': 'value2'
        }, {
            'key1': 'value3',
            'key2': 'value4'
        }],
        'key1',
        self.VALID_FILTER_MAP,
    )
    self.assertEqual(max_partition, 'value1')
Example No. 2
def test_get_max_partition_from_valid_part_specs_return_type(self):
    max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
        [{
            'key1': 'value1',
            'key2': 'value2'
        }, {
            'key1': 'value3',
            'key2': 'value4'
        }],
        'key1',
        self.VALID_FILTER_MAP,
    )
    assert isinstance(max_partition, str)
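
Both tests above exercise the same filtering helper. A minimal standalone sketch of the behavior they assert, assuming the fixture is VALID_FILTER_MAP = {'key2': 'value2'} (an assumption; the fixture is defined elsewhere in the test class). This is not the actual Airflow implementation:

def get_max_partition_from_part_specs(part_specs, partition_key, filter_map):
    # Keep only part specs that match every key/value pair in filter_map,
    # then return the max value of partition_key among them (None if empty).
    if not part_specs:
        return None
    candidates = [
        spec[partition_key]
        for spec in part_specs
        if all(spec.get(k) == v for k, v in (filter_map or {}).items())
    ]
    return max(candidates) if candidates else None

part_specs = [
    {'key1': 'value1', 'key2': 'value2'},
    {'key1': 'value3', 'key2': 'value4'},
]
# With the assumed filter only the first spec matches, so 'value1' wins,
# consistent with test_get_max_partition_from_valid_part_specs.
print(get_max_partition_from_part_specs(part_specs, 'key1', {'key2': 'value2'}))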
Example No. 3
import datetime

def closest_ds_partition(table,
                         ds,
                         before=True,
                         schema="default",
                         metastore_conn_id='metastore_default'):
    """
    This function finds the date in a list closest to the target date.
    An optional parameter can be given to get the closest before or after.

    :param table: A hive table name
    :param ds: A datestamp in ``%Y-%m-%d`` (i.e. ``yyyy-mm-dd``) format
    :param before: closest before (True), after (False), or either side of ds
    :param schema: table schema
    :param metastore_conn_id: which metastore connection to use
    :returns: The closest date
    :rtype: str or None

    >>> tbl = 'airflow.static_babynames_partitioned'
    >>> closest_ds_partition(tbl, '2015-01-02')
    '2015-01-01'
    """
    from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook

    if '.' in table:
        schema, table = table.split('.')
    hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
    partitions = hive_hook.get_partitions(schema=schema, table_name=table)
    if not partitions:
        return None
    # each partition is a {key: value} dict; take the first partition value
    part_vals = [list(p.values())[0] for p in partitions]
    if ds in part_vals:
        return ds
    else:
        parts = [
            datetime.datetime.strptime(pv, '%Y-%m-%d') for pv in part_vals
        ]
        target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d')
        closest_ds = _closest_date(target_dt, parts, before_target=before)
        return closest_ds.isoformat()
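
_closest_date is a module-level helper in the same macros file and is not shown above. A hedged sketch of it, inferred from the call site (it takes a target datetime, a list of datetimes, and a before_target flag, and the returned date's isoformat() matches the docstring's '2015-01-01'). This is not necessarily the exact Airflow implementation:

import datetime

def _closest_date(target_dt, date_list, before_target=None):
    # timedelta.max effectively disqualifies candidates on the wrong
    # side of the target when searching only before or only after it.
    time_before = lambda d: target_dt - d if d <= target_dt else datetime.timedelta.max
    time_after = lambda d: d - target_dt if d >= target_dt else datetime.timedelta.max
    any_side = lambda d: abs(target_dt - d)
    if before_target is None:
        return min(date_list, key=any_side).date()
    if before_target:
        return min(date_list, key=time_before).date()
    return min(date_list, key=time_after).date()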
Example No. 4
def max_partition(table,
                  schema="default",
                  field=None,
                  filter_map=None,
                  metastore_conn_id='metastore_default'):
    """
    Gets the max partition for a table.

    :param schema: The hive schema the table lives in
    :type schema: str
    :param table: The hive table you are interested in. Supports the dot
        notation as in ``my_database.my_table``; if a dot is found,
        the schema param is disregarded
    :type table: str
    :param metastore_conn_id: The hive connection you are interested in.
        If your default connection is set, you don't need to pass this parameter.
    :type metastore_conn_id: str
    :param filter_map: partition_key:partition_value map used for partition filtering,
                       e.g. {'key1': 'value1', 'key2': 'value2'}.
                       Only partitions matching all partition_key:partition_value
                       pairs will be considered as candidates of max partition.
    :type filter_map: dict
    :param field: the field to get the max value from. If there's only
        one partition field, this will be inferred
    :type field: str

    >>> max_partition('airflow.static_babynames_partitioned')
    '2015-01-01'
    """
    from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook

    if '.' in table:
        schema, table = table.split('.')
    hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
    return hive_hook.max_partition(schema=schema,
                                   table_name=table,
                                   field=field,
                                   filter_map=filter_map)
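
A hedged usage sketch of this function; the table, field, and filter values below are made up for illustration:

# Hypothetical call; all names and values are illustrative.
latest = max_partition(
    'my_database.my_table',        # dot notation, so the schema comes from here
    field='ds',                    # partition field to take the max of
    filter_map={'country': 'US'},  # only consider partitions with country=US
)
print(latest)  # e.g. '2015-01-01'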
Example No. 5
    def execute(self, context: Optional[Dict[str, Any]] = None) -> None:
        metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
        table = metastore.get_table(table_name=self.table)
        field_types = {col.name: col.type for col in table.sd.cols}

        # Build a {(col, metric): sql_expression} map, starting with a row count.
        exprs: Any = {('', 'count'): 'COUNT(*)'}
        for col, col_type in list(field_types.items()):
            if self.assignment_func:
                assign_exprs = self.assignment_func(col, col_type)
                if assign_exprs is None:
                    assign_exprs = self.get_default_exprs(col, col_type)
            else:
                assign_exprs = self.get_default_exprs(col, col_type)
            exprs.update(assign_exprs)
        exprs.update(self.extra_exprs)
        exprs = OrderedDict(exprs)
        exprs_str = ",\n        ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

        where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
        where_clause = " AND\n        ".join(where_clause_)
        sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
            exprs_str=exprs_str, table=self.table, where_clause=where_clause
        )

        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info('Executing SQL check: %s', sql)
        row = presto.get_first(hql=sql)
        self.log.info("Record: %s", row)
        if not row:
            raise AirflowException("The query returned None")

        part_json = json.dumps(self.partition, sort_keys=True)

        self.log.info("Deleting rows from previous runs if they exist")
        mysql = MySqlHook(self.mysql_conn_id)
        sql = """
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}'
        LIMIT 1;
        """.format(
            table=self.table, part_json=part_json, dttm=self.dttm
        )
        if mysql.get_records(sql):
            sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{part_json}' AND
                dttm='{dttm}';
            """.format(
                table=self.table, part_json=part_json, dttm=self.dttm
            )
            mysql.run(sql)

        self.log.info("Pivoting and loading cells into the Airflow db")
        rows = [
            # zip pairs each (col, metric) key with its value in the result row
            (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)
        ]
        mysql.insert_rows(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ],
        )
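
The assignment_func branch near the top of execute() lets callers override the default expressions per column. A hedged sketch of such a callable (the metric names are illustrative; returning None falls back to get_default_exprs):

# Hypothetical assignment_func; not part of the operator itself.
def sample_assignment_func(col, col_type):
    # Custom numeric metrics; returning None for any other column makes
    # the operator fall back to self.get_default_exprs(col, col_type).
    if col_type in ('int', 'bigint', 'float', 'double'):
        return {
            (col, 'min'): 'MIN({})'.format(col),
            (col, 'max'): 'MAX({})'.format(col),
            (col, 'avg'): 'AVG({})'.format(col),
        }
    return None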
Example No. 6
def test_get_max_partition_from_empty_part_specs(self):
    max_partition = HiveMetastoreHook._get_max_partition_from_part_specs(
        [], 'key1', self.VALID_FILTER_MAP
    )
    self.assertIsNone(max_partition)
Example No. 7
def tearDown(self):
    hook = HiveMetastoreHook()
    with hook.get_conn() as metastore:
        metastore.drop_table(self.database, self.table, deleteData=True)
Example No. 8
class NamedHivePartitionSensor(BaseSensorOperator):
    """
    Waits for a set of partitions to show up in Hive.

    :param partition_names: List of fully qualified names of the
        partitions to wait for. A fully qualified name is of the
        form ``schema.table/pk1=pv1/pk2=pv2``, for example,
        ``default.users/ds=2016-01-01``. This is passed as-is to the metastore
        Thrift client ``get_partitions_by_name`` method. Note that
        you cannot use logical or comparison operators as in
        HivePartitionSensor.
    :type partition_names: list[str]
    :param metastore_conn_id: reference to the metastore thrift service
        connection id
    :type metastore_conn_id: str
    """

    template_fields = ('partition_names',)
    ui_color = '#8d99ae'

    @apply_defaults
    def __init__(self,
                 partition_names,
                 metastore_conn_id='metastore_default',
                 poke_interval=60 * 3,
                 hook=None,
                 *args,
                 **kwargs):
        super().__init__(poke_interval=poke_interval, *args, **kwargs)

        self.next_index_to_poke = 0
        if isinstance(partition_names, str):
            raise TypeError('partition_names must be an array of strings')

        self.metastore_conn_id = metastore_conn_id
        self.partition_names = partition_names
        self.hook = hook
        if self.hook and metastore_conn_id != 'metastore_default':
            self.log.warning(
                'A hook was passed but a non-default metastore_conn_id=%s was used',
                metastore_conn_id)

    @staticmethod
    def parse_partition_name(partition):
        """Get schema, table, and partition info."""
        first_split = partition.split('.', 1)
        if len(first_split) == 1:
            schema = 'default'
            table_partition = max(first_split)  # poor man's first()
        else:
            schema, table_partition = first_split
        second_split = table_partition.split('/', 1)
        if len(second_split) == 1:
            raise ValueError('Could not parse ' + partition +
                             ' into table, partition')
        else:
            table, partition = second_split
        return schema, table, partition
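
    # Quick illustration of parse_partition_name (the names are made up):
    #   parse_partition_name('default.users/ds=2016-01-01')
    #     -> ('default', 'users', 'ds=2016-01-01')
    #   parse_partition_name('users/ds=2016-01-01')  # schema falls back to 'default'
    #     -> ('default', 'users', 'ds=2016-01-01')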

    def poke_partition(self, partition):
        """Check for a named partition."""
        if not self.hook:
            from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
            self.hook = HiveMetastoreHook(
                metastore_conn_id=self.metastore_conn_id)

        schema, table, partition = self.parse_partition_name(partition)

        self.log.info('Poking for %s.%s/%s', schema, table, partition)
        return self.hook.check_for_named_partition(schema, table, partition)

    def poke(self, context):
        number_of_partitions = len(self.partition_names)
        poke_index_start = self.next_index_to_poke
        for i in range(number_of_partitions):
            # Resume from the partition that was missing on the previous poke,
            # so repeated pokes don't re-check already-found partitions first.
            self.next_index_to_poke = (poke_index_start + i) % number_of_partitions
            if not self.poke_partition(
                    self.partition_names[self.next_index_to_poke]):
                return False

        self.next_index_to_poke = 0
        return True
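
A hedged sketch of wiring this sensor into a DAG; the dag_id, schedule, and partition names below are illustrative, not taken from the source:

# Hypothetical DAG; all names and dates are made up.
from datetime import datetime
from airflow import DAG

with DAG(dag_id='hive_partition_example',
         start_date=datetime(2016, 1, 1),
         schedule_interval='@daily') as dag:
    wait_for_partitions = NamedHivePartitionSensor(
        task_id='wait_for_partitions',
        partition_names=[
            'default.users/ds={{ ds }}',
            'default.events/ds={{ ds }}',
        ],
        metastore_conn_id='metastore_default',
    )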