Пример #1
0
def get_operator_extra_links() -> Set[str]:
    """Return the operator extra links, built-in ones plus those from providers.

    The extra-link class names reported by ``ProvidersManager`` are merged into
    the module-level ``_OPERATOR_EXTRA_LINKS`` collection. That same collection
    object (not a copy) is then returned.
    """
    provider_links = ProvidersManager().extra_links_class_names
    _OPERATOR_EXTRA_LINKS.update(provider_links)
    return _OPERATOR_EXTRA_LINKS
Пример #2
0
    def get_connection_from_secrets(cls, conn_id):
        """Fetch a connection and return it as an instance of the calling class.

        The parent implementation's result has its fields copied into a new
        ``cls`` instance. A ``hook_cls`` attribute is then attached to it. Its
        value is the hook class registered with ``ProvidersManager`` for the
        connection's type, or ``EWAHBaseHook`` when no usable registration exists.

        :param conn_id: id of the connection to look up.
        :return: a ``cls`` instance carrying the connection fields and ``hook_cls``.
        """
        conn = super().get_connection_from_secrets(conn_id)
        # conn is returned as airflow.hooks.base.BaseHook, return as cls instead
        conn = cls(
            conn_id=conn.conn_id,
            conn_type=conn.conn_type,
            description=conn.description,
            host=conn.host,
            login=conn.login,
            password=conn.password,
            schema=conn.schema,
            port=conn.port,
            extra=conn.extra,
        )

        # add hook_cls attribute.
        # Index the registry entry instead of unpacking a fixed 4-tuple. Entries
        # may be None, or may carry a different number of fields depending on the
        # Airflow version. Either case made the unpacking raise instead of falling
        # back to EWAHBaseHook.
        hook_info = ProvidersManager().hooks.get(conn.conn_type)
        hook_class_name = hook_info[0] if hook_info else None
        if hook_class_name:
            conn.hook_cls = import_string(hook_class_name)
        else:
            conn.hook_cls = EWAHBaseHook
        return conn
Пример #3
0
 def __getattr__(self, name):
     """Look up a task decorator registered by a provider under ``name``.

     Dunder names fail fast with AttributeError without consulting
     ProvidersManager. Unknown decorator names also raise AttributeError.
     """
     if name.startswith("__"):
         raise AttributeError(f'{type(self).__name__} has no attribute {name!r}')
     registered = ProvidersManager().taskflow_decorators
     if name in registered:
         return registered[name]
     raise AttributeError(f"task decorator {name!r} not found")
Пример #4
0
 def info(self, console: AirflowConsole):
     """Print a table with one row per provider package: its name and version.

     :param console: console the rendered table is printed to.
     """
     table = SimpleTable(title="Providers info")
     table.add_column()
     table.add_column(width=150)
     # Each registry value is a ``(version, provider_info)`` pair. Report its
     # version element rather than ``provider_info['versions'][0]``, which is only
     # the first entry of the provider's declared ``versions`` list.
     # NOTE(review): the version element is presumably the installed package
     # version; confirm against ProvidersManager.
     for provider_version, provider_info in ProvidersManager().providers.values():
         table.add_row(provider_info['package-name'], provider_version)
     console.print(table)
Пример #5
0
def get_providers() -> APIResponse:
    """Return every provider known to ProvidersManager, serialized as a collection.

    Each registry value is passed through ``_provider_mapper``. The resulting
    list and its length are wrapped in a ``ProviderCollection`` and dumped with
    ``provider_collection_schema``.
    """
    registered = ProvidersManager().providers.values()
    mapped = list(map(_provider_mapper, registered))
    collection = ProviderCollection(providers=mapped, total_entries=len(mapped))
    return provider_collection_schema.dump(collection)
Пример #6
0
def get_providers():
    """Return a serialized collection of every provider known to ProvidersManager.

    :return: output of ``provider_collection_schema.dump`` for a
        ``ProviderCollection`` holding the mapped providers and their count.
    """
    providers_info: List[ProviderInfo] = list(ProvidersManager().providers.values())
    collection = ProviderCollection(
        providers=[_provider_mapper(info) for info in providers_info],
        # _provider_mapper is applied 1:1, so the source count equals the mapped count.
        total_entries=len(providers_info),
    )
    return provider_collection_schema.dump(collection)
