def test_real_get_and_set(self):
    """Round-trip SET/GET/DELETE against a real Redis server."""
    hook = RedisHook(redis_conn_id='redis_default')
    client = hook.get_conn()
    set_ok = client.set('test_key', 'test_value')
    self.assertTrue(set_ok, 'Connection to Redis with SET works.')
    fetched = client.get('test_key')
    self.assertEqual(fetched, b'test_value', 'Connection to Redis with GET works.')
    removed = client.delete('test_key')
    self.assertEqual(removed, 1, 'Connection to Redis with DELETE works.')
def __init__(self, channels, redis_conn_id, *args, **kwargs):
    """Open a pub/sub connection and subscribe to *channels* at construction time."""
    super().__init__(*args, **kwargs)
    self.channels = channels
    self.redis_conn_id = redis_conn_id
    hook = RedisHook(redis_conn_id=redis_conn_id)
    # Subscription happens eagerly so messages published before the first
    # poll are not lost.
    self.pubsub = hook.get_conn().pubsub()
    self.pubsub.subscribe(channels)
def clear_redis(self, dag_name: str) -> None:
    """Delete every Redis key whose name starts with *dag_name*.

    :param dag_name: key prefix to remove (a ``*`` glob is appended)
    """
    redis_conn = RedisHook(self.redis_conn_id).get_conn()
    ns_keys = dag_name + '*'
    # scan_iter walks the keyspace incrementally, so this is safe on large DBs.
    for key in redis_conn.scan_iter(ns_keys):
        redis_conn.delete(key)
        # Lazy %-formatting: only rendered when WARNING is enabled.
        self.log.warning('Delete key: %s in Redis', key)
def test_poke_true(self):
    """First poke consumes the subscribe event, second poke sees the payload."""
    sensor = RedisPubSubSensor(
        task_id='test_task', dag=self.dag, channels='test', redis_conn_id='redis_default'
    )
    client = RedisHook(redis_conn_id='redis_default').get_conn()
    client.publish('test', 'message')

    # The subscribe confirmation is not of type 'message' -> first poke fails.
    self.assertFalse(sensor.poke(self.mock_context))
    # The published payload arrives on the second poke.
    self.assertTrue(sensor.poke(self.mock_context))

    expected_payload = {
        'type': 'message',
        'pattern': None,
        'channel': b'test',
        'data': b'message',
    }
    context_calls = [call.xcom_push(key='message', value=expected_payload)]
    self.assertTrue(
        self.mock_context['ti'].method_calls == context_calls,
        "context calls should be same",
    )
    # The sensor unsubscribed after success, so nothing further arrives.
    self.assertFalse(sensor.poke(self.mock_context))
def test_poke(self):
    """poke() mirrors the existence of the watched key in Redis."""
    client = RedisHook(redis_conn_id='redis_default').get_conn()

    client.set('test_key', 'test_value')
    self.assertTrue(self.sensor.poke(None), "Key exists on first call.")

    client.delete('test_key')
    self.assertFalse(self.sensor.poke(None), "Key does NOT exists on second call.")
def test_execute_hello(self):
    """Publishing through the operator delivers the payload to a subscriber."""
    operator = RedisPublishOperator(
        task_id='test_task',
        dag=self.dag,
        message='hello',
        channel=self.channel,
        redis_conn_id='redis_default',
    )

    pubsub = RedisHook(redis_conn_id='redis_default').get_conn().pubsub()
    pubsub.subscribe(self.channel)

    operator.execute(self.mock_context)
    # The operator must not push anything to XCom.
    assert self.mock_context['ti'].method_calls == [], "context calls should be same"

    first = pubsub.get_message()
    assert first['type'] == 'subscribe'

    second = pubsub.get_message()
    assert second['type'] == 'message'
    assert second['data'] == b'hello'

    pubsub.unsubscribe(self.channel)
def serialize_value(value: Any):
    """Store *value* in Redis under a random key and serialize a prefixed reference."""
    hook = RedisHook(redis_conn_id=XComRedisBackend.CONN_ID)
    key = str(uuid4())
    # The payload itself is pickled into Redis; only the reference travels
    # through the metadata database.
    hook.get_conn().set(key, pickle.dumps(value))
    # The prefix makes it obvious where the real value lives.
    reference = XComRedisBackend.PREFIX + key
    return BaseXCom.serialize_value(reference)
