Example #1
0
    def test_real_get_and_set(self):
        """Round-trip a key through a live Redis server via the hook."""
        conn = RedisHook(redis_conn_id='redis_default').get_conn()

        # SET returns True, GET echoes the stored bytes, DELETE reports one removal.
        self.assertTrue(conn.set('test_key', 'test_value'), 'Connection to Redis with SET works.')
        self.assertEqual(conn.get('test_key'), b'test_value', 'Connection to Redis with GET works.')
        self.assertEqual(conn.delete('test_key'), 1, 'Connection to Redis with DELETE works.')
Example #2
0
 def __init__(self, channels, redis_conn_id, *args, **kwargs):
     """Store the channel list and subscribe via the named Redis connection."""
     super().__init__(*args, **kwargs)
     self.channels = channels
     self.redis_conn_id = redis_conn_id
     # Open the pubsub handle eagerly and register the subscription.
     hook = RedisHook(redis_conn_id=self.redis_conn_id)
     self.pubsub = hook.get_conn().pubsub()
     self.pubsub.subscribe(self.channels)
Example #3
0
    def clear_redis(self, dag_name: str) -> None:
        """Delete every Redis key whose name starts with *dag_name*.

        :param dag_name: prefix that namespaces this DAG's keys in Redis
        """
        redis_conn = RedisHook(self.redis_conn_id).get_conn()
        ns_keys = dag_name + '*'

        for key in redis_conn.scan_iter(ns_keys):
            redis_conn.delete(key)
            # Lazy %-formatting: the message is only built if WARNING is enabled.
            self.log.warning('Delete key: %s in Redis', key)
Example #4
0
    def test_poke_true(self):
        """poke() returns True once a published message arrives and pushes it to XCom."""
        sensor = RedisPubSubSensor(task_id='test_task',
                                   dag=self.dag,
                                   channels='test',
                                   redis_conn_id='redis_default')

        hook = RedisHook(redis_conn_id='redis_default')
        redis = hook.get_conn()
        redis.publish('test', 'message')

        # The first poke consumes the 'subscribe' confirmation, not the payload.
        result = sensor.poke(self.mock_context)
        self.assertFalse(result)
        result = sensor.poke(self.mock_context)
        self.assertTrue(result)
        context_calls = [
            call.xcom_push(key='message',
                           value={
                               'type': 'message',
                               'pattern': None,
                               'channel': b'test',
                               'data': b'message'
                           })
        ]
        # assertEqual prints a diff of the call lists on failure, unlike the
        # original assertTrue(a == b) which only reported "False is not true".
        self.assertEqual(self.mock_context['ti'].method_calls, context_calls,
                         "context calls should be same")
        # The sensor unsubscribed after the hit, so a further poke is False.
        result = sensor.poke(self.mock_context)
        self.assertFalse(result)
 def test_poke(self):
     """poke() mirrors key existence: True while the key is set, False after deletion."""
     hook = RedisHook(redis_conn_id='redis_default')
     redis = hook.get_conn()
     redis.set('test_key', 'test_value')
     self.assertTrue(self.sensor.poke(None), "Key exists on first call.")
     redis.delete('test_key')
     # Fixed grammar in the failure message ("does NOT exists" -> "does NOT exist").
     self.assertFalse(self.sensor.poke(None), "Key does NOT exist on second call.")
Example #6
0
    def test_execute_hello(self):
        """Executing the operator publishes 'hello' to subscribers of the channel."""
        operator = RedisPublishOperator(
            task_id='test_task',
            dag=self.dag,
            message='hello',
            channel=self.channel,
            redis_conn_id='redis_default',
        )

        # Subscribe before executing so the published message is captured.
        pubsub = RedisHook(redis_conn_id='redis_default').get_conn().pubsub()
        pubsub.subscribe(self.channel)

        operator.execute(self.mock_context)
        # The operator must not touch the task instance (no XCom pushes).
        assert self.mock_context['ti'].method_calls == [], "context calls should be same"

        # First message is the subscription confirmation ...
        message = pubsub.get_message()
        assert message['type'] == 'subscribe'

        # ... followed by the payload the operator published.
        message = pubsub.get_message()
        assert message['type'] == 'message'
        assert message['data'] == b'hello'

        pubsub.unsubscribe(self.channel)
