示例#1
0
 def test_get_variable_first_try(self, mock_env_get, mock_meta_get):
     """
     Test that when the Variable is found in the Environment Variable backend,
     get_variable does not go on to look for it in the Metastore DB.
     """
     mock_env_get.return_value = [["something"]]  # returns nonempty list
     get_variable("fake_var_key")
     mock_env_get.assert_called_once_with(key="fake_var_key")
     # Must be the real Mock assertion: a bare `not_called()` is just an
     # auto-created child mock being called, so it could never fail.
     mock_meta_get.assert_not_called()
示例#2
0
 def test_get_variable_second_try(self, mock_env_get, mock_meta_get):
     """
     Check the fallback path: when the Environment Variable lookup finds
     nothing, get_variable then queries the Metastore DB for the same key.
     """
     var_key = "fake_var_key"
     mock_env_get.return_value = None  # simulate a miss in the env backend
     get_variable(var_key)
     mock_env_get.assert_called_once_with(key=var_key)
     mock_meta_get.assert_called_once_with(key=var_key)
示例#3
0
 def get(
     cls,
     key: str,
     default_var: Any = __NO_DEFAULT_SENTINEL,
     deserialize_json: bool = False,
 ):
     """
     Look up an Airflow Variable by key.

     :param key: Variable key passed to ``get_variable``.
     :param default_var: Value returned (as-is) when the Variable is missing.
         When omitted, a missing Variable raises ``KeyError`` instead.
     :param deserialize_json: If True, the stored value is parsed with
         ``json.loads`` before being returned.
     :return: The stored value, its JSON-decoded form, or ``default_var``.
     :raises KeyError: If the Variable is missing and no default was supplied.
     """
     stored = get_variable(key=key)
     if stored is not None:
         return json.loads(stored) if deserialize_json else stored
     # Nothing stored: the sentinel tells "no default given" apart from any
     # real default, including None.
     if default_var is cls.__NO_DEFAULT_SENTINEL:
         raise KeyError('Variable {} does not exist'.format(key))
     return default_var
示例#4
0
    def execute(self, context):
        """Register the database connection and missing Airflow Variables.

        Steps, in order:

        1. Block until the Redis DB reached through ``redis_session`` has no
           keys, polling every 300 seconds (a non-empty DB is treated as
           "previous run not finished").
        2. Choose the database config and connection id for ``self.mode``;
           ``'local'`` additionally runs ``_setup_local()``.
        3. Push that connection through ``update_connection``.
        4. For every ``key -> val`` in ``self.set_variable_keys``, create the
           Variable only if ``get_variable`` returns ``None`` for it; existing
           Variables are left untouched.

        :param context: task context passed by the caller (unused here).
        :raises ValueError: if ``self.mode`` is neither ``'local'`` nor
            ``'redshift'``.
        """
        logger.info('Setting up operator')

        # Fail fast on a bad mode instead of possibly waiting on Redis first
        # and then crashing with an UnboundLocalError on db_cfg/conn_id.
        if self.mode not in ('local', 'redshift'):
            raise ValueError(
                f"Unsupported mode {self.mode!r}; "
                f"expected 'local' or 'redshift'")

        with redis_session() as r:
            # TODO: find other way to handle data race (round-robin?)
            # NOTE(review): KEYS '*' walks the whole keyspace on each poll;
            # DBSIZE or SCAN may be cheaper if this Redis DB can grow large.
            while r.keys('*'):
                logger.info('Not finished previous run, wait for 300 seconds.')
                time.sleep(300)

        # Timing starts after the wait, so the logged time is processing only.
        start = time.perf_counter()

        if self.mode == 'local':
            _setup_local()
            db_cfg = Config.DATABASE
            conn_id = os.getenv('AIRFLOW_POSTGRES_CONN_ID')
        else:  # 'redshift' (validated above)
            db_cfg = Config.AWS['REDSHIFT']
            conn_id = os.getenv('AIRFLOW_REDSHIFT_CONN_ID')

        cfg = ConnectionConfig(conn_id=conn_id,
                               host=db_cfg['HOST'],
                               login=db_cfg['USERNAME'],
                               password=db_cfg['PASSWORD'],
                               schema=db_cfg['DB_NAME'],
                               port=db_cfg['PORT'])

        update_connection(cfg)

        for key, val in self.set_variable_keys.items():
            logger.info(f'Setting key="{key}" to Airflow Variable')

            # Only create Variables that are not already resolvable; an
            # existing value is never overwritten.
            if get_variable(key=key) is None:
                variable = Variable(key=key)
                variable.set_val(value=val)

                with create_session() as sess:
                    sess.add(variable)

        end = time.perf_counter()
        logger.info(
            f'Process Time [{self.__class__.__name__}]: {end-start:.3f} sec.')
示例#5
0
    def load_variables(self, keys: 'str | list[str] | None'):
        """Take variables from registered json formatted variables
        and store them to operator's attributes.

        Each key is looked up with ``get_variable`` after replacing ``-`` with
        ``_``. The stored value must be a JSON object; each of its entries is
        set as an attribute on ``self``.

        :param keys: a single variable key, a list of keys, or ``None``
            (in which case nothing is loaded).
        :raises InvalidDataFormatError: if a key is not registered, or its
            value does not decode to a JSON object (key-value pairs).
        :raises json.JSONDecodeError: if a stored value (including an empty
            string) is not valid JSON.

        Example:
            Set value with following format on airflow:
                key:   sample_key
                value: {
                          "timeout": 20,
                          "count": 10
                        }

            and you can get value by:
                >>> operator = SomeCustomOperator(variable_keys='sample_key')
                >>> operator.timeout == 20
                True
                >>> operator.count == 10
                True
        """
        # The annotation is a string on purpose: the previous
        # `str or List[str]` evaluated to plain `str` at runtime.
        if keys is None:
            return

        if isinstance(keys, str):
            # make iterable
            keys = (keys, )

        for key in keys:
            variables = get_variable(key.replace('-', '_'))

            if variables is None:
                raise InvalidDataFormatError(f'"{key}" is not registered.')

            variables = json.loads(variables)
            if not isinstance(variables, dict):
                raise InvalidDataFormatError(
                    f'"{variables}" is not key-value pairs.')

            # NOTE(review): keys come straight from the stored Variable, so
            # they can overwrite any attribute (including methods) on self.
            for k, v in variables.items():
                setattr(self, k, v)
示例#6
0
    def get(
        cls,
        key: str,
        default_var: Any = __NO_DEFAULT_SENTINEL,
        deserialize_json: bool = False,
    ) -> Any:
        """
        Gets the value of an Airflow Variable Key

        :param key: Variable Key
        :param default_var: Default value returned, as-is, if the Variable
            doesn't exist. If omitted, a missing Variable raises ``KeyError``.
        :param deserialize_json: If True, parse the stored value with
            ``json.loads`` (the result can be any JSON type, e.g. dict or list)
        :return: the stored value, its deserialized form, or ``default_var``
        :raises KeyError: if the Variable doesn't exist and no ``default_var``
            was given
        """
        var_val = get_variable(key=key)
        if var_val is None:
            # The sentinel distinguishes "no default given" from a real
            # default such as None.
            if default_var is not cls.__NO_DEFAULT_SENTINEL:
                return default_var
            raise KeyError('Variable {} does not exist'.format(key))
        if deserialize_json:
            return json.loads(var_val)
        return var_val