def __init__(self, *, channels: Union[List[str], str], redis_conn_id: str, **kwargs) -> None:
    """Open a pub/sub connection and subscribe to *channels* immediately."""
    super().__init__(**kwargs)
    self.channels = channels
    self.redis_conn_id = redis_conn_id
    hook = RedisHook(redis_conn_id=redis_conn_id)
    # Eager subscription: messages published before the first poke are kept.
    self.pubsub = hook.get_conn().pubsub()
    self.pubsub.subscribe(channels)
def test_get_conn(self):
    """Connection attributes start unset and get_conn() caches the client."""
    hook = RedisHook(redis_conn_id='redis_default')

    assert hook.redis is None
    assert hook.host is None, 'host initialised as None.'
    assert hook.port is None, 'port initialised as None.'
    assert hook.password is None, 'password initialised as None.'
    assert hook.db is None, 'db initialised as None.'

    first = hook.get_conn()
    assert first is hook.get_conn(), 'Connection initialized only if None.'
def test_get_conn(self):
    """Connection attributes are initialised to None and get_conn() is memoised."""
    hook = RedisHook(redis_conn_id='redis_default')
    # assertIsNone gives clearer failure output than assertEqual(x, None).
    self.assertIsNone(hook.redis)
    self.assertIsNone(hook.host, 'host initialised as None.')
    self.assertIsNone(hook.port, 'port initialised as None.')
    self.assertIsNone(hook.password, 'password initialised as None.')
    self.assertIsNone(hook.db, 'db initialised as None.')
    self.assertIs(hook.get_conn(), hook.get_conn(), 'Connection initialized only if None.')
def deserialize_value(result) -> Any:
    """Resolve a prefixed Redis reference back into the original Python value.

    :param result: raw XCom row; its deserialized value may be a
        ``PREFIX + key`` string pointing at the real payload in Redis.
    """
    result = BaseXCom.deserialize_value(result)
    prefix = XComRedisBackend.PREFIX
    if isinstance(result, str) and result.startswith(prefix):
        # Strip only the leading prefix; str.replace() would also mangle any
        # later occurrence of the prefix inside the key itself.
        key = result[len(prefix):]
        hook = RedisHook(redis_conn_id=XComRedisBackend.CONN_ID)
        result = hook.get_conn().get(key)
        # pickle.loads is acceptable here only because this backend wrote the
        # payload itself in serialize_value().
        result = pickle.loads(result)
    return result
def serialize_value(value: Any):
    """
    Persist *value* as JSON in Redis and hand the generated key to the base XCom.
    """
    hook = RedisHook()
    hook.get_conn()
    client = hook.redis
    key = f"data_{uuid.uuid4()}"
    # mset with a single-entry mapping stores the JSON payload under *key*.
    client.mset({key: json.dumps(value)})
    return BaseXCom.serialize_value(key)
def deserialize_value(result) -> Any:
    """
    Look up the JSON payload referenced by the stored key and decode it.
    """
    result = BaseXCom.deserialize_value(result)
    if isinstance(result, str):
        hook = RedisHook()
        hook.get_conn()
        client = hook.redis
        payload = client.mget(result)
        # The value was written with json.dumps (see serialize_value), so
        # decode it with json.loads. eval() on data read back from Redis is
        # both a code-injection risk and wrong for JSON literals such as
        # true/false/null.
        result = json.loads(payload[0])
    return result
def execute(self, context: Dict) -> None:
    """
    Publish the message to Redis channel

    :param context: the context object
    :type context: dict
    """
    redis_hook = RedisHook(redis_conn_id=self.redis_conn_id)
    # Fixed typo in the log message ("messsage" -> "message").
    self.log.info('Sending message %s to Redis on channel %s', self.message, self.channel)
    result = redis_hook.get_conn().publish(channel=self.channel, message=self.message)
    self.log.info('Result of publishing %s', result)