Example #7
0
 def serialize_value(value: Any):
     """Store *value* in Redis under a fresh UUID key; XCom keeps only the prefixed key."""
     hook = RedisHook(redis_conn_id=XComRedisBackend.CONN_ID)
     key = str(uuid4())
     # Persist the pickled payload out-of-band in Redis.
     hook.get_conn().set(key, pickle.dumps(value))
     # The prefix marks the XCom value as a Redis pointer rather than real data.
     return BaseXCom.serialize_value(XComRedisBackend.PREFIX + key)
Example #8
0
 def __init__(self, *, channels: Union[List[str], str], redis_conn_id: str,
              **kwargs) -> None:
     """Open a pubsub connection for *redis_conn_id* and subscribe to *channels*."""
     super().__init__(**kwargs)
     self.channels = channels
     self.redis_conn_id = redis_conn_id
     # Build the client first, then attach the pubsub handle to it.
     connection = RedisHook(redis_conn_id=self.redis_conn_id).get_conn()
     self.pubsub = connection.pubsub()
     self.pubsub.subscribe(self.channels)
Example #9
0
    def test_get_conn(self):
        """The hook starts unconfigured and lazily builds exactly one client."""
        hook = RedisHook(redis_conn_id='redis_default')

        # Nothing is resolved until get_conn() is first called.
        assert hook.redis is None
        assert hook.host is None, 'host initialised as None.'
        assert hook.port is None, 'port initialised as None.'
        assert hook.password is None, 'password initialised as None.'
        assert hook.db is None, 'db initialised as None.'

        # Repeated calls hand back the very same client object.
        assert hook.get_conn() is hook.get_conn(), 'Connection initialized only if None.'
Example #10
0
    def test_get_conn(self):
        """The hook starts unconfigured and caches the connection object."""
        hook = RedisHook(redis_conn_id='redis_default')
        # assertIsNone is the idiomatic identity check for None and gives a
        # clearer failure message than assertEqual(x, None).
        self.assertIsNone(hook.redis)

        self.assertIsNone(hook.host, 'host initialised as None.')
        self.assertIsNone(hook.port, 'port initialised as None.')
        self.assertIsNone(hook.password, 'password initialised as None.')
        self.assertIsNone(hook.db, 'db initialised as None.')
        self.assertIs(hook.get_conn(), hook.get_conn(), 'Connection initialized only if None.')
Example #11
0
 def deserialize_value(result) -> Any:
     """Resolve a prefixed Redis pointer back into the original value.

     Values written by this backend live in Redis; the XCom table only
     holds ``PREFIX + key``.  Anything not matching that shape is
     returned unchanged.
     """
     result = BaseXCom.deserialize_value(result)
     prefix = XComRedisBackend.PREFIX
     if isinstance(result, str) and result.startswith(prefix):
         # Strip the prefix from the front only.  str.replace(prefix, "")
         # would also rewrite any accidental occurrence of the prefix
         # substring inside the key itself.
         key = result[len(prefix):]
         hook = RedisHook(redis_conn_id=XComRedisBackend.CONN_ID)
         payload = hook.get_conn().get(key)
         # SECURITY: pickle.loads executes arbitrary code; this is only
         # safe while the Redis instance is fully trusted.
         result = pickle.loads(payload)
     return result
 def serialize_value(value: Any):
     """Store *value* as JSON in Redis and serialize only the generated key.

     :param value: any JSON-serializable object
     :return: the XCom-serialized Redis key (``data_<uuid>``)
     """
     hook = RedisHook()
     # get_conn() both initialises and returns the client; there is no
     # need to reach into hook.redis afterwards.
     redis = hook.get_conn()
     key = f"data_{uuid.uuid4()}"
     redis.mset({key: json.dumps(value)})
     return BaseXCom.serialize_value(key)
 def deserialize_value(result) -> Any:
     """Fetch the JSON payload stored under the deserialized Redis key.

     :param result: the raw XCom row; its value is expected to be a Redis key
     :return: the original Python object, or *result* unchanged when it is
         not a string key
     """
     result = BaseXCom.deserialize_value(result)
     if isinstance(result, str):
         hook = RedisHook()
         redis = hook.get_conn()
         xcom = redis.mget(result)
         # The value was written with json.dumps, so decode with json.loads.
         # The previous eval() call was a code-injection hole and also broke
         # on JSON literals such as true/false/null.
         result = json.loads(xcom[0])
     return result