Пример #7
0
def _get_provider_info() -> Tuple[str, str]:
    """Return the name and version of the provider package that registered DbtCloudHook.

    :return: ``(package_name, version)``.
    :raises KeyError: presumably, when the hook's connection type or its package
        is missing from the ProvidersManager mappings.
    """
    from airflow.providers_manager import ProvidersManager

    manager = ProvidersManager()
    hook_entry = manager.hooks[DbtCloudHook.conn_type]
    package_name = hook_entry.package_name  # type: ignore[union-attr]
    return package_name, manager.providers[package_name].version
Пример #8
0
 def __getattr__(self, name: str) -> TaskDecorator:
     """Resolve ``@task.<name>`` to a provider-registered decorator, e.g. ``@task.docker``.

     :raises AttributeError: for dunder names, and for names no provider registered.
     """
     if not name.startswith("__"):
         registered = ProvidersManager().taskflow_decorators
         if name not in registered:
             raise AttributeError(f"task decorator {name!r} not found")
         return registered[name]
     raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")
Пример #9
0
def _get_connection_types():
    """Return the list of available connection types.

    The list starts with a fixed set of built-in types. It then appends, in
    registry order, every connection type from ``ProvidersManager().hooks``
    whose entry is truthy.
    """
    builtin_types = ['fs', 'mesos_framework-id', 'email', 'generic']
    provider_types = [
        conn_type
        for conn_type, hook_entry in ProvidersManager().hooks.items()
        if hook_entry
    ]
    return builtin_types + provider_types
Пример #10
0
    def get_hook(self):
        """Instantiate the hook class registered for this connection's ``conn_type``.

        :return: an instance of the registered hook class. This connection's id is
            passed under the hook's connection-id parameter name.
        :raises AirflowException: if no hook class is registered for ``conn_type``.
        """
        # Index the registry entry rather than unpacking a fixed 2-tuple. Entries
        # can be None, or carry extra fields depending on the Airflow version, and
        # unpacking would then raise TypeError/ValueError instead of the intended
        # AirflowException.
        hook_info = ProvidersManager().hooks.get(self.conn_type)
        if not hook_info or not hook_info[0]:
            raise AirflowException(f'Unknown hook type "{self.conn_type}"')
        hook_class_name, conn_id_param = hook_info[0], hook_info[1]
        hook_class = import_string(hook_class_name)
        return hook_class(**{conn_id_param: self.conn_id})
Пример #11
0
def get_operator_extra_links():
    """
    Collect operator extra links, both the built-in ones and those from providers.

    Provider-supplied class names are merged into the module-level
    ``_OPERATOR_EXTRA_LINKS`` collection. That collection itself (not a copy) is
    returned.

    :return: set of extra links
    """
    from_providers = ProvidersManager().extra_links_class_names
    _OPERATOR_EXTRA_LINKS.update(from_providers)
    return _OPERATOR_EXTRA_LINKS
Пример #12
0
def auth_backend_list(args):
    """Print every API auth backend module name reported by ProvidersManager.

    :param args: parsed CLI arguments; ``args.output`` is forwarded to ``print_as``.
    """
    def _to_row(module_name):
        # NOTE(review): "backand" is misspelled, but it is the emitted column key,
        # so it is kept as-is to avoid breaking consumers of this output.
        return {"api_auth_backand_module": module_name}

    backends = list(ProvidersManager().auth_backend_module_names)
    AirflowConsole().print_as(data=backends, output=args.output, mapper=_to_row)
Пример #13
0
def logging_list(args):
    """Print each log task handler class name known to ProvidersManager.

    :param args: parsed CLI arguments; only ``args.output`` is read.
    """
    handler_names = list(ProvidersManager().logging_class_names)
    console = AirflowConsole()
    console.print_as(
        data=handler_names,
        output=args.output,
        mapper=lambda name: {"logging_class_name": name},
    )
Пример #14
0
def secrets_backends_list(args):
    """Print each secrets backend class name known to ProvidersManager.

    :param args: parsed CLI arguments; only ``args.output`` is read.
    """
    AirflowConsole().print_as(
        output=args.output,
        data=[*ProvidersManager().secrets_backend_class_names],
        mapper=lambda class_name: dict(secrets_backend_class_name=class_name),
    )