def execute(self, context):
    """Extract rows from Oracle with bind parameters and push them into a Redis list.

    :param context: Airflow task context (unused)
    """
    oracle = OracleHelper(self.oracle_conn_id)
    redis = RedisHook(self.redis_conn_id)
    # Lazy %-formatting instead of f-strings in logger calls.
    self.log.info("Executing SQL:%s", self.sql)
    self.log.info("Extracting data from Oracle")
    conn_redis = redis.get_conn()
    records = oracle.get_rows_with_bind(sql=self.sql, bind=self.dict_bind)
    self.log.info("Inserting rows into Redis")
    pipe = conn_redis.pipeline()
    # Plain loop instead of a side-effect list comprehension.
    for row in records:
        pipe.lpush(self.name_redis_key, str(row))
    pipe.execute()
    self.log.info("Inserted %s rows.", len(records))
class RedisPubSubSensor(BaseSensorOperator):
    """
    Redis sensor for reading a message from pub sub channels

    :param channels: The channels to be subscribed to (templated)
    :type channels: str or list of str
    :param redis_conn_id: the redis connection id
    :type redis_conn_id: str
    """

    template_fields = ('channels',)
    ui_color = '#f0eee4'

    @apply_defaults
    def __init__(self, *, channels: Union[List[str], str], redis_conn_id: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.channels = channels
        self.redis_conn_id = redis_conn_id
        # Subscribe as soon as the sensor is constructed so that messages
        # published before the first poke are not lost.
        hook = RedisHook(redis_conn_id=redis_conn_id)
        self.pubsub = hook.get_conn().pubsub()
        self.pubsub.subscribe(channels)

    def poke(self, context: Dict) -> bool:
        """
        Check for message on subscribed channels and write to xcom the message with key ``message``

        An example of message ``{'type': 'message', 'pattern': None, 'channel': b'test', 'data': b'hello'}``

        :param context: the context object
        :type context: dict
        :return: ``True`` if message (with type 'message') is available or ``False`` if not
        """
        self.log.info('RedisPubSubSensor checking for message on channels: %s', self.channels)
        message = self.pubsub.get_message()
        self.log.info('Message %s from channel %s', message, self.channels)

        # Guard clause: anything other than a real payload (e.g. the
        # subscribe confirmation) does not satisfy the sensor.
        if not message or message['type'] != 'message':
            return False

        context['ti'].xcom_push(key='message', value=message)
        self.pubsub.unsubscribe(self.channels)
        return True
def generate_pagination(self, **context) -> None:
    """Compute per-date page counts and persist them as Airflow Variables.

    Example variable produced:
    ``file_name_000000000000000_total_pg_01-01-1900 = 50``
    """
    # Placeholder-free message: no need for an f-string here.
    self.log.info('Getting pagination total ...')
    ns_keys = context['current_dag_name'] + '_' + '*'
    name_var_total_pg = context['current_dag_name'] + '_' + 'total_pg'
    list_dates_redis_bin = RedisHook(self.redis_conn_id).get_conn().scan_iter(ns_keys)
    list_keys_redis = sorted(x.decode("utf-8") for x in list_dates_redis_bin)
    total_pg = 0
    # NOTE(review): zip assumes list_current_dates and the sorted Redis keys
    # line up one-to-one -- confirm the keys were created from the same dates.
    for date, key in zip(context['list_current_dates'], list_keys_redis):
        list_records = self.get_list_redis(redis_key=key)
        # Pages for this date, e.g. 3.5 rounds up to 4.
        total_pg_by_date = math.ceil(len(list_records) / int(context['items_by_query']))
        name_key_pg = name_var_total_pg + '_' + date
        Variable.set(key=name_key_pg, value=total_pg_by_date)
        # Lazy %-formatting instead of f-strings in logger calls.
        self.log.info('%s = %s', name_key_pg, total_pg_by_date)
        total_pg += total_pg_by_date
    # Grand total across all dates.
    Variable.set(key=name_var_total_pg, value=total_pg)
def test_get_conn_with_extra_config(self, mock_get_connection, mock_redis):
    """The hook forwards every extra-config field to the Redis client."""
    connection = mock_get_connection.return_value
    hook = RedisHook()
    hook.get_conn()

    extras = connection.extra_dejson
    # NOTE(review): redis-py names the client-certificate argument
    # ``ssl_certfile``; confirm the hook's ``ssl_cert_file`` kwarg is intended.
    expected_kwargs = {
        'host': connection.host,
        'password': connection.password,
        'port': connection.port,
        'db': extras["db"],
        'ssl': extras["ssl"],
        'ssl_cert_reqs': extras["ssl_cert_reqs"],
        'ssl_ca_certs': extras["ssl_ca_certs"],
        'ssl_keyfile': extras["ssl_keyfile"],
        'ssl_cert_file': extras["ssl_cert_file"],
        'ssl_check_hostname': extras["ssl_check_hostname"],
    }
    mock_redis.assert_called_once_with(**expected_kwargs)
def fill_data_gaps(self, **context) -> None:
    """Replace missing dt_ref values with '01-01-1900' and rewrite rows to Redis.

    Example:
        input:  ('ef57132e-0d25-4e75-bee6-158c01c5b360', 50000, None)
        output: ('e85e10b5-43af-4721-b832-9b9b4bb366fe', 10708, '01-01-1900')
    """
    list_records = self.get_list_redis(context['redis_key'])
    pipe = RedisHook(self.redis_conn_id).get_conn().pipeline()
    df = pd.DataFrame(data=list_records)
    # The reference-date column is by convention the last one.
    col_date_ref = list(df.columns)[-1]
    df[col_date_ref].replace(to_replace=[None], value='01-01-1900', inplace=True)
    records = [tuple(x) for x in df.to_numpy()]
    # Plain loop instead of a side-effect list comprehension.
    for row in records:
        pipe.lpush(context['current_dag_name'], str(row))
    pipe.execute()
    # Lazy %-formatting instead of an f-string in the logger call.
    self.log.info("\nSample rows:\n%s", df.head(5))
def split_id_by_date(self, **context) -> None:
    """Create one Redis key per date out of a single source key.

    Example (redis keys):
        input:  file_name_000000000000000
        output: file_name_000000000000000_01-01-1900,
                file_name_000000000000000_24-12-2020
    """
    # Use the task logger (self.log) for consistency with the other methods,
    # instead of the module-level logging root.
    self.log.info('Spliting IDs by date ...')
    pipe = RedisHook(self.redis_conn_id).get_conn().pipeline()
    list_records = self.get_list_redis(context['redis_key'])
    for date in context['list_current_dates']:
        # Keep only the rows whose last field contains this date.
        list_item = [item for item in list_records if date in item[-1]]
        name_redis_key = context['current_dag_name'] + '_' + date
        # Plain loop instead of a side-effect list comprehension.
        for row in list_item:
            pipe.lpush(name_redis_key, str(row))
        pipe.execute()
        self.log.info('%s - Storaged at redis %s = %s', date, name_redis_key, len(list_item))
def get_list_redis(self, redis_key: str) -> list:
    """Get all objects from a Redis list of type key

    Args:
    ----------
    redis_key : str
        key name in redis.

    Returns:
    ----------
    A list containing objects without repetition. Example:
    ['01-01-2021', '02-01-2021', '03-01-2021']
    """
    # Local import: safe parser for the str(row) payloads written elsewhere.
    from ast import literal_eval

    set_records = set(
        RedisHook(self.redis_conn_id).get_conn().lrange(redis_key, 0, -1)
    )
    # literal_eval only accepts Python literals, unlike eval() which would
    # execute arbitrary code read back from Redis.
    return [literal_eval(x.decode("utf-8")) for x in set_records]
def poke(self, context):
    """Return truthy when the watched key exists in Redis."""
    self.log.info('Sensor checks for existence of key: %s', self.key)
    hook = RedisHook(self.redis_conn_id)
    return hook.get_conn().exists(self.key)
def test_get_conn_password_stays_none(self):
    """Opening the connection must not populate the password attribute."""
    hook = RedisHook(redis_conn_id='redis_default')
    _ = hook.get_conn()
    assert hook.password is None
def test_get_conn_password_stays_none(self):
    """Opening the connection must not populate the password attribute."""
    hook = RedisHook(redis_conn_id='redis_default')
    hook.get_conn()
    # assertIsNone reports failures more clearly than assertEqual(x, None).
    self.assertIsNone(hook.password)
def test_real_ping(self):
    """PING succeeds against a live Redis server."""
    client = RedisHook(redis_conn_id='redis_default').get_conn()
    self.assertTrue(client.ping(), 'Connection to Redis with PING works.')
def get_hook(self):
    """Return the provider hook matching this connection's ``conn_type``.

    Each branch imports its hook lazily so that only the provider actually
    used needs to be installed. A dispatch table is not a drop-in
    replacement here: every hook takes a differently-named connection-id
    keyword argument.

    :raises AirflowException: if ``conn_type`` has no registered hook.
    """
    if self.conn_type == 'mysql':
        from airflow.providers.mysql.hooks.mysql import MySqlHook
        return MySqlHook(mysql_conn_id=self.conn_id)
    elif self.conn_type == 'google_cloud_platform':
        from airflow.gcp.hooks.bigquery import BigQueryHook
        return BigQueryHook(bigquery_conn_id=self.conn_id)
    elif self.conn_type == 'postgres':
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=self.conn_id)
    elif self.conn_type == 'pig_cli':
        from airflow.providers.apache.pig.hooks.pig import PigCliHook
        return PigCliHook(pig_cli_conn_id=self.conn_id)
    elif self.conn_type == 'hive_cli':
        from airflow.providers.apache.hive.hooks.hive import HiveCliHook
        return HiveCliHook(hive_cli_conn_id=self.conn_id)
    elif self.conn_type == 'presto':
        from airflow.providers.presto.hooks.presto import PrestoHook
        return PrestoHook(presto_conn_id=self.conn_id)
    elif self.conn_type == 'hiveserver2':
        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
        return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
    elif self.conn_type == 'sqlite':
        from airflow.providers.sqlite.hooks.sqlite import SqliteHook
        return SqliteHook(sqlite_conn_id=self.conn_id)
    elif self.conn_type == 'jdbc':
        from airflow.providers.jdbc.hooks.jdbc import JdbcHook
        return JdbcHook(jdbc_conn_id=self.conn_id)
    elif self.conn_type == 'mssql':
        from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
        return MsSqlHook(mssql_conn_id=self.conn_id)
    elif self.conn_type == 'odbc':
        from airflow.providers.odbc.hooks.odbc import OdbcHook
        return OdbcHook(odbc_conn_id=self.conn_id)
    elif self.conn_type == 'oracle':
        from airflow.providers.oracle.hooks.oracle import OracleHook
        return OracleHook(oracle_conn_id=self.conn_id)
    elif self.conn_type == 'vertica':
        from airflow.providers.vertica.hooks.vertica import VerticaHook
        return VerticaHook(vertica_conn_id=self.conn_id)
    elif self.conn_type == 'cloudant':
        from airflow.providers.cloudant.hooks.cloudant import CloudantHook
        return CloudantHook(cloudant_conn_id=self.conn_id)
    elif self.conn_type == 'jira':
        from airflow.providers.jira.hooks.jira import JiraHook
        return JiraHook(jira_conn_id=self.conn_id)
    elif self.conn_type == 'redis':
        from airflow.providers.redis.hooks.redis import RedisHook
        return RedisHook(redis_conn_id=self.conn_id)
    elif self.conn_type == 'wasb':
        from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
        return WasbHook(wasb_conn_id=self.conn_id)
    elif self.conn_type == 'docker':
        from airflow.providers.docker.hooks.docker import DockerHook
        return DockerHook(docker_conn_id=self.conn_id)
    elif self.conn_type == 'azure_data_lake':
        from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
        return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
    elif self.conn_type == 'azure_cosmos':
        from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
    elif self.conn_type == 'cassandra':
        from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
        return CassandraHook(cassandra_conn_id=self.conn_id)
    elif self.conn_type == 'mongo':
        from airflow.providers.mongo.hooks.mongo import MongoHook
        # NOTE: MongoHook takes plain ``conn_id``, not ``mongo_conn_id``.
        return MongoHook(conn_id=self.conn_id)
    elif self.conn_type == 'gcpcloudsql':
        from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
        return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
    elif self.conn_type == 'grpc':
        from airflow.providers.grpc.hooks.grpc import GrpcHook
        return GrpcHook(grpc_conn_id=self.conn_id)
    raise AirflowException("Unknown hook type {}".format(self.conn_type))