Example #14
0
    def execute(self, context: Dict) -> None:
        """
        Publish the message to Redis channel

        :param context: the context object
        :type context: dict
        """
        redis_hook = RedisHook(redis_conn_id=self.redis_conn_id)

        # Fixed typo in the log text: 'messsage' -> 'message'.
        self.log.info('Sending message %s to Redis on channel %s', self.message, self.channel)

        # publish() returns the number of subscribers that received the message.
        result = redis_hook.get_conn().publish(channel=self.channel, message=self.message)

        self.log.info('Result of publishing %s', result)
    def execute(self, context):
        """Copy rows selected from Oracle into a Redis list.

        Each row is stringified and LPUSHed under ``self.name_redis_key``
        in one pipelined round trip.

        :param context: the Airflow task context (unused)
        """
        oracle = OracleHelper(self.oracle_conn_id)
        redis = RedisHook(self.redis_conn_id)
        self.log.info(f"Executing SQL:{self.sql}")

        self.log.info("Extracting data from Oracle")
        conn_redis = redis.get_conn()
        records = oracle.get_rows_with_bind(sql=self.sql,
                                            bind=self.dict_bind)

        self.log.info("Inserting rows into Redis")
        pipe = conn_redis.pipeline()
        # A plain loop — not a discarded list comprehension — for side effects.
        for row in records:
            pipe.lpush(self.name_redis_key, str(row))
        pipe.execute()
        self.log.info(f"Inserted {len(records)} rows.")