Пример #15
0
def connection_field_behaviours(args):
    """Print the keys of ``ProvidersManager().field_behaviours``, one per row.

    The keys are presumably connection types; verify against ProvidersManager.

    :param args: parsed CLI arguments; only ``args.output`` is read.
    """
    def _row(key):
        return {"field_behaviours": key}

    behaviour_keys = list(ProvidersManager().field_behaviours)
    AirflowConsole().print_as(data=behaviour_keys, output=args.output, mapper=_row)
Пример #16
0
def extra_links_list(args):
    """Print every operator extra-link class name known to ProvidersManager.

    :param args: parsed CLI arguments; only ``args.output`` is read.
    """
    def _as_row(class_name):
        return {"extra_link_class_name": class_name}

    link_names = ProvidersManager().extra_links_class_names
    AirflowConsole().print_as(data=link_names, output=args.output, mapper=_as_row)
Пример #17
0
    def __str__(self):
        """Render a table with one row per provider: package name and version.

        :return: the text produced by ``tabulate`` with the row dict keys as headers.
        """
        # Each registry value is a ``(version, provider_info)`` pair. Use its
        # version element, which was previously bound and then ignored. The old
        # ``provider_info['versions'][0]`` is only the first entry of the
        # provider's declared ``versions`` list.
        tabulate_data = [
            {
                'Provider name': provider_info['package-name'],
                'Version': version,
            }
            for version, provider_info in ProvidersManager().providers.values()
        ]
        return tabulate(tabulate_data, headers='keys')
Пример #18
0
 def test_providers_are_loaded(self):
     """Check that every registered provider is loaded correctly.

     Each provider must have a dotted numeric version and be keyed by its own
     package name. The registry must list exactly ALL_PROVIDERS, in order.
     """
     provider_manager = ProvidersManager()
     provider_list = list(provider_manager.providers.keys())
     # No need to sort the list - it should be sorted alphabetically !
     for provider in provider_list:
         entry = provider_manager.providers[provider]
         package_name = entry[1]['package-name']
         version = entry[0]
         # ``+`` rather than ``*``: the old pattern allowed every digit run to be
         # empty, so a bare ".." satisfied it. Require at least one digit in each
         # of the three dotted parts.
         assert re.search(r'[0-9]+\.[0-9]+\.[0-9]+', version)
         assert package_name == provider
     assert ALL_PROVIDERS == provider_list
Пример #19
0
def providers_list(args):
    """Print package name, description (RST markup stripped) and version of every provider.

    :param args: parsed CLI arguments; only ``args.output`` is read.
    """
    def _summarize(provider):
        provider_data = provider.data
        return {
            "package_name": provider_data["package-name"],
            "description": _remove_rst_syntax(provider_data["description"]),
            "version": provider.version,
        }

    registered = list(ProvidersManager().providers.values())
    AirflowConsole().print_as(data=registered, output=args.output, mapper=_summarize)
Пример #20
0
def hooks_list(args):
    """Print every hook registered with ProvidersManager at the command line.

    Each row holds the connection type, the first element of the hook entry
    (class) and its second element (connection attribute name).
    """
    def _to_row(item):
        conn_type, hook_entry = item
        return {
            "connection_type": conn_type,
            "class": hook_entry[0],
            "conn_attribute_name": hook_entry[1],
        }

    AirflowConsole().print_as(
        data=ProvidersManager().hooks.items(),
        output=args.output,
        mapper=_to_row,
    )
Пример #21
0
    def test_providers_are_loaded(self):
        """Check that every registered provider is loaded correctly.

        Each provider must have a dotted numeric version and be keyed by its own
        package name. The registry must list exactly ALL_PROVIDERS, in order.
        """
        provider_manager = ProvidersManager()
        provider_list = list(provider_manager.providers.keys())
        # No need to sort the list - it should be sorted alphabetically !
        for provider in provider_list:
            entry = provider_manager.providers[provider]
            package_name = entry[1]['package-name']
            version = entry[0]
            # ``+`` rather than ``*``: the old pattern allowed every digit run to be
            # empty, so a bare ".." satisfied it. Require at least one digit in each
            # of the three dotted parts.
            self.assertRegex(version, r'[0-9]+\.[0-9]+\.[0-9]+')
            self.assertEqual(package_name, provider)

        self.assertEqual(ALL_PROVIDERS, provider_list)
Пример #22
0
def connection_form_widget_list(args):
    """Print every custom connection form field registered with ProvidersManager.

    Each row holds the parameter name, the owning hook class, its package, and
    the class name of the form field type.
    """
    widgets = list(ProvidersManager().connection_form_widgets.items())

    def _widget_row(entry):
        param_name, widget = entry
        return {
            "connection_parameter_name": param_name,
            "class": widget.hook_class_name,
            "package_name": widget.package_name,
            "field_type": widget.field.field_class.__name__,
        }

    AirflowConsole().print_as(data=widgets, output=args.output, mapper=_widget_row)
def hooks_list(args):
    """Print one row per registered hook at the command line.

    Each row holds the connection type, hook class, connection-id attribute name,
    provider package, and hook name.
    """
    hook_items = list(ProvidersManager().hooks.items())

    def _describe(item):
        conn_type, info = item
        return {
            "connection_type": conn_type,
            "class": info.connection_class,
            "conn_id_attribute_name": info.connection_id_attribute_name,
            "package_name": info.package_name,
            "hook_name": info.hook_name,
        }

    AirflowConsole().print_as(data=hook_items, output=args.output, mapper=_describe)
Пример #24
0
    def provider_user_agent(cls) -> Optional[str]:
        """Construct User-Agent from Airflow core & provider package versions.

        :return: ``"apache-airflow/<airflow version> <provider package>/<provider version>"``,
            or None after emitting a warning. None is returned when the hook's
            connection type, or its provider package, cannot be resolved through
            ProvidersManager.
        """
        import airflow
        from airflow.providers_manager import ProvidersManager

        try:
            manager = ProvidersManager()
            hook_info = manager.hooks[cls.conn_type]
            # Hook entries may be None (the mapping appears to be typed Optional).
            # Such an entry previously escaped as AttributeError. It is now treated
            # like a missing entry and falls through to the warning below.
            if hook_info is not None:
                provider_name = hook_info.package_name
                provider = manager.providers[provider_name]
                return f'apache-airflow/{airflow.__version__} {provider_name}/{provider.version}'
        except KeyError:
            pass
        warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager")
        return None
Пример #25
0
 def user_agent_value(self) -> str:
     """Return the User-Agent string for the Databricks provider.

     The string is composed of the provider package version, the Python version,
     the lower-cased OS name, the Airflow version, and ``self.caller``.

     :raises KeyError: presumably, when the hook's connection type or its provider
         package is missing from the ProvidersManager mappings.
     """
     manager = ProvidersManager()
     package_name = manager.hooks[BaseDatabricksHook.conn_type].package_name  # type: ignore[union-attr]
     version = manager.providers[package_name].version
     python_version = platform.python_version()
     system = platform.system().lower()
     # Product token fixed from the misspelled "databricks-aiflow".
     return (
         f"databricks-airflow/{version} _/0.0.0 python/{python_version} os/{system} "
         f"airflow/{__version__} operator/{self.caller}")
Пример #26
0
    def get_hook(self):
        """Instantiate the hook class registered for this connection's ``conn_type``.

        :return: an instance of the registered hook class. This connection's id is
            passed under the hook's connection-id parameter name.
        :raises AirflowException: if no hook class is registered for ``conn_type``.
        :raises ImportError: if the registered hook class cannot be imported. A
            warning naming the hook and its package is emitted first.
        """
        # Index the registry entry instead of unpacking a fixed 4-tuple. A None
        # entry now yields AirflowException rather than a TypeError.
        hook_info = ProvidersManager().hooks.get(self.conn_type)
        if not hook_info or not hook_info[0]:
            raise AirflowException(f'Unknown hook type "{self.conn_type}"')
        hook_class_name, conn_id_param, package_name, hook_name = hook_info[:4]
        try:
            hook_class = import_string(hook_class_name)
        except ImportError:
            # warnings.warn() does no %-style interpolation. Extra positional
            # arguments bind to category/stacklevel/source, so the old call raised
            # TypeError and masked this ImportError. Format the message up front.
            warnings.warn(
                f"Could not import {hook_class_name} when discovering {hook_name} {package_name}"
            )
            raise
        return hook_class(**{conn_id_param: self.conn_id})