Example #16
0
class RedisPubSubSensor(BaseSensorOperator):
    """
    Redis sensor for reading a message from pub sub channels

    :param channels: The channels to be subscribed to (templated)
    :type channels: str or list of str
    :param redis_conn_id: the redis connection id
    :type redis_conn_id: str
    """

    template_fields = ('channels', )
    ui_color = '#f0eee4'

    @apply_defaults
    def __init__(self, *, channels: Union[List[str], str], redis_conn_id: str,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.channels = channels
        self.redis_conn_id = redis_conn_id
        # Connect and subscribe eagerly so poke() only has to poll.
        connection = RedisHook(redis_conn_id=self.redis_conn_id).get_conn()
        self.pubsub = connection.pubsub()
        self.pubsub.subscribe(self.channels)

    def poke(self, context: Dict) -> bool:
        """
        Check for message on subscribed channels and write to xcom the message with key ``message``

        An example of message ``{'type': 'message', 'pattern': None, 'channel': b'test', 'data': b'hello'}``

        :param context: the context object
        :type context: dict
        :return: ``True`` if message (with type 'message') is available or ``False`` if not
        """
        self.log.info('RedisPubSubSensor checking for message on channels: %s',
                      self.channels)

        message = self.pubsub.get_message()
        self.log.info('Message %s from channel %s', message, self.channels)

        # Guard clause: ignore subscribe/unsubscribe confirmations and empty polls.
        if not message or message['type'] != 'message':
            return False

        context['ti'].xcom_push(key='message', value=message)
        # Stop listening once a message has been delivered.
        self.pubsub.unsubscribe(self.channels)
        return True
Example #17
0
    def generate_pagination(self, **context) -> None:
        """Compute and store per-date and total pagination counts as Variables.

        For each Redis key belonging to the current DAG the page count is
        ``ceil(len(records) / items_by_query)``; per-date totals are stored
        under e.g. ``<dag>_total_pg_01-01-1900`` and the grand total under
        ``<dag>_total_pg``.
        """
        # Plain string literal: the original f-string had no placeholder.
        self.log.info('Getting pagination total ...')
        ns_keys = context['current_dag_name'] + '_' + '*'
        name_var_total_pg = context['current_dag_name'] + '_' + 'total_pg'
        list_dates_redis_bin = RedisHook(self.redis_conn_id).get_conn().scan_iter(ns_keys)
        # Sort the decoded key names so they pair up with the sorted date list.
        list_keys_redis = sorted((x.decode("utf-8") for x in list_dates_redis_bin))
        total_pg = 0

        for date, key in zip(context['list_current_dates'], list_keys_redis):
            list_records = self.get_list_redis(redis_key=key)

            # pg by date
            total_pg_by_date = math.ceil(
                len(list_records) / int(context['items_by_query']))  # e.g: 3.5 to 4
            name_key_pg = name_var_total_pg + '_' + date
            Variable.set(key=name_key_pg, value=total_pg_by_date)
            self.log.info(f'{name_key_pg} = {total_pg_by_date}')

            total_pg += total_pg_by_date

        # total pg
        Variable.set(key=name_var_total_pg, value=total_pg)
Example #18
0
    def test_get_conn_with_extra_config(self, mock_get_connection, mock_redis):
        """Every extra option from the connection is forwarded to the Redis client."""
        connection = mock_get_connection.return_value
        hook = RedisHook()

        hook.get_conn()

        # Alias the extras mapping so each expected kwarg reads cleanly.
        extra = connection.extra_dejson
        mock_redis.assert_called_once_with(
            host=connection.host,
            password=connection.password,
            port=connection.port,
            db=extra["db"],
            ssl=extra["ssl"],
            ssl_cert_reqs=extra["ssl_cert_reqs"],
            ssl_ca_certs=extra["ssl_ca_certs"],
            ssl_keyfile=extra["ssl_keyfile"],
            ssl_cert_file=extra["ssl_cert_file"],
            ssl_check_hostname=extra["ssl_check_hostname"])
Example #19
0
    def fill_data_gaps(self, **context) -> None:
        """Fill data without dt_ref

        Replaces a missing (None) date-reference field with '01-01-1900'
        and pushes the repaired rows back into Redis.

        Example:
            input:  ('ef57132e-0d25-4e75-bee6-158c01c5b360', 50000, None)
            output: ('e85e10b5-43af-4721-b832-9b9b4bb366fe', 10708, '01-01-1900')
        """
        list_records = self.get_list_redis(context['redis_key'])
        pipe = RedisHook(self.redis_conn_id).get_conn().pipeline()
        df = pd.DataFrame(data=list_records)
        # The date-reference column is always the last field of the record tuples.
        col_date_ref = list(df.columns)[-1]
        df[col_date_ref].replace(to_replace=[None], value='01-01-1900', inplace=True)
        records = [tuple(x) for x in df.to_numpy()]

        # Plain loop instead of a throwaway list comprehension used for side effects.
        for row in records:
            pipe.lpush(context['current_dag_name'], str(row))
        pipe.execute()
        self.log.info(f"\nSample rows:\n{df.head(5)}")
Example #20
0
    def split_id_by_date(self, **context) -> None:
        """Create redis key by date

        example (redis keys):
            input: file_name_000000000000000
            output: file_name_000000000000000_01-01-1900, file_name_000000000000000_24-12-2020
        """
        # Fixed typo ('Spliting') and dropped the placeholder-free f-prefix.
        logging.info('Splitting IDs by date ...')
        pipe = RedisHook(self.redis_conn_id).get_conn().pipeline()
        list_records = self.get_list_redis(context['redis_key'])

        for date in context['list_current_dates']:
            # Records carry their date in the last field.
            list_item = [item for item in list_records if date in item[-1]]

            name_redis_key = context['current_dag_name'] + '_' + date
            # Side effects belong in a loop, not a discarded comprehension.
            for row in list_item:
                pipe.lpush(name_redis_key, str(row))
            pipe.execute()
            self.log.info(f'{date} - Storaged at redis {name_redis_key} = {len(list_item)}')
Example #21
0
    def get_list_redis(self, redis_key: str) -> list:
        """Get all objects from a Redis list of type key

        Args:
        ----------
        redis_key : str
            key name in redis.

        Returns:
        ----------
            A list containing objects without repetition, parsed back from
            their ``str(...)`` representation.
            Example: ['01-01-2021', '02-01-2021', '03-01-2021']
        """
        # Function-scope import keeps this fix self-contained.
        from ast import literal_eval

        set_records = set(RedisHook(self.redis_conn_id)
                          .get_conn()
                          .lrange(redis_key, 0, -1))

        # literal_eval parses only Python literals; the previous eval()
        # would execute arbitrary code stored in Redis.
        return [literal_eval(x.decode("utf-8")) for x in set_records]
Example #22
0
 def poke(self, context):
     """Report whether ``self.key`` currently exists in Redis."""
     self.log.info('Sensor checks for existence of key: %s', self.key)
     # EXISTS returns the number of matching keys (truthy when present).
     hook = RedisHook(self.redis_conn_id)
     return hook.get_conn().exists(self.key)
Example #23
0
 def test_get_conn_password_stays_none(self):
     """Connecting must not write a password back onto the hook."""
     hook = RedisHook(redis_conn_id='redis_default')
     hook.get_conn()
     # The password attribute stays untouched by get_conn().
     assert hook.password is None
Example #24
0
 def test_get_conn_password_stays_none(self):
     """get_conn() must leave the hook's password as None."""
     hook = RedisHook(redis_conn_id='redis_default')
     hook.get_conn()
     # assertIsNone is the idiomatic identity check for None.
     self.assertIsNone(hook.password)
Example #25
0
    def test_real_ping(self):
        """A live Redis server answers PING through the hook's connection."""
        connection = RedisHook(redis_conn_id='redis_default').get_conn()

        self.assertTrue(connection.ping(), 'Connection to Redis with PING works.')
Example #26
0
 def get_hook(self):
     """Return the provider hook instance matching ``self.conn_type``.

     :return: a hook constructed with this connection's id
     :raises AirflowException: when ``self.conn_type`` is not recognised
     """
     import importlib

     # conn_type -> (module path, hook class name, conn-id keyword argument).
     # A data-driven table replaces the original 27-branch if/elif chain;
     # modules are still imported lazily, only for the requested type.
     hook_registry = {
         'mysql': ('airflow.providers.mysql.hooks.mysql', 'MySqlHook', 'mysql_conn_id'),
         'google_cloud_platform': ('airflow.gcp.hooks.bigquery', 'BigQueryHook', 'bigquery_conn_id'),
         'postgres': ('airflow.providers.postgres.hooks.postgres', 'PostgresHook', 'postgres_conn_id'),
         'pig_cli': ('airflow.providers.apache.pig.hooks.pig', 'PigCliHook', 'pig_cli_conn_id'),
         'hive_cli': ('airflow.providers.apache.hive.hooks.hive', 'HiveCliHook', 'hive_cli_conn_id'),
         'presto': ('airflow.providers.presto.hooks.presto', 'PrestoHook', 'presto_conn_id'),
         'hiveserver2': ('airflow.providers.apache.hive.hooks.hive', 'HiveServer2Hook', 'hiveserver2_conn_id'),
         'sqlite': ('airflow.providers.sqlite.hooks.sqlite', 'SqliteHook', 'sqlite_conn_id'),
         'jdbc': ('airflow.providers.jdbc.hooks.jdbc', 'JdbcHook', 'jdbc_conn_id'),
         'mssql': ('airflow.providers.microsoft.mssql.hooks.mssql', 'MsSqlHook', 'mssql_conn_id'),
         'odbc': ('airflow.providers.odbc.hooks.odbc', 'OdbcHook', 'odbc_conn_id'),
         'oracle': ('airflow.providers.oracle.hooks.oracle', 'OracleHook', 'oracle_conn_id'),
         'vertica': ('airflow.providers.vertica.hooks.vertica', 'VerticaHook', 'vertica_conn_id'),
         'cloudant': ('airflow.providers.cloudant.hooks.cloudant', 'CloudantHook', 'cloudant_conn_id'),
         'jira': ('airflow.providers.jira.hooks.jira', 'JiraHook', 'jira_conn_id'),
         'redis': ('airflow.providers.redis.hooks.redis', 'RedisHook', 'redis_conn_id'),
         'wasb': ('airflow.providers.microsoft.azure.hooks.wasb', 'WasbHook', 'wasb_conn_id'),
         'docker': ('airflow.providers.docker.hooks.docker', 'DockerHook', 'docker_conn_id'),
         'azure_data_lake': ('airflow.providers.microsoft.azure.hooks.azure_data_lake',
                             'AzureDataLakeHook', 'azure_data_lake_conn_id'),
         'azure_cosmos': ('airflow.providers.microsoft.azure.hooks.azure_cosmos',
                          'AzureCosmosDBHook', 'azure_cosmos_conn_id'),
         'cassandra': ('airflow.providers.apache.cassandra.hooks.cassandra',
                       'CassandraHook', 'cassandra_conn_id'),
         'mongo': ('airflow.providers.mongo.hooks.mongo', 'MongoHook', 'conn_id'),
         'gcpcloudsql': ('airflow.gcp.hooks.cloud_sql', 'CloudSQLDatabaseHook', 'gcp_cloudsql_conn_id'),
         'grpc': ('airflow.providers.grpc.hooks.grpc', 'GrpcHook', 'grpc_conn_id'),
     }
     if self.conn_type not in hook_registry:
         raise AirflowException("Unknown hook type {}".format(self.conn_type))
     module_path, class_name, conn_id_param = hook_registry[self.conn_type]
     hook_class = getattr(importlib.import_module(module_path), class_name)
     return hook_class(**{conn_id_param: self.conn_id})