Пример #27
0
def provider_get(args):
    """Get a provider info.

    Prints the provider's name and version. With ``args.full``, it instead prints
    the full provider info, with RST markup stripped from the description, via
    ``AirflowConsole.print_as``.

    :raises SystemExit: if ``args.provider_name`` is not a registered provider.
    """
    providers = ProvidersManager().providers
    if args.provider_name not in providers:
        raise SystemExit(f"No such provider installed: {args.provider_name}")
    provider = providers[args.provider_name]
    if args.full:
        # Build a cleaned copy instead of assigning into ``provider.provider_info``.
        # That dict comes from the ProvidersManager registry and may be shared with
        # other users of it, so it must not be modified in place. Key order is
        # preserved because "description" keeps its original position.
        provider_info = {
            **provider.provider_info,
            "description": _remove_rst_syntax(provider.provider_info["description"]),
        }
        AirflowConsole().print_as(
            data=[provider_info],
            output=args.output,
        )
    else:
        print(f"Provider: {args.provider_name}")
        print(f"Version: {provider.version}")
Пример #28
0
def provider_get(args):
    """Print a banner with the provider's name and version, plus its YAML info with ``--full``.

    The YAML is syntax-highlighted when ``should_use_colors(args)`` is true.

    :raises SystemExit: if ``args.provider_name`` is not a registered provider.
    """
    providers = ProvidersManager().providers
    if args.provider_name not in providers:
        raise SystemExit(f"No such provider installed: {args.provider_name}")

    provider_version, provider_info = providers[args.provider_name]
    banner = [
        "#",
        f"# Provider: {args.provider_name}",
        f"# Version: {provider_version}",
        "#",
    ]
    for banner_line in banner:
        print(banner_line)
    if not args.full:
        return

    yaml_content = yaml.dump(provider_info)
    if should_use_colors(args):
        yaml_content = pygments.highlight(
            code=yaml_content,
            formatter=get_terminal_formatter(),
            lexer=YamlLexer(),
        )
    print(yaml_content)
Пример #29
0
    def get_hook(self, *, hook_params=None):
        """Return hook based on conn_type.

        :param hook_params: optional extra keyword arguments forwarded to the hook
            constructor, alongside this connection's id.
        :return: an instance of the hook class registered for ``self.conn_type``.
        :raises AirflowException: if no hook is registered for ``conn_type``.
        :raises ImportError: if the registered hook class cannot be imported. A
            warning naming the hook and its package is emitted first.
        """
        hook = ProvidersManager().hooks.get(self.conn_type, None)

        if hook is None:
            raise AirflowException(f'Unknown hook type "{self.conn_type}"')
        try:
            hook_class = import_string(hook.hook_class_name)
        except ImportError:
            # warnings.warn() does no %-style interpolation. Extra positional
            # arguments bind to category/stacklevel/source, so the old call raised
            # TypeError and masked this ImportError. Format the message up front.
            warnings.warn(
                f"Could not import {hook.hook_class_name} when discovering "
                f"{hook.hook_name} {hook.package_name}",
            )
            raise
        if hook_params is None:
            hook_params = {}
        return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)
Пример #30
0
def _backported_get_hook(connection, *, hook_params=None):
    """Instantiate the hook registered for ``connection.conn_type``.

    This backports ``Connection.get_hook()`` for Airflow versions < 2.3. It should
    be removed once "apache-airflow-providers-slack" depends on Airflow >= 2.3.

    :param connection: connection whose ``conn_type`` and ``conn_id`` are used.
    :param hook_params: optional extra keyword arguments for the hook constructor.
    :raises AirflowException: if no hook is registered for the connection type.
    :raises ImportError: if the hook class cannot be imported, after a warning.
    """
    hook = ProvidersManager().hooks.get(connection.conn_type, None)
    if hook is None:
        raise AirflowException(f'Unknown hook type "{connection.conn_type}"')

    try:
        hook_class = import_string(hook.hook_class_name)
    except ImportError:
        warnings.warn(
            f"Could not import {hook.hook_class_name} when discovering {hook.hook_name} {hook.package_name}",
        )
        raise

    conn_kwarg = {hook.connection_id_attribute_name: connection.conn_id}
    extra_kwargs = {} if hook_params is None else hook_params
    # Both mappings are unpacked separately, so a duplicate key still raises TypeError.
    return hook_class(**conn_kwarg, **extra_kwargs)